1 //===- TestTypes.cpp - MLIR Test Dialect Types ------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains types defined by the TestDialect for testing various
10 // features of MLIR.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "TestTypes.h"
15 #include "TestDialect.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/DialectImplementation.h"
18 #include "mlir/IR/Types.h"
19 #include "llvm/ADT/Hashing.h"
20 #include "llvm/ADT/SetVector.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 
23 using namespace mlir;
24 using namespace mlir::test;
25 
26 // Custom parser for SignednessSemantics.
27 static ParseResult
parseSignedness(DialectAsmParser & parser,TestIntegerType::SignednessSemantics & result)28 parseSignedness(DialectAsmParser &parser,
29                 TestIntegerType::SignednessSemantics &result) {
30   StringRef signStr;
31   auto loc = parser.getCurrentLocation();
32   if (parser.parseKeyword(&signStr))
33     return failure();
34   if (signStr.compare_lower("u") || signStr.compare_lower("unsigned"))
35     result = TestIntegerType::SignednessSemantics::Unsigned;
36   else if (signStr.compare_lower("s") || signStr.compare_lower("signed"))
37     result = TestIntegerType::SignednessSemantics::Signed;
38   else if (signStr.compare_lower("n") || signStr.compare_lower("none"))
39     result = TestIntegerType::SignednessSemantics::Signless;
40   else
41     return parser.emitError(loc, "expected signed, unsigned, or none");
42   return success();
43 }
44 
45 // Custom printer for SignednessSemantics.
printSignedness(DialectAsmPrinter & printer,const TestIntegerType::SignednessSemantics & ss)46 static void printSignedness(DialectAsmPrinter &printer,
47                             const TestIntegerType::SignednessSemantics &ss) {
48   switch (ss) {
49   case TestIntegerType::SignednessSemantics::Unsigned:
50     printer << "unsigned";
51     break;
52   case TestIntegerType::SignednessSemantics::Signed:
53     printer << "signed";
54     break;
55   case TestIntegerType::SignednessSemantics::Signless:
56     printer << "none";
57     break;
58   }
59 }
60 
parse(MLIRContext * ctxt,DialectAsmParser & parser)61 Type CompoundAType::parse(MLIRContext *ctxt, DialectAsmParser &parser) {
62   int widthOfSomething;
63   Type oneType;
64   SmallVector<int, 4> arrayOfInts;
65   if (parser.parseLess() || parser.parseInteger(widthOfSomething) ||
66       parser.parseComma() || parser.parseType(oneType) || parser.parseComma() ||
67       parser.parseLSquare())
68     return Type();
69 
70   int i;
71   while (!*parser.parseOptionalInteger(i)) {
72     arrayOfInts.push_back(i);
73     if (parser.parseOptionalComma())
74       break;
75   }
76 
77   if (parser.parseRSquare() || parser.parseGreater())
78     return Type();
79 
80   return get(ctxt, widthOfSomething, oneType, arrayOfInts);
81 }
print(DialectAsmPrinter & printer) const82 void CompoundAType::print(DialectAsmPrinter &printer) const {
83   printer << "cmpnd_a<" << getWidthOfSomething() << ", " << getOneType()
84           << ", [";
85   auto intArray = getArrayOfInts();
86   llvm::interleaveComma(intArray, printer);
87   printer << "]>";
88 }
89 
// These functions don't need to be in the header file, but they do need to be
// declared in the mlir::test namespace (presumably so that argument-dependent
// lookup finds them for FieldInfo from the tablegen-generated code included
// further below -- confirm). Declare them here, then define them immediately
// below using namespace-qualified names. Separating the declaration and the
// definition adheres to the LLVM coding standards.
namespace mlir {
namespace test {
// FieldInfo is used as part of a parameter, so equality comparison is
// compulsory.
static bool operator==(const FieldInfo &a, const FieldInfo &b);
// FieldInfo is used as part of a parameter, so a hash will be computed.
// NOTE(review): the NOLINT presumably silences a clang-tidy naming check for
// the snake_case `hash_value` name required by llvm::hash_combine -- confirm.
static llvm::hash_code hash_value(const FieldInfo &fi); // NOLINT
} // namespace test
} // namespace mlir
102 
103 // FieldInfo is used as part of a parameter, so equality comparison is
104 // compulsory.
operator ==(const FieldInfo & a,const FieldInfo & b)105 static bool mlir::test::operator==(const FieldInfo &a, const FieldInfo &b) {
106   return a.name == b.name && a.type == b.type;
107 }
108 
109 // FieldInfo is used as part of a parameter, so a hash will be computed.
hash_value(const FieldInfo & fi)110 static llvm::hash_code mlir::test::hash_value(const FieldInfo &fi) { // NOLINT
111   return llvm::hash_combine(fi.name, fi.type);
112 }
113 
114 // Example type validity checker.
verifyConstructionInvariants(Location loc,unsigned width,TestIntegerType::SignednessSemantics ss)115 LogicalResult TestIntegerType::verifyConstructionInvariants(
116     Location loc, unsigned width, TestIntegerType::SignednessSemantics ss) {
117   if (width > 8)
118     return failure();
119   return success();
120 }
121 
122 //===----------------------------------------------------------------------===//
123 // Tablegen Generated Definitions
124 //===----------------------------------------------------------------------===//
125 
126 #define GET_TYPEDEF_CLASSES
127 #include "TestTypeDefs.cpp.inc"
128 
129 //===----------------------------------------------------------------------===//
130 // TestDialect
131 //===----------------------------------------------------------------------===//
132 
parseTestType(MLIRContext * ctxt,DialectAsmParser & parser,llvm::SetVector<Type> & stack)133 static Type parseTestType(MLIRContext *ctxt, DialectAsmParser &parser,
134                           llvm::SetVector<Type> &stack) {
135   StringRef typeTag;
136   if (failed(parser.parseKeyword(&typeTag)))
137     return Type();
138 
139   auto genType = generatedTypeParser(ctxt, parser, typeTag);
140   if (genType != Type())
141     return genType;
142 
143   if (typeTag == "test_type")
144     return TestType::get(parser.getBuilder().getContext());
145 
146   if (typeTag != "test_rec")
147     return Type();
148 
149   StringRef name;
150   if (parser.parseLess() || parser.parseKeyword(&name))
151     return Type();
152   auto rec = TestRecursiveType::get(parser.getBuilder().getContext(), name);
153 
154   // If this type already has been parsed above in the stack, expect just the
155   // name.
156   if (stack.contains(rec)) {
157     if (failed(parser.parseGreater()))
158       return Type();
159     return rec;
160   }
161 
162   // Otherwise, parse the body and update the type.
163   if (failed(parser.parseComma()))
164     return Type();
165   stack.insert(rec);
166   Type subtype = parseTestType(ctxt, parser, stack);
167   stack.pop_back();
168   if (!subtype || failed(parser.parseGreater()) || failed(rec.setBody(subtype)))
169     return Type();
170 
171   return rec;
172 }
173 
parseType(DialectAsmParser & parser) const174 Type TestDialect::parseType(DialectAsmParser &parser) const {
175   llvm::SetVector<Type> stack;
176   return parseTestType(getContext(), parser, stack);
177 }
178 
printTestType(Type type,DialectAsmPrinter & printer,llvm::SetVector<Type> & stack)179 static void printTestType(Type type, DialectAsmPrinter &printer,
180                           llvm::SetVector<Type> &stack) {
181   if (succeeded(generatedTypePrinter(type, printer)))
182     return;
183   if (type.isa<TestType>()) {
184     printer << "test_type";
185     return;
186   }
187 
188   auto rec = type.cast<TestRecursiveType>();
189   printer << "test_rec<" << rec.getName();
190   if (!stack.contains(rec)) {
191     printer << ", ";
192     stack.insert(rec);
193     printTestType(rec.getBody(), printer, stack);
194     stack.pop_back();
195   }
196   printer << ">";
197 }
198 
printType(Type type,DialectAsmPrinter & printer) const199 void TestDialect::printType(Type type, DialectAsmPrinter &printer) const {
200   llvm::SetVector<Type> stack;
201   printTestType(type, printer, stack);
202 }
203