1 //===- QuantizationUtilsTest.cpp - unit tests for quantization utils ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Quant/QuantOps.h"
10 #include "mlir/Dialect/Quant/QuantizeUtils.h"
11 #include "mlir/Dialect/Quant/UniformSupport.h"
12 #include "mlir/IR/Attributes.h"
13 #include "mlir/IR/StandardTypes.h"
14 #include "gmock/gmock.h"
15 #include "gtest/gtest.h"
16 
17 using namespace mlir;
18 using namespace mlir::quant;
19 
20 namespace {
21 
22 // Test UniformQuantizedValueConverter converts all APFloat to a magic number 5.
23 class TestUniformQuantizedValueConverter
24     : public UniformQuantizedValueConverter {
25 public:
TestUniformQuantizedValueConverter(UniformQuantizedType type)26   TestUniformQuantizedValueConverter(UniformQuantizedType type)
27       : UniformQuantizedValueConverter(type), qtype(type) {}
quantizeFloatToInt(APFloat expressedValue) const28   APInt quantizeFloatToInt(APFloat expressedValue) const {
29     return APInt(qtype.getStorageType().cast<IntegerType>().getWidth(), 5L);
30   }
31 
32 private:
33   UniformQuantizedType qtype;
34 };
35 
getTestFloatAttr(double value,MLIRContext * ctx)36 Attribute getTestFloatAttr(double value, MLIRContext *ctx) {
37   return FloatAttr::get(FloatType::getF32(ctx), value);
38 }
39 
40 template <typename ConcreteAttrClass, typename... Arg>
41 ConcreteAttrClass getTestElementsAttr(MLIRContext *ctx, ArrayRef<int64_t> shape,
42                                       Arg... value) {
43   auto eleType = FloatType::getF32(ctx);
44   ShapedType tensorType;
45   if (shape.size() == 1 && shape[0] == -1) {
46     tensorType = UnrankedTensorType::get(eleType);
47   } else {
48     tensorType = RankedTensorType::get(shape, eleType);
49   }
50   return ConcreteAttrClass::get(tensorType, value...);
51 }
52 
getTestSparseElementsAttr(MLIRContext * ctx,ArrayRef<int64_t> shape)53 ElementsAttr getTestSparseElementsAttr(MLIRContext *ctx,
54                                        ArrayRef<int64_t> shape) {
55   auto eleType = FloatType::getF32(ctx);
56   ShapedType tensorType;
57   if (shape.size() == 1 && shape[0] == -1) {
58     tensorType = UnrankedTensorType::get(eleType);
59   } else {
60     tensorType = RankedTensorType::get(shape, eleType);
61   }
62   auto indicesType = RankedTensorType::get({1, 2}, IntegerType::get(64, ctx));
63   auto indices =
64       DenseIntElementsAttr::get(indicesType, {APInt(64, 0), APInt(64, 0)});
65   auto valuesType = RankedTensorType::get({1}, eleType);
66   auto values = DenseFPElementsAttr::get(valuesType, {APFloat(0.0f)});
67   return SparseElementsAttr::get(tensorType, indices, values);
68 }
69 
getTestQuantizedType(Type storageType,MLIRContext * ctx)70 UniformQuantizedType getTestQuantizedType(Type storageType, MLIRContext *ctx) {
71   return UniformQuantizedType::get(/*flags=*/false, storageType,
72                                    FloatType::getF32(ctx), /*scale=*/1.0,
73                                    /*zeroPoint=*/0, /*storageTypeMin=*/0,
74                                    /*storageTypeMax=*/255);
75 }
76 
// Quantizing a scalar float attribute must yield an integer attribute that
// carries the converter's magic value and has the storage type's bit width.
TEST(QuantizationUtilsTest, convertFloatAttrUniform) {
  MLIRContext ctx;
  ctx.getOrLoadDialect<QuantizationDialect>();
  IntegerType storageType = IntegerType::get(8, &ctx);
  UniformQuantizedType qType = getTestQuantizedType(storageType, &ctx);
  TestUniformQuantizedValueConverter fakeConverter(qType);

  Attribute input = getTestFloatAttr(1.0, &ctx);
  Type outType;
  auto output = quantizeAttrUniform(input, qType, fakeConverter, outType);

  auto outputAsInt = output.cast<IntegerAttr>();
  EXPECT_EQ(outputAsInt.getInt(), 5);
  EXPECT_EQ(outputAsInt.getType().cast<IntegerType>().getWidth(),
            storageType.getWidth());
}
94 
// Quantizing a ranked dense float tensor attribute must keep the shape, swap
// the element type for the storage type, and produce dense integer elements
// holding the converter's magic value.
TEST(QuantizationUtilsTest, convertRankedDenseAttrUniform) {
  MLIRContext ctx;
  ctx.getOrLoadDialect<QuantizationDialect>();
  IntegerType storageType = IntegerType::get(8, &ctx);
  UniformQuantizedType qType = getTestQuantizedType(storageType, &ctx);
  TestUniformQuantizedValueConverter fakeConverter(qType);
  auto input = getTestElementsAttr<DenseElementsAttr, ArrayRef<Attribute>>(
      &ctx, {1, 2}, {getTestFloatAttr(1.0, &ctx), getTestFloatAttr(2.0, &ctx)});

  Type outType;
  auto output = quantizeAttrUniform(input, qType, fakeConverter, outType);

  // The shape is preserved and the attribute is a dense integer attribute.
  auto outTensorType = outType.cast<TensorType>();
  auto inTensorType = input.getType().cast<TensorType>();
  EXPECT_EQ(outTensorType.getShape(), inTensorType.getShape());
  EXPECT_EQ(outTensorType.getElementType(), storageType);
  EXPECT_TRUE(output.isa<DenseIntElementsAttr>());

  // The element at (0, 0) carries the value produced by the test converter.
  auto elementAtOrigin = output.cast<ElementsAttr>().getValue({0, 0});
  EXPECT_EQ(elementAtOrigin.cast<IntegerAttr>().getInt(), 5);
}
119 
// Quantizing a ranked splat float tensor attribute must keep the shape, swap
// the element type for the storage type, and stay a splat whose value is the
// converter's magic value.
TEST(QuantizationUtilsTest, convertRankedSplatAttrUniform) {
  MLIRContext ctx;
  ctx.getOrLoadDialect<QuantizationDialect>();
  IntegerType storageType = IntegerType::get(8, &ctx);
  UniformQuantizedType qType = getTestQuantizedType(storageType, &ctx);
  TestUniformQuantizedValueConverter fakeConverter(qType);
  auto input = getTestElementsAttr<DenseElementsAttr, Attribute>(
      &ctx, {1, 2}, getTestFloatAttr(1.0, &ctx));

  Type outType;
  auto output = quantizeAttrUniform(input, qType, fakeConverter, outType);

  // The shape is preserved and the attribute is still a splat.
  auto outTensorType = outType.cast<TensorType>();
  auto inTensorType = input.getType().cast<TensorType>();
  EXPECT_EQ(outTensorType.getShape(), inTensorType.getShape());
  EXPECT_EQ(outTensorType.getElementType(), storageType);
  EXPECT_TRUE(output.isa<SplatElementsAttr>());

  // The element at (0, 0) carries the value produced by the test converter.
  auto elementAtOrigin = output.cast<ElementsAttr>().getValue({0, 0});
  EXPECT_EQ(elementAtOrigin.cast<IntegerAttr>().getInt(), 5);
}
144 
// Quantizing a ranked sparse float tensor attribute must keep the shape, swap
// the element type for the storage type, and stay sparse with the stored
// entry replaced by the converter's magic value.
TEST(QuantizationUtilsTest, convertRankedSparseAttrUniform) {
  MLIRContext ctx;
  ctx.getOrLoadDialect<QuantizationDialect>();
  IntegerType storageType = IntegerType::get(8, &ctx);
  UniformQuantizedType qType = getTestQuantizedType(storageType, &ctx);
  TestUniformQuantizedValueConverter fakeConverter(qType);
  auto input = getTestSparseElementsAttr(&ctx, {1, 2});

  Type outType;
  auto output = quantizeAttrUniform(input, qType, fakeConverter, outType);

  // The shape is preserved and the attribute is still sparse.
  auto outTensorType = outType.cast<TensorType>();
  auto inTensorType = input.getType().cast<TensorType>();
  EXPECT_EQ(outTensorType.getShape(), inTensorType.getShape());
  EXPECT_EQ(outTensorType.getElementType(), storageType);
  EXPECT_TRUE(output.isa<SparseElementsAttr>());

  // The element at (0, 0) carries the value produced by the test converter.
  auto elementAtOrigin = output.cast<ElementsAttr>().getValue({0, 0});
  EXPECT_EQ(elementAtOrigin.cast<IntegerAttr>().getInt(), 5);
}
168 
169 } // end namespace
170