1 //===- QuantOps.cpp - Quantization Type and Ops Implementation --*- C++ -*-===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/QuantOps/QuantTypes.h"
10 #include "TypeDetail.h"
11
12 #include "mlir/IR/MLIRContext.h"
13 #include "mlir/IR/StandardTypes.h"
14 #include "llvm/ADT/StringRef.h"
15 #include "llvm/ADT/Twine.h"
16 #include "llvm/Support/MathExtras.h"
17
18 using namespace mlir;
19 using namespace mlir::quant;
20 using namespace mlir::quant::detail;
21
getFlags() const22 unsigned QuantizedType::getFlags() const {
23 return static_cast<ImplType *>(impl)->flags;
24 }
25
verifyConstructionInvariants(Optional<Location> loc,MLIRContext * context,unsigned flags,Type storageType,Type expressedType,int64_t storageTypeMin,int64_t storageTypeMax)26 LogicalResult QuantizedType::verifyConstructionInvariants(
27 Optional<Location> loc, MLIRContext *context, unsigned flags,
28 Type storageType, Type expressedType, int64_t storageTypeMin,
29 int64_t storageTypeMax) {
30 // Verify that the storage type is integral.
31 // This restriction may be lifted at some point in favor of using bf16
32 // or f16 as exact representations on hardware where that is advantageous.
33 auto intStorageType = storageType.dyn_cast<IntegerType>();
34 if (!intStorageType)
35 return emitOptionalError(loc, "storage type must be integral");
36 unsigned integralWidth = intStorageType.getWidth();
37
38 // Verify storage width.
39 if (integralWidth == 0 || integralWidth > MaxStorageBits)
40 return emitOptionalError(loc, "illegal storage type size: ", integralWidth);
41
42 // Verify storageTypeMin and storageTypeMax.
43 bool isSigned =
44 (flags & QuantizationFlags::Signed) == QuantizationFlags::Signed;
45 int64_t defaultIntegerMin =
46 getDefaultMinimumForInteger(isSigned, integralWidth);
47 int64_t defaultIntegerMax =
48 getDefaultMaximumForInteger(isSigned, integralWidth);
49 if (storageTypeMax - storageTypeMin <= 0 ||
50 storageTypeMin < defaultIntegerMin ||
51 storageTypeMax > defaultIntegerMax) {
52 return emitOptionalError(loc, "illegal storage min and storage max: (",
53 storageTypeMin, ":", storageTypeMax, ")");
54 }
55 return success();
56 }
57
getStorageType() const58 Type QuantizedType::getStorageType() const {
59 return static_cast<ImplType *>(impl)->storageType;
60 }
61
getStorageTypeMin() const62 int64_t QuantizedType::getStorageTypeMin() const {
63 return static_cast<ImplType *>(impl)->storageTypeMin;
64 }
65
getStorageTypeMax() const66 int64_t QuantizedType::getStorageTypeMax() const {
67 return static_cast<ImplType *>(impl)->storageTypeMax;
68 }
69
getStorageTypeIntegralWidth() const70 unsigned QuantizedType::getStorageTypeIntegralWidth() const {
71 // NOTE: If ever supporting non-integral storage types, some other scheme
72 // for determining the width will be needed.
73 return static_cast<ImplType *>(impl)->storageType.getIntOrFloatBitWidth();
74 }
75
getExpressedType() const76 Type QuantizedType::getExpressedType() const {
77 return static_cast<ImplType *>(impl)->expressedType;
78 }
79
isCompatibleExpressedType(Type candidateExpressedType)80 bool QuantizedType::isCompatibleExpressedType(Type candidateExpressedType) {
81 if (candidateExpressedType.isa<ShapedType>()) {
82 return candidateExpressedType.cast<ShapedType>().getElementType() ==
83 getExpressedType();
84 }
85 return candidateExpressedType == getExpressedType();
86 }
87
88 QuantizedType
getQuantizedElementType(Type primitiveOrContainerType)89 QuantizedType::getQuantizedElementType(Type primitiveOrContainerType) {
90 if (primitiveOrContainerType.isa<ShapedType>()) {
91 Type elementType =
92 primitiveOrContainerType.cast<ShapedType>().getElementType();
93 return elementType.dyn_cast<QuantizedType>();
94 }
95 return primitiveOrContainerType.dyn_cast<QuantizedType>();
96 }
97
castFromStorageType(Type candidateType)98 Type QuantizedType::castFromStorageType(Type candidateType) {
99 if (candidateType == getStorageType()) {
100 // i.e. i32 -> quant<"uniform[i8:f32]{1.0}">
101 return *this;
102 } else if (candidateType.isa<RankedTensorType>()) {
103 // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
104 return RankedTensorType::get(
105 candidateType.cast<RankedTensorType>().getShape(), getStorageType());
106 } else if (candidateType.isa<UnrankedTensorType>()) {
107 // i.e. tensor<i8> -> tensor<!quant<"uniform[i8:f32]{1.0}">>
108 return UnrankedTensorType::get(getStorageType());
109 } else if (candidateType.isa<VectorType>()) {
110 // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
111 return VectorType::get(candidateType.cast<VectorType>().getShape(),
112 getStorageType());
113 }
114
115 return nullptr;
116 }
117
castToStorageType(Type quantizedType)118 Type QuantizedType::castToStorageType(Type quantizedType) {
119 if (quantizedType.isa<QuantizedType>()) {
120 // i.e. quant<"uniform[i8:f32]{1.0}"> -> i8
121 return quantizedType.cast<QuantizedType>().getStorageType();
122 } else if (quantizedType.isa<ShapedType>()) {
123 // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
124 ShapedType sType = quantizedType.cast<ShapedType>();
125 if (!sType.getElementType().isa<QuantizedType>()) {
126 return nullptr;
127 }
128 Type storageType =
129 sType.getElementType().cast<QuantizedType>().getStorageType();
130 if (quantizedType.isa<RankedTensorType>()) {
131 return RankedTensorType::get(sType.getShape(), storageType);
132 } else if (quantizedType.isa<UnrankedTensorType>()) {
133 return UnrankedTensorType::get(storageType);
134 } else if (quantizedType.isa<VectorType>()) {
135 return VectorType::get(sType.getShape(), storageType);
136 }
137 }
138
139 return nullptr;
140 }
141
castFromExpressedType(Type candidateType)142 Type QuantizedType::castFromExpressedType(Type candidateType) {
143 if (candidateType == getExpressedType()) {
144 // i.e. f32 -> quant<"uniform[i8:f32]{1.0}">
145 return *this;
146 } else if (candidateType.isa<ShapedType>()) {
147 ShapedType candidateShapedType = candidateType.cast<ShapedType>();
148 if (candidateShapedType.getElementType() != getExpressedType()) {
149 return nullptr;
150 }
151
152 if (candidateType.isa<RankedTensorType>()) {
153 // i.e. tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
154 return RankedTensorType::get(candidateShapedType.getShape(), *this);
155 } else if (candidateType.isa<UnrankedTensorType>()) {
156 // i.e. tensor<xf32> -> tensor<x!quant<"uniform[i8:f32]{1.0}">>
157 return UnrankedTensorType::get(*this);
158 } else if (candidateType.isa<VectorType>()) {
159 // i.e. tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
160 return VectorType::get(candidateShapedType.getShape(), *this);
161 }
162 }
163
164 return nullptr;
165 }
166
castToExpressedType(Type quantizedType)167 Type QuantizedType::castToExpressedType(Type quantizedType) {
168 if (quantizedType.isa<QuantizedType>()) {
169 // i.e. quant<"uniform[i8:f32]{1.0}"> -> f32
170 return quantizedType.cast<QuantizedType>().getExpressedType();
171 } else if (quantizedType.isa<ShapedType>()) {
172 // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
173 ShapedType sType = quantizedType.cast<ShapedType>();
174 if (!sType.getElementType().isa<QuantizedType>()) {
175 return nullptr;
176 }
177 Type expressedType =
178 sType.getElementType().cast<QuantizedType>().getExpressedType();
179 if (quantizedType.isa<RankedTensorType>()) {
180 return RankedTensorType::get(sType.getShape(), expressedType);
181 } else if (quantizedType.isa<UnrankedTensorType>()) {
182 return UnrankedTensorType::get(expressedType);
183 } else if (quantizedType.isa<VectorType>()) {
184 return VectorType::get(sType.getShape(), expressedType);
185 }
186 }
187
188 return nullptr;
189 }
190
castExpressedToStorageType(Type candidateType)191 Type QuantizedType::castExpressedToStorageType(Type candidateType) {
192 Type expressedQuantizedType = castFromExpressedType(candidateType);
193 if (!expressedQuantizedType) {
194 return nullptr;
195 }
196 return QuantizedType::castToStorageType(expressedQuantizedType);
197 }
198
get(unsigned flags,Type storageType,Type expressedType,int64_t storageTypeMin,int64_t storageTypeMax)199 AnyQuantizedType AnyQuantizedType::get(unsigned flags, Type storageType,
200 Type expressedType,
201 int64_t storageTypeMin,
202 int64_t storageTypeMax) {
203 return Base::get(storageType.getContext(), QuantizationTypes::Any, flags,
204 storageType, expressedType, storageTypeMin, storageTypeMax);
205 }
206
getChecked(unsigned flags,Type storageType,Type expressedType,int64_t storageTypeMin,int64_t storageTypeMax,Location location)207 AnyQuantizedType AnyQuantizedType::getChecked(unsigned flags, Type storageType,
208 Type expressedType,
209 int64_t storageTypeMin,
210 int64_t storageTypeMax,
211 Location location) {
212 return Base::getChecked(location, storageType.getContext(),
213 QuantizationTypes::Any, flags, storageType,
214 expressedType, storageTypeMin, storageTypeMax);
215 }
216
verifyConstructionInvariants(Optional<Location> loc,MLIRContext * context,unsigned flags,Type storageType,Type expressedType,int64_t storageTypeMin,int64_t storageTypeMax)217 LogicalResult AnyQuantizedType::verifyConstructionInvariants(
218 Optional<Location> loc, MLIRContext *context, unsigned flags,
219 Type storageType, Type expressedType, int64_t storageTypeMin,
220 int64_t storageTypeMax) {
221 if (failed(QuantizedType::verifyConstructionInvariants(
222 loc, context, flags, storageType, expressedType, storageTypeMin,
223 storageTypeMax))) {
224 return failure();
225 }
226
227 // Verify that the expressed type is floating point.
228 // If this restriction is ever eliminated, the parser/printer must be
229 // extended.
230 if (expressedType && !expressedType.isa<FloatType>())
231 return emitOptionalError(loc, "expressed type must be floating point");
232
233 return success();
234 }
235
get(unsigned flags,Type storageType,Type expressedType,double scale,int64_t zeroPoint,int64_t storageTypeMin,int64_t storageTypeMax)236 UniformQuantizedType UniformQuantizedType::get(unsigned flags, Type storageType,
237 Type expressedType, double scale,
238 int64_t zeroPoint,
239 int64_t storageTypeMin,
240 int64_t storageTypeMax) {
241 return Base::get(storageType.getContext(),
242 QuantizationTypes::UniformQuantized, flags, storageType,
243 expressedType, scale, zeroPoint, storageTypeMin,
244 storageTypeMax);
245 }
246
247 UniformQuantizedType
getChecked(unsigned flags,Type storageType,Type expressedType,double scale,int64_t zeroPoint,int64_t storageTypeMin,int64_t storageTypeMax,Location location)248 UniformQuantizedType::getChecked(unsigned flags, Type storageType,
249 Type expressedType, double scale,
250 int64_t zeroPoint, int64_t storageTypeMin,
251 int64_t storageTypeMax, Location location) {
252 return Base::getChecked(location, storageType.getContext(),
253 QuantizationTypes::UniformQuantized, flags,
254 storageType, expressedType, scale, zeroPoint,
255 storageTypeMin, storageTypeMax);
256 }
257
verifyConstructionInvariants(Optional<Location> loc,MLIRContext * context,unsigned flags,Type storageType,Type expressedType,double scale,int64_t zeroPoint,int64_t storageTypeMin,int64_t storageTypeMax)258 LogicalResult UniformQuantizedType::verifyConstructionInvariants(
259 Optional<Location> loc, MLIRContext *context, unsigned flags,
260 Type storageType, Type expressedType, double scale, int64_t zeroPoint,
261 int64_t storageTypeMin, int64_t storageTypeMax) {
262 if (failed(QuantizedType::verifyConstructionInvariants(
263 loc, context, flags, storageType, expressedType, storageTypeMin,
264 storageTypeMax))) {
265 return failure();
266 }
267
268 // Uniform quantization requires fully expressed parameters, including
269 // expressed type.
270 if (!expressedType)
271 return emitOptionalError(loc,
272 "uniform quantization requires expressed type");
273
274 // Verify that the expressed type is floating point.
275 // If this restriction is ever eliminated, the parser/printer must be
276 // extended.
277 if (!expressedType.isa<FloatType>())
278 return emitOptionalError(loc, "expressed type must be floating point");
279
280 // Verify scale.
281 if (scale <= 0.0 || std::isinf(scale) || std::isnan(scale))
282 return emitOptionalError(loc, "illegal scale: ", scale);
283
284 return success();
285 }
286
getScale() const287 double UniformQuantizedType::getScale() const { return getImpl()->scale; }
288
getZeroPoint() const289 int64_t UniformQuantizedType::getZeroPoint() const {
290 return getImpl()->zeroPoint;
291 }
292
get(unsigned flags,Type storageType,Type expressedType,ArrayRef<double> scales,ArrayRef<int64_t> zeroPoints,int32_t quantizedDimension,int64_t storageTypeMin,int64_t storageTypeMax)293 UniformQuantizedPerAxisType UniformQuantizedPerAxisType::get(
294 unsigned flags, Type storageType, Type expressedType,
295 ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
296 int32_t quantizedDimension, int64_t storageTypeMin,
297 int64_t storageTypeMax) {
298 return Base::get(storageType.getContext(),
299 QuantizationTypes::UniformQuantizedPerAxis, flags,
300 storageType, expressedType, scales, zeroPoints,
301 quantizedDimension, storageTypeMin, storageTypeMax);
302 }
303
getChecked(unsigned flags,Type storageType,Type expressedType,ArrayRef<double> scales,ArrayRef<int64_t> zeroPoints,int32_t quantizedDimension,int64_t storageTypeMin,int64_t storageTypeMax,Location location)304 UniformQuantizedPerAxisType UniformQuantizedPerAxisType::getChecked(
305 unsigned flags, Type storageType, Type expressedType,
306 ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
307 int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax,
308 Location location) {
309 return Base::getChecked(location, storageType.getContext(),
310 QuantizationTypes::UniformQuantizedPerAxis, flags,
311 storageType, expressedType, scales, zeroPoints,
312 quantizedDimension, storageTypeMin, storageTypeMax);
313 }
314
verifyConstructionInvariants(Optional<Location> loc,MLIRContext * context,unsigned flags,Type storageType,Type expressedType,ArrayRef<double> scales,ArrayRef<int64_t> zeroPoints,int32_t quantizedDimension,int64_t storageTypeMin,int64_t storageTypeMax)315 LogicalResult UniformQuantizedPerAxisType::verifyConstructionInvariants(
316 Optional<Location> loc, MLIRContext *context, unsigned flags,
317 Type storageType, Type expressedType, ArrayRef<double> scales,
318 ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
319 int64_t storageTypeMin, int64_t storageTypeMax) {
320 if (failed(QuantizedType::verifyConstructionInvariants(
321 loc, context, flags, storageType, expressedType, storageTypeMin,
322 storageTypeMax))) {
323 return failure();
324 }
325
326 // Uniform quantization requires fully expressed parameters, including
327 // expressed type.
328 if (!expressedType)
329 return emitOptionalError(loc,
330 "uniform quantization requires expressed type");
331
332 // Verify that the expressed type is floating point.
333 // If this restriction is ever eliminated, the parser/printer must be
334 // extended.
335 if (!expressedType.isa<FloatType>())
336 return emitOptionalError(loc, "expressed type must be floating point");
337
338 // Ensure that the number of scales and zeroPoints match.
339 if (scales.size() != zeroPoints.size())
340 return emitOptionalError(loc, "illegal number of scales and zeroPoints: ",
341 scales.size(), ", ", zeroPoints.size());
342
343 // Verify scale.
344 for (double scale : scales) {
345 if (scale <= 0.0 || std::isinf(scale) || std::isnan(scale))
346 return emitOptionalError(loc, "illegal scale: ", scale);
347 }
348
349 return success();
350 }
351
getScales() const352 ArrayRef<double> UniformQuantizedPerAxisType::getScales() const {
353 return getImpl()->getScales();
354 }
355
getZeroPoints() const356 ArrayRef<int64_t> UniformQuantizedPerAxisType::getZeroPoints() const {
357 return getImpl()->getZeroPoints();
358 }
359
getQuantizedDimension() const360 int32_t UniformQuantizedPerAxisType::getQuantizedDimension() const {
361 return getImpl()->quantizedDimension;
362 }
363