1 //===- QuantTypes.h - Quantization Ops and Types ----------------*- C++ -*-===// 2 // 3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_DIALECT_QUANTOPS_QUANT_TYPES_H_ 10 #define MLIR_DIALECT_QUANTOPS_QUANT_TYPES_H_ 11 12 #include "mlir/IR/Attributes.h" 13 #include "mlir/IR/Builders.h" 14 #include "mlir/IR/Dialect.h" 15 #include "mlir/IR/OpDefinition.h" 16 #include "mlir/IR/StandardTypes.h" 17 #include "mlir/IR/Types.h" 18 #include "llvm/Support/MathExtras.h" 19 20 namespace mlir { 21 namespace quant { 22 23 class QuantizedIntegerType; 24 25 namespace detail { 26 27 struct QuantizedTypeStorage; 28 struct AnyQuantizedTypeStorage; 29 struct UniformQuantizedTypeStorage; 30 struct UniformQuantizedPerAxisTypeStorage; 31 32 } // namespace detail 33 34 namespace QuantizationTypes { 35 enum Kind { 36 Any = Type::FIRST_QUANTIZATION_TYPE, 37 UniformQuantized, 38 UniformQuantizedPerAxis, 39 LAST_USED_QUANTIZATION_TYPE = UniformQuantizedPerAxis, 40 }; 41 } // namespace QuantizationTypes 42 43 /// Enumeration of bit-mapped flags related to quantized types. 44 namespace QuantizationFlags { 45 enum FlagValue { 46 // Indicates that the storage type should be interpreted as a signed 47 // integer. The default is to interpret it as an unsigned value. 48 Signed = 1, 49 }; 50 } // namespace QuantizationFlags 51 52 /// Base class for all quantized types known to this dialect. 53 /// All quantized types have: 54 /// - storageType: The (narrower) numeric type that is being used to 55 /// approximate some expressed type. 56 /// - expressedType: The type that is being approximated. 57 /// 58 /// The base class provides generic support for manipulating the types based 59 /// on these fields. 60 class QuantizedType : public Type { 61 public: 62 using ImplType = detail::QuantizedTypeStorage; 63 using Type::Type; 64 65 /// The maximum number of bits supported for storage types. 66 static constexpr unsigned MaxStorageBits = 32; 67 68 static LogicalResult 69 verifyConstructionInvariants(Optional<Location> loc, MLIRContext *context, 70 unsigned flags, Type storageType, 71 Type expressedType, int64_t storageTypeMin, 72 int64_t storageTypeMax); 73 74 /// Support method to enable LLVM-style type casting. classof(Type type)75 static bool classof(Type type) { 76 return type.getKind() >= Type::FIRST_QUANTIZATION_TYPE && 77 type.getKind() <= QuantizationTypes::LAST_USED_QUANTIZATION_TYPE; 78 } 79 80 /// Gets the minimum possible stored by a storageType. storageTypeMin must 81 /// be greater than or equal to this value. getDefaultMinimumForInteger(bool isSigned,unsigned integralWidth)82 static int64_t getDefaultMinimumForInteger(bool isSigned, 83 unsigned integralWidth) { 84 if (isSigned) { 85 return llvm::minIntN(integralWidth); 86 } 87 return 0; 88 } 89 90 /// Gets the maximum possible stored by a storageType. storageTypeMax must 91 /// be less than or equal to this value. getDefaultMaximumForInteger(bool isSigned,unsigned integralWidth)92 static int64_t getDefaultMaximumForInteger(bool isSigned, 93 unsigned integralWidth) { 94 if (isSigned) { 95 return llvm::maxIntN(integralWidth); 96 } 97 return llvm::maxUIntN(integralWidth); 98 } 99 100 /// Gets the original expressed type that this quantized type approximates. 101 /// Note that this presumes that the quantized type was always derived from 102 /// a floating point type, which in the broadest definition, is not true (i.e. 103 /// it could be some form of integral, fixed type or affine type in its own 104 /// right); however, at the high level, no examples of such usage are 105 /// presently known and the restriction serves some useful purposes (such as 106 /// always being able to reverse a transformation or measure error). In most 107 /// cases, this will be f32. 108 Type getExpressedType() const; 109 110 /// Gets the flags associated with this type. Typically a more specific 111 /// accessor is appropriate. 112 unsigned getFlags() const; 113 114 // Convenience helpers. 115 /// Whether the storage type should be interpreted as a signed quantity 116 /// (true) or an unsigned value (false). isSigned()117 bool isSigned() const { 118 return (getFlags() & QuantizationFlags::Signed) == 119 QuantizationFlags::Signed; 120 } 121 122 /// Gets the underlying type used for to store values. Note that this may 123 /// be signed or unsigned. Use the isSigned() accessor to differentiate. 124 Type getStorageType() const; 125 126 /// The minimum value that storageType can take. 127 int64_t getStorageTypeMin() const; 128 129 /// The maximum value that storageType can take. 130 int64_t getStorageTypeMax() const; 131 132 /// Gets the integral bit width that the underlying storage type can exactly 133 /// represent. For integral storage types, this will just be their width. 134 unsigned getStorageTypeIntegralWidth() const; 135 136 /// Returns whether the candidateExpressedType is a match for this 137 /// QuantizedType. This will be true if the candidate type is either a 138 /// primitive type or a container type whose element type equals this 139 /// QuantizedType's expressed type. 140 /// Examples of compatible candidateExpressedType: 141 /// !quant.uniform<i8:f32, 1.0> =~ f32 142 /// !quant.uniform<i8:f32, 1.0> =~ tensor<4xf32> 143 bool isCompatibleExpressedType(Type candidateExpressedType); 144 145 /// Returns the element type as a QuantizedType or nullptr if it is not 146 /// a quantized type. If the type is primitive, returns that. If it is a 147 /// container (vector/tensor), return the element type. 148 /// Examples: 149 /// !quant.uniform<i8:f32, 1.0> -> !quant.uniform<i8:f32, 1.0> 150 /// tensor<4x!quant.uniform<i8:f32, 1.0> -> quant.uniform<i8:f32, 1.0> 151 static QuantizedType getQuantizedElementType(Type primitiveOrContainerType); 152 153 /// Casts from a type based on the storageType to a corresponding type based 154 /// on this type (returns nullptr if the cast is not valid). 155 /// Examples: 156 /// i8 -> !quant.uniform<i8:f32, 1.0> 157 /// tensor<4xi8> -> tensor<4x!quant.uniform<i8:f32, 1.0}>> 158 /// vector<4xi8> -> vector<4x!quant.uniform<i8:f32, 1.0>> 159 Type castFromStorageType(Type candidateType); 160 161 /// Casts from a type based on a QuantizedType to a corresponding type based 162 /// on the storageType (returns nullptr if the cast is not valid). 163 /// This is the inverse of castFromStorageType(). 164 static Type castToStorageType(Type quantizedType); 165 166 /// Casts from a type based on the expressedType to a corresponding type based 167 /// on this type (returns nullptr if the cast is not valid). 168 /// Examples: 169 /// f32 -> !quant.uniform<i8:f32, 1.0> 170 /// tensor<4xf32> -> tensor<4x!quant.uniform<i8:f32, 1.0>> 171 /// vector<4xf32> -> vector<4x!quant.uniform<i8:f32, 1.0>> 172 Type castFromExpressedType(Type candidateType); 173 174 /// Casts from a type based on QuantizedType to a corresponding type based 175 /// on the expressedType (returns nullptr if the cast is not valid). 176 /// This is the inverse of castFromExpressedType. 177 static Type castToExpressedType(Type quantizedType); 178 179 /// Casts from a type based on the expressedType to the equivalent type 180 /// based on storageType by way of this QuantizedType. Equivalent to: 181 /// QuantizedType::castToStorageType(castFromExpressedType(candidateType)) 182 /// (but with validity checks). 183 /// Example (for this = !quant.uniform<i8:f32, 1.0>): 184 /// tensor<4xf32> -> tensor<4xi8> 185 Type castExpressedToStorageType(Type candidateType); 186 187 private: 188 /// Hide the following methods inherited from `Type`. It is almost certainly 189 /// a bug to call them from a `QuantizedType` object. Users should call 190 /// `getStorageType` or `getExpressedType` to get the underlying types 191 /// they want to inspect. 192 using Type::isBF16; 193 using Type::isF16; 194 using Type::isF32; 195 using Type::isF64; 196 using Type::isIndex; 197 using Type::isInteger; 198 }; 199 200 /// A quantized type that maps storage to/from expressed types in an 201 /// unspecified way. 202 /// 203 /// Typical syntax: 204 /// quant.any<i8:f32> 205 /// quant.any<i8> 206 /// quant.any<i8<-16,15>> 207 /// 208 /// Note that for the any type, the expressed type is optional. 209 class AnyQuantizedType 210 : public Type::TypeBase<AnyQuantizedType, QuantizedType, 211 detail::AnyQuantizedTypeStorage> { 212 public: 213 using Base::Base; 214 215 /// Support method to enable LLVM-style type casting. kindof(unsigned kind)216 static bool kindof(unsigned kind) { return kind == QuantizationTypes::Any; } 217 218 /// Gets an instance of the type with all parameters specified but not 219 /// checked. 220 static AnyQuantizedType get(unsigned flags, Type storageType, 221 Type expressedType, int64_t storageTypeMin, 222 int64_t storageTypeMax); 223 224 /// Gets an instance of the type with all specified parameters checked. 225 /// Returns a nullptr convertible type on failure. 226 static AnyQuantizedType getChecked(unsigned flags, Type storageType, 227 Type expressedType, int64_t storageTypeMin, 228 int64_t storageTypeMax, Location location); 229 230 /// Verifies construction invariants and issues errors/warnings. 231 static LogicalResult 232 verifyConstructionInvariants(Optional<Location> loc, MLIRContext *context, 233 unsigned flags, Type storageType, 234 Type expressedType, int64_t storageTypeMin, 235 int64_t storageTypeMax); 236 }; 237 238 /// Represents a family of uniform, quantized types. 239 /// 240 /// Each instance of this type expresses a mapping between real values (most 241 /// often expressed in floating point f32) and quantized values (either fixed 242 /// point or affine). 243 /// 244 /// The relationship is: 245 /// real_value = scale * (quantized_value - zero_point) 246 /// 247 /// It is used as part of high level graph transformations that have the goal 248 /// of re-expressing parts of a computation in terms of this common form for 249 /// more efficient execution at runtime. In addition, it is designed to be 250 /// expressive enough to facilitate lowering to precise types and operations 251 /// in target hardware. 252 /// 253 /// As a high-level type, focused on intermediate passes, this type holds 254 /// opinions consistent with high-level usage. If lowering math kernels below 255 /// the high level arithmetic ops (i.e. to LLVM IR or hardware specific 256 /// instruction sets), it is expected that the information expressed here 257 /// will be used to drive low level codegen and target specific type selection, 258 /// but this type will likely be erased in the process. 259 /// 260 /// Syntax synopsis: 261 /// Per-layer, all parameters expressed: 262 /// !quant<uniform[StorageType:ExpressedType]{Scale:ZeroPoint}> 263 /// Per-layer, optional parameters omitted: 264 /// !quant<uniform[StorageType]{Scale}> 265 /// 266 /// StorageType: 'i'|'u' NumBits 267 /// ExpressedType: 'f16', 'f32', 'bf16', 'f64' 268 /// Scale: A legal double value 269 /// ZeroPoint: An integer value 270 class UniformQuantizedType 271 : public Type::TypeBase<UniformQuantizedType, QuantizedType, 272 detail::UniformQuantizedTypeStorage> { 273 public: 274 using Base::Base; 275 276 /// Gets an instance of the type with all parameters specified but not 277 /// checked. 278 static UniformQuantizedType get(unsigned flags, Type storageType, 279 Type expressedType, double scale, 280 int64_t zeroPoint, int64_t storageTypeMin, 281 int64_t storageTypeMax); 282 283 /// Gets an instance of the type with all specified parameters checked. 284 /// Returns a nullptr convertible type on failure. 285 static UniformQuantizedType 286 getChecked(unsigned flags, Type storageType, Type expressedType, double scale, 287 int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax, 288 Location location); 289 290 /// Verifies construction invariants and issues errors/warnings. 291 static LogicalResult verifyConstructionInvariants( 292 Optional<Location> loc, MLIRContext *context, unsigned flags, 293 Type storageType, Type expressedType, double scale, int64_t zeroPoint, 294 int64_t storageTypeMin, int64_t storageTypeMax); 295 296 /// Support method to enable LLVM-style type casting. kindof(unsigned kind)297 static bool kindof(unsigned kind) { 298 return kind == QuantizationTypes::UniformQuantized; 299 } 300 301 /// Gets the scale term. The scale designates the difference between the real 302 /// values corresponding to consecutive quantized values differing by 1. 303 double getScale() const; 304 305 /// Gets the storage value corresponding to the real value 0 in the affine 306 /// equation. 307 int64_t getZeroPoint() const; 308 309 // Fixed point values are real numbers divided by a scale. 310 // Currently, only signed storage types are treated as fixed point. 311 // A fixed point value can be obtained from an affine value by subtracting 312 // the zeroPoint. 313 // In the future, this may be explicit versus implied by type and zeroPoint. isFixedPoint()314 bool isFixedPoint() const { return isSigned() && getZeroPoint() == 0; } 315 }; 316 317 /// Represents per-axis (also known as per-channel quantization). 318 /// 319 /// Syntax synopsis: 320 /// Per-axis, all parameters expressed: 321 /// !quant<uniform[StorageType:ExpressedType:QuantizedDim]{QuantParams}> 322 /// Per-axis, optional parameters omitted: 323 /// !quant<uniform[StorageType]{Scale}> 324 /// 325 /// StorageType: 'i'|'u' NumBits 326 /// ExpressedType: 'f16', 'f32', 'bf16', 'f64' 327 /// QuantizedDim: An integer value 328 /// QuantParams: (Scale ':' ZeroPoint)+ 329 /// Scale: A legal double value 330 /// ZeroPoint: An integer value 331 class UniformQuantizedPerAxisType 332 : public Type::TypeBase<UniformQuantizedPerAxisType, QuantizedType, 333 detail::UniformQuantizedPerAxisTypeStorage> { 334 public: 335 using Base::Base; 336 337 /// Gets an instance of the type with all parameters specified but not 338 /// checked. 339 static UniformQuantizedPerAxisType 340 get(unsigned flags, Type storageType, Type expressedType, 341 ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints, 342 int32_t quantizedDimension, int64_t storageTypeMin, 343 int64_t storageTypeMax); 344 345 /// Gets an instance of the type with all specified parameters checked. 346 /// Returns a nullptr convertible type on failure. 347 static UniformQuantizedPerAxisType 348 getChecked(unsigned flags, Type storageType, Type expressedType, 349 ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints, 350 int32_t quantizedDimension, int64_t storageTypeMin, 351 int64_t storageTypeMax, Location location); 352 353 /// Verifies construction invariants and issues errors/warnings. 354 static LogicalResult verifyConstructionInvariants( 355 Optional<Location> loc, MLIRContext *context, unsigned flags, 356 Type storageType, Type expressedType, ArrayRef<double> scales, 357 ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension, 358 int64_t storageTypeMin, int64_t storageTypeMax); 359 360 /// Support method to enable LLVM-style type casting. kindof(unsigned kind)361 static bool kindof(unsigned kind) { 362 return kind == QuantizationTypes::UniformQuantizedPerAxis; 363 } 364 365 /// Gets the quantization scales. The scales designate the difference between 366 /// the real values corresponding to consecutive quantized values differing 367 /// by 1. The ith scale corresponds to the ith slice in the 368 /// quantized_dimension. 369 ArrayRef<double> getScales() const; 370 371 /// Gets the storage values corresponding to the real value 0 in the affine 372 /// equation. The ith zero point corresponds to the ith slice in the 373 /// quantized_dimension. 374 ArrayRef<int64_t> getZeroPoints() const; 375 376 /// Specifies the dimension of the Tensor's shape that the scales and 377 /// zero_points correspond to. For example, a tensor t, with dims=[4, 3, 2, 1] 378 /// with quantization params: 379 /// scales=[1.0, 2.0, 3.0], zeroPoints=[1, 2, 3], quantizedDimension=1 380 /// will be quantized across the second dimension of t. 381 /// t[:, 0, :, :] will have scale[0]=1.0, zero_point[0]=1 382 /// t[:, 1, :, :] will have scale[1]=2.0, zero_point[0]=2 383 /// t[:, 2, :, :] will have scale[2]=3.0, zero_point[0]=3 384 int32_t getQuantizedDimension() const; 385 386 /// Fixed point values are real numbers divided by a scale. 387 /// Currently, only signed storage types are treated as fixed point. 388 /// A fixed point value can be obtained from an affine value by subtracting 389 /// the zeroPoint. 390 /// In the future, this may be explicit versus implied by type and zeroPoint. isFixedPoint()391 bool isFixedPoint() const { 392 if (!isSigned()) 393 return false; 394 return llvm::all_of(getZeroPoints(), 395 [](int64_t zeroPoint) { return zeroPoint != 0; }); 396 } 397 }; 398 399 } // namespace quant 400 } // namespace mlir 401 402 #endif // MLIR_DIALECT_QUANTOPS_QUANT_TYPES_H_ 403