1 //===- LLVMTypeSyntax.cpp - Parsing/printing for MLIR LLVM Dialect types --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
10 #include "mlir/IR/Builders.h"
11 #include "mlir/IR/DialectImplementation.h"
12 #include "llvm/ADT/ScopeExit.h"
13 #include "llvm/ADT/SetVector.h"
14 #include "llvm/ADT/TypeSwitch.h"
15
16 using namespace mlir;
17 using namespace mlir::LLVM;
18
19 //===----------------------------------------------------------------------===//
20 // Printing.
21 //===----------------------------------------------------------------------===//
22
23 /// If the given type is compatible with the LLVM dialect, prints it using
24 /// internal functions to avoid getting a verbose `!llvm` prefix. Otherwise
25 /// prints it as usual.
dispatchPrint(DialectAsmPrinter & printer,Type type)26 static void dispatchPrint(DialectAsmPrinter &printer, Type type) {
27 if (isCompatibleType(type) && !type.isa<IntegerType, FloatType, VectorType>())
28 return mlir::LLVM::detail::printType(type, printer);
29 printer.printType(type);
30 }
31
32 /// Returns the keyword to use for the given type.
getTypeKeyword(Type type)33 static StringRef getTypeKeyword(Type type) {
34 return TypeSwitch<Type, StringRef>(type)
35 .Case<LLVMVoidType>([&](Type) { return "void"; })
36 .Case<LLVMPPCFP128Type>([&](Type) { return "ppc_fp128"; })
37 .Case<LLVMX86MMXType>([&](Type) { return "x86_mmx"; })
38 .Case<LLVMTokenType>([&](Type) { return "token"; })
39 .Case<LLVMLabelType>([&](Type) { return "label"; })
40 .Case<LLVMMetadataType>([&](Type) { return "metadata"; })
41 .Case<LLVMFunctionType>([&](Type) { return "func"; })
42 .Case<LLVMPointerType>([&](Type) { return "ptr"; })
43 .Case<LLVMFixedVectorType, LLVMScalableVectorType>(
44 [&](Type) { return "vec"; })
45 .Case<LLVMArrayType>([&](Type) { return "array"; })
46 .Case<LLVMStructType>([&](Type) { return "struct"; })
47 .Default([](Type) -> StringRef {
48 llvm_unreachable("unexpected 'llvm' type kind");
49 });
50 }
51
52 /// Prints a structure type. Keeps track of known struct names to handle self-
53 /// or mutually-referring structs without falling into infinite recursion.
printStructType(DialectAsmPrinter & printer,LLVMStructType type)54 static void printStructType(DialectAsmPrinter &printer, LLVMStructType type) {
55 // This keeps track of the names of identified structure types that are
56 // currently being printed. Since such types can refer themselves, this
57 // tracking is necessary to stop the recursion: the current function may be
58 // called recursively from DialectAsmPrinter::printType after the appropriate
59 // dispatch. We maintain the invariant of this storage being modified
60 // exclusively in this function, and at most one name being added per call.
61 // TODO: consider having such functionality inside DialectAsmPrinter.
62 thread_local SetVector<StringRef> knownStructNames;
63 unsigned stackSize = knownStructNames.size();
64 (void)stackSize;
65 auto guard = llvm::make_scope_exit([&]() {
66 assert(knownStructNames.size() == stackSize &&
67 "malformed identified stack when printing recursive structs");
68 });
69
70 printer << "<";
71 if (type.isIdentified()) {
72 printer << '"' << type.getName() << '"';
73 // If we are printing a reference to one of the enclosing structs, just
74 // print the name and stop to avoid infinitely long output.
75 if (knownStructNames.count(type.getName())) {
76 printer << '>';
77 return;
78 }
79 printer << ", ";
80 }
81
82 if (type.isIdentified() && type.isOpaque()) {
83 printer << "opaque>";
84 return;
85 }
86
87 if (type.isPacked())
88 printer << "packed ";
89
90 // Put the current type on stack to avoid infinite recursion.
91 printer << '(';
92 if (type.isIdentified())
93 knownStructNames.insert(type.getName());
94 llvm::interleaveComma(type.getBody(), printer.getStream(),
95 [&](Type subtype) { dispatchPrint(printer, subtype); });
96 if (type.isIdentified())
97 knownStructNames.pop_back();
98 printer << ')';
99 printer << '>';
100 }
101
102 /// Prints a type containing a fixed number of elements.
103 template <typename TypeTy>
printArrayOrVectorType(DialectAsmPrinter & printer,TypeTy type)104 static void printArrayOrVectorType(DialectAsmPrinter &printer, TypeTy type) {
105 printer << '<' << type.getNumElements() << " x ";
106 dispatchPrint(printer, type.getElementType());
107 printer << '>';
108 }
109
110 /// Prints a function type.
printFunctionType(DialectAsmPrinter & printer,LLVMFunctionType funcType)111 static void printFunctionType(DialectAsmPrinter &printer,
112 LLVMFunctionType funcType) {
113 printer << '<';
114 dispatchPrint(printer, funcType.getReturnType());
115 printer << " (";
116 llvm::interleaveComma(
117 funcType.getParams(), printer.getStream(),
118 [&printer](Type subtype) { dispatchPrint(printer, subtype); });
119 if (funcType.isVarArg()) {
120 if (funcType.getNumParams() != 0)
121 printer << ", ";
122 printer << "...";
123 }
124 printer << ")>";
125 }
126
127 /// Prints the given LLVM dialect type recursively. This leverages closedness of
128 /// the LLVM dialect type system to avoid printing the dialect prefix
129 /// repeatedly. For recursive structures, only prints the name of the structure
130 /// when printing a self-reference. Note that this does not apply to sibling
131 /// references. For example,
132 /// struct<"a", (ptr<struct<"a">>)>
133 /// struct<"c", (ptr<struct<"b", (ptr<struct<"c">>)>>,
134 /// ptr<struct<"b", (ptr<struct<"c">>)>>)>
135 /// note that "b" is printed twice.
printType(Type type,DialectAsmPrinter & printer)136 void mlir::LLVM::detail::printType(Type type, DialectAsmPrinter &printer) {
137 if (!type) {
138 printer << "<<NULL-TYPE>>";
139 return;
140 }
141
142 printer << getTypeKeyword(type);
143
144 if (auto ptrType = type.dyn_cast<LLVMPointerType>()) {
145 printer << '<';
146 dispatchPrint(printer, ptrType.getElementType());
147 if (ptrType.getAddressSpace() != 0)
148 printer << ", " << ptrType.getAddressSpace();
149 printer << '>';
150 return;
151 }
152
153 if (auto arrayType = type.dyn_cast<LLVMArrayType>())
154 return printArrayOrVectorType(printer, arrayType);
155 if (auto vectorType = type.dyn_cast<LLVMFixedVectorType>())
156 return printArrayOrVectorType(printer, vectorType);
157
158 if (auto vectorType = type.dyn_cast<LLVMScalableVectorType>()) {
159 printer << "<? x " << vectorType.getMinNumElements() << " x ";
160 dispatchPrint(printer, vectorType.getElementType());
161 printer << '>';
162 return;
163 }
164
165 if (auto structType = type.dyn_cast<LLVMStructType>())
166 return printStructType(printer, structType);
167
168 if (auto funcType = type.dyn_cast<LLVMFunctionType>())
169 return printFunctionType(printer, funcType);
170 }
171
172 //===----------------------------------------------------------------------===//
173 // Parsing.
174 //===----------------------------------------------------------------------===//
175
// Forward declaration: the type parsers below call this helper to parse nested
// types, while its definition appears later in this file.
static ParseResult dispatchParse(DialectAsmParser &parser, Type &type);
177
178 /// Parses an LLVM dialect function type.
179 /// llvm-type :: = `func<` llvm-type `(` llvm-type-list `...`? `)>`
parseFunctionType(DialectAsmParser & parser)180 static LLVMFunctionType parseFunctionType(DialectAsmParser &parser) {
181 llvm::SMLoc loc = parser.getCurrentLocation();
182 Type returnType;
183 if (parser.parseLess() || dispatchParse(parser, returnType) ||
184 parser.parseLParen())
185 return LLVMFunctionType();
186
187 // Function type without arguments.
188 if (succeeded(parser.parseOptionalRParen())) {
189 if (succeeded(parser.parseGreater()))
190 return parser.getChecked<LLVMFunctionType>(loc, returnType, llvm::None,
191 /*isVarArg=*/false);
192 return LLVMFunctionType();
193 }
194
195 // Parse arguments.
196 SmallVector<Type, 8> argTypes;
197 do {
198 if (succeeded(parser.parseOptionalEllipsis())) {
199 if (parser.parseOptionalRParen() || parser.parseOptionalGreater())
200 return LLVMFunctionType();
201 return parser.getChecked<LLVMFunctionType>(loc, returnType, argTypes,
202 /*isVarArg=*/true);
203 }
204
205 Type arg;
206 if (dispatchParse(parser, arg))
207 return LLVMFunctionType();
208 argTypes.push_back(arg);
209 } while (succeeded(parser.parseOptionalComma()));
210
211 if (parser.parseOptionalRParen() || parser.parseOptionalGreater())
212 return LLVMFunctionType();
213 return parser.getChecked<LLVMFunctionType>(loc, returnType, argTypes,
214 /*isVarArg=*/false);
215 }
216
217 /// Parses an LLVM dialect pointer type.
218 /// llvm-type ::= `ptr<` llvm-type (`,` integer)? `>`
parsePointerType(DialectAsmParser & parser)219 static LLVMPointerType parsePointerType(DialectAsmParser &parser) {
220 llvm::SMLoc loc = parser.getCurrentLocation();
221 Type elementType;
222 if (parser.parseLess() || dispatchParse(parser, elementType))
223 return LLVMPointerType();
224
225 unsigned addressSpace = 0;
226 if (succeeded(parser.parseOptionalComma()) &&
227 failed(parser.parseInteger(addressSpace)))
228 return LLVMPointerType();
229 if (failed(parser.parseGreater()))
230 return LLVMPointerType();
231 return parser.getChecked<LLVMPointerType>(loc, elementType, addressSpace);
232 }
233
234 /// Parses an LLVM dialect vector type.
235 /// llvm-type ::= `vec<` `? x`? integer `x` llvm-type `>`
236 /// Supports both fixed and scalable vectors.
parseVectorType(DialectAsmParser & parser)237 static Type parseVectorType(DialectAsmParser &parser) {
238 SmallVector<int64_t, 2> dims;
239 llvm::SMLoc dimPos, typePos;
240 Type elementType;
241 llvm::SMLoc loc = parser.getCurrentLocation();
242 if (parser.parseLess() || parser.getCurrentLocation(&dimPos) ||
243 parser.parseDimensionList(dims, /*allowDynamic=*/true) ||
244 parser.getCurrentLocation(&typePos) ||
245 dispatchParse(parser, elementType) || parser.parseGreater())
246 return Type();
247
248 // We parsed a generic dimension list, but vectors only support two forms:
249 // - single non-dynamic entry in the list (fixed vector);
250 // - two elements, the first dynamic (indicated by -1) and the second
251 // non-dynamic (scalable vector).
252 if (dims.empty() || dims.size() > 2 ||
253 ((dims.size() == 2) ^ (dims[0] == -1)) ||
254 (dims.size() == 2 && dims[1] == -1)) {
255 parser.emitError(dimPos)
256 << "expected '? x <integer> x <type>' or '<integer> x <type>'";
257 return Type();
258 }
259
260 bool isScalable = dims.size() == 2;
261 if (isScalable)
262 return parser.getChecked<LLVMScalableVectorType>(loc, elementType, dims[1]);
263 if (elementType.isSignlessIntOrFloat()) {
264 parser.emitError(typePos)
265 << "cannot use !llvm.vec for built-in primitives, use 'vector' instead";
266 return Type();
267 }
268 return parser.getChecked<LLVMFixedVectorType>(loc, elementType, dims[0]);
269 }
270
271 /// Parses an LLVM dialect array type.
272 /// llvm-type ::= `array<` integer `x` llvm-type `>`
parseArrayType(DialectAsmParser & parser)273 static LLVMArrayType parseArrayType(DialectAsmParser &parser) {
274 SmallVector<int64_t, 1> dims;
275 llvm::SMLoc sizePos;
276 Type elementType;
277 llvm::SMLoc loc = parser.getCurrentLocation();
278 if (parser.parseLess() || parser.getCurrentLocation(&sizePos) ||
279 parser.parseDimensionList(dims, /*allowDynamic=*/false) ||
280 dispatchParse(parser, elementType) || parser.parseGreater())
281 return LLVMArrayType();
282
283 if (dims.size() != 1) {
284 parser.emitError(sizePos) << "expected ? x <type>";
285 return LLVMArrayType();
286 }
287
288 return parser.getChecked<LLVMArrayType>(loc, elementType, dims[0]);
289 }
290
291 /// Attempts to set the body of an identified structure type. Reports a parsing
292 /// error at `subtypesLoc` in case of failure.
trySetStructBody(LLVMStructType type,ArrayRef<Type> subtypes,bool isPacked,DialectAsmParser & parser,llvm::SMLoc subtypesLoc)293 static LLVMStructType trySetStructBody(LLVMStructType type,
294 ArrayRef<Type> subtypes, bool isPacked,
295 DialectAsmParser &parser,
296 llvm::SMLoc subtypesLoc) {
297 for (Type t : subtypes) {
298 if (!LLVMStructType::isValidElementType(t)) {
299 parser.emitError(subtypesLoc)
300 << "invalid LLVM structure element type: " << t;
301 return LLVMStructType();
302 }
303 }
304
305 if (succeeded(type.setBody(subtypes, isPacked)))
306 return type;
307
308 parser.emitError(subtypesLoc)
309 << "identified type already used with a different body";
310 return LLVMStructType();
311 }
312
313 /// Parses an LLVM dialect structure type.
314 /// llvm-type ::= `struct<` (string-literal `,`)? `packed`?
315 /// `(` llvm-type-list `)` `>`
316 /// | `struct<` string-literal `>`
317 /// | `struct<` string-literal `, opaque>`
parseStructType(DialectAsmParser & parser)318 static LLVMStructType parseStructType(DialectAsmParser &parser) {
319 // This keeps track of the names of identified structure types that are
320 // currently being parsed. Since such types can refer themselves, this
321 // tracking is necessary to stop the recursion: the current function may be
322 // called recursively from DialectAsmParser::parseType after the appropriate
323 // dispatch. We maintain the invariant of this storage being modified
324 // exclusively in this function, and at most one name being added per call.
325 // TODO: consider having such functionality inside DialectAsmParser.
326 thread_local SetVector<StringRef> knownStructNames;
327 unsigned stackSize = knownStructNames.size();
328 (void)stackSize;
329 auto guard = llvm::make_scope_exit([&]() {
330 assert(knownStructNames.size() == stackSize &&
331 "malformed identified stack when parsing recursive structs");
332 });
333
334 Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
335
336 if (failed(parser.parseLess()))
337 return LLVMStructType();
338
339 // If we are parsing a self-reference to a recursive struct, i.e. the parsing
340 // stack already contains a struct with the same identifier, bail out after
341 // the name.
342 StringRef name;
343 bool isIdentified = succeeded(parser.parseOptionalString(&name));
344 if (isIdentified) {
345 if (knownStructNames.count(name)) {
346 if (failed(parser.parseGreater()))
347 return LLVMStructType();
348 return LLVMStructType::getIdentifiedChecked(
349 [loc] { return emitError(loc); }, loc.getContext(), name);
350 }
351 if (failed(parser.parseComma()))
352 return LLVMStructType();
353 }
354
355 // Handle intentionally opaque structs.
356 llvm::SMLoc kwLoc = parser.getCurrentLocation();
357 if (succeeded(parser.parseOptionalKeyword("opaque"))) {
358 if (!isIdentified)
359 return parser.emitError(kwLoc, "only identified structs can be opaque"),
360 LLVMStructType();
361 if (failed(parser.parseGreater()))
362 return LLVMStructType();
363 auto type = LLVMStructType::getOpaqueChecked(
364 [loc] { return emitError(loc); }, loc.getContext(), name);
365 if (!type.isOpaque()) {
366 parser.emitError(kwLoc, "redeclaring defined struct as opaque");
367 return LLVMStructType();
368 }
369 return type;
370 }
371
372 // Check for packedness.
373 bool isPacked = succeeded(parser.parseOptionalKeyword("packed"));
374 if (failed(parser.parseLParen()))
375 return LLVMStructType();
376
377 // Fast pass for structs with zero subtypes.
378 if (succeeded(parser.parseOptionalRParen())) {
379 if (failed(parser.parseGreater()))
380 return LLVMStructType();
381 if (!isIdentified)
382 return LLVMStructType::getLiteralChecked([loc] { return emitError(loc); },
383 loc.getContext(), {}, isPacked);
384 auto type = LLVMStructType::getIdentifiedChecked(
385 [loc] { return emitError(loc); }, loc.getContext(), name);
386 return trySetStructBody(type, {}, isPacked, parser, kwLoc);
387 }
388
389 // Parse subtypes. For identified structs, put the identifier of the struct on
390 // the stack to support self-references in the recursive calls.
391 SmallVector<Type, 4> subtypes;
392 llvm::SMLoc subtypesLoc = parser.getCurrentLocation();
393 do {
394 if (isIdentified)
395 knownStructNames.insert(name);
396 Type type;
397 if (dispatchParse(parser, type))
398 return LLVMStructType();
399 subtypes.push_back(type);
400 if (isIdentified)
401 knownStructNames.pop_back();
402 } while (succeeded(parser.parseOptionalComma()));
403
404 if (parser.parseRParen() || parser.parseGreater())
405 return LLVMStructType();
406
407 // Construct the struct with body.
408 if (!isIdentified)
409 return LLVMStructType::getLiteralChecked(
410 [loc] { return emitError(loc); }, loc.getContext(), subtypes, isPacked);
411 auto type = LLVMStructType::getIdentifiedChecked(
412 [loc] { return emitError(loc); }, loc.getContext(), name);
413 return trySetStructBody(type, subtypes, isPacked, parser, subtypesLoc);
414 }
415
416 /// Parses a type appearing inside another LLVM dialect-compatible type. This
417 /// will try to parse any type in full form (including types with the `!llvm`
418 /// prefix), and on failure fall back to parsing the short-hand version of the
419 /// LLVM dialect types without the `!llvm` prefix.
dispatchParse(DialectAsmParser & parser,bool allowAny=true)420 static Type dispatchParse(DialectAsmParser &parser, bool allowAny = true) {
421 llvm::SMLoc keyLoc = parser.getCurrentLocation();
422
423 // Try parsing any MLIR type.
424 Type type;
425 OptionalParseResult result = parser.parseOptionalType(type);
426 if (result.hasValue()) {
427 if (failed(result.getValue()))
428 return nullptr;
429 if (!allowAny) {
430 parser.emitError(keyLoc) << "unexpected type, expected keyword";
431 return nullptr;
432 }
433 return type;
434 }
435
436 // If no type found, fallback to the shorthand form.
437 StringRef key;
438 if (failed(parser.parseKeyword(&key)))
439 return Type();
440
441 MLIRContext *ctx = parser.getBuilder().getContext();
442 return StringSwitch<function_ref<Type()>>(key)
443 .Case("void", [&] { return LLVMVoidType::get(ctx); })
444 .Case("ppc_fp128", [&] { return LLVMPPCFP128Type::get(ctx); })
445 .Case("x86_mmx", [&] { return LLVMX86MMXType::get(ctx); })
446 .Case("token", [&] { return LLVMTokenType::get(ctx); })
447 .Case("label", [&] { return LLVMLabelType::get(ctx); })
448 .Case("metadata", [&] { return LLVMMetadataType::get(ctx); })
449 .Case("func", [&] { return parseFunctionType(parser); })
450 .Case("ptr", [&] { return parsePointerType(parser); })
451 .Case("vec", [&] { return parseVectorType(parser); })
452 .Case("array", [&] { return parseArrayType(parser); })
453 .Case("struct", [&] { return parseStructType(parser); })
454 .Default([&] {
455 parser.emitError(keyLoc) << "unknown LLVM type: " << key;
456 return Type();
457 })();
458 }
459
460 /// Helper to use in parse lists.
dispatchParse(DialectAsmParser & parser,Type & type)461 static ParseResult dispatchParse(DialectAsmParser &parser, Type &type) {
462 type = dispatchParse(parser);
463 return success(type != nullptr);
464 }
465
466 /// Parses one of the LLVM dialect types.
parseType(DialectAsmParser & parser)467 Type mlir::LLVM::detail::parseType(DialectAsmParser &parser) {
468 llvm::SMLoc loc = parser.getCurrentLocation();
469 Type type = dispatchParse(parser, /*allowAny=*/false);
470 if (!type)
471 return type;
472 if (!isCompatibleType(type)) {
473 parser.emitError(loc) << "unexpected type, expected keyword";
474 return nullptr;
475 }
476 return type;
477 }
478