1 //===- StructsGenTest.cpp - TableGen StructsGen Tests ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/Attributes.h"
10 #include "mlir/IR/Identifier.h"
11 #include "mlir/IR/StandardTypes.h"
12 #include "llvm/ADT/DenseMap.h"
13 #include "llvm/ADT/Optional.h"
14 #include "llvm/ADT/StringSwitch.h"
15 #include "gmock/gmock.h"
16 #include <type_traits>
17 
18 namespace mlir {
19 
/// Pull in the generated struct attribute declarations and definitions.
21 #include "StructAttrGenTest.h.inc"
22 #include "StructAttrGenTest.cpp.inc"
23 
24 /// Helper that returns an example test::TestStruct for testing its
25 /// implementation.
getTestStruct(mlir::MLIRContext * context)26 static test::TestStruct getTestStruct(mlir::MLIRContext *context) {
27   auto integerType = mlir::IntegerType::get(32, context);
28   auto integerAttr = mlir::IntegerAttr::get(integerType, 127);
29 
30   auto floatType = mlir::FloatType::getF32(context);
31   auto floatAttr = mlir::FloatAttr::get(floatType, 0.25);
32 
33   auto elementsType = mlir::RankedTensorType::get({2, 3}, integerType);
34   auto elementsAttr =
35       mlir::DenseIntElementsAttr::get(elementsType, {1, 2, 3, 4, 5, 6});
36   auto optionalAttr = nullptr;
37 
38   return test::TestStruct::get(integerAttr, floatAttr, elementsAttr,
39                                optionalAttr, context);
40 }
41 
42 /// Validates that test::TestStruct::classof correctly identifies a valid
43 /// test::TestStruct.
TEST(StructsGenTest,ClassofTrue)44 TEST(StructsGenTest, ClassofTrue) {
45   mlir::MLIRContext context;
46   auto structAttr = getTestStruct(&context);
47   ASSERT_TRUE(test::TestStruct::classof(structAttr));
48 }
49 
50 /// Validates that test::TestStruct::classof fails when an extra attribute is in
51 /// the class.
TEST(StructsGenTest,ClassofExtraFalse)52 TEST(StructsGenTest, ClassofExtraFalse) {
53   mlir::MLIRContext context;
54   mlir::DictionaryAttr structAttr = getTestStruct(&context);
55   auto expectedValues = structAttr.getValue();
56   ASSERT_EQ(expectedValues.size(), 3u);
57 
58   // Copy the set of named attributes.
59   llvm::SmallVector<mlir::NamedAttribute, 5> newValues(expectedValues.begin(),
60                                                        expectedValues.end());
61 
62   // Add an extra NamedAttribute.
63   auto wrongId = mlir::Identifier::get("wrong", &context);
64   auto wrongAttr = mlir::NamedAttribute(wrongId, expectedValues[0].second);
65   newValues.push_back(wrongAttr);
66 
67   // Make a new DictionaryAttr and validate.
68   auto badDictionary = mlir::DictionaryAttr::get(newValues, &context);
69   ASSERT_FALSE(test::TestStruct::classof(badDictionary));
70 }
71 
72 /// Validates that test::TestStruct::classof fails when a NamedAttribute has an
73 /// incorrect name.
TEST(StructsGenTest,ClassofBadNameFalse)74 TEST(StructsGenTest, ClassofBadNameFalse) {
75   mlir::MLIRContext context;
76   mlir::DictionaryAttr structAttr = getTestStruct(&context);
77   auto expectedValues = structAttr.getValue();
78   ASSERT_EQ(expectedValues.size(), 3u);
79 
80   // Create a copy of all but the first NamedAttributes.
81   llvm::SmallVector<mlir::NamedAttribute, 4> newValues(
82       expectedValues.begin() + 1, expectedValues.end());
83 
84   // Add a copy of the first attribute with the wrong Identifier.
85   auto wrongId = mlir::Identifier::get("wrong", &context);
86   auto wrongAttr = mlir::NamedAttribute(wrongId, expectedValues[0].second);
87   newValues.push_back(wrongAttr);
88 
89   auto badDictionary = mlir::DictionaryAttr::get(newValues, &context);
90   ASSERT_FALSE(test::TestStruct::classof(badDictionary));
91 }
92 
93 /// Validates that test::TestStruct::classof fails when a NamedAttribute has an
94 /// incorrect type.
TEST(StructsGenTest,ClassofBadTypeFalse)95 TEST(StructsGenTest, ClassofBadTypeFalse) {
96   mlir::MLIRContext context;
97   mlir::DictionaryAttr structAttr = getTestStruct(&context);
98   auto expectedValues = structAttr.getValue();
99   ASSERT_EQ(expectedValues.size(), 3u);
100 
101   // Create a copy of all but the last NamedAttributes.
102   llvm::SmallVector<mlir::NamedAttribute, 4> newValues(
103       expectedValues.begin(), expectedValues.end() - 1);
104 
105   // Add a copy of the last attribute with the wrong type.
106   auto i64Type = mlir::IntegerType::get(64, &context);
107   auto elementsType = mlir::RankedTensorType::get({3}, i64Type);
108   auto elementsAttr =
109       mlir::DenseIntElementsAttr::get(elementsType, ArrayRef<int64_t>{1, 2, 3});
110   mlir::Identifier id = expectedValues.back().first;
111   auto wrongAttr = mlir::NamedAttribute(id, elementsAttr);
112   newValues.push_back(wrongAttr);
113 
114   auto badDictionary = mlir::DictionaryAttr::get(newValues, &context);
115   ASSERT_FALSE(test::TestStruct::classof(badDictionary));
116 }
117 
118 /// Validates that test::TestStruct::classof fails when a NamedAttribute is
119 /// missing.
TEST(StructsGenTest,ClassofMissingFalse)120 TEST(StructsGenTest, ClassofMissingFalse) {
121   mlir::MLIRContext context;
122   mlir::DictionaryAttr structAttr = getTestStruct(&context);
123   auto expectedValues = structAttr.getValue();
124   ASSERT_EQ(expectedValues.size(), 3u);
125 
126   // Copy a subset of the structures Named Attributes.
127   llvm::SmallVector<mlir::NamedAttribute, 3> newValues(
128       expectedValues.begin() + 1, expectedValues.end());
129 
130   // Make a new DictionaryAttr and validate it is not a validate TestStruct.
131   auto badDictionary = mlir::DictionaryAttr::get(newValues, &context);
132   ASSERT_FALSE(test::TestStruct::classof(badDictionary));
133 }
134 
135 /// Validate the accessor for the FloatAttr value.
TEST(StructsGenTest,GetFloat)136 TEST(StructsGenTest, GetFloat) {
137   mlir::MLIRContext context;
138   auto structAttr = getTestStruct(&context);
139   auto returnedAttr = structAttr.sample_float();
140   EXPECT_EQ(returnedAttr.getValueAsDouble(), 0.25);
141 }
142 
143 /// Validate the accessor for the IntegerAttr value.
TEST(StructsGenTest,GetInteger)144 TEST(StructsGenTest, GetInteger) {
145   mlir::MLIRContext context;
146   auto structAttr = getTestStruct(&context);
147   auto returnedAttr = structAttr.sample_integer();
148   EXPECT_EQ(returnedAttr.getInt(), 127);
149 }
150 
151 /// Validate the accessor for the ElementsAttr value.
TEST(StructsGenTest,GetElements)152 TEST(StructsGenTest, GetElements) {
153   mlir::MLIRContext context;
154   auto structAttr = getTestStruct(&context);
155   auto returnedAttr = structAttr.sample_elements();
156   auto denseAttr = returnedAttr.dyn_cast<mlir::DenseElementsAttr>();
157   ASSERT_TRUE(denseAttr);
158 
159   for (const auto &valIndexIt : llvm::enumerate(denseAttr.getIntValues())) {
160     EXPECT_EQ(valIndexIt.value(), valIndexIt.index() + 1);
161   }
162 }
163 
/// Checks that the accessor for the unset optional field returns a null
/// attribute.
TEST(StructsGenTest, EmptyOptional) {
  mlir::MLIRContext ctx;
  test::TestStruct testStruct = getTestStruct(&ctx);
  auto optionalField = testStruct.sample_optional_integer();
  EXPECT_EQ(optionalField, nullptr);
}
169 
170 } // namespace mlir
171