1 //===- AttributeTest.cpp - Attribute unit tests ---------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/BuiltinAttributes.h"
10 #include "mlir/IR/BuiltinTypes.h"
11 #include "mlir/IR/Identifier.h"
12 #include "gtest/gtest.h"
13 
14 using namespace mlir;
15 using namespace mlir::detail;
16 
17 template <typename EltTy>
testSplat(Type eltType,const EltTy & splatElt)18 static void testSplat(Type eltType, const EltTy &splatElt) {
19   RankedTensorType shape = RankedTensorType::get({2, 1}, eltType);
20 
21   // Check that the generated splat is the same for 1 element and N elements.
22   DenseElementsAttr splat = DenseElementsAttr::get(shape, splatElt);
23   EXPECT_TRUE(splat.isSplat());
24 
25   auto detectedSplat =
26       DenseElementsAttr::get(shape, llvm::makeArrayRef({splatElt, splatElt}));
27   EXPECT_EQ(detectedSplat, splat);
28 
29   for (auto newValue : detectedSplat.template getValues<EltTy>())
30     EXPECT_TRUE(newValue == splatElt);
31 }
32 
33 namespace {
TEST(DenseSplatTest,BoolSplat)34 TEST(DenseSplatTest, BoolSplat) {
35   MLIRContext context;
36   IntegerType boolTy = IntegerType::get(&context, 1);
37   RankedTensorType shape = RankedTensorType::get({2, 2}, boolTy);
38 
39   // Check that splat is automatically detected for boolean values.
40   /// True.
41   DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true);
42   EXPECT_TRUE(trueSplat.isSplat());
43   /// False.
44   DenseElementsAttr falseSplat = DenseElementsAttr::get(shape, false);
45   EXPECT_TRUE(falseSplat.isSplat());
46   EXPECT_NE(falseSplat, trueSplat);
47 
48   /// Detect and handle splat within 8 elements (bool values are bit-packed).
49   /// True.
50   auto detectedSplat = DenseElementsAttr::get(shape, {true, true, true, true});
51   EXPECT_EQ(detectedSplat, trueSplat);
52   /// False.
53   detectedSplat = DenseElementsAttr::get(shape, {false, false, false, false});
54   EXPECT_EQ(detectedSplat, falseSplat);
55 }
56 
TEST(DenseSplatTest,LargeBoolSplat)57 TEST(DenseSplatTest, LargeBoolSplat) {
58   constexpr int64_t boolCount = 56;
59 
60   MLIRContext context;
61   IntegerType boolTy = IntegerType::get(&context, 1);
62   RankedTensorType shape = RankedTensorType::get({boolCount}, boolTy);
63 
64   // Check that splat is automatically detected for boolean values.
65   /// True.
66   DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true);
67   DenseElementsAttr falseSplat = DenseElementsAttr::get(shape, false);
68   EXPECT_TRUE(trueSplat.isSplat());
69   EXPECT_TRUE(falseSplat.isSplat());
70 
71   /// Detect that the large boolean arrays are properly splatted.
72   /// True.
73   SmallVector<bool, 64> trueValues(boolCount, true);
74   auto detectedSplat = DenseElementsAttr::get(shape, trueValues);
75   EXPECT_EQ(detectedSplat, trueSplat);
76   /// False.
77   SmallVector<bool, 64> falseValues(boolCount, false);
78   detectedSplat = DenseElementsAttr::get(shape, falseValues);
79   EXPECT_EQ(detectedSplat, falseSplat);
80 }
81 
TEST(DenseSplatTest,BoolNonSplat)82 TEST(DenseSplatTest, BoolNonSplat) {
83   MLIRContext context;
84   IntegerType boolTy = IntegerType::get(&context, 1);
85   RankedTensorType shape = RankedTensorType::get({6}, boolTy);
86 
87   // Check that we properly handle non-splat values.
88   DenseElementsAttr nonSplat =
89       DenseElementsAttr::get(shape, {false, false, true, false, false, true});
90   EXPECT_FALSE(nonSplat.isSplat());
91 }
92 
TEST(DenseSplatTest,OddIntSplat)93 TEST(DenseSplatTest, OddIntSplat) {
94   // Test detecting a splat with an odd(non 8-bit) integer bitwidth.
95   MLIRContext context;
96   constexpr size_t intWidth = 19;
97   IntegerType intTy = IntegerType::get(&context, intWidth);
98   APInt value(intWidth, 10);
99 
100   testSplat(intTy, value);
101 }
102 
TEST(DenseSplatTest,Int32Splat)103 TEST(DenseSplatTest, Int32Splat) {
104   MLIRContext context;
105   IntegerType intTy = IntegerType::get(&context, 32);
106   int value = 64;
107 
108   testSplat(intTy, value);
109 }
110 
TEST(DenseSplatTest,IntAttrSplat)111 TEST(DenseSplatTest, IntAttrSplat) {
112   MLIRContext context;
113   IntegerType intTy = IntegerType::get(&context, 85);
114   Attribute value = IntegerAttr::get(intTy, 109);
115 
116   testSplat(intTy, value);
117 }
118 
TEST(DenseSplatTest,F32Splat)119 TEST(DenseSplatTest, F32Splat) {
120   MLIRContext context;
121   FloatType floatTy = FloatType::getF32(&context);
122   float value = 10.0;
123 
124   testSplat(floatTy, value);
125 }
126 
TEST(DenseSplatTest,F64Splat)127 TEST(DenseSplatTest, F64Splat) {
128   MLIRContext context;
129   FloatType floatTy = FloatType::getF64(&context);
130   double value = 10.0;
131 
132   testSplat(floatTy, APFloat(value));
133 }
134 
TEST(DenseSplatTest,FloatAttrSplat)135 TEST(DenseSplatTest, FloatAttrSplat) {
136   MLIRContext context;
137   FloatType floatTy = FloatType::getF32(&context);
138   Attribute value = FloatAttr::get(floatTy, 10.0);
139 
140   testSplat(floatTy, value);
141 }
142 
TEST(DenseSplatTest,BF16Splat)143 TEST(DenseSplatTest, BF16Splat) {
144   MLIRContext context;
145   FloatType floatTy = FloatType::getBF16(&context);
146   Attribute value = FloatAttr::get(floatTy, 10.0);
147 
148   testSplat(floatTy, value);
149 }
150 
TEST(DenseSplatTest,StringSplat)151 TEST(DenseSplatTest, StringSplat) {
152   MLIRContext context;
153   context.allowUnregisteredDialects();
154   Type stringType =
155       OpaqueType::get(Identifier::get("test", &context), "string");
156   StringRef value = "test-string";
157   testSplat(stringType, value);
158 }
159 
TEST(DenseSplatTest,StringAttrSplat)160 TEST(DenseSplatTest, StringAttrSplat) {
161   MLIRContext context;
162   context.allowUnregisteredDialects();
163   Type stringType =
164       OpaqueType::get(Identifier::get("test", &context), "string");
165   Attribute stringAttr = StringAttr::get("test-string", stringType);
166   testSplat(stringType, stringAttr);
167 }
168 
TEST(DenseComplexTest,ComplexFloatSplat)169 TEST(DenseComplexTest, ComplexFloatSplat) {
170   MLIRContext context;
171   ComplexType complexType = ComplexType::get(FloatType::getF32(&context));
172   std::complex<float> value(10.0, 15.0);
173   testSplat(complexType, value);
174 }
175 
TEST(DenseComplexTest,ComplexIntSplat)176 TEST(DenseComplexTest, ComplexIntSplat) {
177   MLIRContext context;
178   ComplexType complexType = ComplexType::get(IntegerType::get(&context, 64));
179   std::complex<int64_t> value(10, 15);
180   testSplat(complexType, value);
181 }
182 
TEST(DenseComplexTest,ComplexAPFloatSplat)183 TEST(DenseComplexTest, ComplexAPFloatSplat) {
184   MLIRContext context;
185   ComplexType complexType = ComplexType::get(FloatType::getF32(&context));
186   std::complex<APFloat> value(APFloat(10.0f), APFloat(15.0f));
187   testSplat(complexType, value);
188 }
189 
TEST(DenseComplexTest,ComplexAPIntSplat)190 TEST(DenseComplexTest, ComplexAPIntSplat) {
191   MLIRContext context;
192   ComplexType complexType = ComplexType::get(IntegerType::get(&context, 64));
193   std::complex<APInt> value(APInt(64, 10), APInt(64, 15));
194   testSplat(complexType, value);
195 }
196 
TEST(DenseScalarTest,ExtractZeroRankElement)197 TEST(DenseScalarTest, ExtractZeroRankElement) {
198   MLIRContext context;
199   const int elementValue = 12;
200   IntegerType intTy = IntegerType::get(&context, 32);
201   Attribute value = IntegerAttr::get(intTy, elementValue);
202   RankedTensorType shape = RankedTensorType::get({}, intTy);
203 
204   auto attr = DenseElementsAttr::get(shape, llvm::makeArrayRef({elementValue}));
205   EXPECT_TRUE(attr.getValue({0}) == value);
206 }
207 
208 } // end namespace
209