1 //===- AttributeTest.cpp - Attribute unit tests ---------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/IR/BuiltinAttributes.h"
10 #include "mlir/IR/BuiltinTypes.h"
11 #include "mlir/IR/Identifier.h"
12 #include "gtest/gtest.h"
13
14 using namespace mlir;
15 using namespace mlir::detail;
16
17 template <typename EltTy>
testSplat(Type eltType,const EltTy & splatElt)18 static void testSplat(Type eltType, const EltTy &splatElt) {
19 RankedTensorType shape = RankedTensorType::get({2, 1}, eltType);
20
21 // Check that the generated splat is the same for 1 element and N elements.
22 DenseElementsAttr splat = DenseElementsAttr::get(shape, splatElt);
23 EXPECT_TRUE(splat.isSplat());
24
25 auto detectedSplat =
26 DenseElementsAttr::get(shape, llvm::makeArrayRef({splatElt, splatElt}));
27 EXPECT_EQ(detectedSplat, splat);
28
29 for (auto newValue : detectedSplat.template getValues<EltTy>())
30 EXPECT_TRUE(newValue == splatElt);
31 }
32
33 namespace {
/// Splat detection for 1-bit integer elements on a 2x2 tensor.
TEST(DenseSplatTest, BoolSplat) {
  MLIRContext context;
  IntegerType i1Ty = IntegerType::get(&context, 1);
  RankedTensorType tensorTy = RankedTensorType::get({2, 2}, i1Ty);

  // A single boolean must produce a splat, and the true and false splats must
  // be distinct attributes.
  DenseElementsAttr allTrue = DenseElementsAttr::get(tensorTy, true);
  EXPECT_TRUE(allTrue.isSplat());
  DenseElementsAttr allFalse = DenseElementsAttr::get(tensorTy, false);
  EXPECT_TRUE(allFalse.isSplat());
  EXPECT_NE(allFalse, allTrue);

  // Detect and handle a splat given as fewer than 8 explicit elements (bool
  // values are bit-packed): all-true list first...
  auto fromTrueList =
      DenseElementsAttr::get(tensorTy, {true, true, true, true});
  EXPECT_EQ(fromTrueList, allTrue);
  // ...then the all-false list.
  auto fromFalseList =
      DenseElementsAttr::get(tensorTy, {false, false, false, false});
  EXPECT_EQ(fromFalseList, allFalse);
}
56
/// Splat detection for a 1-D tensor of 56 1-bit integer elements.
TEST(DenseSplatTest, LargeBoolSplat) {
  constexpr int64_t numBools = 56;

  MLIRContext context;
  IntegerType i1Ty = IntegerType::get(&context, 1);
  RankedTensorType tensorTy = RankedTensorType::get({numBools}, i1Ty);

  // Single-value construction must give a splat for both true and false.
  DenseElementsAttr allTrue = DenseElementsAttr::get(tensorTy, true);
  DenseElementsAttr allFalse = DenseElementsAttr::get(tensorTy, false);
  EXPECT_TRUE(allTrue.isSplat());
  EXPECT_TRUE(allFalse.isSplat());

  // A full-length list of identical values must compare equal to the
  // corresponding single-value splat: all-true first...
  SmallVector<bool, 64> trueList(numBools, true);
  auto fromTrueList = DenseElementsAttr::get(tensorTy, trueList);
  EXPECT_EQ(fromTrueList, allTrue);
  // ...then all-false.
  SmallVector<bool, 64> falseList(numBools, false);
  auto fromFalseList = DenseElementsAttr::get(tensorTy, falseList);
  EXPECT_EQ(fromFalseList, allFalse);
}
81
/// A boolean tensor holding mixed values must not be reported as a splat.
TEST(DenseSplatTest, BoolNonSplat) {
  MLIRContext context;
  IntegerType i1Ty = IntegerType::get(&context, 1);
  RankedTensorType tensorTy = RankedTensorType::get({6}, i1Ty);

  // Six elements, two of which are true: not a splat.
  DenseElementsAttr mixed = DenseElementsAttr::get(
      tensorTy, {false, false, true, false, false, true});
  EXPECT_FALSE(mixed.isSplat());
}
92
/// Splat round-trip for an integer type whose bitwidth (19) is not a multiple
/// of 8, using an APInt element value.
TEST(DenseSplatTest, OddIntSplat) {
  constexpr size_t bitWidth = 19;

  MLIRContext context;
  IntegerType oddIntTy = IntegerType::get(&context, bitWidth);
  APInt splatValue(bitWidth, 10);

  testSplat(oddIntTy, splatValue);
}
102
/// Splat round-trip for a 32-bit integer type with a plain `int` value.
TEST(DenseSplatTest, Int32Splat) {
  MLIRContext context;
  IntegerType i32Ty = IntegerType::get(&context, 32);
  int splatValue = 64;

  testSplat(i32Ty, splatValue);
}
110
/// Splat round-trip for an 85-bit integer type where the element value is
/// supplied as an `Attribute` wrapping an IntegerAttr.
TEST(DenseSplatTest, IntAttrSplat) {
  MLIRContext context;
  IntegerType wideIntTy = IntegerType::get(&context, 85);
  // Held as the generic Attribute type so the helper is instantiated with
  // Attribute as its element type.
  Attribute splatAttr = IntegerAttr::get(wideIntTy, 109);

  testSplat(wideIntTy, splatAttr);
}
118
/// Splat round-trip for the f32 type with a native `float` value.
TEST(DenseSplatTest, F32Splat) {
  MLIRContext context;
  FloatType f32Ty = FloatType::getF32(&context);
  float splatValue = 10.0;

  testSplat(f32Ty, splatValue);
}
126
/// Splat round-trip for the f64 type with the value wrapped in an APFloat.
TEST(DenseSplatTest, F64Splat) {
  MLIRContext context;
  FloatType f64Ty = FloatType::getF64(&context);
  double rawValue = 10.0;
  // The APFloat is built from a double and passed as the element value.
  APFloat splatValue(rawValue);

  testSplat(f64Ty, splatValue);
}
134
/// Splat round-trip for the f32 type where the element value is supplied as
/// an `Attribute` wrapping a FloatAttr.
TEST(DenseSplatTest, FloatAttrSplat) {
  MLIRContext context;
  FloatType f32Ty = FloatType::getF32(&context);
  // Held as the generic Attribute type so the helper is instantiated with
  // Attribute as its element type.
  Attribute splatAttr = FloatAttr::get(f32Ty, 10.0);

  testSplat(f32Ty, splatAttr);
}
142
/// Splat round-trip for the bf16 type where the element value is supplied as
/// an `Attribute` wrapping a FloatAttr.
TEST(DenseSplatTest, BF16Splat) {
  MLIRContext context;
  FloatType bf16Ty = FloatType::getBF16(&context);
  // Held as the generic Attribute type so the helper is instantiated with
  // Attribute as its element type.
  Attribute splatAttr = FloatAttr::get(bf16Ty, 10.0);

  testSplat(bf16Ty, splatAttr);
}
150
/// Splat round-trip for a string-like element type with a StringRef value.
TEST(DenseSplatTest, StringSplat) {
  MLIRContext context;
  // The opaque type below uses the identifier "test", so unregistered
  // dialects are allowed on this context first.
  context.allowUnregisteredDialects();
  Identifier dialectName = Identifier::get("test", &context);
  Type opaqueStringTy = OpaqueType::get(dialectName, "string");

  StringRef splatValue = "test-string";
  testSplat(opaqueStringTy, splatValue);
}
159
/// Splat round-trip for a string-like element type where the element value is
/// supplied as an `Attribute` wrapping a StringAttr.
TEST(DenseSplatTest, StringAttrSplat) {
  MLIRContext context;
  // The opaque type below uses the identifier "test", so unregistered
  // dialects are allowed on this context first.
  context.allowUnregisteredDialects();
  Identifier dialectName = Identifier::get("test", &context);
  Type opaqueStringTy = OpaqueType::get(dialectName, "string");

  // Held as the generic Attribute type so the helper is instantiated with
  // Attribute as its element type.
  Attribute splatAttr = StringAttr::get("test-string", opaqueStringTy);
  testSplat(opaqueStringTy, splatAttr);
}
168
/// Splat round-trip for a complex type over f32 with a std::complex<float>
/// value of (10, 15).
TEST(DenseComplexTest, ComplexFloatSplat) {
  MLIRContext context;
  FloatType f32Ty = FloatType::getF32(&context);
  ComplexType complexF32Ty = ComplexType::get(f32Ty);

  std::complex<float> splatValue(10.0, 15.0);
  testSplat(complexF32Ty, splatValue);
}
175
/// Splat round-trip for a complex type over a 64-bit integer with a
/// std::complex<int64_t> value of (10, 15).
TEST(DenseComplexTest, ComplexIntSplat) {
  MLIRContext context;
  IntegerType i64Ty = IntegerType::get(&context, 64);
  ComplexType complexI64Ty = ComplexType::get(i64Ty);

  std::complex<int64_t> splatValue(10, 15);
  testSplat(complexI64Ty, splatValue);
}
182
/// Splat round-trip for a complex type over f32 with a std::complex<APFloat>
/// value whose parts are built from the floats 10 and 15.
TEST(DenseComplexTest, ComplexAPFloatSplat) {
  MLIRContext context;
  FloatType f32Ty = FloatType::getF32(&context);
  ComplexType complexF32Ty = ComplexType::get(f32Ty);

  std::complex<APFloat> splatValue(APFloat(10.0f), APFloat(15.0f));
  testSplat(complexF32Ty, splatValue);
}
189
/// Splat round-trip for a complex type over a 64-bit integer with a
/// std::complex<APInt> value whose parts are 64-bit APInts 10 and 15.
TEST(DenseComplexTest, ComplexAPIntSplat) {
  MLIRContext context;
  IntegerType i64Ty = IntegerType::get(&context, 64);
  ComplexType complexI64Ty = ComplexType::get(i64Ty);

  std::complex<APInt> splatValue(APInt(64, 10), APInt(64, 15));
  testSplat(complexI64Ty, splatValue);
}
196
/// Builds a dense attribute over a ranked tensor type with an empty shape and
/// checks that the element fetched with index {0} equals the IntegerAttr for
/// the value it was built from.
TEST(DenseScalarTest, ExtractZeroRankElement) {
  MLIRContext context;
  const int scalarValue = 12;
  IntegerType i32Ty = IntegerType::get(&context, 32);
  Attribute expectedAttr = IntegerAttr::get(i32Ty, scalarValue);
  RankedTensorType scalarTensorTy = RankedTensorType::get({}, i32Ty);

  // The attribute is built from a one-entry list holding the scalar.
  auto denseAttr = DenseElementsAttr::get(scalarTensorTy,
                                          llvm::makeArrayRef({scalarValue}));
  EXPECT_TRUE(denseAttr.getValue({0}) == expectedAttr);
}
207
208 } // end namespace
209