1 //===- StructsGenTest.cpp - TableGen StructsGen Tests ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/IR/Attributes.h"
10 #include "mlir/IR/Identifier.h"
11 #include "mlir/IR/StandardTypes.h"
12 #include "llvm/ADT/DenseMap.h"
13 #include "llvm/ADT/Optional.h"
14 #include "llvm/ADT/StringSwitch.h"
15 #include "gmock/gmock.h"
16 #include <type_traits>
17
18 namespace mlir {
19
/// Pull in generated struct attribute declarations and definitions.
21 #include "StructAttrGenTest.h.inc"
22 #include "StructAttrGenTest.cpp.inc"
23
24 /// Helper that returns an example test::TestStruct for testing its
25 /// implementation.
getTestStruct(mlir::MLIRContext * context)26 static test::TestStruct getTestStruct(mlir::MLIRContext *context) {
27 auto integerType = mlir::IntegerType::get(32, context);
28 auto integerAttr = mlir::IntegerAttr::get(integerType, 127);
29
30 auto floatType = mlir::FloatType::getF32(context);
31 auto floatAttr = mlir::FloatAttr::get(floatType, 0.25);
32
33 auto elementsType = mlir::RankedTensorType::get({2, 3}, integerType);
34 auto elementsAttr =
35 mlir::DenseIntElementsAttr::get(elementsType, {1, 2, 3, 4, 5, 6});
36 auto optionalAttr = nullptr;
37
38 return test::TestStruct::get(integerAttr, floatAttr, elementsAttr,
39 optionalAttr, context);
40 }
41
42 /// Validates that test::TestStruct::classof correctly identifies a valid
43 /// test::TestStruct.
TEST(StructsGenTest,ClassofTrue)44 TEST(StructsGenTest, ClassofTrue) {
45 mlir::MLIRContext context;
46 auto structAttr = getTestStruct(&context);
47 ASSERT_TRUE(test::TestStruct::classof(structAttr));
48 }
49
50 /// Validates that test::TestStruct::classof fails when an extra attribute is in
51 /// the class.
TEST(StructsGenTest,ClassofExtraFalse)52 TEST(StructsGenTest, ClassofExtraFalse) {
53 mlir::MLIRContext context;
54 mlir::DictionaryAttr structAttr = getTestStruct(&context);
55 auto expectedValues = structAttr.getValue();
56 ASSERT_EQ(expectedValues.size(), 3u);
57
58 // Copy the set of named attributes.
59 llvm::SmallVector<mlir::NamedAttribute, 5> newValues(expectedValues.begin(),
60 expectedValues.end());
61
62 // Add an extra NamedAttribute.
63 auto wrongId = mlir::Identifier::get("wrong", &context);
64 auto wrongAttr = mlir::NamedAttribute(wrongId, expectedValues[0].second);
65 newValues.push_back(wrongAttr);
66
67 // Make a new DictionaryAttr and validate.
68 auto badDictionary = mlir::DictionaryAttr::get(newValues, &context);
69 ASSERT_FALSE(test::TestStruct::classof(badDictionary));
70 }
71
72 /// Validates that test::TestStruct::classof fails when a NamedAttribute has an
73 /// incorrect name.
TEST(StructsGenTest,ClassofBadNameFalse)74 TEST(StructsGenTest, ClassofBadNameFalse) {
75 mlir::MLIRContext context;
76 mlir::DictionaryAttr structAttr = getTestStruct(&context);
77 auto expectedValues = structAttr.getValue();
78 ASSERT_EQ(expectedValues.size(), 3u);
79
80 // Create a copy of all but the first NamedAttributes.
81 llvm::SmallVector<mlir::NamedAttribute, 4> newValues(
82 expectedValues.begin() + 1, expectedValues.end());
83
84 // Add a copy of the first attribute with the wrong Identifier.
85 auto wrongId = mlir::Identifier::get("wrong", &context);
86 auto wrongAttr = mlir::NamedAttribute(wrongId, expectedValues[0].second);
87 newValues.push_back(wrongAttr);
88
89 auto badDictionary = mlir::DictionaryAttr::get(newValues, &context);
90 ASSERT_FALSE(test::TestStruct::classof(badDictionary));
91 }
92
93 /// Validates that test::TestStruct::classof fails when a NamedAttribute has an
94 /// incorrect type.
TEST(StructsGenTest,ClassofBadTypeFalse)95 TEST(StructsGenTest, ClassofBadTypeFalse) {
96 mlir::MLIRContext context;
97 mlir::DictionaryAttr structAttr = getTestStruct(&context);
98 auto expectedValues = structAttr.getValue();
99 ASSERT_EQ(expectedValues.size(), 3u);
100
101 // Create a copy of all but the last NamedAttributes.
102 llvm::SmallVector<mlir::NamedAttribute, 4> newValues(
103 expectedValues.begin(), expectedValues.end() - 1);
104
105 // Add a copy of the last attribute with the wrong type.
106 auto i64Type = mlir::IntegerType::get(64, &context);
107 auto elementsType = mlir::RankedTensorType::get({3}, i64Type);
108 auto elementsAttr =
109 mlir::DenseIntElementsAttr::get(elementsType, ArrayRef<int64_t>{1, 2, 3});
110 mlir::Identifier id = expectedValues.back().first;
111 auto wrongAttr = mlir::NamedAttribute(id, elementsAttr);
112 newValues.push_back(wrongAttr);
113
114 auto badDictionary = mlir::DictionaryAttr::get(newValues, &context);
115 ASSERT_FALSE(test::TestStruct::classof(badDictionary));
116 }
117
118 /// Validates that test::TestStruct::classof fails when a NamedAttribute is
119 /// missing.
TEST(StructsGenTest,ClassofMissingFalse)120 TEST(StructsGenTest, ClassofMissingFalse) {
121 mlir::MLIRContext context;
122 mlir::DictionaryAttr structAttr = getTestStruct(&context);
123 auto expectedValues = structAttr.getValue();
124 ASSERT_EQ(expectedValues.size(), 3u);
125
126 // Copy a subset of the structures Named Attributes.
127 llvm::SmallVector<mlir::NamedAttribute, 3> newValues(
128 expectedValues.begin() + 1, expectedValues.end());
129
130 // Make a new DictionaryAttr and validate it is not a validate TestStruct.
131 auto badDictionary = mlir::DictionaryAttr::get(newValues, &context);
132 ASSERT_FALSE(test::TestStruct::classof(badDictionary));
133 }
134
135 /// Validate the accessor for the FloatAttr value.
TEST(StructsGenTest,GetFloat)136 TEST(StructsGenTest, GetFloat) {
137 mlir::MLIRContext context;
138 auto structAttr = getTestStruct(&context);
139 auto returnedAttr = structAttr.sample_float();
140 EXPECT_EQ(returnedAttr.getValueAsDouble(), 0.25);
141 }
142
143 /// Validate the accessor for the IntegerAttr value.
TEST(StructsGenTest,GetInteger)144 TEST(StructsGenTest, GetInteger) {
145 mlir::MLIRContext context;
146 auto structAttr = getTestStruct(&context);
147 auto returnedAttr = structAttr.sample_integer();
148 EXPECT_EQ(returnedAttr.getInt(), 127);
149 }
150
151 /// Validate the accessor for the ElementsAttr value.
TEST(StructsGenTest,GetElements)152 TEST(StructsGenTest, GetElements) {
153 mlir::MLIRContext context;
154 auto structAttr = getTestStruct(&context);
155 auto returnedAttr = structAttr.sample_elements();
156 auto denseAttr = returnedAttr.dyn_cast<mlir::DenseElementsAttr>();
157 ASSERT_TRUE(denseAttr);
158
159 for (const auto &valIndexIt : llvm::enumerate(denseAttr.getIntValues())) {
160 EXPECT_EQ(valIndexIt.value(), valIndexIt.index() + 1);
161 }
162 }
163
/// Validate that the optional integer field, which getTestStruct leaves unset,
/// is reported as a null attribute by its accessor.
TEST(StructsGenTest, EmptyOptional) {
  mlir::MLIRContext context;
  test::TestStruct structAttr = getTestStruct(&context);
  auto optionalInteger = structAttr.sample_optional_integer();
  EXPECT_FALSE(optionalInteger);
}
169
170 } // namespace mlir
171