1 //===- UniformKernelUtils.h - Utilities for lowering uniform math - C++ -*-===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_FXPMATH_UNIFORM_KERNEL_UTILS_H_
10 #define MLIR_FXPMATH_UNIFORM_KERNEL_UTILS_H_
11 
#include "mlir/Dialect/QuantOps/QuantOps.h"
#include "mlir/Dialect/QuantOps/QuantTypes.h"
#include "mlir/Dialect/QuantOps/UniformSupport.h"
#include "mlir/IR/Operation.h"

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>
18 
19 namespace mlir {
20 namespace fxpmath {
21 namespace detail {
22 
getUniformElementType(Type t)23 inline quant::UniformQuantizedType getUniformElementType(Type t) {
24   return quant::QuantizedType::getQuantizedElementType(t)
25       .dyn_cast_or_null<quant::UniformQuantizedType>();
26 }
27 
hasStorageBitWidth(quant::QuantizedType t,ArrayRef<unsigned> checkWidths)28 inline bool hasStorageBitWidth(quant::QuantizedType t,
29                                ArrayRef<unsigned> checkWidths) {
30   unsigned w = t.getStorageType().getIntOrFloatBitWidth();
31   for (unsigned checkWidth : checkWidths) {
32     if (w == checkWidth)
33       return true;
34   }
35   return false;
36 }
37 
/// Computes log2(x), rounded to the nearest integral value, and stores it in
/// 'log2Result'. Returns whether 'x' can be considered an exact integral power
/// of two (i.e. the unrounded log2 is within a small tolerance of an integer).
///
/// Inputs with no finite logarithm (zero, negative, infinite, or NaN) return
/// false and set 'log2Result' to 0. Otherwise a non-finite value would be
/// converted to int, which is undefined behavior.
template <typename F> bool integralLog2(F x, int &log2Result) {
  if (!(x > 0) || !std::isfinite(x)) {
    log2Result = 0;
    return false;
  }
  // std::log2 avoids the extra rounding step of log(x) * (1 / log(2)).
  const F xLog2 = std::log2(x);
  const F xLog2Rounded = std::round(xLog2);
  const F xLog2Frac = xLog2 - xLog2Rounded;
  log2Result = static_cast<int>(xLog2Rounded);
  // Allow small comparison slop below the level that would make a difference
  // for 2^16 levels.
  return std::abs(xLog2Frac) < 1e-6;
}
49 
50 /// Helper class for operating on binary operations where all operands
51 /// and the result are a UniformQuantizedType.
52 struct UniformBinaryOpInfo {
UniformBinaryOpInfoUniformBinaryOpInfo53   UniformBinaryOpInfo(Operation *op, Value lhs, Value rhs,
54                       Optional<APFloat> clampMin, Optional<APFloat> clampMax)
55       : op(op), lhs(lhs), rhs(rhs), clampMin(clampMin), clampMax(clampMax),
56         lhsType(getUniformElementType(lhs.getType())),
57         rhsType(getUniformElementType(rhs.getType())),
58         resultType(getUniformElementType(*op->result_type_begin())),
59         lhsStorageType(quant::QuantizedType::castToStorageType(lhs.getType())),
60         rhsStorageType(quant::QuantizedType::castToStorageType(rhs.getType())),
61         resultStorageType(
62             quant::QuantizedType::castToStorageType(*op->result_type_begin())) {
63   }
64 
65   /// Returns whether this info is valid (all types defined, etc).
isValidUniformBinaryOpInfo66   bool isValid() const {
67     return lhsType && rhsType && resultType && lhsStorageType &&
68            rhsStorageType && resultStorageType;
69   }
70 
71   /// Gets the final quantized result type of the result.
getQuantizedResultTypeUniformBinaryOpInfo72   Type getQuantizedResultType() const { return *op->result_type_begin(); }
73 
74   /// Returns whether the storage type of all operands is identical.
isSameStorageTypeUniformBinaryOpInfo75   bool isSameStorageType() const {
76     return lhsType.getStorageType() == rhsType.getStorageType() &&
77            lhsType.getStorageType() == resultType.getStorageType();
78   }
79 
80   /// Returns whether all operands and result are considered fixedpoint power
81   /// of two, setting the lhs, rhs, and result log2 scale references.
isFixedPointPOTUniformBinaryOpInfo82   bool isFixedPointPOT(int &lhsLog2Scale, int &rhsLog2Scale,
83                        int &resultLog2Scale) const {
84     if (!lhsType.isFixedPoint() || !rhsType.isFixedPoint() ||
85         !resultType.isFixedPoint()) {
86       return false;
87     }
88 
89     if (!integralLog2(lhsType.getScale(), lhsLog2Scale) ||
90         !integralLog2(rhsType.getScale(), rhsLog2Scale) ||
91         !integralLog2(resultType.getScale(), resultLog2Scale)) {
92       return false;
93     }
94 
95     return true;
96   }
97 
98   /// Gets the result integer clamp range given the result quantized type
99   // and any explicit clamp provided as attributes.
getClampMinMaxUniformBinaryOpInfo100   std::pair<IntegerAttr, IntegerAttr> getClampMinMax(IntegerType ty) const {
101     int64_t typeMin = resultType.getStorageTypeMin();
102     int64_t typeMax = resultType.getStorageTypeMax();
103 
104     if (clampMin || clampMax) {
105       quant::UniformQuantizedValueConverter conv(resultType);
106       if (clampMin) {
107         typeMin = std::max(typeMin, conv.quantizeFloatToInt64(*clampMin));
108       }
109       if (clampMax) {
110         typeMax = std::min(typeMax, conv.quantizeFloatToInt64(*clampMax));
111       }
112     }
113 
114     // The quantized, integral ops expect clamps as 32bit ints.
115     return {
116         IntegerAttr::get(ty, typeMin),
117         IntegerAttr::get(ty, typeMax),
118     };
119   }
120 
121   Operation *op;
122   Value lhs;
123   Value rhs;
124   Optional<APFloat> clampMin;
125   Optional<APFloat> clampMax;
126 
127   // Element UniformQuantizedType for operands/result.
128   quant::UniformQuantizedType lhsType;
129   quant::UniformQuantizedType rhsType;
130   quant::UniformQuantizedType resultType;
131 
132   // Full storage-based types.
133   Type lhsStorageType;
134   Type rhsStorageType;
135   Type resultStorageType;
136 };
137 
/// Derives a quantized multiplier and shift from a real valued multiplier
/// in the open interval (0, 1).
///
/// For such inputs, 'multiplier' is the frexp mantissa of 'realMultiplier'
/// scaled by 2^31 and rounded to nearest, so it lies in [2^30, 2^31 - 1].
/// 'exponent' is the matching binary exponent, so that
///   realMultiplier ~= multiplier * 2^(exponent - 31).
struct QuantizedMultiplierSmallerThanOneExp {
  QuantizedMultiplierSmallerThanOneExp(double realMultiplier) {
    assert(realMultiplier < 1.0);
    assert(realMultiplier > 0.0);

    // 1.0 in Q31 fixed point.
    constexpr int64_t kQ31One = int64_t{1} << 31;
    // frexp splits a positive value into a mantissa in [0.5, 1) and a power
    // of two stored directly into 'exponent'.
    const double mantissa = std::frexp(realMultiplier, &exponent);
    auto qFixed = static_cast<int64_t>(std::round(mantissa * kQ31One));
    assert(qFixed <= kQ31One);
    // Rounding can carry the mantissa up to exactly 1.0 (2^31), which does not
    // fit in int32_t. Renormalize to 0.5 and bump the exponent instead.
    if (qFixed == kQ31One) {
      qFixed /= 2;
      ++exponent;
    }
    assert(qFixed <= std::numeric_limits<int32_t>::max());
    multiplier = static_cast<int32_t>(qFixed);
  }

  int32_t multiplier;
  int exponent;
};
159 
160 /// Casts an integer or floating point based shaped type to a new element type.
castElementType(Type t,Type newElementType)161 inline Type castElementType(Type t, Type newElementType) {
162   if (auto st = t.dyn_cast<ShapedType>()) {
163     switch (st.getKind()) {
164     case StandardTypes::Kind::Vector:
165       return VectorType::get(st.getShape(), newElementType);
166     case StandardTypes::Kind::RankedTensor:
167       return RankedTensorType::get(st.getShape(), newElementType);
168     case StandardTypes::Kind::UnrankedTensor:
169       return UnrankedTensorType::get(newElementType);
170     case StandardTypes::Kind::MemRef:
171       return MemRefType::get(st.getShape(), newElementType,
172                              st.cast<MemRefType>().getAffineMaps());
173     }
174   }
175   assert(t.isIntOrFloat());
176   return newElementType;
177 }
178 
179 /// Creates an IntegerAttr with a type that matches the shape of 't' (which can
180 /// be a scalar primitive or a shaped type).
broadcastScalarConstIntValue(Type t,int64_t value)181 inline Attribute broadcastScalarConstIntValue(Type t, int64_t value) {
182   if (auto st = t.dyn_cast<ShapedType>()) {
183     assert(st.getElementType().isa<IntegerType>());
184     return DenseElementsAttr::get(st,
185                                   IntegerAttr::get(st.getElementType(), value));
186   }
187 
188   auto integerType = t.cast<IntegerType>();
189   assert(t.isa<IntegerType>() && "integer broadcast must be of integer type");
190   return IntegerAttr::get(integerType, value);
191 }
192 
193 /// Given an APFloat, converts it to the float semantics that matches the
194 /// given FloatType, silently ignoring inexact conversions.
convertFloatToType(FloatType ft,APFloat value)195 inline APFloat convertFloatToType(FloatType ft, APFloat value) {
196   bool losesInfo;
197   auto status = value.convert(ft.getFloatSemantics(),
198                               APFloat::rmNearestTiesToEven, &losesInfo);
199   (void)status; // unused in opt mode
200   assert((status & (APFloat::opDivByZero | APFloat::opInvalidOp)) == 0 &&
201          "could not convert to float const");
202   return value;
203 }
204 
205 /// Creates a FloatAttr with a type that matches the shape of 't' (which can be
206 /// a scalar primitive or a shaped type).
broadcastScalarConstFloatValue(Type t,APFloat value)207 inline Attribute broadcastScalarConstFloatValue(Type t, APFloat value) {
208   if (auto st = t.dyn_cast<ShapedType>()) {
209     FloatType floatElementType = st.getElementType().dyn_cast<FloatType>();
210     assert(floatElementType &&
211            "float broadcast element type must be float like");
212     APFloat apValue = convertFloatToType(floatElementType, value);
213     return DenseElementsAttr::get(st,
214                                   FloatAttr::get(st.getElementType(), apValue));
215   } else {
216     auto floatType = t.dyn_cast<FloatType>();
217     assert(floatType && "float broadcast must be of float type");
218     APFloat apValue = convertFloatToType(floatType, value);
219     return FloatAttr::get(floatType, apValue);
220   }
221 }
222 
223 } // namespace detail
224 } // namespace fxpmath
225 } // namespace mlir
226 
227 #endif // MLIR_FXPMATH_UNIFORM_KERNEL_UTILS_H_
228