1 //===- LLVMTypeSyntax.cpp - Parsing/printing for MLIR LLVM Dialect types --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
10 #include "mlir/IR/Builders.h"
11 #include "mlir/IR/DialectImplementation.h"
12 #include "llvm/ADT/ScopeExit.h"
13 #include "llvm/ADT/SetVector.h"
14 #include "llvm/ADT/TypeSwitch.h"
15 
16 using namespace mlir;
17 using namespace mlir::LLVM;
18 
19 //===----------------------------------------------------------------------===//
20 // Printing.
21 //===----------------------------------------------------------------------===//
22 
23 /// If the given type is compatible with the LLVM dialect, prints it using
24 /// internal functions to avoid getting a verbose `!llvm` prefix. Otherwise
25 /// prints it as usual.
26 static void dispatchPrint(DialectAsmPrinter &printer, Type type) {
27   if (isCompatibleType(type) && !type.isa<IntegerType, FloatType, VectorType>())
28     return mlir::LLVM::detail::printType(type, printer);
29   printer.printType(type);
30 }
31 
32 /// Returns the keyword to use for the given type.
33 static StringRef getTypeKeyword(Type type) {
34   return TypeSwitch<Type, StringRef>(type)
35       .Case<LLVMVoidType>([&](Type) { return "void"; })
36       .Case<LLVMPPCFP128Type>([&](Type) { return "ppc_fp128"; })
37       .Case<LLVMX86MMXType>([&](Type) { return "x86_mmx"; })
38       .Case<LLVMTokenType>([&](Type) { return "token"; })
39       .Case<LLVMLabelType>([&](Type) { return "label"; })
40       .Case<LLVMMetadataType>([&](Type) { return "metadata"; })
41       .Case<LLVMFunctionType>([&](Type) { return "func"; })
42       .Case<LLVMPointerType>([&](Type) { return "ptr"; })
43       .Case<LLVMFixedVectorType, LLVMScalableVectorType>(
44           [&](Type) { return "vec"; })
45       .Case<LLVMArrayType>([&](Type) { return "array"; })
46       .Case<LLVMStructType>([&](Type) { return "struct"; })
47       .Default([](Type) -> StringRef {
48         llvm_unreachable("unexpected 'llvm' type kind");
49       });
50 }
51 
52 /// Prints a structure type. Keeps track of known struct names to handle self-
53 /// or mutually-referring structs without falling into infinite recursion.
54 static void printStructType(DialectAsmPrinter &printer, LLVMStructType type) {
55   // This keeps track of the names of identified structure types that are
56   // currently being printed. Since such types can refer themselves, this
57   // tracking is necessary to stop the recursion: the current function may be
58   // called recursively from DialectAsmPrinter::printType after the appropriate
59   // dispatch. We maintain the invariant of this storage being modified
60   // exclusively in this function, and at most one name being added per call.
61   // TODO: consider having such functionality inside DialectAsmPrinter.
62   thread_local SetVector<StringRef> knownStructNames;
63   unsigned stackSize = knownStructNames.size();
64   (void)stackSize;
65   auto guard = llvm::make_scope_exit([&]() {
66     assert(knownStructNames.size() == stackSize &&
67            "malformed identified stack when printing recursive structs");
68   });
69 
70   printer << "<";
71   if (type.isIdentified()) {
72     printer << '"' << type.getName() << '"';
73     // If we are printing a reference to one of the enclosing structs, just
74     // print the name and stop to avoid infinitely long output.
75     if (knownStructNames.count(type.getName())) {
76       printer << '>';
77       return;
78     }
79     printer << ", ";
80   }
81 
82   if (type.isIdentified() && type.isOpaque()) {
83     printer << "opaque>";
84     return;
85   }
86 
87   if (type.isPacked())
88     printer << "packed ";
89 
90   // Put the current type on stack to avoid infinite recursion.
91   printer << '(';
92   if (type.isIdentified())
93     knownStructNames.insert(type.getName());
94   llvm::interleaveComma(type.getBody(), printer.getStream(),
95                         [&](Type subtype) { dispatchPrint(printer, subtype); });
96   if (type.isIdentified())
97     knownStructNames.pop_back();
98   printer << ')';
99   printer << '>';
100 }
101 
102 /// Prints a type containing a fixed number of elements.
103 template <typename TypeTy>
104 static void printArrayOrVectorType(DialectAsmPrinter &printer, TypeTy type) {
105   printer << '<' << type.getNumElements() << " x ";
106   dispatchPrint(printer, type.getElementType());
107   printer << '>';
108 }
109 
110 /// Prints a function type.
111 static void printFunctionType(DialectAsmPrinter &printer,
112                               LLVMFunctionType funcType) {
113   printer << '<';
114   dispatchPrint(printer, funcType.getReturnType());
115   printer << " (";
116   llvm::interleaveComma(
117       funcType.getParams(), printer.getStream(),
118       [&printer](Type subtype) { dispatchPrint(printer, subtype); });
119   if (funcType.isVarArg()) {
120     if (funcType.getNumParams() != 0)
121       printer << ", ";
122     printer << "...";
123   }
124   printer << ")>";
125 }
126 
127 /// Prints the given LLVM dialect type recursively. This leverages closedness of
128 /// the LLVM dialect type system to avoid printing the dialect prefix
129 /// repeatedly. For recursive structures, only prints the name of the structure
130 /// when printing a self-reference. Note that this does not apply to sibling
131 /// references. For example,
132 ///   struct<"a", (ptr<struct<"a">>)>
133 ///   struct<"c", (ptr<struct<"b", (ptr<struct<"c">>)>>,
134 ///                ptr<struct<"b", (ptr<struct<"c">>)>>)>
135 /// note that "b" is printed twice.
136 void mlir::LLVM::detail::printType(Type type, DialectAsmPrinter &printer) {
137   if (!type) {
138     printer << "<<NULL-TYPE>>";
139     return;
140   }
141 
142   printer << getTypeKeyword(type);
143 
144   if (auto ptrType = type.dyn_cast<LLVMPointerType>()) {
145     printer << '<';
146     dispatchPrint(printer, ptrType.getElementType());
147     if (ptrType.getAddressSpace() != 0)
148       printer << ", " << ptrType.getAddressSpace();
149     printer << '>';
150     return;
151   }
152 
153   if (auto arrayType = type.dyn_cast<LLVMArrayType>())
154     return printArrayOrVectorType(printer, arrayType);
155   if (auto vectorType = type.dyn_cast<LLVMFixedVectorType>())
156     return printArrayOrVectorType(printer, vectorType);
157 
158   if (auto vectorType = type.dyn_cast<LLVMScalableVectorType>()) {
159     printer << "<? x " << vectorType.getMinNumElements() << " x ";
160     dispatchPrint(printer, vectorType.getElementType());
161     printer << '>';
162     return;
163   }
164 
165   if (auto structType = type.dyn_cast<LLVMStructType>())
166     return printStructType(printer, structType);
167 
168   if (auto funcType = type.dyn_cast<LLVMFunctionType>())
169     return printFunctionType(printer, funcType);
170 }
171 
172 //===----------------------------------------------------------------------===//
173 // Parsing.
174 //===----------------------------------------------------------------------===//
175 
176 static ParseResult dispatchParse(DialectAsmParser &parser, Type &type);
177 
178 /// Parses an LLVM dialect function type.
179 ///   llvm-type :: = `func<` llvm-type `(` llvm-type-list `...`? `)>`
180 static LLVMFunctionType parseFunctionType(DialectAsmParser &parser) {
181   llvm::SMLoc loc = parser.getCurrentLocation();
182   Type returnType;
183   if (parser.parseLess() || dispatchParse(parser, returnType) ||
184       parser.parseLParen())
185     return LLVMFunctionType();
186 
187   // Function type without arguments.
188   if (succeeded(parser.parseOptionalRParen())) {
189     if (succeeded(parser.parseGreater()))
190       return parser.getChecked<LLVMFunctionType>(loc, returnType, llvm::None,
191                                                  /*isVarArg=*/false);
192     return LLVMFunctionType();
193   }
194 
195   // Parse arguments.
196   SmallVector<Type, 8> argTypes;
197   do {
198     if (succeeded(parser.parseOptionalEllipsis())) {
199       if (parser.parseOptionalRParen() || parser.parseOptionalGreater())
200         return LLVMFunctionType();
201       return parser.getChecked<LLVMFunctionType>(loc, returnType, argTypes,
202                                                  /*isVarArg=*/true);
203     }
204 
205     Type arg;
206     if (dispatchParse(parser, arg))
207       return LLVMFunctionType();
208     argTypes.push_back(arg);
209   } while (succeeded(parser.parseOptionalComma()));
210 
211   if (parser.parseOptionalRParen() || parser.parseOptionalGreater())
212     return LLVMFunctionType();
213   return parser.getChecked<LLVMFunctionType>(loc, returnType, argTypes,
214                                              /*isVarArg=*/false);
215 }
216 
217 /// Parses an LLVM dialect pointer type.
218 ///   llvm-type ::= `ptr<` llvm-type (`,` integer)? `>`
219 static LLVMPointerType parsePointerType(DialectAsmParser &parser) {
220   llvm::SMLoc loc = parser.getCurrentLocation();
221   Type elementType;
222   if (parser.parseLess() || dispatchParse(parser, elementType))
223     return LLVMPointerType();
224 
225   unsigned addressSpace = 0;
226   if (succeeded(parser.parseOptionalComma()) &&
227       failed(parser.parseInteger(addressSpace)))
228     return LLVMPointerType();
229   if (failed(parser.parseGreater()))
230     return LLVMPointerType();
231   return parser.getChecked<LLVMPointerType>(loc, elementType, addressSpace);
232 }
233 
234 /// Parses an LLVM dialect vector type.
235 ///   llvm-type ::= `vec<` `? x`? integer `x` llvm-type `>`
236 /// Supports both fixed and scalable vectors.
237 static Type parseVectorType(DialectAsmParser &parser) {
238   SmallVector<int64_t, 2> dims;
239   llvm::SMLoc dimPos, typePos;
240   Type elementType;
241   llvm::SMLoc loc = parser.getCurrentLocation();
242   if (parser.parseLess() || parser.getCurrentLocation(&dimPos) ||
243       parser.parseDimensionList(dims, /*allowDynamic=*/true) ||
244       parser.getCurrentLocation(&typePos) ||
245       dispatchParse(parser, elementType) || parser.parseGreater())
246     return Type();
247 
248   // We parsed a generic dimension list, but vectors only support two forms:
249   //  - single non-dynamic entry in the list (fixed vector);
250   //  - two elements, the first dynamic (indicated by -1) and the second
251   //    non-dynamic (scalable vector).
252   if (dims.empty() || dims.size() > 2 ||
253       ((dims.size() == 2) ^ (dims[0] == -1)) ||
254       (dims.size() == 2 && dims[1] == -1)) {
255     parser.emitError(dimPos)
256         << "expected '? x <integer> x <type>' or '<integer> x <type>'";
257     return Type();
258   }
259 
260   bool isScalable = dims.size() == 2;
261   if (isScalable)
262     return parser.getChecked<LLVMScalableVectorType>(loc, elementType, dims[1]);
263   if (elementType.isSignlessIntOrFloat()) {
264     parser.emitError(typePos)
265         << "cannot use !llvm.vec for built-in primitives, use 'vector' instead";
266     return Type();
267   }
268   return parser.getChecked<LLVMFixedVectorType>(loc, elementType, dims[0]);
269 }
270 
271 /// Parses an LLVM dialect array type.
272 ///   llvm-type ::= `array<` integer `x` llvm-type `>`
273 static LLVMArrayType parseArrayType(DialectAsmParser &parser) {
274   SmallVector<int64_t, 1> dims;
275   llvm::SMLoc sizePos;
276   Type elementType;
277   llvm::SMLoc loc = parser.getCurrentLocation();
278   if (parser.parseLess() || parser.getCurrentLocation(&sizePos) ||
279       parser.parseDimensionList(dims, /*allowDynamic=*/false) ||
280       dispatchParse(parser, elementType) || parser.parseGreater())
281     return LLVMArrayType();
282 
283   if (dims.size() != 1) {
284     parser.emitError(sizePos) << "expected ? x <type>";
285     return LLVMArrayType();
286   }
287 
288   return parser.getChecked<LLVMArrayType>(loc, elementType, dims[0]);
289 }
290 
291 /// Attempts to set the body of an identified structure type. Reports a parsing
292 /// error at `subtypesLoc` in case of failure.
293 static LLVMStructType trySetStructBody(LLVMStructType type,
294                                        ArrayRef<Type> subtypes, bool isPacked,
295                                        DialectAsmParser &parser,
296                                        llvm::SMLoc subtypesLoc) {
297   for (Type t : subtypes) {
298     if (!LLVMStructType::isValidElementType(t)) {
299       parser.emitError(subtypesLoc)
300           << "invalid LLVM structure element type: " << t;
301       return LLVMStructType();
302     }
303   }
304 
305   if (succeeded(type.setBody(subtypes, isPacked)))
306     return type;
307 
308   parser.emitError(subtypesLoc)
309       << "identified type already used with a different body";
310   return LLVMStructType();
311 }
312 
313 /// Parses an LLVM dialect structure type.
314 ///   llvm-type ::= `struct<` (string-literal `,`)? `packed`?
315 ///                 `(` llvm-type-list `)` `>`
316 ///               | `struct<` string-literal `>`
317 ///               | `struct<` string-literal `, opaque>`
318 static LLVMStructType parseStructType(DialectAsmParser &parser) {
319   // This keeps track of the names of identified structure types that are
320   // currently being parsed. Since such types can refer themselves, this
321   // tracking is necessary to stop the recursion: the current function may be
322   // called recursively from DialectAsmParser::parseType after the appropriate
323   // dispatch. We maintain the invariant of this storage being modified
324   // exclusively in this function, and at most one name being added per call.
325   // TODO: consider having such functionality inside DialectAsmParser.
326   thread_local SetVector<StringRef> knownStructNames;
327   unsigned stackSize = knownStructNames.size();
328   (void)stackSize;
329   auto guard = llvm::make_scope_exit([&]() {
330     assert(knownStructNames.size() == stackSize &&
331            "malformed identified stack when parsing recursive structs");
332   });
333 
334   Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
335 
336   if (failed(parser.parseLess()))
337     return LLVMStructType();
338 
339   // If we are parsing a self-reference to a recursive struct, i.e. the parsing
340   // stack already contains a struct with the same identifier, bail out after
341   // the name.
342   StringRef name;
343   bool isIdentified = succeeded(parser.parseOptionalString(&name));
344   if (isIdentified) {
345     if (knownStructNames.count(name)) {
346       if (failed(parser.parseGreater()))
347         return LLVMStructType();
348       return LLVMStructType::getIdentifiedChecked(
349           [loc] { return emitError(loc); }, loc.getContext(), name);
350     }
351     if (failed(parser.parseComma()))
352       return LLVMStructType();
353   }
354 
355   // Handle intentionally opaque structs.
356   llvm::SMLoc kwLoc = parser.getCurrentLocation();
357   if (succeeded(parser.parseOptionalKeyword("opaque"))) {
358     if (!isIdentified)
359       return parser.emitError(kwLoc, "only identified structs can be opaque"),
360              LLVMStructType();
361     if (failed(parser.parseGreater()))
362       return LLVMStructType();
363     auto type = LLVMStructType::getOpaqueChecked(
364         [loc] { return emitError(loc); }, loc.getContext(), name);
365     if (!type.isOpaque()) {
366       parser.emitError(kwLoc, "redeclaring defined struct as opaque");
367       return LLVMStructType();
368     }
369     return type;
370   }
371 
372   // Check for packedness.
373   bool isPacked = succeeded(parser.parseOptionalKeyword("packed"));
374   if (failed(parser.parseLParen()))
375     return LLVMStructType();
376 
377   // Fast pass for structs with zero subtypes.
378   if (succeeded(parser.parseOptionalRParen())) {
379     if (failed(parser.parseGreater()))
380       return LLVMStructType();
381     if (!isIdentified)
382       return LLVMStructType::getLiteralChecked([loc] { return emitError(loc); },
383                                                loc.getContext(), {}, isPacked);
384     auto type = LLVMStructType::getIdentifiedChecked(
385         [loc] { return emitError(loc); }, loc.getContext(), name);
386     return trySetStructBody(type, {}, isPacked, parser, kwLoc);
387   }
388 
389   // Parse subtypes. For identified structs, put the identifier of the struct on
390   // the stack to support self-references in the recursive calls.
391   SmallVector<Type, 4> subtypes;
392   llvm::SMLoc subtypesLoc = parser.getCurrentLocation();
393   do {
394     if (isIdentified)
395       knownStructNames.insert(name);
396     Type type;
397     if (dispatchParse(parser, type))
398       return LLVMStructType();
399     subtypes.push_back(type);
400     if (isIdentified)
401       knownStructNames.pop_back();
402   } while (succeeded(parser.parseOptionalComma()));
403 
404   if (parser.parseRParen() || parser.parseGreater())
405     return LLVMStructType();
406 
407   // Construct the struct with body.
408   if (!isIdentified)
409     return LLVMStructType::getLiteralChecked(
410         [loc] { return emitError(loc); }, loc.getContext(), subtypes, isPacked);
411   auto type = LLVMStructType::getIdentifiedChecked(
412       [loc] { return emitError(loc); }, loc.getContext(), name);
413   return trySetStructBody(type, subtypes, isPacked, parser, subtypesLoc);
414 }
415 
416 /// Parses a type appearing inside another LLVM dialect-compatible type. This
417 /// will try to parse any type in full form (including types with the `!llvm`
418 /// prefix), and on failure fall back to parsing the short-hand version of the
419 /// LLVM dialect types without the `!llvm` prefix.
420 static Type dispatchParse(DialectAsmParser &parser, bool allowAny = true) {
421   llvm::SMLoc keyLoc = parser.getCurrentLocation();
422 
423   // Try parsing any MLIR type.
424   Type type;
425   OptionalParseResult result = parser.parseOptionalType(type);
426   if (result.hasValue()) {
427     if (failed(result.getValue()))
428       return nullptr;
429     if (!allowAny) {
430       parser.emitError(keyLoc) << "unexpected type, expected keyword";
431       return nullptr;
432     }
433     return type;
434   }
435 
436   // If no type found, fallback to the shorthand form.
437   StringRef key;
438   if (failed(parser.parseKeyword(&key)))
439     return Type();
440 
441   MLIRContext *ctx = parser.getBuilder().getContext();
442   return StringSwitch<function_ref<Type()>>(key)
443       .Case("void", [&] { return LLVMVoidType::get(ctx); })
444       .Case("ppc_fp128", [&] { return LLVMPPCFP128Type::get(ctx); })
445       .Case("x86_mmx", [&] { return LLVMX86MMXType::get(ctx); })
446       .Case("token", [&] { return LLVMTokenType::get(ctx); })
447       .Case("label", [&] { return LLVMLabelType::get(ctx); })
448       .Case("metadata", [&] { return LLVMMetadataType::get(ctx); })
449       .Case("func", [&] { return parseFunctionType(parser); })
450       .Case("ptr", [&] { return parsePointerType(parser); })
451       .Case("vec", [&] { return parseVectorType(parser); })
452       .Case("array", [&] { return parseArrayType(parser); })
453       .Case("struct", [&] { return parseStructType(parser); })
454       .Default([&] {
455         parser.emitError(keyLoc) << "unknown LLVM type: " << key;
456         return Type();
457       })();
458 }
459 
460 /// Helper to use in parse lists.
461 static ParseResult dispatchParse(DialectAsmParser &parser, Type &type) {
462   type = dispatchParse(parser);
463   return success(type != nullptr);
464 }
465 
466 /// Parses one of the LLVM dialect types.
467 Type mlir::LLVM::detail::parseType(DialectAsmParser &parser) {
468   llvm::SMLoc loc = parser.getCurrentLocation();
469   Type type = dispatchParse(parser, /*allowAny=*/false);
470   if (!type)
471     return type;
472   if (!isCompatibleType(type)) {
473     parser.emitError(loc) << "unexpected type, expected keyword";
474     return nullptr;
475   }
476   return type;
477 }
478