1 //===- QuantizationUtilsTest.cpp - unit tests for quantization utils ------===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/QuantOps/QuantizeUtils.h"
10 #include "mlir/Dialect/QuantOps/UniformSupport.h"
11 #include "mlir/IR/Attributes.h"
12 #include "mlir/IR/StandardTypes.h"
13 #include "gmock/gmock.h"
14 #include "gtest/gtest.h"
15
16 using namespace mlir;
17 using namespace mlir::quant;
18
19 namespace {
20
21 // Test UniformQuantizedValueConverter converts all APFloat to a magic number 5.
22 class TestUniformQuantizedValueConverter
23 : public UniformQuantizedValueConverter {
24 public:
TestUniformQuantizedValueConverter(UniformQuantizedType type)25 TestUniformQuantizedValueConverter(UniformQuantizedType type)
26 : UniformQuantizedValueConverter(type), qtype(type) {}
quantizeFloatToInt(APFloat expressedValue) const27 APInt quantizeFloatToInt(APFloat expressedValue) const {
28 return APInt(qtype.getStorageType().cast<IntegerType>().getWidth(), 5L);
29 }
30
31 private:
32 UniformQuantizedType qtype;
33 };
34
getTestFloatAttr(double value,MLIRContext * ctx)35 Attribute getTestFloatAttr(double value, MLIRContext *ctx) {
36 return FloatAttr::get(FloatType::getF32(ctx), value);
37 }
38
39 template <typename ConcreteAttrClass, typename... Arg>
40 ConcreteAttrClass getTestElementsAttr(MLIRContext *ctx, ArrayRef<int64_t> shape,
41 Arg... value) {
42 auto eleType = FloatType::getF32(ctx);
43 ShapedType tensorType;
44 if (shape.size() == 1 && shape[0] == -1) {
45 tensorType = UnrankedTensorType::get(eleType);
46 } else {
47 tensorType = RankedTensorType::get(shape, eleType);
48 }
49 return ConcreteAttrClass::get(tensorType, value...);
50 }
51
getTestSparseElementsAttr(MLIRContext * ctx,ArrayRef<int64_t> shape)52 ElementsAttr getTestSparseElementsAttr(MLIRContext *ctx,
53 ArrayRef<int64_t> shape) {
54 auto eleType = FloatType::getF32(ctx);
55 ShapedType tensorType;
56 if (shape.size() == 1 && shape[0] == -1) {
57 tensorType = UnrankedTensorType::get(eleType);
58 } else {
59 tensorType = RankedTensorType::get(shape, eleType);
60 }
61 auto indicesType = RankedTensorType::get({1, 2}, IntegerType::get(64, ctx));
62 auto indices =
63 DenseIntElementsAttr::get(indicesType, {APInt(64, 0), APInt(64, 0)});
64 auto valuesType = RankedTensorType::get({1}, eleType);
65 auto values = DenseFPElementsAttr::get(valuesType, {APFloat(0.0f)});
66 return SparseElementsAttr::get(tensorType, indices, values);
67 }
68
getTestQuantizedType(Type storageType,MLIRContext * ctx)69 UniformQuantizedType getTestQuantizedType(Type storageType, MLIRContext *ctx) {
70 return UniformQuantizedType::get(/*flags=*/false, storageType,
71 FloatType::getF32(ctx), /*scale=*/1.0,
72 /*zeroPoint=*/0, /*storageTypeMin=*/0,
73 /*storageTypeMax=*/255);
74 }
75
// Checks that quantizeAttrUniform turns a scalar f32 FloatAttr into an
// IntegerAttr that carries the converter's value (5) and has the same bit
// width as the 8-bit storage type.
TEST(QuantizationUtilsTest, convertFloatAttrUniform) {
  MLIRContext ctx;
  IntegerType convertedType = IntegerType::get(8, &ctx);
  auto quantizedType = getTestQuantizedType(convertedType, &ctx);
  TestUniformQuantizedValueConverter converter(quantizedType);

  auto realValue = getTestFloatAttr(1.0, &ctx);
  Type typeResult;
  auto valueResult =
      quantizeAttrUniform(realValue, quantizedType, converter, typeResult);

  // Report a clean test failure instead of applying cast<> to a null or
  // unexpected attribute.
  ASSERT_TRUE(static_cast<bool>(valueResult));
  ASSERT_TRUE(valueResult.isa<IntegerAttr>());

  auto intResult = valueResult.cast<IntegerAttr>();
  EXPECT_EQ(intResult.getInt(), 5);
  EXPECT_EQ(intResult.getType().cast<IntegerType>().getWidth(),
            convertedType.getWidth());
}
92
// Checks that quantizeAttrUniform converts a ranked dense f32 elements
// attribute into a dense integer elements attribute with the same shape, the
// storage type as element type, and the converter's value (5) as element.
TEST(QuantizationUtilsTest, convertRankedDenseAttrUniform) {
  MLIRContext ctx;
  IntegerType convertedType = IntegerType::get(8, &ctx);
  auto quantizedType = getTestQuantizedType(convertedType, &ctx);
  TestUniformQuantizedValueConverter converter(quantizedType);
  auto realValue = getTestElementsAttr<DenseElementsAttr, ArrayRef<Attribute>>(
      &ctx, {1, 2}, {getTestFloatAttr(1.0, &ctx), getTestFloatAttr(2.0, &ctx)});

  Type returnedType;
  auto returnedValue =
      quantizeAttrUniform(realValue, quantizedType, converter, returnedType);

  // Report a clean test failure instead of applying cast<> to a null result.
  ASSERT_TRUE(static_cast<bool>(returnedValue));
  ASSERT_TRUE(static_cast<bool>(returnedType));
  ASSERT_TRUE(returnedType.isa<TensorType>());

  // Check Elements attribute shape and kind are not changed.
  auto tensorType = returnedType.cast<TensorType>();
  auto expectedTensorType = realValue.getType().cast<TensorType>();
  EXPECT_EQ(tensorType.getShape(), expectedTensorType.getShape());
  EXPECT_EQ(tensorType.getElementType(), convertedType);
  EXPECT_TRUE(returnedValue.isa<DenseIntElementsAttr>());

  // Check Elements attribute element value is expected.
  auto firstValue = returnedValue.cast<ElementsAttr>().getValue({0, 0});
  EXPECT_EQ(firstValue.cast<IntegerAttr>().getInt(), 5);
}
116
// Checks that quantizeAttrUniform converts a ranked splat f32 elements
// attribute into a splat elements attribute with the same shape, the storage
// type as element type, and the converter's value (5) as element.
TEST(QuantizationUtilsTest, convertRankedSplatAttrUniform) {
  MLIRContext ctx;
  IntegerType convertedType = IntegerType::get(8, &ctx);
  auto quantizedType = getTestQuantizedType(convertedType, &ctx);
  TestUniformQuantizedValueConverter converter(quantizedType);
  auto realValue = getTestElementsAttr<DenseElementsAttr, Attribute>(
      &ctx, {1, 2}, getTestFloatAttr(1.0, &ctx));

  Type returnedType;
  auto returnedValue =
      quantizeAttrUniform(realValue, quantizedType, converter, returnedType);

  // Report a clean test failure instead of applying cast<> to a null result.
  ASSERT_TRUE(static_cast<bool>(returnedValue));
  ASSERT_TRUE(static_cast<bool>(returnedType));
  ASSERT_TRUE(returnedType.isa<TensorType>());

  // Check Elements attribute shape and kind are not changed.
  auto tensorType = returnedType.cast<TensorType>();
  auto expectedTensorType = realValue.getType().cast<TensorType>();
  EXPECT_EQ(tensorType.getShape(), expectedTensorType.getShape());
  EXPECT_EQ(tensorType.getElementType(), convertedType);
  EXPECT_TRUE(returnedValue.isa<SplatElementsAttr>());

  // Check Elements attribute element value is expected.
  auto firstValue = returnedValue.cast<ElementsAttr>().getValue({0, 0});
  EXPECT_EQ(firstValue.cast<IntegerAttr>().getInt(), 5);
}
140
// Checks that quantizeAttrUniform converts a ranked sparse f32 elements
// attribute into a sparse elements attribute with the same shape, the storage
// type as element type, and the converter's value (5) at index (0, 0).
TEST(QuantizationUtilsTest, convertRankedSparseAttrUniform) {
  MLIRContext ctx;
  IntegerType convertedType = IntegerType::get(8, &ctx);
  auto quantizedType = getTestQuantizedType(convertedType, &ctx);
  TestUniformQuantizedValueConverter converter(quantizedType);
  auto realValue = getTestSparseElementsAttr(&ctx, {1, 2});

  Type returnedType;
  auto returnedValue =
      quantizeAttrUniform(realValue, quantizedType, converter, returnedType);

  // Report a clean test failure instead of applying cast<> to a null result.
  ASSERT_TRUE(static_cast<bool>(returnedValue));
  ASSERT_TRUE(static_cast<bool>(returnedType));
  ASSERT_TRUE(returnedType.isa<TensorType>());

  // Check Elements attribute shape and kind are not changed.
  auto tensorType = returnedType.cast<TensorType>();
  auto expectedTensorType = realValue.getType().cast<TensorType>();
  EXPECT_EQ(tensorType.getShape(), expectedTensorType.getShape());
  EXPECT_EQ(tensorType.getElementType(), convertedType);
  // Same isa<> form of kind check as the dense and splat tests.
  EXPECT_TRUE(returnedValue.isa<SparseElementsAttr>());

  // Check Elements attribute element value is expected.
  auto firstValue = returnedValue.cast<ElementsAttr>().getValue({0, 0});
  EXPECT_EQ(firstValue.cast<IntegerAttr>().getInt(), 5);
}
163
164 } // end namespace
165