1 //===- LLVMTypeSyntax.cpp - Parsing/printing for MLIR LLVM Dialect types --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
10 #include "mlir/IR/Builders.h"
11 #include "mlir/IR/DialectImplementation.h"
12 #include "llvm/ADT/ScopeExit.h"
13 #include "llvm/ADT/SetVector.h"
14 #include "llvm/ADT/TypeSwitch.h"
15 
16 using namespace mlir;
17 using namespace mlir::LLVM;
18 
19 //===----------------------------------------------------------------------===//
20 // Printing.
21 //===----------------------------------------------------------------------===//
22 
23 /// If the given type is compatible with the LLVM dialect, prints it using
24 /// internal functions to avoid getting a verbose `!llvm` prefix. Otherwise
25 /// prints it as usual.
dispatchPrint(DialectAsmPrinter & printer,Type type)26 static void dispatchPrint(DialectAsmPrinter &printer, Type type) {
27   if (isCompatibleType(type) && !type.isa<IntegerType, FloatType, VectorType>())
28     return mlir::LLVM::detail::printType(type, printer);
29   printer.printType(type);
30 }
31 
32 /// Returns the keyword to use for the given type.
getTypeKeyword(Type type)33 static StringRef getTypeKeyword(Type type) {
34   return TypeSwitch<Type, StringRef>(type)
35       .Case<LLVMVoidType>([&](Type) { return "void"; })
36       .Case<LLVMPPCFP128Type>([&](Type) { return "ppc_fp128"; })
37       .Case<LLVMX86MMXType>([&](Type) { return "x86_mmx"; })
38       .Case<LLVMTokenType>([&](Type) { return "token"; })
39       .Case<LLVMLabelType>([&](Type) { return "label"; })
40       .Case<LLVMMetadataType>([&](Type) { return "metadata"; })
41       .Case<LLVMFunctionType>([&](Type) { return "func"; })
42       .Case<LLVMPointerType>([&](Type) { return "ptr"; })
43       .Case<LLVMFixedVectorType, LLVMScalableVectorType>(
44           [&](Type) { return "vec"; })
45       .Case<LLVMArrayType>([&](Type) { return "array"; })
46       .Case<LLVMStructType>([&](Type) { return "struct"; })
47       .Default([](Type) -> StringRef {
48         llvm_unreachable("unexpected 'llvm' type kind");
49       });
50 }
51 
52 /// Prints a structure type. Keeps track of known struct names to handle self-
53 /// or mutually-referring structs without falling into infinite recursion.
printStructType(DialectAsmPrinter & printer,LLVMStructType type)54 static void printStructType(DialectAsmPrinter &printer, LLVMStructType type) {
55   // This keeps track of the names of identified structure types that are
56   // currently being printed. Since such types can refer themselves, this
57   // tracking is necessary to stop the recursion: the current function may be
58   // called recursively from DialectAsmPrinter::printType after the appropriate
59   // dispatch. We maintain the invariant of this storage being modified
60   // exclusively in this function, and at most one name being added per call.
61   // TODO: consider having such functionality inside DialectAsmPrinter.
62   thread_local llvm::SetVector<StringRef> knownStructNames;
63   unsigned stackSize = knownStructNames.size();
64   (void)stackSize;
65   auto guard = llvm::make_scope_exit([&]() {
66     assert(knownStructNames.size() == stackSize &&
67            "malformed identified stack when printing recursive structs");
68   });
69 
70   printer << "<";
71   if (type.isIdentified()) {
72     printer << '"' << type.getName() << '"';
73     // If we are printing a reference to one of the enclosing structs, just
74     // print the name and stop to avoid infinitely long output.
75     if (knownStructNames.count(type.getName())) {
76       printer << '>';
77       return;
78     }
79     printer << ", ";
80   }
81 
82   if (type.isIdentified() && type.isOpaque()) {
83     printer << "opaque>";
84     return;
85   }
86 
87   if (type.isPacked())
88     printer << "packed ";
89 
90   // Put the current type on stack to avoid infinite recursion.
91   printer << '(';
92   if (type.isIdentified())
93     knownStructNames.insert(type.getName());
94   llvm::interleaveComma(type.getBody(), printer.getStream(),
95                         [&](Type subtype) { dispatchPrint(printer, subtype); });
96   if (type.isIdentified())
97     knownStructNames.pop_back();
98   printer << ')';
99   printer << '>';
100 }
101 
102 /// Prints a type containing a fixed number of elements.
103 template <typename TypeTy>
printArrayOrVectorType(DialectAsmPrinter & printer,TypeTy type)104 static void printArrayOrVectorType(DialectAsmPrinter &printer, TypeTy type) {
105   printer << '<' << type.getNumElements() << " x ";
106   dispatchPrint(printer, type.getElementType());
107   printer << '>';
108 }
109 
110 /// Prints a function type.
printFunctionType(DialectAsmPrinter & printer,LLVMFunctionType funcType)111 static void printFunctionType(DialectAsmPrinter &printer,
112                               LLVMFunctionType funcType) {
113   printer << '<';
114   dispatchPrint(printer, funcType.getReturnType());
115   printer << " (";
116   llvm::interleaveComma(
117       funcType.getParams(), printer.getStream(),
118       [&printer](Type subtype) { dispatchPrint(printer, subtype); });
119   if (funcType.isVarArg()) {
120     if (funcType.getNumParams() != 0)
121       printer << ", ";
122     printer << "...";
123   }
124   printer << ")>";
125 }
126 
127 /// Prints the given LLVM dialect type recursively. This leverages closedness of
128 /// the LLVM dialect type system to avoid printing the dialect prefix
129 /// repeatedly. For recursive structures, only prints the name of the structure
130 /// when printing a self-reference. Note that this does not apply to sibling
131 /// references. For example,
132 ///   struct<"a", (ptr<struct<"a">>)>
133 ///   struct<"c", (ptr<struct<"b", (ptr<struct<"c">>)>>,
134 ///                ptr<struct<"b", (ptr<struct<"c">>)>>)>
135 /// note that "b" is printed twice.
printType(Type type,DialectAsmPrinter & printer)136 void mlir::LLVM::detail::printType(Type type, DialectAsmPrinter &printer) {
137   if (!type) {
138     printer << "<<NULL-TYPE>>";
139     return;
140   }
141 
142   printer << getTypeKeyword(type);
143 
144   if (auto ptrType = type.dyn_cast<LLVMPointerType>()) {
145     printer << '<';
146     dispatchPrint(printer, ptrType.getElementType());
147     if (ptrType.getAddressSpace() != 0)
148       printer << ", " << ptrType.getAddressSpace();
149     printer << '>';
150     return;
151   }
152 
153   if (auto arrayType = type.dyn_cast<LLVMArrayType>())
154     return printArrayOrVectorType(printer, arrayType);
155   if (auto vectorType = type.dyn_cast<LLVMFixedVectorType>())
156     return printArrayOrVectorType(printer, vectorType);
157 
158   if (auto vectorType = type.dyn_cast<LLVMScalableVectorType>()) {
159     printer << "<? x " << vectorType.getMinNumElements() << " x ";
160     dispatchPrint(printer, vectorType.getElementType());
161     printer << '>';
162     return;
163   }
164 
165   if (auto structType = type.dyn_cast<LLVMStructType>())
166     return printStructType(printer, structType);
167 
168   if (auto funcType = type.dyn_cast<LLVMFunctionType>())
169     return printFunctionType(printer, funcType);
170 }
171 
172 //===----------------------------------------------------------------------===//
173 // Parsing.
174 //===----------------------------------------------------------------------===//
175 
176 static ParseResult dispatchParse(DialectAsmParser &parser, Type &type);
177 
178 /// Parses an LLVM dialect function type.
179 ///   llvm-type :: = `func<` llvm-type `(` llvm-type-list `...`? `)>`
parseFunctionType(DialectAsmParser & parser)180 static LLVMFunctionType parseFunctionType(DialectAsmParser &parser) {
181   Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
182   Type returnType;
183   if (parser.parseLess() || dispatchParse(parser, returnType) ||
184       parser.parseLParen())
185     return LLVMFunctionType();
186 
187   // Function type without arguments.
188   if (succeeded(parser.parseOptionalRParen())) {
189     if (succeeded(parser.parseGreater()))
190       return LLVMFunctionType::getChecked(loc, returnType, {},
191                                           /*isVarArg=*/false);
192     return LLVMFunctionType();
193   }
194 
195   // Parse arguments.
196   SmallVector<Type, 8> argTypes;
197   do {
198     if (succeeded(parser.parseOptionalEllipsis())) {
199       if (parser.parseOptionalRParen() || parser.parseOptionalGreater())
200         return LLVMFunctionType();
201       return LLVMFunctionType::getChecked(loc, returnType, argTypes,
202                                           /*isVarArg=*/true);
203     }
204 
205     Type arg;
206     if (dispatchParse(parser, arg))
207       return LLVMFunctionType();
208     argTypes.push_back(arg);
209   } while (succeeded(parser.parseOptionalComma()));
210 
211   if (parser.parseOptionalRParen() || parser.parseOptionalGreater())
212     return LLVMFunctionType();
213   return LLVMFunctionType::getChecked(loc, returnType, argTypes,
214                                       /*isVarArg=*/false);
215 }
216 
217 /// Parses an LLVM dialect pointer type.
218 ///   llvm-type ::= `ptr<` llvm-type (`,` integer)? `>`
parsePointerType(DialectAsmParser & parser)219 static LLVMPointerType parsePointerType(DialectAsmParser &parser) {
220   Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
221   Type elementType;
222   if (parser.parseLess() || dispatchParse(parser, elementType))
223     return LLVMPointerType();
224 
225   unsigned addressSpace = 0;
226   if (succeeded(parser.parseOptionalComma()) &&
227       failed(parser.parseInteger(addressSpace)))
228     return LLVMPointerType();
229   if (failed(parser.parseGreater()))
230     return LLVMPointerType();
231   return LLVMPointerType::getChecked(loc, elementType, addressSpace);
232 }
233 
234 /// Parses an LLVM dialect vector type.
235 ///   llvm-type ::= `vec<` `? x`? integer `x` llvm-type `>`
236 /// Supports both fixed and scalable vectors.
parseVectorType(DialectAsmParser & parser)237 static Type parseVectorType(DialectAsmParser &parser) {
238   SmallVector<int64_t, 2> dims;
239   llvm::SMLoc dimPos;
240   Type elementType;
241   Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
242   if (parser.parseLess() || parser.getCurrentLocation(&dimPos) ||
243       parser.parseDimensionList(dims, /*allowDynamic=*/true) ||
244       dispatchParse(parser, elementType) || parser.parseGreater())
245     return Type();
246 
247   // We parsed a generic dimension list, but vectors only support two forms:
248   //  - single non-dynamic entry in the list (fixed vector);
249   //  - two elements, the first dynamic (indicated by -1) and the second
250   //    non-dynamic (scalable vector).
251   if (dims.empty() || dims.size() > 2 ||
252       ((dims.size() == 2) ^ (dims[0] == -1)) ||
253       (dims.size() == 2 && dims[1] == -1)) {
254     parser.emitError(dimPos)
255         << "expected '? x <integer> x <type>' or '<integer> x <type>'";
256     return Type();
257   }
258 
259   bool isScalable = dims.size() == 2;
260   if (isScalable)
261     return LLVMScalableVectorType::getChecked(loc, elementType, dims[1]);
262   if (elementType.isSignlessIntOrFloat())
263     return VectorType::getChecked(loc, dims, elementType);
264   return LLVMFixedVectorType::getChecked(loc, elementType, dims[0]);
265 }
266 
267 /// Parses an LLVM dialect array type.
268 ///   llvm-type ::= `array<` integer `x` llvm-type `>`
parseArrayType(DialectAsmParser & parser)269 static LLVMArrayType parseArrayType(DialectAsmParser &parser) {
270   SmallVector<int64_t, 1> dims;
271   llvm::SMLoc sizePos;
272   Type elementType;
273   Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
274   if (parser.parseLess() || parser.getCurrentLocation(&sizePos) ||
275       parser.parseDimensionList(dims, /*allowDynamic=*/false) ||
276       dispatchParse(parser, elementType) || parser.parseGreater())
277     return LLVMArrayType();
278 
279   if (dims.size() != 1) {
280     parser.emitError(sizePos) << "expected ? x <type>";
281     return LLVMArrayType();
282   }
283 
284   return LLVMArrayType::getChecked(loc, elementType, dims[0]);
285 }
286 
287 /// Attempts to set the body of an identified structure type. Reports a parsing
288 /// error at `subtypesLoc` in case of failure.
trySetStructBody(LLVMStructType type,ArrayRef<Type> subtypes,bool isPacked,DialectAsmParser & parser,llvm::SMLoc subtypesLoc)289 static LLVMStructType trySetStructBody(LLVMStructType type,
290                                        ArrayRef<Type> subtypes, bool isPacked,
291                                        DialectAsmParser &parser,
292                                        llvm::SMLoc subtypesLoc) {
293   for (Type t : subtypes) {
294     if (!LLVMStructType::isValidElementType(t)) {
295       parser.emitError(subtypesLoc)
296           << "invalid LLVM structure element type: " << t;
297       return LLVMStructType();
298     }
299   }
300 
301   if (succeeded(type.setBody(subtypes, isPacked)))
302     return type;
303 
304   parser.emitError(subtypesLoc)
305       << "identified type already used with a different body";
306   return LLVMStructType();
307 }
308 
309 /// Parses an LLVM dialect structure type.
310 ///   llvm-type ::= `struct<` (string-literal `,`)? `packed`?
311 ///                 `(` llvm-type-list `)` `>`
312 ///               | `struct<` string-literal `>`
313 ///               | `struct<` string-literal `, opaque>`
parseStructType(DialectAsmParser & parser)314 static LLVMStructType parseStructType(DialectAsmParser &parser) {
315   // This keeps track of the names of identified structure types that are
316   // currently being parsed. Since such types can refer themselves, this
317   // tracking is necessary to stop the recursion: the current function may be
318   // called recursively from DialectAsmParser::parseType after the appropriate
319   // dispatch. We maintain the invariant of this storage being modified
320   // exclusively in this function, and at most one name being added per call.
321   // TODO: consider having such functionality inside DialectAsmParser.
322   thread_local llvm::SetVector<StringRef> knownStructNames;
323   unsigned stackSize = knownStructNames.size();
324   (void)stackSize;
325   auto guard = llvm::make_scope_exit([&]() {
326     assert(knownStructNames.size() == stackSize &&
327            "malformed identified stack when parsing recursive structs");
328   });
329 
330   Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
331 
332   if (failed(parser.parseLess()))
333     return LLVMStructType();
334 
335   // If we are parsing a self-reference to a recursive struct, i.e. the parsing
336   // stack already contains a struct with the same identifier, bail out after
337   // the name.
338   StringRef name;
339   bool isIdentified = succeeded(parser.parseOptionalString(&name));
340   if (isIdentified) {
341     if (knownStructNames.count(name)) {
342       if (failed(parser.parseGreater()))
343         return LLVMStructType();
344       return LLVMStructType::getIdentifiedChecked(loc, name);
345     }
346     if (failed(parser.parseComma()))
347       return LLVMStructType();
348   }
349 
350   // Handle intentionally opaque structs.
351   llvm::SMLoc kwLoc = parser.getCurrentLocation();
352   if (succeeded(parser.parseOptionalKeyword("opaque"))) {
353     if (!isIdentified)
354       return parser.emitError(kwLoc, "only identified structs can be opaque"),
355              LLVMStructType();
356     if (failed(parser.parseGreater()))
357       return LLVMStructType();
358     auto type = LLVMStructType::getOpaqueChecked(loc, name);
359     if (!type.isOpaque()) {
360       parser.emitError(kwLoc, "redeclaring defined struct as opaque");
361       return LLVMStructType();
362     }
363     return type;
364   }
365 
366   // Check for packedness.
367   bool isPacked = succeeded(parser.parseOptionalKeyword("packed"));
368   if (failed(parser.parseLParen()))
369     return LLVMStructType();
370 
371   // Fast pass for structs with zero subtypes.
372   if (succeeded(parser.parseOptionalRParen())) {
373     if (failed(parser.parseGreater()))
374       return LLVMStructType();
375     if (!isIdentified)
376       return LLVMStructType::getLiteralChecked(loc, {}, isPacked);
377     auto type = LLVMStructType::getIdentifiedChecked(loc, name);
378     return trySetStructBody(type, {}, isPacked, parser, kwLoc);
379   }
380 
381   // Parse subtypes. For identified structs, put the identifier of the struct on
382   // the stack to support self-references in the recursive calls.
383   SmallVector<Type, 4> subtypes;
384   llvm::SMLoc subtypesLoc = parser.getCurrentLocation();
385   do {
386     if (isIdentified)
387       knownStructNames.insert(name);
388     Type type;
389     if (dispatchParse(parser, type))
390       return LLVMStructType();
391     subtypes.push_back(type);
392     if (isIdentified)
393       knownStructNames.pop_back();
394   } while (succeeded(parser.parseOptionalComma()));
395 
396   if (parser.parseRParen() || parser.parseGreater())
397     return LLVMStructType();
398 
399   // Construct the struct with body.
400   if (!isIdentified)
401     return LLVMStructType::getLiteralChecked(loc, subtypes, isPacked);
402   auto type = LLVMStructType::getIdentifiedChecked(loc, name);
403   return trySetStructBody(type, subtypes, isPacked, parser, subtypesLoc);
404 }
405 
406 /// Parses a type appearing inside another LLVM dialect-compatible type. This
407 /// will try to parse any type in full form (including types with the `!llvm`
408 /// prefix), and on failure fall back to parsing the short-hand version of the
409 /// LLVM dialect types without the `!llvm` prefix.
dispatchParse(DialectAsmParser & parser,bool allowAny=true)410 static Type dispatchParse(DialectAsmParser &parser, bool allowAny = true) {
411   llvm::SMLoc keyLoc = parser.getCurrentLocation();
412   Location loc = parser.getEncodedSourceLoc(keyLoc);
413 
414   // Try parsing any MLIR type.
415   Type type;
416   OptionalParseResult result = parser.parseOptionalType(type);
417   if (result.hasValue()) {
418     if (failed(result.getValue()))
419       return nullptr;
420     // TODO: integer types are temporarily allowed for compatibility with the
421     // deprecated !llvm.i[0-9]+ syntax.
422     if (!allowAny) {
423       auto intType = type.dyn_cast<IntegerType>();
424       if (!intType || !intType.isSignless()) {
425         parser.emitError(keyLoc) << "unexpected type, expected keyword";
426         return nullptr;
427       }
428       emitWarning(loc) << "deprecated syntax, drop '!llvm.' for integers";
429     }
430     return type;
431   }
432 
433   // If no type found, fallback to the shorthand form.
434   StringRef key;
435   if (failed(parser.parseKeyword(&key)))
436     return Type();
437 
438   MLIRContext *ctx = parser.getBuilder().getContext();
439   return StringSwitch<function_ref<Type()>>(key)
440       .Case("void", [&] { return LLVMVoidType::get(ctx); })
441       .Case("bfloat",
442             [&] {
443               emitWarning(loc) << "deprecated syntax, use bf16 instead";
444               return BFloat16Type::get(ctx);
445             })
446       .Case("half",
447             [&] {
448               emitWarning(loc) << "deprecated syntax, use f16 instead";
449               return Float16Type::get(ctx);
450             })
451       .Case("float",
452             [&] {
453               emitWarning(loc) << "deprecated syntax, use f32 instead";
454               return Float32Type::get(ctx);
455             })
456       .Case("double",
457             [&] {
458               emitWarning(loc) << "deprecated syntax, use f64 instead";
459               return Float64Type::get(ctx);
460             })
461       .Case("fp128",
462             [&] {
463               emitWarning(loc) << "deprecated syntax, use f128 instead";
464               return Float128Type::get(ctx);
465             })
466       .Case("x86_fp80",
467             [&] {
468               emitWarning(loc) << "deprecated syntax, use f80 instead";
469               return Float80Type::get(ctx);
470             })
471       .Case("ppc_fp128", [&] { return LLVMPPCFP128Type::get(ctx); })
472       .Case("x86_mmx", [&] { return LLVMX86MMXType::get(ctx); })
473       .Case("token", [&] { return LLVMTokenType::get(ctx); })
474       .Case("label", [&] { return LLVMLabelType::get(ctx); })
475       .Case("metadata", [&] { return LLVMMetadataType::get(ctx); })
476       .Case("func", [&] { return parseFunctionType(parser); })
477       .Case("ptr", [&] { return parsePointerType(parser); })
478       .Case("vec", [&] { return parseVectorType(parser); })
479       .Case("array", [&] { return parseArrayType(parser); })
480       .Case("struct", [&] { return parseStructType(parser); })
481       .Default([&] {
482         parser.emitError(keyLoc) << "unknown LLVM type: " << key;
483         return Type();
484       })();
485 }
486 
487 /// Helper to use in parse lists.
dispatchParse(DialectAsmParser & parser,Type & type)488 static ParseResult dispatchParse(DialectAsmParser &parser, Type &type) {
489   type = dispatchParse(parser);
490   return success(type != nullptr);
491 }
492 
493 /// Parses one of the LLVM dialect types.
parseType(DialectAsmParser & parser)494 Type mlir::LLVM::detail::parseType(DialectAsmParser &parser) {
495   llvm::SMLoc loc = parser.getCurrentLocation();
496   Type type = dispatchParse(parser, /*allowAny=*/false);
497   if (!type)
498     return type;
499   if (!isCompatibleType(type)) {
500     parser.emitError(loc) << "unexpected type, expected keyword";
501     return nullptr;
502   }
503   return type;
504 }
505