1 //===- QuantTypes.h - Quantization Ops and Types ----------------*- C++ -*-===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_QUANTOPS_QUANT_TYPES_H_
10 #define MLIR_DIALECT_QUANTOPS_QUANT_TYPES_H_
11 
12 #include "mlir/IR/Attributes.h"
13 #include "mlir/IR/Builders.h"
14 #include "mlir/IR/Dialect.h"
15 #include "mlir/IR/OpDefinition.h"
16 #include "mlir/IR/StandardTypes.h"
17 #include "mlir/IR/Types.h"
18 #include "llvm/Support/MathExtras.h"
19 
20 namespace mlir {
21 namespace quant {
22 
// NOTE(review): QuantizedIntegerType is forward-declared here but is neither
// defined nor referenced anywhere else in this header; confirm whether any
// other file still needs it before relying on it.
class QuantizedIntegerType;

namespace detail {

// Storage classes that back the quantized types declared below (used as the
// ImplType / TypeBase storage parameter). Their definitions are not visible in
// this header.
struct QuantizedTypeStorage;
struct AnyQuantizedTypeStorage;
struct UniformQuantizedTypeStorage;
struct UniformQuantizedPerAxisTypeStorage;

} // namespace detail
33 
34 namespace QuantizationTypes {
35 enum Kind {
36   Any = Type::FIRST_QUANTIZATION_TYPE,
37   UniformQuantized,
38   UniformQuantizedPerAxis,
39   LAST_USED_QUANTIZATION_TYPE = UniformQuantizedPerAxis,
40 };
41 } // namespace QuantizationTypes
42 
/// Bit-mapped flags that describe properties of a quantized type. They are
/// combined into the `flags` word accepted by the quantized type constructors
/// and reported by QuantizedType::getFlags().
namespace QuantizationFlags {
enum FlagValue {
  /// When set, the storage type holds signed integers; when clear (the
  /// default), the storage type is interpreted as unsigned.
  Signed = 1 << 0,
};
} // namespace QuantizationFlags
51 
52 /// Base class for all quantized types known to this dialect.
53 /// All quantized types have:
54 ///   - storageType: The (narrower) numeric type that is being used to
55 ///     approximate some expressed type.
56 ///   - expressedType: The type that is being approximated.
57 ///
58 /// The base class provides generic support for manipulating the types based
59 /// on these fields.
60 class QuantizedType : public Type {
61 public:
62   using ImplType = detail::QuantizedTypeStorage;
63   using Type::Type;
64 
65   /// The maximum number of bits supported for storage types.
66   static constexpr unsigned MaxStorageBits = 32;
67 
68   static LogicalResult
69   verifyConstructionInvariants(Optional<Location> loc, MLIRContext *context,
70                                unsigned flags, Type storageType,
71                                Type expressedType, int64_t storageTypeMin,
72                                int64_t storageTypeMax);
73 
74   /// Support method to enable LLVM-style type casting.
classof(Type type)75   static bool classof(Type type) {
76     return type.getKind() >= Type::FIRST_QUANTIZATION_TYPE &&
77            type.getKind() <= QuantizationTypes::LAST_USED_QUANTIZATION_TYPE;
78   }
79 
80   /// Gets the minimum possible stored by a storageType. storageTypeMin must
81   /// be greater than or equal to this value.
getDefaultMinimumForInteger(bool isSigned,unsigned integralWidth)82   static int64_t getDefaultMinimumForInteger(bool isSigned,
83                                              unsigned integralWidth) {
84     if (isSigned) {
85       return llvm::minIntN(integralWidth);
86     }
87     return 0;
88   }
89 
90   /// Gets the maximum possible stored by a storageType. storageTypeMax must
91   /// be less than or equal to this value.
getDefaultMaximumForInteger(bool isSigned,unsigned integralWidth)92   static int64_t getDefaultMaximumForInteger(bool isSigned,
93                                              unsigned integralWidth) {
94     if (isSigned) {
95       return llvm::maxIntN(integralWidth);
96     }
97     return llvm::maxUIntN(integralWidth);
98   }
99 
100   /// Gets the original expressed type that this quantized type approximates.
101   /// Note that this presumes that the quantized type was always derived from
102   /// a floating point type, which in the broadest definition, is not true (i.e.
103   /// it could be some form of integral, fixed type or affine type in its own
104   /// right); however, at the high level, no examples of such usage are
105   /// presently known and the restriction serves some useful purposes (such as
106   /// always being able to reverse a transformation or measure error). In most
107   /// cases, this will be f32.
108   Type getExpressedType() const;
109 
110   /// Gets the flags associated with this type. Typically a more specific
111   /// accessor is appropriate.
112   unsigned getFlags() const;
113 
114   // Convenience helpers.
115   /// Whether the storage type should be interpreted as a signed quantity
116   /// (true) or an unsigned value (false).
isSigned()117   bool isSigned() const {
118     return (getFlags() & QuantizationFlags::Signed) ==
119            QuantizationFlags::Signed;
120   }
121 
122   /// Gets the underlying type used for to store values. Note that this may
123   /// be signed or unsigned. Use the isSigned() accessor to differentiate.
124   Type getStorageType() const;
125 
126   /// The minimum value that storageType can take.
127   int64_t getStorageTypeMin() const;
128 
129   /// The maximum value that storageType can take.
130   int64_t getStorageTypeMax() const;
131 
132   /// Gets the integral bit width that the underlying storage type can exactly
133   /// represent. For integral storage types, this will just be their width.
134   unsigned getStorageTypeIntegralWidth() const;
135 
136   /// Returns whether the candidateExpressedType is a match for this
137   /// QuantizedType. This will be true if the candidate type is either a
138   /// primitive type or a container type whose element type equals this
139   /// QuantizedType's expressed type.
140   /// Examples of compatible candidateExpressedType:
141   ///   !quant.uniform<i8:f32, 1.0> =~ f32
142   ///   !quant.uniform<i8:f32, 1.0> =~ tensor<4xf32>
143   bool isCompatibleExpressedType(Type candidateExpressedType);
144 
145   /// Returns the element type as a QuantizedType or nullptr if it is not
146   /// a quantized type. If the type is primitive, returns that. If it is a
147   /// container (vector/tensor), return the element type.
148   /// Examples:
149   ///   !quant.uniform<i8:f32, 1.0> -> !quant.uniform<i8:f32, 1.0>
150   ///   tensor<4x!quant.uniform<i8:f32, 1.0> -> quant.uniform<i8:f32, 1.0>
151   static QuantizedType getQuantizedElementType(Type primitiveOrContainerType);
152 
153   /// Casts from a type based on the storageType to a corresponding type based
154   /// on this type (returns nullptr if the cast is not valid).
155   /// Examples:
156   ///   i8 -> !quant.uniform<i8:f32, 1.0>
157   ///   tensor<4xi8> -> tensor<4x!quant.uniform<i8:f32, 1.0}>>
158   ///   vector<4xi8> -> vector<4x!quant.uniform<i8:f32, 1.0>>
159   Type castFromStorageType(Type candidateType);
160 
161   /// Casts from a type based on a QuantizedType to a corresponding type based
162   /// on the storageType (returns nullptr if the cast is not valid).
163   /// This is the inverse of castFromStorageType().
164   static Type castToStorageType(Type quantizedType);
165 
166   /// Casts from a type based on the expressedType to a corresponding type based
167   /// on this type (returns nullptr if the cast is not valid).
168   /// Examples:
169   ///   f32 -> !quant.uniform<i8:f32, 1.0>
170   ///   tensor<4xf32> -> tensor<4x!quant.uniform<i8:f32, 1.0>>
171   ///   vector<4xf32> -> vector<4x!quant.uniform<i8:f32, 1.0>>
172   Type castFromExpressedType(Type candidateType);
173 
174   /// Casts from a type based on QuantizedType to a corresponding type based
175   /// on the expressedType (returns nullptr if the cast is not valid).
176   /// This is the inverse of castFromExpressedType.
177   static Type castToExpressedType(Type quantizedType);
178 
179   /// Casts from a type based on the expressedType to the equivalent type
180   /// based on storageType by way of this QuantizedType. Equivalent to:
181   ///   QuantizedType::castToStorageType(castFromExpressedType(candidateType))
182   /// (but with validity checks).
183   /// Example (for this = !quant.uniform<i8:f32, 1.0>):
184   ///   tensor<4xf32> -> tensor<4xi8>
185   Type castExpressedToStorageType(Type candidateType);
186 
187 private:
188   /// Hide the following methods inherited from `Type`. It is almost certainly
189   /// a bug to call them from a `QuantizedType` object. Users should call
190   /// `getStorageType` or `getExpressedType` to get the underlying types
191   /// they want to inspect.
192   using Type::isBF16;
193   using Type::isF16;
194   using Type::isF32;
195   using Type::isF64;
196   using Type::isIndex;
197   using Type::isInteger;
198 };
199 
200 /// A quantized type that maps storage to/from expressed types in an
201 /// unspecified way.
202 ///
203 /// Typical syntax:
204 ///   quant.any<i8:f32>
205 ///   quant.any<i8>
206 ///   quant.any<i8<-16,15>>
207 ///
208 /// Note that for the any type, the expressed type is optional.
209 class AnyQuantizedType
210     : public Type::TypeBase<AnyQuantizedType, QuantizedType,
211                             detail::AnyQuantizedTypeStorage> {
212 public:
213   using Base::Base;
214 
215   /// Support method to enable LLVM-style type casting.
kindof(unsigned kind)216   static bool kindof(unsigned kind) { return kind == QuantizationTypes::Any; }
217 
218   /// Gets an instance of the type with all parameters specified but not
219   /// checked.
220   static AnyQuantizedType get(unsigned flags, Type storageType,
221                               Type expressedType, int64_t storageTypeMin,
222                               int64_t storageTypeMax);
223 
224   /// Gets an instance of the type with all specified parameters checked.
225   /// Returns a nullptr convertible type on failure.
226   static AnyQuantizedType getChecked(unsigned flags, Type storageType,
227                                      Type expressedType, int64_t storageTypeMin,
228                                      int64_t storageTypeMax, Location location);
229 
230   /// Verifies construction invariants and issues errors/warnings.
231   static LogicalResult
232   verifyConstructionInvariants(Optional<Location> loc, MLIRContext *context,
233                                unsigned flags, Type storageType,
234                                Type expressedType, int64_t storageTypeMin,
235                                int64_t storageTypeMax);
236 };
237 
238 /// Represents a family of uniform, quantized types.
239 ///
240 /// Each instance of this type expresses a mapping between real values (most
241 /// often expressed in floating point f32) and quantized values (either fixed
242 /// point or affine).
243 ///
244 /// The relationship is:
245 ///     real_value = scale * (quantized_value - zero_point)
246 ///
247 /// It is used as part of high level graph transformations that have the goal
248 /// of re-expressing parts of a computation in terms of this common form for
249 /// more efficient execution at runtime. In addition, it is designed to be
250 /// expressive enough to facilitate lowering to precise types and operations
251 /// in target hardware.
252 ///
253 /// As a high-level type, focused on intermediate passes, this type holds
254 /// opinions consistent with high-level usage. If lowering math kernels below
255 /// the high level arithmetic ops (i.e. to LLVM IR or hardware specific
256 /// instruction sets), it is expected that the information expressed here
257 /// will be used to drive low level codegen and target specific type selection,
258 /// but this type will likely be erased in the process.
259 ///
260 /// Syntax synopsis:
261 ///   Per-layer, all parameters expressed:
262 ///     !quant<uniform[StorageType:ExpressedType]{Scale:ZeroPoint}>
263 ///   Per-layer, optional parameters omitted:
264 ///     !quant<uniform[StorageType]{Scale}>
265 ///
266 ///   StorageType: 'i'|'u' NumBits
267 ///   ExpressedType: 'f16', 'f32', 'bf16', 'f64'
268 ///   Scale: A legal double value
269 ///   ZeroPoint: An integer value
270 class UniformQuantizedType
271     : public Type::TypeBase<UniformQuantizedType, QuantizedType,
272                             detail::UniformQuantizedTypeStorage> {
273 public:
274   using Base::Base;
275 
276   /// Gets an instance of the type with all parameters specified but not
277   /// checked.
278   static UniformQuantizedType get(unsigned flags, Type storageType,
279                                   Type expressedType, double scale,
280                                   int64_t zeroPoint, int64_t storageTypeMin,
281                                   int64_t storageTypeMax);
282 
283   /// Gets an instance of the type with all specified parameters checked.
284   /// Returns a nullptr convertible type on failure.
285   static UniformQuantizedType
286   getChecked(unsigned flags, Type storageType, Type expressedType, double scale,
287              int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax,
288              Location location);
289 
290   /// Verifies construction invariants and issues errors/warnings.
291   static LogicalResult verifyConstructionInvariants(
292       Optional<Location> loc, MLIRContext *context, unsigned flags,
293       Type storageType, Type expressedType, double scale, int64_t zeroPoint,
294       int64_t storageTypeMin, int64_t storageTypeMax);
295 
296   /// Support method to enable LLVM-style type casting.
kindof(unsigned kind)297   static bool kindof(unsigned kind) {
298     return kind == QuantizationTypes::UniformQuantized;
299   }
300 
301   /// Gets the scale term. The scale designates the difference between the real
302   /// values corresponding to consecutive quantized values differing by 1.
303   double getScale() const;
304 
305   /// Gets the storage value corresponding to the real value 0 in the affine
306   /// equation.
307   int64_t getZeroPoint() const;
308 
309   // Fixed point values are real numbers divided by a scale.
310   // Currently, only signed storage types are treated as fixed point.
311   // A fixed point value can be obtained from an affine value by subtracting
312   // the zeroPoint.
313   // In the future, this may be explicit versus implied by type and zeroPoint.
isFixedPoint()314   bool isFixedPoint() const { return isSigned() && getZeroPoint() == 0; }
315 };
316 
317 /// Represents per-axis (also known as per-channel quantization).
318 ///
319 /// Syntax synopsis:
320 ///   Per-axis, all parameters expressed:
321 ///     !quant<uniform[StorageType:ExpressedType:QuantizedDim]{QuantParams}>
322 ///   Per-axis, optional parameters omitted:
323 ///     !quant<uniform[StorageType]{Scale}>
324 ///
325 ///   StorageType: 'i'|'u' NumBits
326 ///   ExpressedType: 'f16', 'f32', 'bf16', 'f64'
327 ///   QuantizedDim: An integer value
328 ///   QuantParams: (Scale ':' ZeroPoint)+
329 ///   Scale: A legal double value
330 ///   ZeroPoint: An integer value
331 class UniformQuantizedPerAxisType
332     : public Type::TypeBase<UniformQuantizedPerAxisType, QuantizedType,
333                             detail::UniformQuantizedPerAxisTypeStorage> {
334 public:
335   using Base::Base;
336 
337   /// Gets an instance of the type with all parameters specified but not
338   /// checked.
339   static UniformQuantizedPerAxisType
340   get(unsigned flags, Type storageType, Type expressedType,
341       ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
342       int32_t quantizedDimension, int64_t storageTypeMin,
343       int64_t storageTypeMax);
344 
345   /// Gets an instance of the type with all specified parameters checked.
346   /// Returns a nullptr convertible type on failure.
347   static UniformQuantizedPerAxisType
348   getChecked(unsigned flags, Type storageType, Type expressedType,
349              ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
350              int32_t quantizedDimension, int64_t storageTypeMin,
351              int64_t storageTypeMax, Location location);
352 
353   /// Verifies construction invariants and issues errors/warnings.
354   static LogicalResult verifyConstructionInvariants(
355       Optional<Location> loc, MLIRContext *context, unsigned flags,
356       Type storageType, Type expressedType, ArrayRef<double> scales,
357       ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
358       int64_t storageTypeMin, int64_t storageTypeMax);
359 
360   /// Support method to enable LLVM-style type casting.
kindof(unsigned kind)361   static bool kindof(unsigned kind) {
362     return kind == QuantizationTypes::UniformQuantizedPerAxis;
363   }
364 
365   /// Gets the quantization scales. The scales designate the difference between
366   /// the real values corresponding to consecutive quantized values differing
367   /// by 1. The ith scale corresponds to the ith slice in the
368   /// quantized_dimension.
369   ArrayRef<double> getScales() const;
370 
371   /// Gets the storage values corresponding to the real value 0 in the affine
372   /// equation. The ith zero point corresponds to the ith slice in the
373   /// quantized_dimension.
374   ArrayRef<int64_t> getZeroPoints() const;
375 
376   /// Specifies the dimension of the Tensor's shape that the scales and
377   /// zero_points correspond to. For example, a tensor t, with dims=[4, 3, 2, 1]
378   /// with quantization params:
379   ///   scales=[1.0, 2.0, 3.0], zeroPoints=[1, 2, 3], quantizedDimension=1
380   /// will be quantized across the second dimension of t.
381   ///   t[:, 0, :, :] will have scale[0]=1.0, zero_point[0]=1
382   ///   t[:, 1, :, :] will have scale[1]=2.0, zero_point[0]=2
383   ///   t[:, 2, :, :] will have scale[2]=3.0, zero_point[0]=3
384   int32_t getQuantizedDimension() const;
385 
386   /// Fixed point values are real numbers divided by a scale.
387   /// Currently, only signed storage types are treated as fixed point.
388   /// A fixed point value can be obtained from an affine value by subtracting
389   /// the zeroPoint.
390   /// In the future, this may be explicit versus implied by type and zeroPoint.
isFixedPoint()391   bool isFixedPoint() const {
392     if (!isSigned())
393       return false;
394     return llvm::all_of(getZeroPoints(),
395                         [](int64_t zeroPoint) { return zeroPoint != 0; });
396   }
397 };
398 
399 } // namespace quant
400 } // namespace mlir
401 
402 #endif // MLIR_DIALECT_QUANTOPS_QUANT_TYPES_H_
403