1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file src/relay/qnn/util.cc
22  * \brief Utility functions for QNN.
23  */
24 
25 #include "util.h"
26 #include "../pass/pattern_util.h"
27 
28 namespace tvm {
29 namespace relay {
30 namespace qnn {
31 
32 /*
33  * \brief Convert FP32 representation into fixed point representation.
34  * \param double_multplier The input FP32 number.
35  * \return The pair of multiplier and shift for fixed point representation.
36  * \note Converts a floating point number so that it can be represented by
37  *       integers. The representation is
38  *             float_number = (significand) * 2^(exponent)
39  *
40  *       The significand is a number between 0.5 and 1. This is represented by
41  *       an integer number. For example, if it is int32, then the decimal point
42  *       exists between bit 31 and 30 from LSB (or between first and second bit
43  *       from the left).
44  *
45  *       Some examples are
46  *           0.25 = (0.5) * 2^(-1)
47  *           0.125 = (0.5) * 2^(-2)
48  *
49  *       Credit to TFLite reference implementation.
50  */
GetFixedPointMultiplierShift(double double_multiplier)51 std::pair<int32_t, int32_t> GetFixedPointMultiplierShift(
52     double double_multiplier) {
53   int32_t significand, exponent;
54   if (double_multiplier == 0.) {
55     significand = 0;
56     exponent = 0;
57     return std::make_pair(significand, exponent);
58   }
59 
60   // Get the significand and exponent.
61   double significand_d = std::frexp(double_multiplier, &exponent);
62 
63   // Convert the double significand to int significand, i.e., convert into a
64   // integer where the decimal point is between bit 31 and 30. This is done by
65   // multiplying the double value with 2^31 and then casting to int.
66   significand_d = std::round(significand_d * (1ll << 31));
67   auto significand_int64 = static_cast<int64_t>(significand_d);
68   CHECK_LE(significand_int64, (1ll << 31));
69   if (significand_int64 == (1ll << 31)) {
70     significand_int64 /= 2;
71     ++exponent;
72   }
73   CHECK_LE(significand_int64, std::numeric_limits<int32_t>::max());
74   significand = static_cast<int32_t>(significand_int64);
75   return std::make_pair(significand, exponent);
76 }
77 
FixedPointMultiply(Expr tensor,double multiplier,const Array<IndexExpr> & input_shape,const std::string & rounding)78 Expr FixedPointMultiply(Expr tensor, double multiplier,
79                    const Array<IndexExpr>& input_shape, const std::string& rounding) {
80   // Choose high precision datatype to be int64. This is for avoiding overflow
81   // in multiplication of two int32 values.
82   DataType hp_dtype = Int(64);
83 
84   // 1) Calculating the integer multiplier and integer shift
85   int32_t fixed_point_multiplier, shift;
86   std::tie(fixed_point_multiplier, shift) =
87       GetFixedPointMultiplierShift(multiplier);
88   int left_shift = shift > 0 ? shift : 0;
89   int right_shift = shift > 0 ? 0 : -shift;
90 
91   // 2) Multiply the integer multiplier
92   if (left_shift != 0) {
93     tensor = LeftShift(tensor, MakeConstantScalar(hp_dtype, left_shift));
94   }
95 
96   // 3) Perform the multiplication in higher precision.
97   // The scalar is a fixed point value of int32 where the decimal point is
98   // between bits 31 and 30. After multiplying with input_tensor, the result
99   // is in int64 where the decimal point is sitting between bits 31 and 30
100   // (from the right, rightmost bit is bit 0). The computation is performed in
101   // higher precision to avoid overflow in multiplying two int32 values.
102   Expr scalar = MakeConstantScalar(hp_dtype, fixed_point_multiplier);
103   tensor = Multiply(tensor, scalar);
104 
105   // 4) Find the rounding scalar. This depends on where the final decimal
106   // point sits. As we will be right shifting the multiplied_t, we need to
107   // first calculate the total_right_shift.
108   int total_right_shift = right_shift + 31;
109   int64_t pos_rounding_value = (1ll << (total_right_shift - 1));
110 
111   Expr round_scalar;
112   if (rounding == "UPWARD") {
113     round_scalar = MakeConstantScalar(hp_dtype, pos_rounding_value);
114   } else if (rounding == "TONEAREST") {
115     auto pos_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value);
116     auto neg_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value - 1);
117     auto pos_rounder_t = Full(pos_rounder, input_shape, hp_dtype);
118     auto neg_rounder_t = Full(neg_rounder, input_shape, hp_dtype);
119 
120     auto zero_t = Zeros(input_shape, hp_dtype);
121     round_scalar =
122         Where(GreaterEqual(tensor, zero_t), pos_rounder_t, neg_rounder_t);
123   } else {
124     LOG(FATAL) << "Rounding mode " << rounding << " not supported.";
125   }
126   // Add the rounding scalar.
127   tensor = Add(tensor, round_scalar);
128 
129   // 5) Simply right shift the result to get the final output.
130   tensor =
131       RightShift(tensor, MakeConstantScalar(hp_dtype, total_right_shift));
132 
133   return tensor;
134 }
135 
136 }  // namespace qnn
137 }  // namespace relay
138 }  // namespace tvm
139