1 //===- LLVMTypeSyntax.cpp - Parsing/printing for MLIR LLVM Dialect types --===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/LLVMIR/LLVMTypes.h" 10 #include "mlir/IR/Builders.h" 11 #include "mlir/IR/DialectImplementation.h" 12 #include "llvm/ADT/ScopeExit.h" 13 #include "llvm/ADT/SetVector.h" 14 #include "llvm/ADT/TypeSwitch.h" 15 16 using namespace mlir; 17 using namespace mlir::LLVM; 18 19 //===----------------------------------------------------------------------===// 20 // Printing. 21 //===----------------------------------------------------------------------===// 22 23 /// If the given type is compatible with the LLVM dialect, prints it using 24 /// internal functions to avoid getting a verbose `!llvm` prefix. Otherwise 25 /// prints it as usual. 26 static void dispatchPrint(DialectAsmPrinter &printer, Type type) { 27 if (isCompatibleType(type) && !type.isa<IntegerType, FloatType, VectorType>()) 28 return mlir::LLVM::detail::printType(type, printer); 29 printer.printType(type); 30 } 31 32 /// Returns the keyword to use for the given type. 33 static StringRef getTypeKeyword(Type type) { 34 return TypeSwitch<Type, StringRef>(type) 35 .Case<LLVMVoidType>([&](Type) { return "void"; }) 36 .Case<LLVMPPCFP128Type>([&](Type) { return "ppc_fp128"; }) 37 .Case<LLVMX86MMXType>([&](Type) { return "x86_mmx"; }) 38 .Case<LLVMTokenType>([&](Type) { return "token"; }) 39 .Case<LLVMLabelType>([&](Type) { return "label"; }) 40 .Case<LLVMMetadataType>([&](Type) { return "metadata"; }) 41 .Case<LLVMFunctionType>([&](Type) { return "func"; }) 42 .Case<LLVMPointerType>([&](Type) { return "ptr"; }) 43 .Case<LLVMFixedVectorType, LLVMScalableVectorType>( 44 [&](Type) { return "vec"; }) 45 .Case<LLVMArrayType>([&](Type) { return "array"; }) 46 .Case<LLVMStructType>([&](Type) { return "struct"; }) 47 .Default([](Type) -> StringRef { 48 llvm_unreachable("unexpected 'llvm' type kind"); 49 }); 50 } 51 52 /// Prints a structure type. Keeps track of known struct names to handle self- 53 /// or mutually-referring structs without falling into infinite recursion. 54 static void printStructType(DialectAsmPrinter &printer, LLVMStructType type) { 55 // This keeps track of the names of identified structure types that are 56 // currently being printed. Since such types can refer themselves, this 57 // tracking is necessary to stop the recursion: the current function may be 58 // called recursively from DialectAsmPrinter::printType after the appropriate 59 // dispatch. We maintain the invariant of this storage being modified 60 // exclusively in this function, and at most one name being added per call. 61 // TODO: consider having such functionality inside DialectAsmPrinter. 62 thread_local SetVector<StringRef> knownStructNames; 63 unsigned stackSize = knownStructNames.size(); 64 (void)stackSize; 65 auto guard = llvm::make_scope_exit([&]() { 66 assert(knownStructNames.size() == stackSize && 67 "malformed identified stack when printing recursive structs"); 68 }); 69 70 printer << "<"; 71 if (type.isIdentified()) { 72 printer << '"' << type.getName() << '"'; 73 // If we are printing a reference to one of the enclosing structs, just 74 // print the name and stop to avoid infinitely long output. 75 if (knownStructNames.count(type.getName())) { 76 printer << '>'; 77 return; 78 } 79 printer << ", "; 80 } 81 82 if (type.isIdentified() && type.isOpaque()) { 83 printer << "opaque>"; 84 return; 85 } 86 87 if (type.isPacked()) 88 printer << "packed "; 89 90 // Put the current type on stack to avoid infinite recursion. 91 printer << '('; 92 if (type.isIdentified()) 93 knownStructNames.insert(type.getName()); 94 llvm::interleaveComma(type.getBody(), printer.getStream(), 95 [&](Type subtype) { dispatchPrint(printer, subtype); }); 96 if (type.isIdentified()) 97 knownStructNames.pop_back(); 98 printer << ')'; 99 printer << '>'; 100 } 101 102 /// Prints a type containing a fixed number of elements. 103 template <typename TypeTy> 104 static void printArrayOrVectorType(DialectAsmPrinter &printer, TypeTy type) { 105 printer << '<' << type.getNumElements() << " x "; 106 dispatchPrint(printer, type.getElementType()); 107 printer << '>'; 108 } 109 110 /// Prints a function type. 111 static void printFunctionType(DialectAsmPrinter &printer, 112 LLVMFunctionType funcType) { 113 printer << '<'; 114 dispatchPrint(printer, funcType.getReturnType()); 115 printer << " ("; 116 llvm::interleaveComma( 117 funcType.getParams(), printer.getStream(), 118 [&printer](Type subtype) { dispatchPrint(printer, subtype); }); 119 if (funcType.isVarArg()) { 120 if (funcType.getNumParams() != 0) 121 printer << ", "; 122 printer << "..."; 123 } 124 printer << ")>"; 125 } 126 127 /// Prints the given LLVM dialect type recursively. This leverages closedness of 128 /// the LLVM dialect type system to avoid printing the dialect prefix 129 /// repeatedly. For recursive structures, only prints the name of the structure 130 /// when printing a self-reference. Note that this does not apply to sibling 131 /// references. For example, 132 /// struct<"a", (ptr<struct<"a">>)> 133 /// struct<"c", (ptr<struct<"b", (ptr<struct<"c">>)>>, 134 /// ptr<struct<"b", (ptr<struct<"c">>)>>)> 135 /// note that "b" is printed twice. 136 void mlir::LLVM::detail::printType(Type type, DialectAsmPrinter &printer) { 137 if (!type) { 138 printer << "<<NULL-TYPE>>"; 139 return; 140 } 141 142 printer << getTypeKeyword(type); 143 144 if (auto ptrType = type.dyn_cast<LLVMPointerType>()) { 145 printer << '<'; 146 dispatchPrint(printer, ptrType.getElementType()); 147 if (ptrType.getAddressSpace() != 0) 148 printer << ", " << ptrType.getAddressSpace(); 149 printer << '>'; 150 return; 151 } 152 153 if (auto arrayType = type.dyn_cast<LLVMArrayType>()) 154 return printArrayOrVectorType(printer, arrayType); 155 if (auto vectorType = type.dyn_cast<LLVMFixedVectorType>()) 156 return printArrayOrVectorType(printer, vectorType); 157 158 if (auto vectorType = type.dyn_cast<LLVMScalableVectorType>()) { 159 printer << "<? x " << vectorType.getMinNumElements() << " x "; 160 dispatchPrint(printer, vectorType.getElementType()); 161 printer << '>'; 162 return; 163 } 164 165 if (auto structType = type.dyn_cast<LLVMStructType>()) 166 return printStructType(printer, structType); 167 168 if (auto funcType = type.dyn_cast<LLVMFunctionType>()) 169 return printFunctionType(printer, funcType); 170 } 171 172 //===----------------------------------------------------------------------===// 173 // Parsing. 174 //===----------------------------------------------------------------------===// 175 176 static ParseResult dispatchParse(DialectAsmParser &parser, Type &type); 177 178 /// Parses an LLVM dialect function type. 179 /// llvm-type :: = `func<` llvm-type `(` llvm-type-list `...`? `)>` 180 static LLVMFunctionType parseFunctionType(DialectAsmParser &parser) { 181 llvm::SMLoc loc = parser.getCurrentLocation(); 182 Type returnType; 183 if (parser.parseLess() || dispatchParse(parser, returnType) || 184 parser.parseLParen()) 185 return LLVMFunctionType(); 186 187 // Function type without arguments. 188 if (succeeded(parser.parseOptionalRParen())) { 189 if (succeeded(parser.parseGreater())) 190 return parser.getChecked<LLVMFunctionType>(loc, returnType, llvm::None, 191 /*isVarArg=*/false); 192 return LLVMFunctionType(); 193 } 194 195 // Parse arguments. 196 SmallVector<Type, 8> argTypes; 197 do { 198 if (succeeded(parser.parseOptionalEllipsis())) { 199 if (parser.parseOptionalRParen() || parser.parseOptionalGreater()) 200 return LLVMFunctionType(); 201 return parser.getChecked<LLVMFunctionType>(loc, returnType, argTypes, 202 /*isVarArg=*/true); 203 } 204 205 Type arg; 206 if (dispatchParse(parser, arg)) 207 return LLVMFunctionType(); 208 argTypes.push_back(arg); 209 } while (succeeded(parser.parseOptionalComma())); 210 211 if (parser.parseOptionalRParen() || parser.parseOptionalGreater()) 212 return LLVMFunctionType(); 213 return parser.getChecked<LLVMFunctionType>(loc, returnType, argTypes, 214 /*isVarArg=*/false); 215 } 216 217 /// Parses an LLVM dialect pointer type. 218 /// llvm-type ::= `ptr<` llvm-type (`,` integer)? `>` 219 static LLVMPointerType parsePointerType(DialectAsmParser &parser) { 220 llvm::SMLoc loc = parser.getCurrentLocation(); 221 Type elementType; 222 if (parser.parseLess() || dispatchParse(parser, elementType)) 223 return LLVMPointerType(); 224 225 unsigned addressSpace = 0; 226 if (succeeded(parser.parseOptionalComma()) && 227 failed(parser.parseInteger(addressSpace))) 228 return LLVMPointerType(); 229 if (failed(parser.parseGreater())) 230 return LLVMPointerType(); 231 return parser.getChecked<LLVMPointerType>(loc, elementType, addressSpace); 232 } 233 234 /// Parses an LLVM dialect vector type. 235 /// llvm-type ::= `vec<` `? x`? integer `x` llvm-type `>` 236 /// Supports both fixed and scalable vectors. 237 static Type parseVectorType(DialectAsmParser &parser) { 238 SmallVector<int64_t, 2> dims; 239 llvm::SMLoc dimPos, typePos; 240 Type elementType; 241 llvm::SMLoc loc = parser.getCurrentLocation(); 242 if (parser.parseLess() || parser.getCurrentLocation(&dimPos) || 243 parser.parseDimensionList(dims, /*allowDynamic=*/true) || 244 parser.getCurrentLocation(&typePos) || 245 dispatchParse(parser, elementType) || parser.parseGreater()) 246 return Type(); 247 248 // We parsed a generic dimension list, but vectors only support two forms: 249 // - single non-dynamic entry in the list (fixed vector); 250 // - two elements, the first dynamic (indicated by -1) and the second 251 // non-dynamic (scalable vector). 252 if (dims.empty() || dims.size() > 2 || 253 ((dims.size() == 2) ^ (dims[0] == -1)) || 254 (dims.size() == 2 && dims[1] == -1)) { 255 parser.emitError(dimPos) 256 << "expected '? x <integer> x <type>' or '<integer> x <type>'"; 257 return Type(); 258 } 259 260 bool isScalable = dims.size() == 2; 261 if (isScalable) 262 return parser.getChecked<LLVMScalableVectorType>(loc, elementType, dims[1]); 263 if (elementType.isSignlessIntOrFloat()) { 264 parser.emitError(typePos) 265 << "cannot use !llvm.vec for built-in primitives, use 'vector' instead"; 266 return Type(); 267 } 268 return parser.getChecked<LLVMFixedVectorType>(loc, elementType, dims[0]); 269 } 270 271 /// Parses an LLVM dialect array type. 272 /// llvm-type ::= `array<` integer `x` llvm-type `>` 273 static LLVMArrayType parseArrayType(DialectAsmParser &parser) { 274 SmallVector<int64_t, 1> dims; 275 llvm::SMLoc sizePos; 276 Type elementType; 277 llvm::SMLoc loc = parser.getCurrentLocation(); 278 if (parser.parseLess() || parser.getCurrentLocation(&sizePos) || 279 parser.parseDimensionList(dims, /*allowDynamic=*/false) || 280 dispatchParse(parser, elementType) || parser.parseGreater()) 281 return LLVMArrayType(); 282 283 if (dims.size() != 1) { 284 parser.emitError(sizePos) << "expected ? x <type>"; 285 return LLVMArrayType(); 286 } 287 288 return parser.getChecked<LLVMArrayType>(loc, elementType, dims[0]); 289 } 290 291 /// Attempts to set the body of an identified structure type. Reports a parsing 292 /// error at `subtypesLoc` in case of failure. 293 static LLVMStructType trySetStructBody(LLVMStructType type, 294 ArrayRef<Type> subtypes, bool isPacked, 295 DialectAsmParser &parser, 296 llvm::SMLoc subtypesLoc) { 297 for (Type t : subtypes) { 298 if (!LLVMStructType::isValidElementType(t)) { 299 parser.emitError(subtypesLoc) 300 << "invalid LLVM structure element type: " << t; 301 return LLVMStructType(); 302 } 303 } 304 305 if (succeeded(type.setBody(subtypes, isPacked))) 306 return type; 307 308 parser.emitError(subtypesLoc) 309 << "identified type already used with a different body"; 310 return LLVMStructType(); 311 } 312 313 /// Parses an LLVM dialect structure type. 314 /// llvm-type ::= `struct<` (string-literal `,`)? `packed`? 315 /// `(` llvm-type-list `)` `>` 316 /// | `struct<` string-literal `>` 317 /// | `struct<` string-literal `, opaque>` 318 static LLVMStructType parseStructType(DialectAsmParser &parser) { 319 // This keeps track of the names of identified structure types that are 320 // currently being parsed. Since such types can refer themselves, this 321 // tracking is necessary to stop the recursion: the current function may be 322 // called recursively from DialectAsmParser::parseType after the appropriate 323 // dispatch. We maintain the invariant of this storage being modified 324 // exclusively in this function, and at most one name being added per call. 325 // TODO: consider having such functionality inside DialectAsmParser. 326 thread_local SetVector<StringRef> knownStructNames; 327 unsigned stackSize = knownStructNames.size(); 328 (void)stackSize; 329 auto guard = llvm::make_scope_exit([&]() { 330 assert(knownStructNames.size() == stackSize && 331 "malformed identified stack when parsing recursive structs"); 332 }); 333 334 Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation()); 335 336 if (failed(parser.parseLess())) 337 return LLVMStructType(); 338 339 // If we are parsing a self-reference to a recursive struct, i.e. the parsing 340 // stack already contains a struct with the same identifier, bail out after 341 // the name. 342 StringRef name; 343 bool isIdentified = succeeded(parser.parseOptionalString(&name)); 344 if (isIdentified) { 345 if (knownStructNames.count(name)) { 346 if (failed(parser.parseGreater())) 347 return LLVMStructType(); 348 return LLVMStructType::getIdentifiedChecked( 349 [loc] { return emitError(loc); }, loc.getContext(), name); 350 } 351 if (failed(parser.parseComma())) 352 return LLVMStructType(); 353 } 354 355 // Handle intentionally opaque structs. 356 llvm::SMLoc kwLoc = parser.getCurrentLocation(); 357 if (succeeded(parser.parseOptionalKeyword("opaque"))) { 358 if (!isIdentified) 359 return parser.emitError(kwLoc, "only identified structs can be opaque"), 360 LLVMStructType(); 361 if (failed(parser.parseGreater())) 362 return LLVMStructType(); 363 auto type = LLVMStructType::getOpaqueChecked( 364 [loc] { return emitError(loc); }, loc.getContext(), name); 365 if (!type.isOpaque()) { 366 parser.emitError(kwLoc, "redeclaring defined struct as opaque"); 367 return LLVMStructType(); 368 } 369 return type; 370 } 371 372 // Check for packedness. 373 bool isPacked = succeeded(parser.parseOptionalKeyword("packed")); 374 if (failed(parser.parseLParen())) 375 return LLVMStructType(); 376 377 // Fast pass for structs with zero subtypes. 378 if (succeeded(parser.parseOptionalRParen())) { 379 if (failed(parser.parseGreater())) 380 return LLVMStructType(); 381 if (!isIdentified) 382 return LLVMStructType::getLiteralChecked([loc] { return emitError(loc); }, 383 loc.getContext(), {}, isPacked); 384 auto type = LLVMStructType::getIdentifiedChecked( 385 [loc] { return emitError(loc); }, loc.getContext(), name); 386 return trySetStructBody(type, {}, isPacked, parser, kwLoc); 387 } 388 389 // Parse subtypes. For identified structs, put the identifier of the struct on 390 // the stack to support self-references in the recursive calls. 391 SmallVector<Type, 4> subtypes; 392 llvm::SMLoc subtypesLoc = parser.getCurrentLocation(); 393 do { 394 if (isIdentified) 395 knownStructNames.insert(name); 396 Type type; 397 if (dispatchParse(parser, type)) 398 return LLVMStructType(); 399 subtypes.push_back(type); 400 if (isIdentified) 401 knownStructNames.pop_back(); 402 } while (succeeded(parser.parseOptionalComma())); 403 404 if (parser.parseRParen() || parser.parseGreater()) 405 return LLVMStructType(); 406 407 // Construct the struct with body. 408 if (!isIdentified) 409 return LLVMStructType::getLiteralChecked( 410 [loc] { return emitError(loc); }, loc.getContext(), subtypes, isPacked); 411 auto type = LLVMStructType::getIdentifiedChecked( 412 [loc] { return emitError(loc); }, loc.getContext(), name); 413 return trySetStructBody(type, subtypes, isPacked, parser, subtypesLoc); 414 } 415 416 /// Parses a type appearing inside another LLVM dialect-compatible type. This 417 /// will try to parse any type in full form (including types with the `!llvm` 418 /// prefix), and on failure fall back to parsing the short-hand version of the 419 /// LLVM dialect types without the `!llvm` prefix. 420 static Type dispatchParse(DialectAsmParser &parser, bool allowAny = true) { 421 llvm::SMLoc keyLoc = parser.getCurrentLocation(); 422 423 // Try parsing any MLIR type. 424 Type type; 425 OptionalParseResult result = parser.parseOptionalType(type); 426 if (result.hasValue()) { 427 if (failed(result.getValue())) 428 return nullptr; 429 if (!allowAny) { 430 parser.emitError(keyLoc) << "unexpected type, expected keyword"; 431 return nullptr; 432 } 433 return type; 434 } 435 436 // If no type found, fallback to the shorthand form. 437 StringRef key; 438 if (failed(parser.parseKeyword(&key))) 439 return Type(); 440 441 MLIRContext *ctx = parser.getBuilder().getContext(); 442 return StringSwitch<function_ref<Type()>>(key) 443 .Case("void", [&] { return LLVMVoidType::get(ctx); }) 444 .Case("ppc_fp128", [&] { return LLVMPPCFP128Type::get(ctx); }) 445 .Case("x86_mmx", [&] { return LLVMX86MMXType::get(ctx); }) 446 .Case("token", [&] { return LLVMTokenType::get(ctx); }) 447 .Case("label", [&] { return LLVMLabelType::get(ctx); }) 448 .Case("metadata", [&] { return LLVMMetadataType::get(ctx); }) 449 .Case("func", [&] { return parseFunctionType(parser); }) 450 .Case("ptr", [&] { return parsePointerType(parser); }) 451 .Case("vec", [&] { return parseVectorType(parser); }) 452 .Case("array", [&] { return parseArrayType(parser); }) 453 .Case("struct", [&] { return parseStructType(parser); }) 454 .Default([&] { 455 parser.emitError(keyLoc) << "unknown LLVM type: " << key; 456 return Type(); 457 })(); 458 } 459 460 /// Helper to use in parse lists. 461 static ParseResult dispatchParse(DialectAsmParser &parser, Type &type) { 462 type = dispatchParse(parser); 463 return success(type != nullptr); 464 } 465 466 /// Parses one of the LLVM dialect types. 467 Type mlir::LLVM::detail::parseType(DialectAsmParser &parser) { 468 llvm::SMLoc loc = parser.getCurrentLocation(); 469 Type type = dispatchParse(parser, /*allowAny=*/false); 470 if (!type) 471 return type; 472 if (!isCompatibleType(type)) { 473 parser.emitError(loc) << "unexpected type, expected keyword"; 474 return nullptr; 475 } 476 return type; 477 } 478