1 //===- TestTypes.cpp - MLIR Test Dialect Types ------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains types defined by the TestDialect for testing various
10 // features of MLIR.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "TestTypes.h"
15 #include "TestDialect.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/DialectImplementation.h"
18 #include "mlir/IR/Types.h"
19 #include "llvm/ADT/Hashing.h"
20 #include "llvm/ADT/SetVector.h"
21 #include "llvm/ADT/TypeSwitch.h"
22
23 using namespace mlir;
24 using namespace mlir::test;
25
26 // Custom parser for SignednessSemantics.
27 static ParseResult
parseSignedness(DialectAsmParser & parser,TestIntegerType::SignednessSemantics & result)28 parseSignedness(DialectAsmParser &parser,
29 TestIntegerType::SignednessSemantics &result) {
30 StringRef signStr;
31 auto loc = parser.getCurrentLocation();
32 if (parser.parseKeyword(&signStr))
33 return failure();
34 if (signStr.compare_lower("u") || signStr.compare_lower("unsigned"))
35 result = TestIntegerType::SignednessSemantics::Unsigned;
36 else if (signStr.compare_lower("s") || signStr.compare_lower("signed"))
37 result = TestIntegerType::SignednessSemantics::Signed;
38 else if (signStr.compare_lower("n") || signStr.compare_lower("none"))
39 result = TestIntegerType::SignednessSemantics::Signless;
40 else
41 return parser.emitError(loc, "expected signed, unsigned, or none");
42 return success();
43 }
44
45 // Custom printer for SignednessSemantics.
printSignedness(DialectAsmPrinter & printer,const TestIntegerType::SignednessSemantics & ss)46 static void printSignedness(DialectAsmPrinter &printer,
47 const TestIntegerType::SignednessSemantics &ss) {
48 switch (ss) {
49 case TestIntegerType::SignednessSemantics::Unsigned:
50 printer << "unsigned";
51 break;
52 case TestIntegerType::SignednessSemantics::Signed:
53 printer << "signed";
54 break;
55 case TestIntegerType::SignednessSemantics::Signless:
56 printer << "none";
57 break;
58 }
59 }
60
parse(MLIRContext * ctxt,DialectAsmParser & parser)61 Type CompoundAType::parse(MLIRContext *ctxt, DialectAsmParser &parser) {
62 int widthOfSomething;
63 Type oneType;
64 SmallVector<int, 4> arrayOfInts;
65 if (parser.parseLess() || parser.parseInteger(widthOfSomething) ||
66 parser.parseComma() || parser.parseType(oneType) || parser.parseComma() ||
67 parser.parseLSquare())
68 return Type();
69
70 int i;
71 while (!*parser.parseOptionalInteger(i)) {
72 arrayOfInts.push_back(i);
73 if (parser.parseOptionalComma())
74 break;
75 }
76
77 if (parser.parseRSquare() || parser.parseGreater())
78 return Type();
79
80 return get(ctxt, widthOfSomething, oneType, arrayOfInts);
81 }
print(DialectAsmPrinter & printer) const82 void CompoundAType::print(DialectAsmPrinter &printer) const {
83 printer << "cmpnd_a<" << getWidthOfSomething() << ", " << getOneType()
84 << ", [";
85 auto intArray = getArrayOfInts();
86 llvm::interleaveComma(intArray, printer);
87 printer << "]>";
88 }
89
90 // The functions don't need to be in the header file, but need to be in the mlir
91 // namespace. Declare them here, then define them immediately below. Separating
92 // the declaration and definition adheres to the LLVM coding standards.
93 namespace mlir {
94 namespace test {
95 // FieldInfo is used as part of a parameter, so equality comparison is
96 // compulsory.
97 static bool operator==(const FieldInfo &a, const FieldInfo &b);
98 // FieldInfo is used as part of a parameter, so a hash will be computed.
99 static llvm::hash_code hash_value(const FieldInfo &fi); // NOLINT
100 } // namespace test
101 } // namespace mlir
102
103 // FieldInfo is used as part of a parameter, so equality comparison is
104 // compulsory.
operator ==(const FieldInfo & a,const FieldInfo & b)105 static bool mlir::test::operator==(const FieldInfo &a, const FieldInfo &b) {
106 return a.name == b.name && a.type == b.type;
107 }
108
109 // FieldInfo is used as part of a parameter, so a hash will be computed.
hash_value(const FieldInfo & fi)110 static llvm::hash_code mlir::test::hash_value(const FieldInfo &fi) { // NOLINT
111 return llvm::hash_combine(fi.name, fi.type);
112 }
113
114 // Example type validity checker.
verifyConstructionInvariants(Location loc,unsigned width,TestIntegerType::SignednessSemantics ss)115 LogicalResult TestIntegerType::verifyConstructionInvariants(
116 Location loc, unsigned width, TestIntegerType::SignednessSemantics ss) {
117 if (width > 8)
118 return failure();
119 return success();
120 }
121
122 //===----------------------------------------------------------------------===//
123 // Tablegen Generated Definitions
124 //===----------------------------------------------------------------------===//
125
126 #define GET_TYPEDEF_CLASSES
127 #include "TestTypeDefs.cpp.inc"
128
129 //===----------------------------------------------------------------------===//
130 // TestDialect
131 //===----------------------------------------------------------------------===//
132
parseTestType(MLIRContext * ctxt,DialectAsmParser & parser,llvm::SetVector<Type> & stack)133 static Type parseTestType(MLIRContext *ctxt, DialectAsmParser &parser,
134 llvm::SetVector<Type> &stack) {
135 StringRef typeTag;
136 if (failed(parser.parseKeyword(&typeTag)))
137 return Type();
138
139 auto genType = generatedTypeParser(ctxt, parser, typeTag);
140 if (genType != Type())
141 return genType;
142
143 if (typeTag == "test_type")
144 return TestType::get(parser.getBuilder().getContext());
145
146 if (typeTag != "test_rec")
147 return Type();
148
149 StringRef name;
150 if (parser.parseLess() || parser.parseKeyword(&name))
151 return Type();
152 auto rec = TestRecursiveType::get(parser.getBuilder().getContext(), name);
153
154 // If this type already has been parsed above in the stack, expect just the
155 // name.
156 if (stack.contains(rec)) {
157 if (failed(parser.parseGreater()))
158 return Type();
159 return rec;
160 }
161
162 // Otherwise, parse the body and update the type.
163 if (failed(parser.parseComma()))
164 return Type();
165 stack.insert(rec);
166 Type subtype = parseTestType(ctxt, parser, stack);
167 stack.pop_back();
168 if (!subtype || failed(parser.parseGreater()) || failed(rec.setBody(subtype)))
169 return Type();
170
171 return rec;
172 }
173
parseType(DialectAsmParser & parser) const174 Type TestDialect::parseType(DialectAsmParser &parser) const {
175 llvm::SetVector<Type> stack;
176 return parseTestType(getContext(), parser, stack);
177 }
178
printTestType(Type type,DialectAsmPrinter & printer,llvm::SetVector<Type> & stack)179 static void printTestType(Type type, DialectAsmPrinter &printer,
180 llvm::SetVector<Type> &stack) {
181 if (succeeded(generatedTypePrinter(type, printer)))
182 return;
183 if (type.isa<TestType>()) {
184 printer << "test_type";
185 return;
186 }
187
188 auto rec = type.cast<TestRecursiveType>();
189 printer << "test_rec<" << rec.getName();
190 if (!stack.contains(rec)) {
191 printer << ", ";
192 stack.insert(rec);
193 printTestType(rec.getBody(), printer, stack);
194 stack.pop_back();
195 }
196 printer << ">";
197 }
198
printType(Type type,DialectAsmPrinter & printer) const199 void TestDialect::printType(Type type, DialectAsmPrinter &printer) const {
200 llvm::SetVector<Type> stack;
201 printTestType(type, printer, stack);
202 }
203