1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file src/relay/qnn/util.cc
22 * \brief Utility functions for QNN.
23 */
24
25 #include "util.h"
26 #include "../pass/pattern_util.h"
27
28 namespace tvm {
29 namespace relay {
30 namespace qnn {
31
32 /*
33 * \brief Convert FP32 representation into fixed point representation.
34 * \param double_multplier The input FP32 number.
35 * \return The pair of multiplier and shift for fixed point representation.
36 * \note Converts a floating point number so that it can be represented by
37 * integers. The representation is
38 * float_number = (significand) * 2^(exponent)
39 *
40 * The significand is a number between 0.5 and 1. This is represented by
41 * an integer number. For example, if it is int32, then the decimal point
42 * exists between bit 31 and 30 from LSB (or between first and second bit
43 * from the left).
44 *
45 * Some examples are
46 * 0.25 = (0.5) * 2^(-1)
47 * 0.125 = (0.5) * 2^(-2)
48 *
49 * Credit to TFLite reference implementation.
50 */
GetFixedPointMultiplierShift(double double_multiplier)51 std::pair<int32_t, int32_t> GetFixedPointMultiplierShift(
52 double double_multiplier) {
53 int32_t significand, exponent;
54 if (double_multiplier == 0.) {
55 significand = 0;
56 exponent = 0;
57 return std::make_pair(significand, exponent);
58 }
59
60 // Get the significand and exponent.
61 double significand_d = std::frexp(double_multiplier, &exponent);
62
63 // Convert the double significand to int significand, i.e., convert into a
64 // integer where the decimal point is between bit 31 and 30. This is done by
65 // multiplying the double value with 2^31 and then casting to int.
66 significand_d = std::round(significand_d * (1ll << 31));
67 auto significand_int64 = static_cast<int64_t>(significand_d);
68 CHECK_LE(significand_int64, (1ll << 31));
69 if (significand_int64 == (1ll << 31)) {
70 significand_int64 /= 2;
71 ++exponent;
72 }
73 CHECK_LE(significand_int64, std::numeric_limits<int32_t>::max());
74 significand = static_cast<int32_t>(significand_int64);
75 return std::make_pair(significand, exponent);
76 }
77
FixedPointMultiply(Expr tensor,double multiplier,const Array<IndexExpr> & input_shape,const std::string & rounding)78 Expr FixedPointMultiply(Expr tensor, double multiplier,
79 const Array<IndexExpr>& input_shape, const std::string& rounding) {
80 // Choose high precision datatype to be int64. This is for avoiding overflow
81 // in multiplication of two int32 values.
82 DataType hp_dtype = Int(64);
83
84 // 1) Calculating the integer multiplier and integer shift
85 int32_t fixed_point_multiplier, shift;
86 std::tie(fixed_point_multiplier, shift) =
87 GetFixedPointMultiplierShift(multiplier);
88 int left_shift = shift > 0 ? shift : 0;
89 int right_shift = shift > 0 ? 0 : -shift;
90
91 // 2) Multiply the integer multiplier
92 if (left_shift != 0) {
93 tensor = LeftShift(tensor, MakeConstantScalar(hp_dtype, left_shift));
94 }
95
96 // 3) Perform the multiplication in higher precision.
97 // The scalar is a fixed point value of int32 where the decimal point is
98 // between bits 31 and 30. After multiplying with input_tensor, the result
99 // is in int64 where the decimal point is sitting between bits 31 and 30
100 // (from the right, rightmost bit is bit 0). The computation is performed in
101 // higher precision to avoid overflow in multiplying two int32 values.
102 Expr scalar = MakeConstantScalar(hp_dtype, fixed_point_multiplier);
103 tensor = Multiply(tensor, scalar);
104
105 // 4) Find the rounding scalar. This depends on where the final decimal
106 // point sits. As we will be right shifting the multiplied_t, we need to
107 // first calculate the total_right_shift.
108 int total_right_shift = right_shift + 31;
109 int64_t pos_rounding_value = (1ll << (total_right_shift - 1));
110
111 Expr round_scalar;
112 if (rounding == "UPWARD") {
113 round_scalar = MakeConstantScalar(hp_dtype, pos_rounding_value);
114 } else if (rounding == "TONEAREST") {
115 auto pos_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value);
116 auto neg_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value - 1);
117 auto pos_rounder_t = Full(pos_rounder, input_shape, hp_dtype);
118 auto neg_rounder_t = Full(neg_rounder, input_shape, hp_dtype);
119
120 auto zero_t = Zeros(input_shape, hp_dtype);
121 round_scalar =
122 Where(GreaterEqual(tensor, zero_t), pos_rounder_t, neg_rounder_t);
123 } else {
124 LOG(FATAL) << "Rounding mode " << rounding << " not supported.";
125 }
126 // Add the rounding scalar.
127 tensor = Add(tensor, round_scalar);
128
129 // 5) Simply right shift the result to get the final output.
130 tensor =
131 RightShift(tensor, MakeConstantScalar(hp_dtype, total_right_shift));
132
133 return tensor;
134 }
135
136 } // namespace qnn
137 } // namespace relay
138 } // namespace tvm
139