1 //===- QuantOps.cpp - Quantization Type and Ops Implementation --*- C++ -*-===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/QuantOps/QuantTypes.h"
10 #include "TypeDetail.h"
11 
12 #include "mlir/IR/MLIRContext.h"
13 #include "mlir/IR/StandardTypes.h"
14 #include "llvm/ADT/StringRef.h"
15 #include "llvm/ADT/Twine.h"
16 #include "llvm/Support/MathExtras.h"
17 
18 using namespace mlir;
19 using namespace mlir::quant;
20 using namespace mlir::quant::detail;
21 
getFlags() const22 unsigned QuantizedType::getFlags() const {
23   return static_cast<ImplType *>(impl)->flags;
24 }
25 
verifyConstructionInvariants(Optional<Location> loc,MLIRContext * context,unsigned flags,Type storageType,Type expressedType,int64_t storageTypeMin,int64_t storageTypeMax)26 LogicalResult QuantizedType::verifyConstructionInvariants(
27     Optional<Location> loc, MLIRContext *context, unsigned flags,
28     Type storageType, Type expressedType, int64_t storageTypeMin,
29     int64_t storageTypeMax) {
30   // Verify that the storage type is integral.
31   // This restriction may be lifted at some point in favor of using bf16
32   // or f16 as exact representations on hardware where that is advantageous.
33   auto intStorageType = storageType.dyn_cast<IntegerType>();
34   if (!intStorageType)
35     return emitOptionalError(loc, "storage type must be integral");
36   unsigned integralWidth = intStorageType.getWidth();
37 
38   // Verify storage width.
39   if (integralWidth == 0 || integralWidth > MaxStorageBits)
40     return emitOptionalError(loc, "illegal storage type size: ", integralWidth);
41 
42   // Verify storageTypeMin and storageTypeMax.
43   bool isSigned =
44       (flags & QuantizationFlags::Signed) == QuantizationFlags::Signed;
45   int64_t defaultIntegerMin =
46       getDefaultMinimumForInteger(isSigned, integralWidth);
47   int64_t defaultIntegerMax =
48       getDefaultMaximumForInteger(isSigned, integralWidth);
49   if (storageTypeMax - storageTypeMin <= 0 ||
50       storageTypeMin < defaultIntegerMin ||
51       storageTypeMax > defaultIntegerMax) {
52     return emitOptionalError(loc, "illegal storage min and storage max: (",
53                              storageTypeMin, ":", storageTypeMax, ")");
54   }
55   return success();
56 }
57 
getStorageType() const58 Type QuantizedType::getStorageType() const {
59   return static_cast<ImplType *>(impl)->storageType;
60 }
61 
getStorageTypeMin() const62 int64_t QuantizedType::getStorageTypeMin() const {
63   return static_cast<ImplType *>(impl)->storageTypeMin;
64 }
65 
getStorageTypeMax() const66 int64_t QuantizedType::getStorageTypeMax() const {
67   return static_cast<ImplType *>(impl)->storageTypeMax;
68 }
69 
getStorageTypeIntegralWidth() const70 unsigned QuantizedType::getStorageTypeIntegralWidth() const {
71   // NOTE: If ever supporting non-integral storage types, some other scheme
72   // for determining the width will be needed.
73   return static_cast<ImplType *>(impl)->storageType.getIntOrFloatBitWidth();
74 }
75 
getExpressedType() const76 Type QuantizedType::getExpressedType() const {
77   return static_cast<ImplType *>(impl)->expressedType;
78 }
79 
isCompatibleExpressedType(Type candidateExpressedType)80 bool QuantizedType::isCompatibleExpressedType(Type candidateExpressedType) {
81   if (candidateExpressedType.isa<ShapedType>()) {
82     return candidateExpressedType.cast<ShapedType>().getElementType() ==
83            getExpressedType();
84   }
85   return candidateExpressedType == getExpressedType();
86 }
87 
88 QuantizedType
getQuantizedElementType(Type primitiveOrContainerType)89 QuantizedType::getQuantizedElementType(Type primitiveOrContainerType) {
90   if (primitiveOrContainerType.isa<ShapedType>()) {
91     Type elementType =
92         primitiveOrContainerType.cast<ShapedType>().getElementType();
93     return elementType.dyn_cast<QuantizedType>();
94   }
95   return primitiveOrContainerType.dyn_cast<QuantizedType>();
96 }
97 
castFromStorageType(Type candidateType)98 Type QuantizedType::castFromStorageType(Type candidateType) {
99   if (candidateType == getStorageType()) {
100     // i.e. i32 -> quant<"uniform[i8:f32]{1.0}">
101     return *this;
102   } else if (candidateType.isa<RankedTensorType>()) {
103     // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
104     return RankedTensorType::get(
105         candidateType.cast<RankedTensorType>().getShape(), getStorageType());
106   } else if (candidateType.isa<UnrankedTensorType>()) {
107     // i.e. tensor<i8> -> tensor<!quant<"uniform[i8:f32]{1.0}">>
108     return UnrankedTensorType::get(getStorageType());
109   } else if (candidateType.isa<VectorType>()) {
110     // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
111     return VectorType::get(candidateType.cast<VectorType>().getShape(),
112                            getStorageType());
113   }
114 
115   return nullptr;
116 }
117 
castToStorageType(Type quantizedType)118 Type QuantizedType::castToStorageType(Type quantizedType) {
119   if (quantizedType.isa<QuantizedType>()) {
120     // i.e. quant<"uniform[i8:f32]{1.0}"> -> i8
121     return quantizedType.cast<QuantizedType>().getStorageType();
122   } else if (quantizedType.isa<ShapedType>()) {
123     // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
124     ShapedType sType = quantizedType.cast<ShapedType>();
125     if (!sType.getElementType().isa<QuantizedType>()) {
126       return nullptr;
127     }
128     Type storageType =
129         sType.getElementType().cast<QuantizedType>().getStorageType();
130     if (quantizedType.isa<RankedTensorType>()) {
131       return RankedTensorType::get(sType.getShape(), storageType);
132     } else if (quantizedType.isa<UnrankedTensorType>()) {
133       return UnrankedTensorType::get(storageType);
134     } else if (quantizedType.isa<VectorType>()) {
135       return VectorType::get(sType.getShape(), storageType);
136     }
137   }
138 
139   return nullptr;
140 }
141 
castFromExpressedType(Type candidateType)142 Type QuantizedType::castFromExpressedType(Type candidateType) {
143   if (candidateType == getExpressedType()) {
144     // i.e. f32 -> quant<"uniform[i8:f32]{1.0}">
145     return *this;
146   } else if (candidateType.isa<ShapedType>()) {
147     ShapedType candidateShapedType = candidateType.cast<ShapedType>();
148     if (candidateShapedType.getElementType() != getExpressedType()) {
149       return nullptr;
150     }
151 
152     if (candidateType.isa<RankedTensorType>()) {
153       // i.e. tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
154       return RankedTensorType::get(candidateShapedType.getShape(), *this);
155     } else if (candidateType.isa<UnrankedTensorType>()) {
156       // i.e. tensor<xf32> -> tensor<x!quant<"uniform[i8:f32]{1.0}">>
157       return UnrankedTensorType::get(*this);
158     } else if (candidateType.isa<VectorType>()) {
159       // i.e. tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
160       return VectorType::get(candidateShapedType.getShape(), *this);
161     }
162   }
163 
164   return nullptr;
165 }
166 
castToExpressedType(Type quantizedType)167 Type QuantizedType::castToExpressedType(Type quantizedType) {
168   if (quantizedType.isa<QuantizedType>()) {
169     // i.e. quant<"uniform[i8:f32]{1.0}"> -> f32
170     return quantizedType.cast<QuantizedType>().getExpressedType();
171   } else if (quantizedType.isa<ShapedType>()) {
172     // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
173     ShapedType sType = quantizedType.cast<ShapedType>();
174     if (!sType.getElementType().isa<QuantizedType>()) {
175       return nullptr;
176     }
177     Type expressedType =
178         sType.getElementType().cast<QuantizedType>().getExpressedType();
179     if (quantizedType.isa<RankedTensorType>()) {
180       return RankedTensorType::get(sType.getShape(), expressedType);
181     } else if (quantizedType.isa<UnrankedTensorType>()) {
182       return UnrankedTensorType::get(expressedType);
183     } else if (quantizedType.isa<VectorType>()) {
184       return VectorType::get(sType.getShape(), expressedType);
185     }
186   }
187 
188   return nullptr;
189 }
190 
castExpressedToStorageType(Type candidateType)191 Type QuantizedType::castExpressedToStorageType(Type candidateType) {
192   Type expressedQuantizedType = castFromExpressedType(candidateType);
193   if (!expressedQuantizedType) {
194     return nullptr;
195   }
196   return QuantizedType::castToStorageType(expressedQuantizedType);
197 }
198 
get(unsigned flags,Type storageType,Type expressedType,int64_t storageTypeMin,int64_t storageTypeMax)199 AnyQuantizedType AnyQuantizedType::get(unsigned flags, Type storageType,
200                                        Type expressedType,
201                                        int64_t storageTypeMin,
202                                        int64_t storageTypeMax) {
203   return Base::get(storageType.getContext(), QuantizationTypes::Any, flags,
204                    storageType, expressedType, storageTypeMin, storageTypeMax);
205 }
206 
getChecked(unsigned flags,Type storageType,Type expressedType,int64_t storageTypeMin,int64_t storageTypeMax,Location location)207 AnyQuantizedType AnyQuantizedType::getChecked(unsigned flags, Type storageType,
208                                               Type expressedType,
209                                               int64_t storageTypeMin,
210                                               int64_t storageTypeMax,
211                                               Location location) {
212   return Base::getChecked(location, storageType.getContext(),
213                           QuantizationTypes::Any, flags, storageType,
214                           expressedType, storageTypeMin, storageTypeMax);
215 }
216 
verifyConstructionInvariants(Optional<Location> loc,MLIRContext * context,unsigned flags,Type storageType,Type expressedType,int64_t storageTypeMin,int64_t storageTypeMax)217 LogicalResult AnyQuantizedType::verifyConstructionInvariants(
218     Optional<Location> loc, MLIRContext *context, unsigned flags,
219     Type storageType, Type expressedType, int64_t storageTypeMin,
220     int64_t storageTypeMax) {
221   if (failed(QuantizedType::verifyConstructionInvariants(
222           loc, context, flags, storageType, expressedType, storageTypeMin,
223           storageTypeMax))) {
224     return failure();
225   }
226 
227   // Verify that the expressed type is floating point.
228   // If this restriction is ever eliminated, the parser/printer must be
229   // extended.
230   if (expressedType && !expressedType.isa<FloatType>())
231     return emitOptionalError(loc, "expressed type must be floating point");
232 
233   return success();
234 }
235 
get(unsigned flags,Type storageType,Type expressedType,double scale,int64_t zeroPoint,int64_t storageTypeMin,int64_t storageTypeMax)236 UniformQuantizedType UniformQuantizedType::get(unsigned flags, Type storageType,
237                                                Type expressedType, double scale,
238                                                int64_t zeroPoint,
239                                                int64_t storageTypeMin,
240                                                int64_t storageTypeMax) {
241   return Base::get(storageType.getContext(),
242                    QuantizationTypes::UniformQuantized, flags, storageType,
243                    expressedType, scale, zeroPoint, storageTypeMin,
244                    storageTypeMax);
245 }
246 
247 UniformQuantizedType
getChecked(unsigned flags,Type storageType,Type expressedType,double scale,int64_t zeroPoint,int64_t storageTypeMin,int64_t storageTypeMax,Location location)248 UniformQuantizedType::getChecked(unsigned flags, Type storageType,
249                                  Type expressedType, double scale,
250                                  int64_t zeroPoint, int64_t storageTypeMin,
251                                  int64_t storageTypeMax, Location location) {
252   return Base::getChecked(location, storageType.getContext(),
253                           QuantizationTypes::UniformQuantized, flags,
254                           storageType, expressedType, scale, zeroPoint,
255                           storageTypeMin, storageTypeMax);
256 }
257 
verifyConstructionInvariants(Optional<Location> loc,MLIRContext * context,unsigned flags,Type storageType,Type expressedType,double scale,int64_t zeroPoint,int64_t storageTypeMin,int64_t storageTypeMax)258 LogicalResult UniformQuantizedType::verifyConstructionInvariants(
259     Optional<Location> loc, MLIRContext *context, unsigned flags,
260     Type storageType, Type expressedType, double scale, int64_t zeroPoint,
261     int64_t storageTypeMin, int64_t storageTypeMax) {
262   if (failed(QuantizedType::verifyConstructionInvariants(
263           loc, context, flags, storageType, expressedType, storageTypeMin,
264           storageTypeMax))) {
265     return failure();
266   }
267 
268   // Uniform quantization requires fully expressed parameters, including
269   // expressed type.
270   if (!expressedType)
271     return emitOptionalError(loc,
272                              "uniform quantization requires expressed type");
273 
274   // Verify that the expressed type is floating point.
275   // If this restriction is ever eliminated, the parser/printer must be
276   // extended.
277   if (!expressedType.isa<FloatType>())
278     return emitOptionalError(loc, "expressed type must be floating point");
279 
280   // Verify scale.
281   if (scale <= 0.0 || std::isinf(scale) || std::isnan(scale))
282     return emitOptionalError(loc, "illegal scale: ", scale);
283 
284   return success();
285 }
286 
getScale() const287 double UniformQuantizedType::getScale() const { return getImpl()->scale; }
288 
getZeroPoint() const289 int64_t UniformQuantizedType::getZeroPoint() const {
290   return getImpl()->zeroPoint;
291 }
292 
get(unsigned flags,Type storageType,Type expressedType,ArrayRef<double> scales,ArrayRef<int64_t> zeroPoints,int32_t quantizedDimension,int64_t storageTypeMin,int64_t storageTypeMax)293 UniformQuantizedPerAxisType UniformQuantizedPerAxisType::get(
294     unsigned flags, Type storageType, Type expressedType,
295     ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
296     int32_t quantizedDimension, int64_t storageTypeMin,
297     int64_t storageTypeMax) {
298   return Base::get(storageType.getContext(),
299                    QuantizationTypes::UniformQuantizedPerAxis, flags,
300                    storageType, expressedType, scales, zeroPoints,
301                    quantizedDimension, storageTypeMin, storageTypeMax);
302 }
303 
getChecked(unsigned flags,Type storageType,Type expressedType,ArrayRef<double> scales,ArrayRef<int64_t> zeroPoints,int32_t quantizedDimension,int64_t storageTypeMin,int64_t storageTypeMax,Location location)304 UniformQuantizedPerAxisType UniformQuantizedPerAxisType::getChecked(
305     unsigned flags, Type storageType, Type expressedType,
306     ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
307     int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax,
308     Location location) {
309   return Base::getChecked(location, storageType.getContext(),
310                           QuantizationTypes::UniformQuantizedPerAxis, flags,
311                           storageType, expressedType, scales, zeroPoints,
312                           quantizedDimension, storageTypeMin, storageTypeMax);
313 }
314 
verifyConstructionInvariants(Optional<Location> loc,MLIRContext * context,unsigned flags,Type storageType,Type expressedType,ArrayRef<double> scales,ArrayRef<int64_t> zeroPoints,int32_t quantizedDimension,int64_t storageTypeMin,int64_t storageTypeMax)315 LogicalResult UniformQuantizedPerAxisType::verifyConstructionInvariants(
316     Optional<Location> loc, MLIRContext *context, unsigned flags,
317     Type storageType, Type expressedType, ArrayRef<double> scales,
318     ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
319     int64_t storageTypeMin, int64_t storageTypeMax) {
320   if (failed(QuantizedType::verifyConstructionInvariants(
321           loc, context, flags, storageType, expressedType, storageTypeMin,
322           storageTypeMax))) {
323     return failure();
324   }
325 
326   // Uniform quantization requires fully expressed parameters, including
327   // expressed type.
328   if (!expressedType)
329     return emitOptionalError(loc,
330                              "uniform quantization requires expressed type");
331 
332   // Verify that the expressed type is floating point.
333   // If this restriction is ever eliminated, the parser/printer must be
334   // extended.
335   if (!expressedType.isa<FloatType>())
336     return emitOptionalError(loc, "expressed type must be floating point");
337 
338   // Ensure that the number of scales and zeroPoints match.
339   if (scales.size() != zeroPoints.size())
340     return emitOptionalError(loc, "illegal number of scales and zeroPoints: ",
341                              scales.size(), ", ", zeroPoints.size());
342 
343   // Verify scale.
344   for (double scale : scales) {
345     if (scale <= 0.0 || std::isinf(scale) || std::isnan(scale))
346       return emitOptionalError(loc, "illegal scale: ", scale);
347   }
348 
349   return success();
350 }
351 
getScales() const352 ArrayRef<double> UniformQuantizedPerAxisType::getScales() const {
353   return getImpl()->getScales();
354 }
355 
getZeroPoints() const356 ArrayRef<int64_t> UniformQuantizedPerAxisType::getZeroPoints() const {
357   return getImpl()->getZeroPoints();
358 }
359 
getQuantizedDimension() const360 int32_t UniformQuantizedPerAxisType::getQuantizedDimension() const {
361   return getImpl()->quantizedDimension;
362 }
363