1 //===- OpImplementation.h - Classes for implementing Op types ---*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This classes used by the implementation details of Op types. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef MLIR_IR_OPIMPLEMENTATION_H 14 #define MLIR_IR_OPIMPLEMENTATION_H 15 16 #include "mlir/IR/BuiltinTypes.h" 17 #include "mlir/IR/DialectInterface.h" 18 #include "mlir/IR/OpDefinition.h" 19 #include "llvm/ADT/Twine.h" 20 #include "llvm/Support/SMLoc.h" 21 #include "llvm/Support/raw_ostream.h" 22 23 namespace mlir { 24 25 class Builder; 26 27 //===----------------------------------------------------------------------===// 28 // OpAsmPrinter 29 //===----------------------------------------------------------------------===// 30 31 /// This is a pure-virtual base class that exposes the asmprinter hooks 32 /// necessary to implement a custom print() method. 33 class OpAsmPrinter { 34 public: OpAsmPrinter()35 OpAsmPrinter() {} 36 virtual ~OpAsmPrinter(); 37 virtual raw_ostream &getStream() const = 0; 38 39 /// Print a newline and indent the printer to the start of the current 40 /// operation. 41 virtual void printNewline() = 0; 42 43 /// Print implementations for various things an operation contains. 44 virtual void printOperand(Value value) = 0; 45 virtual void printOperand(Value value, raw_ostream &os) = 0; 46 47 /// Print a comma separated list of operands. 48 template <typename ContainerType> printOperands(const ContainerType & container)49 void printOperands(const ContainerType &container) { 50 printOperands(container.begin(), container.end()); 51 } 52 53 /// Print a comma separated list of operands. 54 template <typename IteratorType> printOperands(IteratorType it,IteratorType end)55 void printOperands(IteratorType it, IteratorType end) { 56 if (it == end) 57 return; 58 printOperand(*it); 59 for (++it; it != end; ++it) { 60 getStream() << ", "; 61 printOperand(*it); 62 } 63 } 64 virtual void printType(Type type) = 0; 65 virtual void printAttribute(Attribute attr) = 0; 66 67 /// Print the given attribute without its type. The corresponding parser must 68 /// provide a valid type for the attribute. 69 virtual void printAttributeWithoutType(Attribute attr) = 0; 70 71 /// Print the given successor. 72 virtual void printSuccessor(Block *successor) = 0; 73 74 /// Print the successor and its operands. 75 virtual void printSuccessorAndUseList(Block *successor, 76 ValueRange succOperands) = 0; 77 78 /// If the specified operation has attributes, print out an attribute 79 /// dictionary with their values. elidedAttrs allows the client to ignore 80 /// specific well known attributes, commonly used if the attribute value is 81 /// printed some other way (like as a fixed operand). 82 virtual void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs, 83 ArrayRef<StringRef> elidedAttrs = {}) = 0; 84 85 /// If the specified operation has attributes, print out an attribute 86 /// dictionary prefixed with 'attributes'. 87 virtual void 88 printOptionalAttrDictWithKeyword(ArrayRef<NamedAttribute> attrs, 89 ArrayRef<StringRef> elidedAttrs = {}) = 0; 90 91 /// Print the entire operation with the default generic assembly form. 92 virtual void printGenericOp(Operation *op) = 0; 93 94 /// Prints a region. 95 virtual void printRegion(Region &blocks, bool printEntryBlockArgs = true, 96 bool printBlockTerminators = true) = 0; 97 98 /// Renumber the arguments for the specified region to the same names as the 99 /// SSA values in namesToUse. This may only be used for IsolatedFromAbove 100 /// operations. If any entry in namesToUse is null, the corresponding 101 /// argument name is left alone. 102 virtual void shadowRegionArgs(Region ®ion, ValueRange namesToUse) = 0; 103 104 /// Prints an affine map of SSA ids, where SSA id names are used in place 105 /// of dims/symbols. 106 /// Operand values must come from single-result sources, and be valid 107 /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol. 108 virtual void printAffineMapOfSSAIds(AffineMapAttr mapAttr, 109 ValueRange operands) = 0; 110 111 /// Print an optional arrow followed by a type list. 112 template <typename TypeRange> printOptionalArrowTypeList(TypeRange && types)113 void printOptionalArrowTypeList(TypeRange &&types) { 114 if (types.begin() != types.end()) 115 printArrowTypeList(types); 116 } 117 template <typename TypeRange> printArrowTypeList(TypeRange && types)118 void printArrowTypeList(TypeRange &&types) { 119 auto &os = getStream() << " -> "; 120 121 bool wrapped = !llvm::hasSingleElement(types) || 122 (*types.begin()).template isa<FunctionType>(); 123 if (wrapped) 124 os << '('; 125 llvm::interleaveComma(types, *this); 126 if (wrapped) 127 os << ')'; 128 } 129 130 /// Print the complete type of an operation in functional form. 131 void printFunctionalType(Operation *op); 132 133 /// Print the two given type ranges in a functional form. 134 template <typename InputRangeT, typename ResultRangeT> printFunctionalType(InputRangeT && inputs,ResultRangeT && results)135 void printFunctionalType(InputRangeT &&inputs, ResultRangeT &&results) { 136 auto &os = getStream(); 137 os << '('; 138 llvm::interleaveComma(inputs, *this); 139 os << ')'; 140 printArrowTypeList(results); 141 } 142 143 /// Print the given string as a symbol reference, i.e. a form representable by 144 /// a SymbolRefAttr. A symbol reference is represented as a string prefixed 145 /// with '@'. The reference is surrounded with ""'s and escaped if it has any 146 /// special or non-printable characters in it. 147 virtual void printSymbolName(StringRef symbolRef) = 0; 148 149 private: 150 OpAsmPrinter(const OpAsmPrinter &) = delete; 151 void operator=(const OpAsmPrinter &) = delete; 152 }; 153 154 // Make the implementations convenient to use. 155 inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Value value) { 156 p.printOperand(value); 157 return p; 158 } 159 160 template <typename T, 161 typename std::enable_if<std::is_convertible<T &, ValueRange>::value && 162 !std::is_convertible<T &, Value &>::value, 163 T>::type * = nullptr> 164 inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const T &values) { 165 p.printOperands(values); 166 return p; 167 } 168 169 inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Type type) { 170 p.printType(type); 171 return p; 172 } 173 174 inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Attribute attr) { 175 p.printAttribute(attr); 176 return p; 177 } 178 179 // Support printing anything that isn't convertible to one of the above types, 180 // even if it isn't exactly one of them. For example, we want to print 181 // FunctionType with the Type version above, not have it match this. 182 template <typename T, typename std::enable_if< 183 !std::is_convertible<T &, Value &>::value && 184 !std::is_convertible<T &, Type &>::value && 185 !std::is_convertible<T &, Attribute &>::value && 186 !std::is_convertible<T &, ValueRange>::value && 187 !llvm::is_one_of<T, bool>::value, 188 T>::type * = nullptr> 189 inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const T &other) { 190 p.getStream() << other; 191 return p; 192 } 193 194 inline OpAsmPrinter &operator<<(OpAsmPrinter &p, bool value) { 195 return p << (value ? StringRef("true") : "false"); 196 } 197 198 inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Block *value) { 199 p.printSuccessor(value); 200 return p; 201 } 202 203 template <typename ValueRangeT> 204 inline OpAsmPrinter &operator<<(OpAsmPrinter &p, 205 const ValueTypeRange<ValueRangeT> &types) { 206 llvm::interleaveComma(types, p); 207 return p; 208 } 209 inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const TypeRange &types) { 210 llvm::interleaveComma(types, p); 211 return p; 212 } 213 inline OpAsmPrinter &operator<<(OpAsmPrinter &p, ArrayRef<Type> types) { 214 llvm::interleaveComma(types, p); 215 return p; 216 } 217 218 //===----------------------------------------------------------------------===// 219 // OpAsmParser 220 //===----------------------------------------------------------------------===// 221 222 /// The OpAsmParser has methods for interacting with the asm parser: parsing 223 /// things from it, emitting errors etc. It has an intentionally high-level API 224 /// that is designed to reduce/constrain syntax innovation in individual 225 /// operations. 226 /// 227 /// For example, consider an op like this: 228 /// 229 /// %x = load %p[%1, %2] : memref<...> 230 /// 231 /// The "%x = load" tokens are already parsed and therefore invisible to the 232 /// custom op parser. This can be supported by calling `parseOperandList` to 233 /// parse the %p, then calling `parseOperandList` with a `SquareDelimiter` to 234 /// parse the indices, then calling `parseColonTypeList` to parse the result 235 /// type. 236 /// 237 class OpAsmParser { 238 public: 239 virtual ~OpAsmParser(); 240 241 /// Emit a diagnostic at the specified location and return failure. 242 virtual InFlightDiagnostic emitError(llvm::SMLoc loc, 243 const Twine &message = {}) = 0; 244 245 /// Return a builder which provides useful access to MLIRContext, global 246 /// objects like types and attributes. 247 virtual Builder &getBuilder() const = 0; 248 249 /// Get the location of the next token and store it into the argument. This 250 /// always succeeds. 251 virtual llvm::SMLoc getCurrentLocation() = 0; getCurrentLocation(llvm::SMLoc * loc)252 ParseResult getCurrentLocation(llvm::SMLoc *loc) { 253 *loc = getCurrentLocation(); 254 return success(); 255 } 256 257 /// Return the name of the specified result in the specified syntax, as well 258 /// as the sub-element in the name. It returns an empty string and ~0U for 259 /// invalid result numbers. For example, in this operation: 260 /// 261 /// %x, %y:2, %z = foo.op 262 /// 263 /// getResultName(0) == {"x", 0 } 264 /// getResultName(1) == {"y", 0 } 265 /// getResultName(2) == {"y", 1 } 266 /// getResultName(3) == {"z", 0 } 267 /// getResultName(4) == {"", ~0U } 268 virtual std::pair<StringRef, unsigned> 269 getResultName(unsigned resultNo) const = 0; 270 271 /// Return the number of declared SSA results. This returns 4 for the foo.op 272 /// example in the comment for `getResultName`. 273 virtual size_t getNumResults() const = 0; 274 275 /// Return the location of the original name token. 276 virtual llvm::SMLoc getNameLoc() const = 0; 277 278 // These methods emit an error and return failure or success. This allows 279 // these to be chained together into a linear sequence of || expressions in 280 // many cases. 281 282 /// Parse an operation in its generic form. 283 /// The parsed operation is parsed in the current context and inserted in the 284 /// provided block and insertion point. The results produced by this operation 285 /// aren't mapped to any named value in the parser. Returns nullptr on 286 /// failure. 287 virtual Operation *parseGenericOperation(Block *insertBlock, 288 Block::iterator insertPt) = 0; 289 290 //===--------------------------------------------------------------------===// 291 // Token Parsing 292 //===--------------------------------------------------------------------===// 293 294 /// Parse a '->' token. 295 virtual ParseResult parseArrow() = 0; 296 297 /// Parse a '->' token if present 298 virtual ParseResult parseOptionalArrow() = 0; 299 300 /// Parse a `{` token. 301 virtual ParseResult parseLBrace() = 0; 302 303 /// Parse a `{` token if present. 304 virtual ParseResult parseOptionalLBrace() = 0; 305 306 /// Parse a `}` token. 307 virtual ParseResult parseRBrace() = 0; 308 309 /// Parse a `}` token if present. 310 virtual ParseResult parseOptionalRBrace() = 0; 311 312 /// Parse a `:` token. 313 virtual ParseResult parseColon() = 0; 314 315 /// Parse a `:` token if present. 316 virtual ParseResult parseOptionalColon() = 0; 317 318 /// Parse a `,` token. 319 virtual ParseResult parseComma() = 0; 320 321 /// Parse a `,` token if present. 322 virtual ParseResult parseOptionalComma() = 0; 323 324 /// Parse a `=` token. 325 virtual ParseResult parseEqual() = 0; 326 327 /// Parse a `=` token if present. 328 virtual ParseResult parseOptionalEqual() = 0; 329 330 /// Parse a '<' token. 331 virtual ParseResult parseLess() = 0; 332 333 /// Parse a '<' token if present. 334 virtual ParseResult parseOptionalLess() = 0; 335 336 /// Parse a '>' token. 337 virtual ParseResult parseGreater() = 0; 338 339 /// Parse a '>' token if present. 340 virtual ParseResult parseOptionalGreater() = 0; 341 342 /// Parse a '?' token. 343 virtual ParseResult parseQuestion() = 0; 344 345 /// Parse a '?' token if present. 346 virtual ParseResult parseOptionalQuestion() = 0; 347 348 /// Parse a '+' token. 349 virtual ParseResult parsePlus() = 0; 350 351 /// Parse a '+' token if present. 352 virtual ParseResult parseOptionalPlus() = 0; 353 354 /// Parse a '*' token. 355 virtual ParseResult parseStar() = 0; 356 357 /// Parse a '*' token if present. 358 virtual ParseResult parseOptionalStar() = 0; 359 360 /// Parse a given keyword. 361 ParseResult parseKeyword(StringRef keyword, const Twine &msg = "") { 362 auto loc = getCurrentLocation(); 363 if (parseOptionalKeyword(keyword)) 364 return emitError(loc, "expected '") << keyword << "'" << msg; 365 return success(); 366 } 367 368 /// Parse a keyword into 'keyword'. parseKeyword(StringRef * keyword)369 ParseResult parseKeyword(StringRef *keyword) { 370 auto loc = getCurrentLocation(); 371 if (parseOptionalKeyword(keyword)) 372 return emitError(loc, "expected valid keyword"); 373 return success(); 374 } 375 376 /// Parse the given keyword if present. 377 virtual ParseResult parseOptionalKeyword(StringRef keyword) = 0; 378 379 /// Parse a keyword, if present, into 'keyword'. 380 virtual ParseResult parseOptionalKeyword(StringRef *keyword) = 0; 381 382 /// Parse a keyword, if present, and if one of the 'allowedValues', 383 /// into 'keyword' 384 virtual ParseResult 385 parseOptionalKeyword(StringRef *keyword, 386 ArrayRef<StringRef> allowedValues) = 0; 387 388 /// Parse a `(` token. 389 virtual ParseResult parseLParen() = 0; 390 391 /// Parse a `(` token if present. 392 virtual ParseResult parseOptionalLParen() = 0; 393 394 /// Parse a `)` token. 395 virtual ParseResult parseRParen() = 0; 396 397 /// Parse a `)` token if present. 398 virtual ParseResult parseOptionalRParen() = 0; 399 400 /// Parse a `[` token. 401 virtual ParseResult parseLSquare() = 0; 402 403 /// Parse a `[` token if present. 404 virtual ParseResult parseOptionalLSquare() = 0; 405 406 /// Parse a `]` token. 407 virtual ParseResult parseRSquare() = 0; 408 409 /// Parse a `]` token if present. 410 virtual ParseResult parseOptionalRSquare() = 0; 411 412 /// Parse a `...` token if present; 413 virtual ParseResult parseOptionalEllipsis() = 0; 414 415 /// Parse an integer value from the stream. 416 template <typename IntT> parseInteger(IntT & result)417 ParseResult parseInteger(IntT &result) { 418 auto loc = getCurrentLocation(); 419 OptionalParseResult parseResult = parseOptionalInteger(result); 420 if (!parseResult.hasValue()) 421 return emitError(loc, "expected integer value"); 422 return *parseResult; 423 } 424 425 /// Parse an optional integer value from the stream. 426 virtual OptionalParseResult parseOptionalInteger(uint64_t &result) = 0; 427 428 template <typename IntT> parseOptionalInteger(IntT & result)429 OptionalParseResult parseOptionalInteger(IntT &result) { 430 auto loc = getCurrentLocation(); 431 432 // Parse the unsigned variant. 433 uint64_t uintResult; 434 OptionalParseResult parseResult = parseOptionalInteger(uintResult); 435 if (!parseResult.hasValue() || failed(*parseResult)) 436 return parseResult; 437 438 // Try to convert to the provided integer type. 439 result = IntT(uintResult); 440 if (uint64_t(result) != uintResult) 441 return emitError(loc, "integer value too large"); 442 return success(); 443 } 444 445 //===--------------------------------------------------------------------===// 446 // Attribute Parsing 447 //===--------------------------------------------------------------------===// 448 449 /// Parse an arbitrary attribute of a given type and return it in result. 450 virtual ParseResult parseAttribute(Attribute &result, Type type = {}) = 0; 451 452 /// Parse an attribute of a specific kind and type. 453 template <typename AttrType> 454 ParseResult parseAttribute(AttrType &result, Type type = {}) { 455 llvm::SMLoc loc = getCurrentLocation(); 456 457 // Parse any kind of attribute. 458 Attribute attr; 459 if (parseAttribute(attr, type)) 460 return failure(); 461 462 // Check for the right kind of attribute. 463 if (!(result = attr.dyn_cast<AttrType>())) 464 return emitError(loc, "invalid kind of attribute specified"); 465 466 return success(); 467 } 468 469 /// Parse an arbitrary attribute and return it in result. This also adds the 470 /// attribute to the specified attribute list with the specified name. parseAttribute(Attribute & result,StringRef attrName,NamedAttrList & attrs)471 ParseResult parseAttribute(Attribute &result, StringRef attrName, 472 NamedAttrList &attrs) { 473 return parseAttribute(result, Type(), attrName, attrs); 474 } 475 476 /// Parse an attribute of a specific kind and type. 477 template <typename AttrType> parseAttribute(AttrType & result,StringRef attrName,NamedAttrList & attrs)478 ParseResult parseAttribute(AttrType &result, StringRef attrName, 479 NamedAttrList &attrs) { 480 return parseAttribute(result, Type(), attrName, attrs); 481 } 482 483 /// Parse an optional attribute. 484 virtual OptionalParseResult parseOptionalAttribute(Attribute &result, 485 Type type, 486 StringRef attrName, 487 NamedAttrList &attrs) = 0; 488 template <typename AttrT> parseOptionalAttribute(AttrT & result,StringRef attrName,NamedAttrList & attrs)489 OptionalParseResult parseOptionalAttribute(AttrT &result, StringRef attrName, 490 NamedAttrList &attrs) { 491 return parseOptionalAttribute(result, Type(), attrName, attrs); 492 } 493 494 /// Specialized variants of `parseOptionalAttribute` that remove potential 495 /// ambiguities in syntax. 496 virtual OptionalParseResult parseOptionalAttribute(ArrayAttr &result, 497 Type type, 498 StringRef attrName, 499 NamedAttrList &attrs) = 0; 500 virtual OptionalParseResult parseOptionalAttribute(StringAttr &result, 501 Type type, 502 StringRef attrName, 503 NamedAttrList &attrs) = 0; 504 505 /// Parse an arbitrary attribute of a given type and return it in result. This 506 /// also adds the attribute to the specified attribute list with the specified 507 /// name. 508 template <typename AttrType> parseAttribute(AttrType & result,Type type,StringRef attrName,NamedAttrList & attrs)509 ParseResult parseAttribute(AttrType &result, Type type, StringRef attrName, 510 NamedAttrList &attrs) { 511 llvm::SMLoc loc = getCurrentLocation(); 512 513 // Parse any kind of attribute. 514 Attribute attr; 515 if (parseAttribute(attr, type)) 516 return failure(); 517 518 // Check for the right kind of attribute. 519 result = attr.dyn_cast<AttrType>(); 520 if (!result) 521 return emitError(loc, "invalid kind of attribute specified"); 522 523 attrs.append(attrName, result); 524 return success(); 525 } 526 527 /// Parse a named dictionary into 'result' if it is present. 528 virtual ParseResult parseOptionalAttrDict(NamedAttrList &result) = 0; 529 530 /// Parse a named dictionary into 'result' if the `attributes` keyword is 531 /// present. 532 virtual ParseResult 533 parseOptionalAttrDictWithKeyword(NamedAttrList &result) = 0; 534 535 /// Parse an affine map instance into 'map'. 536 virtual ParseResult parseAffineMap(AffineMap &map) = 0; 537 538 /// Parse an integer set instance into 'set'. 539 virtual ParseResult printIntegerSet(IntegerSet &set) = 0; 540 541 //===--------------------------------------------------------------------===// 542 // Identifier Parsing 543 //===--------------------------------------------------------------------===// 544 545 /// Parse an @-identifier and store it (without the '@' symbol) in a string 546 /// attribute named 'attrName'. parseSymbolName(StringAttr & result,StringRef attrName,NamedAttrList & attrs)547 ParseResult parseSymbolName(StringAttr &result, StringRef attrName, 548 NamedAttrList &attrs) { 549 if (failed(parseOptionalSymbolName(result, attrName, attrs))) 550 return emitError(getCurrentLocation()) 551 << "expected valid '@'-identifier for symbol name"; 552 return success(); 553 } 554 555 /// Parse an optional @-identifier and store it (without the '@' symbol) in a 556 /// string attribute named 'attrName'. 557 virtual ParseResult parseOptionalSymbolName(StringAttr &result, 558 StringRef attrName, 559 NamedAttrList &attrs) = 0; 560 561 //===--------------------------------------------------------------------===// 562 // Operand Parsing 563 //===--------------------------------------------------------------------===// 564 565 /// This is the representation of an operand reference. 566 struct OperandType { 567 llvm::SMLoc location; // Location of the token. 568 StringRef name; // Value name, e.g. %42 or %abc 569 unsigned number; // Number, e.g. 12 for an operand like %xyz#12 570 }; 571 572 /// Parse a single operand. 573 virtual ParseResult parseOperand(OperandType &result) = 0; 574 575 /// Parse a single operand if present. 576 virtual OptionalParseResult parseOptionalOperand(OperandType &result) = 0; 577 578 /// These are the supported delimiters around operand lists and region 579 /// argument lists, used by parseOperandList and parseRegionArgumentList. 580 enum class Delimiter { 581 /// Zero or more operands with no delimiters. 582 None, 583 /// Parens surrounding zero or more operands. 584 Paren, 585 /// Square brackets surrounding zero or more operands. 586 Square, 587 /// Parens supporting zero or more operands, or nothing. 588 OptionalParen, 589 /// Square brackets supporting zero or more ops, or nothing. 590 OptionalSquare, 591 }; 592 593 /// Parse zero or more SSA comma-separated operand references with a specified 594 /// surrounding delimiter, and an optional required operand count. 595 virtual ParseResult 596 parseOperandList(SmallVectorImpl<OperandType> &result, 597 int requiredOperandCount = -1, 598 Delimiter delimiter = Delimiter::None) = 0; parseOperandList(SmallVectorImpl<OperandType> & result,Delimiter delimiter)599 ParseResult parseOperandList(SmallVectorImpl<OperandType> &result, 600 Delimiter delimiter) { 601 return parseOperandList(result, /*requiredOperandCount=*/-1, delimiter); 602 } 603 604 /// Parse zero or more trailing SSA comma-separated trailing operand 605 /// references with a specified surrounding delimiter, and an optional 606 /// required operand count. A leading comma is expected before the operands. 607 virtual ParseResult 608 parseTrailingOperandList(SmallVectorImpl<OperandType> &result, 609 int requiredOperandCount = -1, 610 Delimiter delimiter = Delimiter::None) = 0; parseTrailingOperandList(SmallVectorImpl<OperandType> & result,Delimiter delimiter)611 ParseResult parseTrailingOperandList(SmallVectorImpl<OperandType> &result, 612 Delimiter delimiter) { 613 return parseTrailingOperandList(result, /*requiredOperandCount=*/-1, 614 delimiter); 615 } 616 617 /// Resolve an operand to an SSA value, emitting an error on failure. 618 virtual ParseResult resolveOperand(const OperandType &operand, Type type, 619 SmallVectorImpl<Value> &result) = 0; 620 621 /// Resolve a list of operands to SSA values, emitting an error on failure, or 622 /// appending the results to the list on success. This method should be used 623 /// when all operands have the same type. resolveOperands(ArrayRef<OperandType> operands,Type type,SmallVectorImpl<Value> & result)624 ParseResult resolveOperands(ArrayRef<OperandType> operands, Type type, 625 SmallVectorImpl<Value> &result) { 626 for (auto elt : operands) 627 if (resolveOperand(elt, type, result)) 628 return failure(); 629 return success(); 630 } 631 632 /// Resolve a list of operands and a list of operand types to SSA values, 633 /// emitting an error and returning failure, or appending the results 634 /// to the list on success. resolveOperands(ArrayRef<OperandType> operands,ArrayRef<Type> types,llvm::SMLoc loc,SmallVectorImpl<Value> & result)635 ParseResult resolveOperands(ArrayRef<OperandType> operands, 636 ArrayRef<Type> types, llvm::SMLoc loc, 637 SmallVectorImpl<Value> &result) { 638 if (operands.size() != types.size()) 639 return emitError(loc) 640 << operands.size() << " operands present, but expected " 641 << types.size(); 642 643 for (unsigned i = 0, e = operands.size(); i != e; ++i) 644 if (resolveOperand(operands[i], types[i], result)) 645 return failure(); 646 return success(); 647 } 648 template <typename Operands> resolveOperands(Operands && operands,Type type,llvm::SMLoc loc,SmallVectorImpl<Value> & result)649 ParseResult resolveOperands(Operands &&operands, Type type, llvm::SMLoc loc, 650 SmallVectorImpl<Value> &result) { 651 return resolveOperands(std::forward<Operands>(operands), 652 ArrayRef<Type>(type), loc, result); 653 } 654 template <typename Operands, typename Types> 655 std::enable_if_t<!std::is_convertible<Types, Type>::value, ParseResult> resolveOperands(Operands && operands,Types && types,llvm::SMLoc loc,SmallVectorImpl<Value> & result)656 resolveOperands(Operands &&operands, Types &&types, llvm::SMLoc loc, 657 SmallVectorImpl<Value> &result) { 658 size_t operandSize = std::distance(operands.begin(), operands.end()); 659 size_t typeSize = std::distance(types.begin(), types.end()); 660 if (operandSize != typeSize) 661 return emitError(loc) 662 << operandSize << " operands present, but expected " << typeSize; 663 664 for (auto it : llvm::zip(operands, types)) 665 if (resolveOperand(std::get<0>(it), std::get<1>(it), result)) 666 return failure(); 667 return success(); 668 } 669 670 /// Parses an affine map attribute where dims and symbols are SSA operands. 671 /// Operand values must come from single-result sources, and be valid 672 /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol. 673 virtual ParseResult 674 parseAffineMapOfSSAIds(SmallVectorImpl<OperandType> &operands, Attribute &map, 675 StringRef attrName, NamedAttrList &attrs, 676 Delimiter delimiter = Delimiter::Square) = 0; 677 678 //===--------------------------------------------------------------------===// 679 // Region Parsing 680 //===--------------------------------------------------------------------===// 681 682 /// Parses a region. Any parsed blocks are appended to 'region' and must be 683 /// moved to the op regions after the op is created. The first block of the 684 /// region takes 'arguments' of types 'argTypes'. If 'enableNameShadowing' is 685 /// set to true, the argument names are allowed to shadow the names of other 686 /// existing SSA values defined above the region scope. 'enableNameShadowing' 687 /// can only be set to true for regions attached to operations that are 688 /// 'IsolatedFromAbove. 689 virtual ParseResult parseRegion(Region ®ion, 690 ArrayRef<OperandType> arguments = {}, 691 ArrayRef<Type> argTypes = {}, 692 bool enableNameShadowing = false) = 0; 693 694 /// Parses a region if present. 695 virtual OptionalParseResult 696 parseOptionalRegion(Region ®ion, ArrayRef<OperandType> arguments = {}, 697 ArrayRef<Type> argTypes = {}, 698 bool enableNameShadowing = false) = 0; 699 700 /// Parses a region if present. If the region is present, a new region is 701 /// allocated and placed in `region`. If no region is present or on failure, 702 /// `region` remains untouched. 703 virtual OptionalParseResult parseOptionalRegion( 704 std::unique_ptr<Region> ®ion, ArrayRef<OperandType> arguments = {}, 705 ArrayRef<Type> argTypes = {}, bool enableNameShadowing = false) = 0; 706 707 /// Parse a region argument, this argument is resolved when calling 708 /// 'parseRegion'. 709 virtual ParseResult parseRegionArgument(OperandType &argument) = 0; 710 711 /// Parse zero or more region arguments with a specified surrounding 712 /// delimiter, and an optional required argument count. Region arguments 713 /// define new values; so this also checks if values with the same names have 714 /// not been defined yet. 715 virtual ParseResult 716 parseRegionArgumentList(SmallVectorImpl<OperandType> &result, 717 int requiredOperandCount = -1, 718 Delimiter delimiter = Delimiter::None) = 0; 719 virtual ParseResult parseRegionArgumentList(SmallVectorImpl<OperandType> & result,Delimiter delimiter)720 parseRegionArgumentList(SmallVectorImpl<OperandType> &result, 721 Delimiter delimiter) { 722 return parseRegionArgumentList(result, /*requiredOperandCount=*/-1, 723 delimiter); 724 } 725 726 /// Parse a region argument if present. 727 virtual ParseResult parseOptionalRegionArgument(OperandType &argument) = 0; 728 729 //===--------------------------------------------------------------------===// 730 // Successor Parsing 731 //===--------------------------------------------------------------------===// 732 733 /// Parse a single operation successor. 734 virtual ParseResult parseSuccessor(Block *&dest) = 0; 735 736 /// Parse an optional operation successor. 737 virtual OptionalParseResult parseOptionalSuccessor(Block *&dest) = 0; 738 739 /// Parse a single operation successor and its operand list. 740 virtual ParseResult 741 parseSuccessorAndUseList(Block *&dest, SmallVectorImpl<Value> &operands) = 0; 742 743 //===--------------------------------------------------------------------===// 744 // Type Parsing 745 //===--------------------------------------------------------------------===// 746 747 /// Parse a type. 748 virtual ParseResult parseType(Type &result) = 0; 749 750 /// Parse an optional type. 751 virtual OptionalParseResult parseOptionalType(Type &result) = 0; 752 753 /// Parse a type of a specific type. 754 template <typename TypeT> parseType(TypeT & result)755 ParseResult parseType(TypeT &result) { 756 llvm::SMLoc loc = getCurrentLocation(); 757 758 // Parse any kind of type. 759 Type type; 760 if (parseType(type)) 761 return failure(); 762 763 // Check for the right kind of attribute. 764 result = type.dyn_cast<TypeT>(); 765 if (!result) 766 return emitError(loc, "invalid kind of type specified"); 767 768 return success(); 769 } 770 771 /// Parse a type list. parseTypeList(SmallVectorImpl<Type> & result)772 ParseResult parseTypeList(SmallVectorImpl<Type> &result) { 773 do { 774 Type type; 775 if (parseType(type)) 776 return failure(); 777 result.push_back(type); 778 } while (succeeded(parseOptionalComma())); 779 return success(); 780 } 781 782 /// Parse an arrow followed by a type list. 783 virtual ParseResult parseArrowTypeList(SmallVectorImpl<Type> &result) = 0; 784 785 /// Parse an optional arrow followed by a type list. 786 virtual ParseResult 787 parseOptionalArrowTypeList(SmallVectorImpl<Type> &result) = 0; 788 789 /// Parse a colon followed by a type. 790 virtual ParseResult parseColonType(Type &result) = 0; 791 792 /// Parse a colon followed by a type of a specific kind, e.g. a FunctionType. 793 template <typename TypeType> parseColonType(TypeType & result)794 ParseResult parseColonType(TypeType &result) { 795 llvm::SMLoc loc = getCurrentLocation(); 796 797 // Parse any kind of type. 798 Type type; 799 if (parseColonType(type)) 800 return failure(); 801 802 // Check for the right kind of attribute. 803 result = type.dyn_cast<TypeType>(); 804 if (!result) 805 return emitError(loc, "invalid kind of type specified"); 806 807 return success(); 808 } 809 810 /// Parse a colon followed by a type list, which must have at least one type. 811 virtual ParseResult parseColonTypeList(SmallVectorImpl<Type> &result) = 0; 812 813 /// Parse an optional colon followed by a type list, which if present must 814 /// have at least one type. 815 virtual ParseResult 816 parseOptionalColonTypeList(SmallVectorImpl<Type> &result) = 0; 817 818 /// Parse a list of assignments of the form 819 /// (%x1 = %y1, %x2 = %y2, ...) parseAssignmentList(SmallVectorImpl<OperandType> & lhs,SmallVectorImpl<OperandType> & rhs)820 ParseResult parseAssignmentList(SmallVectorImpl<OperandType> &lhs, 821 SmallVectorImpl<OperandType> &rhs) { 822 OptionalParseResult result = parseOptionalAssignmentList(lhs, rhs); 823 if (!result.hasValue()) 824 return emitError(getCurrentLocation(), "expected '('"); 825 return result.getValue(); 826 } 827 828 virtual OptionalParseResult 829 parseOptionalAssignmentList(SmallVectorImpl<OperandType> &lhs, 830 SmallVectorImpl<OperandType> &rhs) = 0; 831 832 /// Parse a keyword followed by a type. parseKeywordType(const char * keyword,Type & result)833 ParseResult parseKeywordType(const char *keyword, Type &result) { 834 return failure(parseKeyword(keyword) || parseType(result)); 835 } 836 837 /// Add the specified type to the end of the specified type list and return 838 /// success. This is a helper designed to allow parse methods to be simple 839 /// and chain through || operators. addTypeToList(Type type,SmallVectorImpl<Type> & result)840 ParseResult addTypeToList(Type type, SmallVectorImpl<Type> &result) { 841 result.push_back(type); 842 return success(); 843 } 844 845 /// Add the specified types to the end of the specified type list and return 846 /// success. This is a helper designed to allow parse methods to be simple 847 /// and chain through || operators. addTypesToList(ArrayRef<Type> types,SmallVectorImpl<Type> & result)848 ParseResult addTypesToList(ArrayRef<Type> types, 849 SmallVectorImpl<Type> &result) { 850 result.append(types.begin(), types.end()); 851 return success(); 852 } 853 854 private: 855 /// Parse either an operand list or a region argument list depending on 856 /// whether isOperandList is true. 857 ParseResult parseOperandOrRegionArgList(SmallVectorImpl<OperandType> &result, 858 bool isOperandList, 859 int requiredOperandCount, 860 Delimiter delimiter); 861 }; 862 863 //===--------------------------------------------------------------------===// 864 // Dialect OpAsm interface. 865 //===--------------------------------------------------------------------===// 866 867 /// A functor used to set the name of the start of a result group of an 868 /// operation. See 'getAsmResultNames' below for more details. 869 using OpAsmSetValueNameFn = function_ref<void(Value, StringRef)>; 870 871 class OpAsmDialectInterface 872 : public DialectInterface::Base<OpAsmDialectInterface> { 873 public: OpAsmDialectInterface(Dialect * dialect)874 OpAsmDialectInterface(Dialect *dialect) : Base(dialect) {} 875 876 /// Hooks for getting an alias identifier alias for a given symbol, that is 877 /// not necessarily a part of this dialect. The identifier is used in place of 878 /// the symbol when printing textual IR. These aliases must not contain `.` or 879 /// end with a numeric digit([0-9]+). Returns success if an alias was 880 /// provided, failure otherwise. getAlias(Attribute attr,raw_ostream & os)881 virtual LogicalResult getAlias(Attribute attr, raw_ostream &os) const { 882 return failure(); 883 } getAlias(Type type,raw_ostream & os)884 virtual LogicalResult getAlias(Type type, raw_ostream &os) const { 885 return failure(); 886 } 887 888 /// Get a special name to use when printing the given operation. See 889 /// OpAsmInterface.td#getAsmResultNames for usage details and documentation. getAsmResultNames(Operation * op,OpAsmSetValueNameFn setNameFn)890 virtual void getAsmResultNames(Operation *op, 891 OpAsmSetValueNameFn setNameFn) const {} 892 893 /// Get a special name to use when printing the entry block arguments of the 894 /// region contained by an operation in this dialect. getAsmBlockArgumentNames(Block * block,OpAsmSetValueNameFn setNameFn)895 virtual void getAsmBlockArgumentNames(Block *block, 896 OpAsmSetValueNameFn setNameFn) const {} 897 }; 898 } // end namespace mlir 899 900 //===--------------------------------------------------------------------===// 901 // Operation OpAsm interface. 902 //===--------------------------------------------------------------------===// 903 904 /// The OpAsmOpInterface, see OpAsmInterface.td for more details. 905 #include "mlir/IR/OpAsmInterface.h.inc" 906 907 #endif 908