1 //===- LLVMTypeSyntax.cpp - Parsing/printing for MLIR LLVM Dialect types --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
10 #include "mlir/IR/Builders.h"
11 #include "mlir/IR/DialectImplementation.h"
12 #include "llvm/ADT/ScopeExit.h"
13 #include "llvm/ADT/SetVector.h"
14 #include "llvm/ADT/TypeSwitch.h"
15
16 using namespace mlir;
17 using namespace mlir::LLVM;
18
19 //===----------------------------------------------------------------------===//
20 // Printing.
21 //===----------------------------------------------------------------------===//
22
23 /// If the given type is compatible with the LLVM dialect, prints it using
24 /// internal functions to avoid getting a verbose `!llvm` prefix. Otherwise
25 /// prints it as usual.
dispatchPrint(DialectAsmPrinter & printer,Type type)26 static void dispatchPrint(DialectAsmPrinter &printer, Type type) {
27 if (isCompatibleType(type) && !type.isa<IntegerType, FloatType, VectorType>())
28 return mlir::LLVM::detail::printType(type, printer);
29 printer.printType(type);
30 }
31
32 /// Returns the keyword to use for the given type.
getTypeKeyword(Type type)33 static StringRef getTypeKeyword(Type type) {
34 return TypeSwitch<Type, StringRef>(type)
35 .Case<LLVMVoidType>([&](Type) { return "void"; })
36 .Case<LLVMPPCFP128Type>([&](Type) { return "ppc_fp128"; })
37 .Case<LLVMX86MMXType>([&](Type) { return "x86_mmx"; })
38 .Case<LLVMTokenType>([&](Type) { return "token"; })
39 .Case<LLVMLabelType>([&](Type) { return "label"; })
40 .Case<LLVMMetadataType>([&](Type) { return "metadata"; })
41 .Case<LLVMFunctionType>([&](Type) { return "func"; })
42 .Case<LLVMPointerType>([&](Type) { return "ptr"; })
43 .Case<LLVMFixedVectorType, LLVMScalableVectorType>(
44 [&](Type) { return "vec"; })
45 .Case<LLVMArrayType>([&](Type) { return "array"; })
46 .Case<LLVMStructType>([&](Type) { return "struct"; })
47 .Default([](Type) -> StringRef {
48 llvm_unreachable("unexpected 'llvm' type kind");
49 });
50 }
51
52 /// Prints a structure type. Keeps track of known struct names to handle self-
53 /// or mutually-referring structs without falling into infinite recursion.
printStructType(DialectAsmPrinter & printer,LLVMStructType type)54 static void printStructType(DialectAsmPrinter &printer, LLVMStructType type) {
55 // This keeps track of the names of identified structure types that are
56 // currently being printed. Since such types can refer themselves, this
57 // tracking is necessary to stop the recursion: the current function may be
58 // called recursively from DialectAsmPrinter::printType after the appropriate
59 // dispatch. We maintain the invariant of this storage being modified
60 // exclusively in this function, and at most one name being added per call.
61 // TODO: consider having such functionality inside DialectAsmPrinter.
62 thread_local llvm::SetVector<StringRef> knownStructNames;
63 unsigned stackSize = knownStructNames.size();
64 (void)stackSize;
65 auto guard = llvm::make_scope_exit([&]() {
66 assert(knownStructNames.size() == stackSize &&
67 "malformed identified stack when printing recursive structs");
68 });
69
70 printer << "<";
71 if (type.isIdentified()) {
72 printer << '"' << type.getName() << '"';
73 // If we are printing a reference to one of the enclosing structs, just
74 // print the name and stop to avoid infinitely long output.
75 if (knownStructNames.count(type.getName())) {
76 printer << '>';
77 return;
78 }
79 printer << ", ";
80 }
81
82 if (type.isIdentified() && type.isOpaque()) {
83 printer << "opaque>";
84 return;
85 }
86
87 if (type.isPacked())
88 printer << "packed ";
89
90 // Put the current type on stack to avoid infinite recursion.
91 printer << '(';
92 if (type.isIdentified())
93 knownStructNames.insert(type.getName());
94 llvm::interleaveComma(type.getBody(), printer.getStream(),
95 [&](Type subtype) { dispatchPrint(printer, subtype); });
96 if (type.isIdentified())
97 knownStructNames.pop_back();
98 printer << ')';
99 printer << '>';
100 }
101
102 /// Prints a type containing a fixed number of elements.
103 template <typename TypeTy>
printArrayOrVectorType(DialectAsmPrinter & printer,TypeTy type)104 static void printArrayOrVectorType(DialectAsmPrinter &printer, TypeTy type) {
105 printer << '<' << type.getNumElements() << " x ";
106 dispatchPrint(printer, type.getElementType());
107 printer << '>';
108 }
109
110 /// Prints a function type.
printFunctionType(DialectAsmPrinter & printer,LLVMFunctionType funcType)111 static void printFunctionType(DialectAsmPrinter &printer,
112 LLVMFunctionType funcType) {
113 printer << '<';
114 dispatchPrint(printer, funcType.getReturnType());
115 printer << " (";
116 llvm::interleaveComma(
117 funcType.getParams(), printer.getStream(),
118 [&printer](Type subtype) { dispatchPrint(printer, subtype); });
119 if (funcType.isVarArg()) {
120 if (funcType.getNumParams() != 0)
121 printer << ", ";
122 printer << "...";
123 }
124 printer << ")>";
125 }
126
127 /// Prints the given LLVM dialect type recursively. This leverages closedness of
128 /// the LLVM dialect type system to avoid printing the dialect prefix
129 /// repeatedly. For recursive structures, only prints the name of the structure
130 /// when printing a self-reference. Note that this does not apply to sibling
131 /// references. For example,
132 /// struct<"a", (ptr<struct<"a">>)>
133 /// struct<"c", (ptr<struct<"b", (ptr<struct<"c">>)>>,
134 /// ptr<struct<"b", (ptr<struct<"c">>)>>)>
135 /// note that "b" is printed twice.
printType(Type type,DialectAsmPrinter & printer)136 void mlir::LLVM::detail::printType(Type type, DialectAsmPrinter &printer) {
137 if (!type) {
138 printer << "<<NULL-TYPE>>";
139 return;
140 }
141
142 printer << getTypeKeyword(type);
143
144 if (auto ptrType = type.dyn_cast<LLVMPointerType>()) {
145 printer << '<';
146 dispatchPrint(printer, ptrType.getElementType());
147 if (ptrType.getAddressSpace() != 0)
148 printer << ", " << ptrType.getAddressSpace();
149 printer << '>';
150 return;
151 }
152
153 if (auto arrayType = type.dyn_cast<LLVMArrayType>())
154 return printArrayOrVectorType(printer, arrayType);
155 if (auto vectorType = type.dyn_cast<LLVMFixedVectorType>())
156 return printArrayOrVectorType(printer, vectorType);
157
158 if (auto vectorType = type.dyn_cast<LLVMScalableVectorType>()) {
159 printer << "<? x " << vectorType.getMinNumElements() << " x ";
160 dispatchPrint(printer, vectorType.getElementType());
161 printer << '>';
162 return;
163 }
164
165 if (auto structType = type.dyn_cast<LLVMStructType>())
166 return printStructType(printer, structType);
167
168 if (auto funcType = type.dyn_cast<LLVMFunctionType>())
169 return printFunctionType(printer, funcType);
170 }
171
172 //===----------------------------------------------------------------------===//
173 // Parsing.
174 //===----------------------------------------------------------------------===//
175
176 static ParseResult dispatchParse(DialectAsmParser &parser, Type &type);
177
178 /// Parses an LLVM dialect function type.
179 /// llvm-type :: = `func<` llvm-type `(` llvm-type-list `...`? `)>`
parseFunctionType(DialectAsmParser & parser)180 static LLVMFunctionType parseFunctionType(DialectAsmParser &parser) {
181 Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
182 Type returnType;
183 if (parser.parseLess() || dispatchParse(parser, returnType) ||
184 parser.parseLParen())
185 return LLVMFunctionType();
186
187 // Function type without arguments.
188 if (succeeded(parser.parseOptionalRParen())) {
189 if (succeeded(parser.parseGreater()))
190 return LLVMFunctionType::getChecked(loc, returnType, {},
191 /*isVarArg=*/false);
192 return LLVMFunctionType();
193 }
194
195 // Parse arguments.
196 SmallVector<Type, 8> argTypes;
197 do {
198 if (succeeded(parser.parseOptionalEllipsis())) {
199 if (parser.parseOptionalRParen() || parser.parseOptionalGreater())
200 return LLVMFunctionType();
201 return LLVMFunctionType::getChecked(loc, returnType, argTypes,
202 /*isVarArg=*/true);
203 }
204
205 Type arg;
206 if (dispatchParse(parser, arg))
207 return LLVMFunctionType();
208 argTypes.push_back(arg);
209 } while (succeeded(parser.parseOptionalComma()));
210
211 if (parser.parseOptionalRParen() || parser.parseOptionalGreater())
212 return LLVMFunctionType();
213 return LLVMFunctionType::getChecked(loc, returnType, argTypes,
214 /*isVarArg=*/false);
215 }
216
217 /// Parses an LLVM dialect pointer type.
218 /// llvm-type ::= `ptr<` llvm-type (`,` integer)? `>`
parsePointerType(DialectAsmParser & parser)219 static LLVMPointerType parsePointerType(DialectAsmParser &parser) {
220 Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
221 Type elementType;
222 if (parser.parseLess() || dispatchParse(parser, elementType))
223 return LLVMPointerType();
224
225 unsigned addressSpace = 0;
226 if (succeeded(parser.parseOptionalComma()) &&
227 failed(parser.parseInteger(addressSpace)))
228 return LLVMPointerType();
229 if (failed(parser.parseGreater()))
230 return LLVMPointerType();
231 return LLVMPointerType::getChecked(loc, elementType, addressSpace);
232 }
233
234 /// Parses an LLVM dialect vector type.
235 /// llvm-type ::= `vec<` `? x`? integer `x` llvm-type `>`
236 /// Supports both fixed and scalable vectors.
parseVectorType(DialectAsmParser & parser)237 static Type parseVectorType(DialectAsmParser &parser) {
238 SmallVector<int64_t, 2> dims;
239 llvm::SMLoc dimPos;
240 Type elementType;
241 Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
242 if (parser.parseLess() || parser.getCurrentLocation(&dimPos) ||
243 parser.parseDimensionList(dims, /*allowDynamic=*/true) ||
244 dispatchParse(parser, elementType) || parser.parseGreater())
245 return Type();
246
247 // We parsed a generic dimension list, but vectors only support two forms:
248 // - single non-dynamic entry in the list (fixed vector);
249 // - two elements, the first dynamic (indicated by -1) and the second
250 // non-dynamic (scalable vector).
251 if (dims.empty() || dims.size() > 2 ||
252 ((dims.size() == 2) ^ (dims[0] == -1)) ||
253 (dims.size() == 2 && dims[1] == -1)) {
254 parser.emitError(dimPos)
255 << "expected '? x <integer> x <type>' or '<integer> x <type>'";
256 return Type();
257 }
258
259 bool isScalable = dims.size() == 2;
260 if (isScalable)
261 return LLVMScalableVectorType::getChecked(loc, elementType, dims[1]);
262 if (elementType.isSignlessIntOrFloat())
263 return VectorType::getChecked(loc, dims, elementType);
264 return LLVMFixedVectorType::getChecked(loc, elementType, dims[0]);
265 }
266
267 /// Parses an LLVM dialect array type.
268 /// llvm-type ::= `array<` integer `x` llvm-type `>`
parseArrayType(DialectAsmParser & parser)269 static LLVMArrayType parseArrayType(DialectAsmParser &parser) {
270 SmallVector<int64_t, 1> dims;
271 llvm::SMLoc sizePos;
272 Type elementType;
273 Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
274 if (parser.parseLess() || parser.getCurrentLocation(&sizePos) ||
275 parser.parseDimensionList(dims, /*allowDynamic=*/false) ||
276 dispatchParse(parser, elementType) || parser.parseGreater())
277 return LLVMArrayType();
278
279 if (dims.size() != 1) {
280 parser.emitError(sizePos) << "expected ? x <type>";
281 return LLVMArrayType();
282 }
283
284 return LLVMArrayType::getChecked(loc, elementType, dims[0]);
285 }
286
287 /// Attempts to set the body of an identified structure type. Reports a parsing
288 /// error at `subtypesLoc` in case of failure.
trySetStructBody(LLVMStructType type,ArrayRef<Type> subtypes,bool isPacked,DialectAsmParser & parser,llvm::SMLoc subtypesLoc)289 static LLVMStructType trySetStructBody(LLVMStructType type,
290 ArrayRef<Type> subtypes, bool isPacked,
291 DialectAsmParser &parser,
292 llvm::SMLoc subtypesLoc) {
293 for (Type t : subtypes) {
294 if (!LLVMStructType::isValidElementType(t)) {
295 parser.emitError(subtypesLoc)
296 << "invalid LLVM structure element type: " << t;
297 return LLVMStructType();
298 }
299 }
300
301 if (succeeded(type.setBody(subtypes, isPacked)))
302 return type;
303
304 parser.emitError(subtypesLoc)
305 << "identified type already used with a different body";
306 return LLVMStructType();
307 }
308
309 /// Parses an LLVM dialect structure type.
310 /// llvm-type ::= `struct<` (string-literal `,`)? `packed`?
311 /// `(` llvm-type-list `)` `>`
312 /// | `struct<` string-literal `>`
313 /// | `struct<` string-literal `, opaque>`
parseStructType(DialectAsmParser & parser)314 static LLVMStructType parseStructType(DialectAsmParser &parser) {
315 // This keeps track of the names of identified structure types that are
316 // currently being parsed. Since such types can refer themselves, this
317 // tracking is necessary to stop the recursion: the current function may be
318 // called recursively from DialectAsmParser::parseType after the appropriate
319 // dispatch. We maintain the invariant of this storage being modified
320 // exclusively in this function, and at most one name being added per call.
321 // TODO: consider having such functionality inside DialectAsmParser.
322 thread_local llvm::SetVector<StringRef> knownStructNames;
323 unsigned stackSize = knownStructNames.size();
324 (void)stackSize;
325 auto guard = llvm::make_scope_exit([&]() {
326 assert(knownStructNames.size() == stackSize &&
327 "malformed identified stack when parsing recursive structs");
328 });
329
330 Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
331
332 if (failed(parser.parseLess()))
333 return LLVMStructType();
334
335 // If we are parsing a self-reference to a recursive struct, i.e. the parsing
336 // stack already contains a struct with the same identifier, bail out after
337 // the name.
338 StringRef name;
339 bool isIdentified = succeeded(parser.parseOptionalString(&name));
340 if (isIdentified) {
341 if (knownStructNames.count(name)) {
342 if (failed(parser.parseGreater()))
343 return LLVMStructType();
344 return LLVMStructType::getIdentifiedChecked(loc, name);
345 }
346 if (failed(parser.parseComma()))
347 return LLVMStructType();
348 }
349
350 // Handle intentionally opaque structs.
351 llvm::SMLoc kwLoc = parser.getCurrentLocation();
352 if (succeeded(parser.parseOptionalKeyword("opaque"))) {
353 if (!isIdentified)
354 return parser.emitError(kwLoc, "only identified structs can be opaque"),
355 LLVMStructType();
356 if (failed(parser.parseGreater()))
357 return LLVMStructType();
358 auto type = LLVMStructType::getOpaqueChecked(loc, name);
359 if (!type.isOpaque()) {
360 parser.emitError(kwLoc, "redeclaring defined struct as opaque");
361 return LLVMStructType();
362 }
363 return type;
364 }
365
366 // Check for packedness.
367 bool isPacked = succeeded(parser.parseOptionalKeyword("packed"));
368 if (failed(parser.parseLParen()))
369 return LLVMStructType();
370
371 // Fast pass for structs with zero subtypes.
372 if (succeeded(parser.parseOptionalRParen())) {
373 if (failed(parser.parseGreater()))
374 return LLVMStructType();
375 if (!isIdentified)
376 return LLVMStructType::getLiteralChecked(loc, {}, isPacked);
377 auto type = LLVMStructType::getIdentifiedChecked(loc, name);
378 return trySetStructBody(type, {}, isPacked, parser, kwLoc);
379 }
380
381 // Parse subtypes. For identified structs, put the identifier of the struct on
382 // the stack to support self-references in the recursive calls.
383 SmallVector<Type, 4> subtypes;
384 llvm::SMLoc subtypesLoc = parser.getCurrentLocation();
385 do {
386 if (isIdentified)
387 knownStructNames.insert(name);
388 Type type;
389 if (dispatchParse(parser, type))
390 return LLVMStructType();
391 subtypes.push_back(type);
392 if (isIdentified)
393 knownStructNames.pop_back();
394 } while (succeeded(parser.parseOptionalComma()));
395
396 if (parser.parseRParen() || parser.parseGreater())
397 return LLVMStructType();
398
399 // Construct the struct with body.
400 if (!isIdentified)
401 return LLVMStructType::getLiteralChecked(loc, subtypes, isPacked);
402 auto type = LLVMStructType::getIdentifiedChecked(loc, name);
403 return trySetStructBody(type, subtypes, isPacked, parser, subtypesLoc);
404 }
405
406 /// Parses a type appearing inside another LLVM dialect-compatible type. This
407 /// will try to parse any type in full form (including types with the `!llvm`
408 /// prefix), and on failure fall back to parsing the short-hand version of the
409 /// LLVM dialect types without the `!llvm` prefix.
dispatchParse(DialectAsmParser & parser,bool allowAny=true)410 static Type dispatchParse(DialectAsmParser &parser, bool allowAny = true) {
411 llvm::SMLoc keyLoc = parser.getCurrentLocation();
412 Location loc = parser.getEncodedSourceLoc(keyLoc);
413
414 // Try parsing any MLIR type.
415 Type type;
416 OptionalParseResult result = parser.parseOptionalType(type);
417 if (result.hasValue()) {
418 if (failed(result.getValue()))
419 return nullptr;
420 // TODO: integer types are temporarily allowed for compatibility with the
421 // deprecated !llvm.i[0-9]+ syntax.
422 if (!allowAny) {
423 auto intType = type.dyn_cast<IntegerType>();
424 if (!intType || !intType.isSignless()) {
425 parser.emitError(keyLoc) << "unexpected type, expected keyword";
426 return nullptr;
427 }
428 emitWarning(loc) << "deprecated syntax, drop '!llvm.' for integers";
429 }
430 return type;
431 }
432
433 // If no type found, fallback to the shorthand form.
434 StringRef key;
435 if (failed(parser.parseKeyword(&key)))
436 return Type();
437
438 MLIRContext *ctx = parser.getBuilder().getContext();
439 return StringSwitch<function_ref<Type()>>(key)
440 .Case("void", [&] { return LLVMVoidType::get(ctx); })
441 .Case("bfloat",
442 [&] {
443 emitWarning(loc) << "deprecated syntax, use bf16 instead";
444 return BFloat16Type::get(ctx);
445 })
446 .Case("half",
447 [&] {
448 emitWarning(loc) << "deprecated syntax, use f16 instead";
449 return Float16Type::get(ctx);
450 })
451 .Case("float",
452 [&] {
453 emitWarning(loc) << "deprecated syntax, use f32 instead";
454 return Float32Type::get(ctx);
455 })
456 .Case("double",
457 [&] {
458 emitWarning(loc) << "deprecated syntax, use f64 instead";
459 return Float64Type::get(ctx);
460 })
461 .Case("fp128",
462 [&] {
463 emitWarning(loc) << "deprecated syntax, use f128 instead";
464 return Float128Type::get(ctx);
465 })
466 .Case("x86_fp80",
467 [&] {
468 emitWarning(loc) << "deprecated syntax, use f80 instead";
469 return Float80Type::get(ctx);
470 })
471 .Case("ppc_fp128", [&] { return LLVMPPCFP128Type::get(ctx); })
472 .Case("x86_mmx", [&] { return LLVMX86MMXType::get(ctx); })
473 .Case("token", [&] { return LLVMTokenType::get(ctx); })
474 .Case("label", [&] { return LLVMLabelType::get(ctx); })
475 .Case("metadata", [&] { return LLVMMetadataType::get(ctx); })
476 .Case("func", [&] { return parseFunctionType(parser); })
477 .Case("ptr", [&] { return parsePointerType(parser); })
478 .Case("vec", [&] { return parseVectorType(parser); })
479 .Case("array", [&] { return parseArrayType(parser); })
480 .Case("struct", [&] { return parseStructType(parser); })
481 .Default([&] {
482 parser.emitError(keyLoc) << "unknown LLVM type: " << key;
483 return Type();
484 })();
485 }
486
487 /// Helper to use in parse lists.
dispatchParse(DialectAsmParser & parser,Type & type)488 static ParseResult dispatchParse(DialectAsmParser &parser, Type &type) {
489 type = dispatchParse(parser);
490 return success(type != nullptr);
491 }
492
493 /// Parses one of the LLVM dialect types.
parseType(DialectAsmParser & parser)494 Type mlir::LLVM::detail::parseType(DialectAsmParser &parser) {
495 llvm::SMLoc loc = parser.getCurrentLocation();
496 Type type = dispatchParse(parser, /*allowAny=*/false);
497 if (!type)
498 return type;
499 if (!isCompatibleType(type)) {
500 parser.emitError(loc) << "unexpected type, expected keyword";
501 return nullptr;
502 }
503 return type;
504 }
505