1 //===- UniformKernelUtils.h - Utilities for lowering uniform math - C++ -*-===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #ifndef MLIR_FXPMATH_UNIFORM_KERNEL_UTILS_H_
10 #define MLIR_FXPMATH_UNIFORM_KERNEL_UTILS_H_
11
12 #include "mlir/Dialect/QuantOps/QuantOps.h"
13 #include "mlir/Dialect/QuantOps/QuantTypes.h"
14 #include "mlir/Dialect/QuantOps/UniformSupport.h"
15 #include "mlir/IR/Operation.h"
16
17 #include <cmath>
18
19 namespace mlir {
20 namespace fxpmath {
21 namespace detail {
22
getUniformElementType(Type t)23 inline quant::UniformQuantizedType getUniformElementType(Type t) {
24 return quant::QuantizedType::getQuantizedElementType(t)
25 .dyn_cast_or_null<quant::UniformQuantizedType>();
26 }
27
hasStorageBitWidth(quant::QuantizedType t,ArrayRef<unsigned> checkWidths)28 inline bool hasStorageBitWidth(quant::QuantizedType t,
29 ArrayRef<unsigned> checkWidths) {
30 unsigned w = t.getStorageType().getIntOrFloatBitWidth();
31 for (unsigned checkWidth : checkWidths) {
32 if (w == checkWidth)
33 return true;
34 }
35 return false;
36 }
37
/// Computes log2(x) rounded to the nearest integral value and stores it in
/// 'log2Result'. Returns whether 'x' can be considered an exact integral power
/// of two, i.e. whether log2(x) lies within 1e-6 of that rounded value.
///
/// Inputs that are not strictly positive and finite have no finite real
/// logarithm; for those, 'log2Result' is set to 0 and false is returned.
template <typename F> bool integralLog2(F x, int &log2Result) {
  // Reject zero, negative, NaN and infinite inputs up front: their logarithm
  // is NaN or infinite, and converting such a value to int below would be
  // undefined behavior.
  if (!(x > 0) || !std::isfinite(x)) {
    log2Result = 0;
    return false;
  }

  // std::log2 avoids the extra rounding step of log(x) * (1 / log(2)).
  const F xLog2 = std::log2(x);
  const F xLog2Rounded = std::round(xLog2);
  const F xLog2Frac = xLog2 - xLog2Rounded;
  log2Result = static_cast<int>(xLog2Rounded);
  // Allow small comparison slop below the level that would make a difference
  // for 2^16 levels.
  return std::abs(xLog2Frac) < 1e-6;
}
49
50 /// Helper class for operating on binary operations where all operands
51 /// and the result are a UniformQuantizedType.
52 struct UniformBinaryOpInfo {
UniformBinaryOpInfoUniformBinaryOpInfo53 UniformBinaryOpInfo(Operation *op, Value lhs, Value rhs,
54 Optional<APFloat> clampMin, Optional<APFloat> clampMax)
55 : op(op), lhs(lhs), rhs(rhs), clampMin(clampMin), clampMax(clampMax),
56 lhsType(getUniformElementType(lhs.getType())),
57 rhsType(getUniformElementType(rhs.getType())),
58 resultType(getUniformElementType(*op->result_type_begin())),
59 lhsStorageType(quant::QuantizedType::castToStorageType(lhs.getType())),
60 rhsStorageType(quant::QuantizedType::castToStorageType(rhs.getType())),
61 resultStorageType(
62 quant::QuantizedType::castToStorageType(*op->result_type_begin())) {
63 }
64
65 /// Returns whether this info is valid (all types defined, etc).
isValidUniformBinaryOpInfo66 bool isValid() const {
67 return lhsType && rhsType && resultType && lhsStorageType &&
68 rhsStorageType && resultStorageType;
69 }
70
71 /// Gets the final quantized result type of the result.
getQuantizedResultTypeUniformBinaryOpInfo72 Type getQuantizedResultType() const { return *op->result_type_begin(); }
73
74 /// Returns whether the storage type of all operands is identical.
isSameStorageTypeUniformBinaryOpInfo75 bool isSameStorageType() const {
76 return lhsType.getStorageType() == rhsType.getStorageType() &&
77 lhsType.getStorageType() == resultType.getStorageType();
78 }
79
80 /// Returns whether all operands and result are considered fixedpoint power
81 /// of two, setting the lhs, rhs, and result log2 scale references.
isFixedPointPOTUniformBinaryOpInfo82 bool isFixedPointPOT(int &lhsLog2Scale, int &rhsLog2Scale,
83 int &resultLog2Scale) const {
84 if (!lhsType.isFixedPoint() || !rhsType.isFixedPoint() ||
85 !resultType.isFixedPoint()) {
86 return false;
87 }
88
89 if (!integralLog2(lhsType.getScale(), lhsLog2Scale) ||
90 !integralLog2(rhsType.getScale(), rhsLog2Scale) ||
91 !integralLog2(resultType.getScale(), resultLog2Scale)) {
92 return false;
93 }
94
95 return true;
96 }
97
98 /// Gets the result integer clamp range given the result quantized type
99 // and any explicit clamp provided as attributes.
getClampMinMaxUniformBinaryOpInfo100 std::pair<IntegerAttr, IntegerAttr> getClampMinMax(IntegerType ty) const {
101 int64_t typeMin = resultType.getStorageTypeMin();
102 int64_t typeMax = resultType.getStorageTypeMax();
103
104 if (clampMin || clampMax) {
105 quant::UniformQuantizedValueConverter conv(resultType);
106 if (clampMin) {
107 typeMin = std::max(typeMin, conv.quantizeFloatToInt64(*clampMin));
108 }
109 if (clampMax) {
110 typeMax = std::min(typeMax, conv.quantizeFloatToInt64(*clampMax));
111 }
112 }
113
114 // The quantized, integral ops expect clamps as 32bit ints.
115 return {
116 IntegerAttr::get(ty, typeMin),
117 IntegerAttr::get(ty, typeMax),
118 };
119 }
120
121 Operation *op;
122 Value lhs;
123 Value rhs;
124 Optional<APFloat> clampMin;
125 Optional<APFloat> clampMax;
126
127 // Element UniformQuantizedType for operands/result.
128 quant::UniformQuantizedType lhsType;
129 quant::UniformQuantizedType rhsType;
130 quant::UniformQuantizedType resultType;
131
132 // Full storage-based types.
133 Type lhsStorageType;
134 Type rhsStorageType;
135 Type resultStorageType;
136 };
137
/// Derives a quantized multiplier and shift from a real valued multiplier
/// that is greater than 0 and less than 1. After construction, 'multiplier'
/// holds the fraction returned by frexp scaled by 2^31 and rounded, and
/// 'exponent' holds the matching power-of-two exponent.
struct QuantizedMultiplierSmallerThanOneExp {
  QuantizedMultiplierSmallerThanOneExp(double realMultiplier) {
    assert(realMultiplier < 1.0);
    assert(realMultiplier > 0.0);

    // Split into fraction * 2^exponent, then express the fraction in Q31.
    const double fraction = std::frexp(realMultiplier, &exponent);
    int64_t qFixed = static_cast<int64_t>(std::round(fraction * (1ll << 31)));
    assert(qFixed <= (1ll << 31));
    // Rounding can land exactly on 2^31, which does not fit in an int32_t:
    // halve the fixed-point value and compensate in the exponent.
    if (qFixed == (1ll << 31)) {
      qFixed >>= 1;
      exponent += 1;
    }
    assert(qFixed <= std::numeric_limits<int32_t>::max());
    multiplier = static_cast<int32_t>(qFixed);
  }

  // Fixed-point multiplier; at most the maximum int32_t value.
  int32_t multiplier;
  // Power-of-two exponent paired with 'multiplier'.
  int exponent;
};
159
160 /// Casts an integer or floating point based shaped type to a new element type.
castElementType(Type t,Type newElementType)161 inline Type castElementType(Type t, Type newElementType) {
162 if (auto st = t.dyn_cast<ShapedType>()) {
163 switch (st.getKind()) {
164 case StandardTypes::Kind::Vector:
165 return VectorType::get(st.getShape(), newElementType);
166 case StandardTypes::Kind::RankedTensor:
167 return RankedTensorType::get(st.getShape(), newElementType);
168 case StandardTypes::Kind::UnrankedTensor:
169 return UnrankedTensorType::get(newElementType);
170 case StandardTypes::Kind::MemRef:
171 return MemRefType::get(st.getShape(), newElementType,
172 st.cast<MemRefType>().getAffineMaps());
173 }
174 }
175 assert(t.isIntOrFloat());
176 return newElementType;
177 }
178
179 /// Creates an IntegerAttr with a type that matches the shape of 't' (which can
180 /// be a scalar primitive or a shaped type).
broadcastScalarConstIntValue(Type t,int64_t value)181 inline Attribute broadcastScalarConstIntValue(Type t, int64_t value) {
182 if (auto st = t.dyn_cast<ShapedType>()) {
183 assert(st.getElementType().isa<IntegerType>());
184 return DenseElementsAttr::get(st,
185 IntegerAttr::get(st.getElementType(), value));
186 }
187
188 auto integerType = t.cast<IntegerType>();
189 assert(t.isa<IntegerType>() && "integer broadcast must be of integer type");
190 return IntegerAttr::get(integerType, value);
191 }
192
193 /// Given an APFloat, converts it to the float semantics that matches the
194 /// given FloatType, silently ignoring inexact conversions.
convertFloatToType(FloatType ft,APFloat value)195 inline APFloat convertFloatToType(FloatType ft, APFloat value) {
196 bool losesInfo;
197 auto status = value.convert(ft.getFloatSemantics(),
198 APFloat::rmNearestTiesToEven, &losesInfo);
199 (void)status; // unused in opt mode
200 assert((status & (APFloat::opDivByZero | APFloat::opInvalidOp)) == 0 &&
201 "could not convert to float const");
202 return value;
203 }
204
205 /// Creates a FloatAttr with a type that matches the shape of 't' (which can be
206 /// a scalar primitive or a shaped type).
broadcastScalarConstFloatValue(Type t,APFloat value)207 inline Attribute broadcastScalarConstFloatValue(Type t, APFloat value) {
208 if (auto st = t.dyn_cast<ShapedType>()) {
209 FloatType floatElementType = st.getElementType().dyn_cast<FloatType>();
210 assert(floatElementType &&
211 "float broadcast element type must be float like");
212 APFloat apValue = convertFloatToType(floatElementType, value);
213 return DenseElementsAttr::get(st,
214 FloatAttr::get(st.getElementType(), apValue));
215 } else {
216 auto floatType = t.dyn_cast<FloatType>();
217 assert(floatType && "float broadcast must be of float type");
218 APFloat apValue = convertFloatToType(floatType, value);
219 return FloatAttr::get(floatType, apValue);
220 }
221 }
222
223 } // namespace detail
224 } // namespace fxpmath
225 } // namespace mlir
226
227 #endif // MLIR_FXPMATH_UNIFORM_KERNEL_UTILS_H_
228