1 //===- QuantizationUtilsTest.cpp - unit tests for quantization utils ------===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/QuantOps/QuantizeUtils.h"
10 #include "mlir/Dialect/QuantOps/UniformSupport.h"
11 #include "mlir/IR/Attributes.h"
12 #include "mlir/IR/StandardTypes.h"
13 #include "gmock/gmock.h"
14 #include "gtest/gtest.h"
15 
16 using namespace mlir;
17 using namespace mlir::quant;
18 
19 namespace {
20 
21 // Test UniformQuantizedValueConverter converts all APFloat to a magic number 5.
22 class TestUniformQuantizedValueConverter
23     : public UniformQuantizedValueConverter {
24 public:
TestUniformQuantizedValueConverter(UniformQuantizedType type)25   TestUniformQuantizedValueConverter(UniformQuantizedType type)
26       : UniformQuantizedValueConverter(type), qtype(type) {}
quantizeFloatToInt(APFloat expressedValue) const27   APInt quantizeFloatToInt(APFloat expressedValue) const {
28     return APInt(qtype.getStorageType().cast<IntegerType>().getWidth(), 5L);
29   }
30 
31 private:
32   UniformQuantizedType qtype;
33 };
34 
getTestFloatAttr(double value,MLIRContext * ctx)35 Attribute getTestFloatAttr(double value, MLIRContext *ctx) {
36   return FloatAttr::get(FloatType::getF32(ctx), value);
37 }
38 
39 template <typename ConcreteAttrClass, typename... Arg>
40 ConcreteAttrClass getTestElementsAttr(MLIRContext *ctx, ArrayRef<int64_t> shape,
41                                       Arg... value) {
42   auto eleType = FloatType::getF32(ctx);
43   ShapedType tensorType;
44   if (shape.size() == 1 && shape[0] == -1) {
45     tensorType = UnrankedTensorType::get(eleType);
46   } else {
47     tensorType = RankedTensorType::get(shape, eleType);
48   }
49   return ConcreteAttrClass::get(tensorType, value...);
50 }
51 
getTestSparseElementsAttr(MLIRContext * ctx,ArrayRef<int64_t> shape)52 ElementsAttr getTestSparseElementsAttr(MLIRContext *ctx,
53                                        ArrayRef<int64_t> shape) {
54   auto eleType = FloatType::getF32(ctx);
55   ShapedType tensorType;
56   if (shape.size() == 1 && shape[0] == -1) {
57     tensorType = UnrankedTensorType::get(eleType);
58   } else {
59     tensorType = RankedTensorType::get(shape, eleType);
60   }
61   auto indicesType = RankedTensorType::get({1, 2}, IntegerType::get(64, ctx));
62   auto indices =
63       DenseIntElementsAttr::get(indicesType, {APInt(64, 0), APInt(64, 0)});
64   auto valuesType = RankedTensorType::get({1}, eleType);
65   auto values = DenseFPElementsAttr::get(valuesType, {APFloat(0.0f)});
66   return SparseElementsAttr::get(tensorType, indices, values);
67 }
68 
getTestQuantizedType(Type storageType,MLIRContext * ctx)69 UniformQuantizedType getTestQuantizedType(Type storageType, MLIRContext *ctx) {
70   return UniformQuantizedType::get(/*flags=*/false, storageType,
71                                    FloatType::getF32(ctx), /*scale=*/1.0,
72                                    /*zeroPoint=*/0, /*storageTypeMin=*/0,
73                                    /*storageTypeMax=*/255);
74 }
75 
// A scalar FloatAttr is quantized to an IntegerAttr of the storage type.
TEST(QuantizationUtilsTest, convertFloatAttrUniform) {
  MLIRContext ctx;
  IntegerType convertedType = IntegerType::get(8, &ctx);
  auto quantizedType = getTestQuantizedType(convertedType, &ctx);
  TestUniformQuantizedValueConverter converter(quantizedType);

  auto realValue = getTestFloatAttr(1.0, &ctx);
  Type typeResult;
  auto valueResult =
      quantizeAttrUniform(realValue, quantizedType, converter, typeResult);

  // Fail the test cleanly, instead of tripping a cast<> assertion, if no
  // attribute or the wrong kind of attribute was produced.
  ASSERT_TRUE(static_cast<bool>(valueResult));
  ASSERT_TRUE(valueResult.isa<IntegerAttr>());

  EXPECT_EQ(valueResult.cast<IntegerAttr>().getInt(), 5);
  EXPECT_EQ(
      valueResult.cast<IntegerAttr>().getType().cast<IntegerType>().getWidth(),
      convertedType.getWidth());
  // The converted type must also be reported through the out-parameter.
  EXPECT_EQ(typeResult, convertedType);
}
92 
// A ranked dense f32 elements attribute is quantized element-wise into a dense
// integer elements attribute of the same shape.
TEST(QuantizationUtilsTest, convertRankedDenseAttrUniform) {
  MLIRContext ctx;
  IntegerType convertedType = IntegerType::get(8, &ctx);
  auto quantizedType = getTestQuantizedType(convertedType, &ctx);
  TestUniformQuantizedValueConverter converter(quantizedType);
  auto realValue = getTestElementsAttr<DenseElementsAttr, ArrayRef<Attribute>>(
      &ctx, {1, 2}, {getTestFloatAttr(1.0, &ctx), getTestFloatAttr(2.0, &ctx)});

  Type returnedType;
  auto returnedValue =
      quantizeAttrUniform(realValue, quantizedType, converter, returnedType);

  // Fail cleanly rather than asserting inside cast<> on a null result.
  ASSERT_TRUE(static_cast<bool>(returnedValue));
  ASSERT_TRUE(static_cast<bool>(returnedType));

  // Check Elements attribute shape and kind are not changed.
  auto tensorType = returnedType.cast<TensorType>();
  auto expectedTensorType = realValue.getType().cast<TensorType>();
  EXPECT_EQ(tensorType.getShape(), expectedTensorType.getShape());
  EXPECT_EQ(tensorType.getElementType(), convertedType);
  EXPECT_TRUE(returnedValue.isa<DenseIntElementsAttr>());

  // Check that both elements went through the converter, not just the first.
  auto elements = returnedValue.cast<ElementsAttr>();
  EXPECT_EQ(elements.getValue({0, 0}).cast<IntegerAttr>().getInt(), 5);
  EXPECT_EQ(elements.getValue({0, 1}).cast<IntegerAttr>().getInt(), 5);
}
116 
// A ranked splat f32 elements attribute is quantized into a splat integer
// elements attribute of the same shape.
TEST(QuantizationUtilsTest, convertRankedSplatAttrUniform) {
  MLIRContext ctx;
  IntegerType convertedType = IntegerType::get(8, &ctx);
  auto quantizedType = getTestQuantizedType(convertedType, &ctx);
  TestUniformQuantizedValueConverter converter(quantizedType);
  auto realValue = getTestElementsAttr<DenseElementsAttr, Attribute>(
      &ctx, {1, 2}, getTestFloatAttr(1.0, &ctx));

  Type returnedType;
  auto returnedValue =
      quantizeAttrUniform(realValue, quantizedType, converter, returnedType);

  // Fail cleanly rather than asserting inside cast<> on a null result.
  ASSERT_TRUE(static_cast<bool>(returnedValue));
  ASSERT_TRUE(static_cast<bool>(returnedType));

  // Check Elements attribute shape and kind are not changed.
  auto tensorType = returnedType.cast<TensorType>();
  auto expectedTensorType = realValue.getType().cast<TensorType>();
  EXPECT_EQ(tensorType.getShape(), expectedTensorType.getShape());
  EXPECT_EQ(tensorType.getElementType(), convertedType);
  EXPECT_TRUE(returnedValue.isa<SplatElementsAttr>());

  // A splat must read back the quantized value at every index.
  auto elements = returnedValue.cast<ElementsAttr>();
  EXPECT_EQ(elements.getValue({0, 0}).cast<IntegerAttr>().getInt(), 5);
  EXPECT_EQ(elements.getValue({0, 1}).cast<IntegerAttr>().getInt(), 5);
}
140 
// A ranked sparse f32 elements attribute is quantized into a sparse integer
// elements attribute of the same shape.
TEST(QuantizationUtilsTest, convertRankedSparseAttrUniform) {
  MLIRContext ctx;
  IntegerType convertedType = IntegerType::get(8, &ctx);
  auto quantizedType = getTestQuantizedType(convertedType, &ctx);
  TestUniformQuantizedValueConverter converter(quantizedType);
  auto realValue = getTestSparseElementsAttr(&ctx, {1, 2});

  Type returnedType;
  auto returnedValue =
      quantizeAttrUniform(realValue, quantizedType, converter, returnedType);

  // Fail cleanly rather than asserting inside cast<> on a null result.
  ASSERT_TRUE(static_cast<bool>(returnedValue));
  ASSERT_TRUE(static_cast<bool>(returnedType));

  // Check Elements attribute shape and kind are not changed. Uses isa<> for
  // the kind check, matching the dense and splat tests.
  auto tensorType = returnedType.cast<TensorType>();
  auto expectedTensorType = realValue.getType().cast<TensorType>();
  EXPECT_EQ(tensorType.getShape(), expectedTensorType.getShape());
  EXPECT_EQ(tensorType.getElementType(), convertedType);
  EXPECT_TRUE(returnedValue.isa<SparseElementsAttr>());

  // Check the explicitly stored element (index {0, 0}) was quantized.
  auto firstValue = returnedValue.cast<ElementsAttr>().getValue({0, 0});
  EXPECT_EQ(firstValue.cast<IntegerAttr>().getInt(), 5);
}
163 
164 } // end namespace
165