1 //===- mlir-linalg-ods-gen.cpp - Linalg ODS generation from math form -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains the implementation for the Tensor Comprehension-inspired 10 // parser and ODS pretty-printer for specifying Linalg "named ops" from a 11 // mathematical form. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #include "mlir/IR/AffineExpr.h" 16 #include "mlir/IR/AffineMap.h" 17 #include "mlir/IR/MLIRContext.h" 18 #include "mlir/IR/OpImplementation.h" 19 #include "mlir/Support/FileUtilities.h" 20 #include "mlir/Support/LLVM.h" 21 #include "mlir/Support/LogicalResult.h" 22 #include "llvm/ADT/DenseMap.h" 23 #include "llvm/ADT/Optional.h" 24 #include "llvm/ADT/STLExtras.h" 25 #include "llvm/ADT/SetVector.h" 26 #include "llvm/ADT/SmallVector.h" 27 #include "llvm/ADT/StringExtras.h" 28 #include "llvm/ADT/StringRef.h" 29 #include "llvm/ADT/StringSwitch.h" 30 #include "llvm/ADT/Twine.h" 31 #include "llvm/Support/Casting.h" 32 #include "llvm/Support/CommandLine.h" 33 #include "llvm/Support/Debug.h" 34 #include "llvm/Support/FormatVariadic.h" 35 #include "llvm/Support/MemoryBuffer.h" 36 #include "llvm/Support/SourceMgr.h" 37 #include "llvm/Support/ToolOutputFile.h" 38 39 #include <map> 40 #include <set> 41 42 #define DEBUG_TYPE "linalg-ods-gen" 43 44 static llvm::cl::OptionCategory ODSGenCat("Linalg ODS Gen"); 45 46 // Commandline options 47 static llvm::cl::opt<std::string> 48 inputFilename(llvm::cl::Positional, llvm::cl::desc("<input file>"), 49 llvm::cl::init("-"), llvm::cl::value_desc("filename")); 50 51 static llvm::cl::opt<std::string> 52 outputFilename("o", llvm::cl::desc("Output filename"), 53 
llvm::cl::value_desc("filename"), llvm::cl::init("-")); 54 55 static llvm::cl::opt<bool> 56 genODSDecl("gen-ods-decl", llvm::cl::desc("Emit the ODS ops declarations."), 57 llvm::cl::cat(ODSGenCat)); 58 59 static llvm::cl::opt<bool> 60 genODSImpl("gen-impl", llvm::cl::desc("Emit the ops implementations"), 61 llvm::cl::init(false), llvm::cl::cat(ODSGenCat)); 62 63 static llvm::cl::opt<bool> testEmitIncludeTdHeader( 64 "test-emit-include-td-header", 65 llvm::cl::desc("Include LinalgStructuredOps.td for end-to-end " 66 "tblgen testing."), 67 llvm::cl::init(false), llvm::cl::cat(ODSGenCat)); 68 69 using llvm::SMLoc; 70 using llvm::StringRef; 71 using llvm::Twine; 72 73 using namespace mlir; 74 75 //===----------------------------------------------------------------------===// 76 // Special "op aliases" substitutions. 77 //===----------------------------------------------------------------------===// 78 79 /// Perform substitutions of known special ops. 80 /// This is a poor man's way of achieving "op aliases": i.e. giving an op a 81 /// name. 82 /// This is hacky and temporary until migration to the python opdsl is complete. 
83 static void substituteOpAliases(std::string &expressionsStr) { 84 for (auto kvp : SmallVector<std::pair<std::string, std::string>>{ 85 {"b.create<CmpIOpSGT>(", "b.create<CmpIOp>(CmpIPredicate::sgt, "}, 86 {"b.create<CmpFOpOGT>(", "b.create<CmpFOp>(CmpFPredicate::OGT, "}, 87 {"b.create<CmpFOpOLT>(", "b.create<CmpFOp>(CmpFPredicate::OLT, "}, 88 {"b.create<SignExtendIOp32>(", 89 "b.create<SignExtendIOp>(b.getI32Type(), "}, 90 }) { 91 size_t pos = 0; 92 while ((pos = expressionsStr.find(kvp.first, pos)) != std::string::npos) { 93 expressionsStr.replace(pos, kvp.first.size(), kvp.second); 94 pos += kvp.second.size(); 95 } 96 } 97 } 98 99 //===----------------------------------------------------------------------===// 100 // Lexer 101 //===----------------------------------------------------------------------===// 102 103 namespace { 104 /// This class represents a specific token in the input format. 105 class Token { 106 public: 107 enum class Kind { 108 // Markers. 109 eof, 110 error, 111 112 // Tokens with no info. 113 colon, 114 comma, 115 doc_str, 116 equal, 117 gt, 118 l_brace, 119 l_paren, 120 l_square, 121 lt, 122 minus, 123 plus, 124 question, 125 r_brace, 126 r_paren, 127 r_square, 128 semicolon, 129 star, 130 131 // Keywords. 132 kw_def, 133 FIRST_KEYWORD = kw_def, 134 kw_ods_def, 135 kw_implements_interface, 136 kw_attr_def, 137 kw_floordiv, 138 kw_ceildiv, 139 kw_mod, 140 LAST_KEYWORD = kw_mod, 141 142 // String valued tokens. 143 id, 144 integer, 145 }; 146 147 Token(Kind kind, StringRef spelling) : kind(kind), spelling(spelling) {} 148 149 /// Return the bytes that make up this token. 150 StringRef getSpelling() const { return spelling; } 151 152 /// Return the kind of this token. 153 Kind getKind() const { return kind; } 154 155 /// Return a location for this token. 156 llvm::SMLoc getLoc() const { 157 return llvm::SMLoc::getFromPointer(spelling.data()); 158 } 159 160 /// Return if this token is a keyword. 
161 bool isKeyword() const { 162 return kind >= Kind::FIRST_KEYWORD && kind <= Kind::LAST_KEYWORD; 163 } 164 bool is(Kind k) const { return kind == k; } 165 bool isNot(Kind k) const { return kind != k; } 166 167 Optional<uint64_t> getUInt64IntegerValue() const { 168 bool isHex = spelling.size() > 1 && spelling[1] == 'x'; 169 170 uint64_t result = 0; 171 if (spelling.getAsInteger(isHex ? 0 : 10, result)) 172 return None; 173 return result; 174 } 175 176 private: 177 /// Discriminator that indicates the kind of token this is. 178 Kind kind; 179 180 /// A reference to the entire token contents; this is always a pointer into 181 /// a memory buffer owned by the source manager. 182 StringRef spelling; 183 }; 184 185 /// This class implements a simple lexer. 186 class Lexer { 187 public: 188 Lexer(llvm::SourceMgr &mgr); 189 190 /// Lex the next token and return it. 191 Token lexToken(); 192 193 /// Emit an error to the lexer with the given location and message. 194 Token emitError(llvm::SMLoc loc, const Twine &msg); 195 Token emitError(const char *loc, const Twine &msg); 196 197 /// Change the position of the lexer cursor. The next token we lex will start 198 /// at the designated point in the input. 199 void resetPointer(const char *newPtr) { curPtr = newPtr; } 200 201 private: 202 Token formToken(Token::Kind kind, const char *tokStart) { 203 return Token(kind, StringRef(tokStart, curPtr - tokStart)); 204 } 205 206 /// Return the next character in the stream. 207 int getNextChar(); 208 209 /// Lex an identifier. 210 Token lexIdentifier(const char *tokStart); 211 212 // Lex an integer. 213 Token lexInteger(const char *tokStart); 214 215 // Lex a string. 216 Token lexString(const char *tokStart); 217 218 // Skip a comment line, starting with a '//'. 
219 void skipComment(); 220 221 llvm::SourceMgr &srcMgr; 222 StringRef curBuffer; 223 const char *curPtr; 224 }; 225 } // end anonymous namespace 226 227 Lexer::Lexer(llvm::SourceMgr &mgr) : srcMgr(mgr) { 228 curBuffer = srcMgr.getMemoryBuffer(mgr.getMainFileID())->getBuffer(); 229 curPtr = curBuffer.begin(); 230 } 231 232 Token Lexer::emitError(llvm::SMLoc loc, const Twine &msg) { 233 srcMgr.PrintMessage(loc, llvm::SourceMgr::DK_Error, msg); 234 return formToken(Token::Kind::error, loc.getPointer()); 235 } 236 Token Lexer::emitError(const char *loc, const Twine &msg) { 237 return emitError(llvm::SMLoc::getFromPointer(loc), msg); 238 } 239 240 int Lexer::getNextChar() { 241 char curChar = *curPtr++; 242 switch (curChar) { 243 default: 244 return (unsigned char)curChar; 245 case 0: { 246 // A nul character in the stream is either the end of the current buffer 247 // or a random nul in the file. Disambiguate that here. 248 if (curPtr - 1 != curBuffer.end()) 249 return 0; 250 251 // Otherwise, return end of file. 252 --curPtr; 253 return EOF; 254 } 255 case '\n': 256 case '\r': 257 // Handle the newline character by ignoring it and incrementing the line 258 // count. However, be careful about 'dos style' files with \n\r in them. 259 // Only treat a \n\r or \r\n as a single line. 260 if ((*curPtr == '\n' || (*curPtr == '\r')) && *curPtr != curChar) 261 ++curPtr; 262 return '\n'; 263 } 264 } 265 266 Token Lexer::lexToken() { 267 while (true) { 268 const char *tokStart = curPtr; 269 270 // This always consumes at least one character. 271 int curChar = getNextChar(); 272 switch (curChar) { 273 default: 274 // Handle identifiers: [a-zA-Z_] 275 if (isalpha(curChar) || curChar == '_') 276 return lexIdentifier(tokStart); 277 278 // Handle integers: [0-9] 279 if (isdigit(curChar)) 280 return lexInteger(tokStart); 281 282 // Unknown character, emit an error. 283 return emitError(tokStart, "unexpected character"); 284 285 case EOF: 286 // Return EOF denoting the end of lexing. 
287 return formToken(Token::Kind::eof, tokStart); 288 289 // Lex punctuation. 290 case ':': 291 return formToken(Token::Kind::colon, tokStart); 292 case ',': 293 return formToken(Token::Kind::comma, tokStart); 294 case '=': 295 return formToken(Token::Kind::equal, tokStart); 296 case '{': 297 return formToken(Token::Kind::l_brace, tokStart); 298 case '(': 299 return formToken(Token::Kind::l_paren, tokStart); 300 case '[': 301 return formToken(Token::Kind::l_square, tokStart); 302 case '}': 303 return formToken(Token::Kind::r_brace, tokStart); 304 case ')': 305 return formToken(Token::Kind::r_paren, tokStart); 306 case ']': 307 return formToken(Token::Kind::r_square, tokStart); 308 case '<': 309 return formToken(Token::Kind::lt, tokStart); 310 case '>': 311 return formToken(Token::Kind::gt, tokStart); 312 case '+': 313 return formToken(Token::Kind::plus, tokStart); 314 case '-': 315 return formToken(Token::Kind::minus, tokStart); 316 case ';': 317 return formToken(Token::Kind::semicolon, tokStart); 318 case '*': 319 return formToken(Token::Kind::star, tokStart); 320 case '?': 321 return formToken(Token::Kind::question, tokStart); 322 case '"': 323 return lexString(tokStart); 324 case '/': 325 if (*curPtr == '/') { 326 skipComment(); 327 continue; 328 } 329 // Unknown character, emit an error. 330 return emitError(tokStart, "unexpected character: not a comment"); 331 332 // Ignore whitespace characters. 333 case 0: 334 case ' ': 335 case '\t': 336 case '\n': 337 return lexToken(); 338 } 339 } 340 } 341 342 Token Lexer::lexIdentifier(const char *tokStart) { 343 // Match the rest of the identifier regex: [0-9a-zA-Z_\-]* 344 while (isalnum(*curPtr) || *curPtr == '_' || *curPtr == '-') 345 ++curPtr; 346 347 // Check to see if this identifier is a keyword. 
348 StringRef str(tokStart, curPtr - tokStart); 349 Token::Kind kind = 350 StringSwitch<Token::Kind>(str) 351 .Case("attr", Token::Kind::kw_attr_def) 352 .Case("def", Token::Kind::kw_def) 353 .Case("ods_def", Token::Kind::kw_ods_def) 354 .Case("implements_interface", Token::Kind::kw_implements_interface) 355 .Case("floordiv", Token::Kind::kw_floordiv) 356 .Case("ceildiv", Token::Kind::kw_ceildiv) 357 .Case("mod", Token::Kind::kw_mod) 358 .Default(Token::Kind::id); 359 360 return Token(kind, str); 361 } 362 363 Token Lexer::lexInteger(const char *tokStart) { 364 // Match the rest of the identifier regex: [0-9a-zA-Z_\-]* 365 while (isdigit(*curPtr)) 366 ++curPtr; 367 368 StringRef str(tokStart, curPtr - tokStart); 369 return Token(Token::Kind::integer, str); 370 } 371 372 Token Lexer::lexString(const char *tokStart) { 373 assert(curPtr[-1] == '"'); 374 375 if (*curPtr == '"' && *(curPtr + 1) == '"') { 376 curPtr += 2; 377 while (true) { 378 switch (*curPtr++) { 379 case '"': 380 if (*curPtr == '"' && *(curPtr + 1) == '"') { 381 Token token(Token::Kind::doc_str, 382 StringRef(tokStart + 3, curPtr - tokStart - 4)); 383 curPtr += 2; 384 return token; 385 } 386 continue; 387 case 0: 388 // If this is a random nul character in the middle of the doc string, 389 // just include it. If it is the end of file, then it is an error. 390 if (curPtr - 1 != curBuffer.end()) 391 continue; 392 return emitError(curPtr - 1, "expected '\"\"\"' to end doc string"); 393 default: 394 continue; 395 } 396 } 397 } 398 399 return emitError(curPtr - 1, "expected '\"\"\"' to start doc string"); 400 } 401 402 /// Skip a comment line, starting with a '//'. 403 void Lexer::skipComment() { 404 // Advance over the second '/' in a '//' comment. 405 assert(*curPtr == '/'); 406 ++curPtr; 407 408 while (true) { 409 switch (*curPtr++) { 410 case '\n': 411 case '\r': 412 // Newline is end of comment. 413 return; 414 case 0: 415 // If this is the end of the buffer, end the comment. 
416 if (curPtr - 1 == curBuffer.end()) { 417 --curPtr; 418 return; 419 } 420 LLVM_FALLTHROUGH; 421 default: 422 // Skip over other characters. 423 break; 424 } 425 } 426 } 427 428 namespace { 429 430 class Parser { 431 public: 432 Parser(llvm::SourceMgr &mgr, MLIRContext *ctx) 433 : lexer(mgr), curToken(lexer.lexToken()), context(ctx) {} 434 435 //===--------------------------------------------------------------------===// 436 // Lexer Utilities 437 //===--------------------------------------------------------------------===// 438 439 LogicalResult parseInteger(uint64_t &value) { 440 if (!curToken.is(Token::Kind::integer)) 441 return emitError(curToken.getLoc(), "expected integer"); 442 value = curToken.getUInt64IntegerValue().getValue(); 443 consumeToken(); 444 return success(); 445 } 446 447 /// Advance the current lexer onto the next token. 448 void consumeToken() { 449 assert(curToken.getKind() != Token::Kind::eof && 450 curToken.getKind() != Token::Kind::error && 451 "shouldn't advance past EOF or errors"); 452 curToken = lexer.lexToken(); 453 } 454 455 void consumeToken(Token::Kind kind) { 456 assert(curToken.getKind() == kind && "unexpected token"); 457 curToken = lexer.lexToken(); 458 } 459 460 LogicalResult parseToken(Token::Kind kind, const Twine &msg) { 461 if (curToken.getKind() != kind) 462 return emitError(curToken.getLoc(), msg); 463 consumeToken(); 464 return success(); 465 } 466 467 /// Parses an optional token and returns failure if failed to parse. 
468 LogicalResult parseOptionalToken(Token::Kind kind) { 469 return success(consumeIf(kind)); 470 } 471 472 LogicalResult emitError(llvm::SMLoc loc, const Twine &msg) { 473 lexer.emitError(loc, msg); 474 return failure(); 475 } 476 477 LogicalResult emitError(const Twine &msg) { 478 return emitError(curToken.getLoc(), msg); 479 } 480 481 bool consumeIf(Token::Kind kind) { 482 if (curToken.isNot(kind)) 483 return false; 484 consumeToken(kind); 485 return true; 486 } 487 488 LogicalResult 489 parseCommaSeparatedList(llvm::function_ref<ParseResult()> parseElement) { 490 // Non-empty case starts with an element. 491 if (parseElement()) 492 return failure(); 493 494 // Otherwise we have a list of comma separated elements. 495 while (consumeIf(Token::Kind::comma)) { 496 if (parseElement()) 497 return failure(); 498 } 499 return success(); 500 } 501 502 LogicalResult 503 parseCommaSeparatedListUntil(Token::Kind rightToken, 504 llvm::function_ref<ParseResult()> parseElement, 505 bool allowEmptyList) { 506 // Handle the empty case. 507 if (curToken.is(rightToken)) { 508 if (!allowEmptyList) 509 return emitError("expected list element"); 510 consumeToken(rightToken); 511 return success(); 512 } 513 514 if (failed(parseCommaSeparatedList(parseElement)) || 515 failed( 516 parseToken(rightToken, "expected ',' or right-terminating token"))) 517 return failure(); 518 519 return success(); 520 } 521 522 Lexer lexer; 523 Token curToken; 524 MLIRContext *context; 525 }; 526 } // namespace 527 528 /// Encodes an attribute use of the form: 529 /// 530 /// index-list ::= integer-literal (`,` integer-literal)* 531 /// attr-use ::= bare-id `[` index-list `]` 532 struct AttrUse { 533 // Referenced attribute 534 StringRef attrName; 535 // Indices into the attribute 536 SmallVector<uint64_t, 4> indices; 537 /// Affine symbol for this usage. 
538 /// This is represented as an affine symbol because at the time of parsing the 539 /// spec and generating the op's ODS/C++, we don't know the concrete constant 540 /// value. But they should be replaced with constants read from the attribute 541 /// and thus folded away for concrete op instances. 542 AffineExpr symbol; 543 544 std::string getKey() { 545 SmallVector<std::string, 4> indexStrs; 546 for (uint64_t index : indices) 547 indexStrs.push_back(std::to_string(index)); 548 return llvm::formatv("{0}[{1}]", attrName, llvm::join(indexStrs, ",")); 549 } 550 }; 551 552 //===----------------------------------------------------------------------===// 553 // Affine parsing. 554 //===----------------------------------------------------------------------===// 555 556 namespace { 557 558 /// Lower precedence ops (all at the same precedence level). LNoOp is false in 559 /// the boolean sense. 560 enum AffineLowPrecOp { 561 /// Null value. 562 LNoOp, 563 Add, 564 Sub 565 }; 566 567 /// Higher precedence ops - all at the same precedence level. HNoOp is false 568 /// in the boolean sense. 569 enum AffineHighPrecOp { 570 /// Null value. 571 HNoOp, 572 Mul, 573 FloorDiv, 574 CeilDiv, 575 Mod 576 }; 577 578 using AffineDimList = SmallVector<std::pair<StringRef, AffineExpr>, 4>; 579 using AffineSymbolList = SmallVector<std::pair<StringRef, AffineExpr>, 4>; 580 581 /// This is a specialized parser for affine expressions. 582 class AffineParser { 583 public: 584 /// Creates an affine parser that parses tokens from `p`. 585 /// 586 /// The affine parser introduces new dimensions and symbols eagerly as new 587 /// `id` are discovered. To additionally support attribute use `id`s, for a 588 /// parsed `id`, the resolution mechanism proceeds as follows: 589 /// 1. Try to parse `id` as an attribute use (using the `attrUseParsingHook`). 590 /// 2. If unsuccessful, try to match `id` to a known dim or symbol. 591 /// 3. 
If still unsuccessful, eagerly create a new dim or symbol and add it to 592 /// the known dims or symbols (using the `bareIdParsingHook`). 593 explicit AffineParser( 594 Parser &p, std::function<AffineExpr(StringRef)> bareIdParsingHook, 595 std::function<llvm::Optional<AffineExpr>()> attrUseParsingHook, 596 AffineDimList &dimList, AffineSymbolList &symbolList) 597 : parser(p), bareIdFallback(bareIdParsingHook), 598 attrUseCallback(attrUseParsingHook), dims(dimList), 599 symbols(symbolList) {} 600 601 /// Parse a comma-separated list of affine exprs. 602 SmallVector<AffineExpr, 4> 603 parseAffineExprs(Token::Kind lDelim = Token::Kind::l_paren, 604 Token::Kind rDelim = Token::Kind::r_paren); 605 606 /// Parse a single affine expr.`. 607 AffineExpr parseAffineExpr(); 608 609 private: 610 // Binary affine op parsing. 611 AffineLowPrecOp consumeIfLowPrecOp(); 612 AffineHighPrecOp consumeIfHighPrecOp(); 613 614 // AffineExpr parsing. 615 AffineExpr parseParentheticalExpr(); 616 AffineExpr parseNegateExpression(AffineExpr lhs); 617 AffineExpr parseIntegerExpr(); 618 AffineExpr parseAttrUseOrBareIdExpr(); 619 AffineExpr parseBareIdExpr(); 620 621 AffineExpr getAffineBinaryOpExpr(AffineHighPrecOp op, AffineExpr lhs, 622 AffineExpr rhs, SMLoc opLoc); 623 AffineExpr getAffineBinaryOpExpr(AffineLowPrecOp op, AffineExpr lhs, 624 AffineExpr rhs); 625 AffineExpr parseAffineOperandExpr(AffineExpr lhs); 626 AffineExpr parseAffineLowPrecOpExpr(AffineExpr llhs, AffineLowPrecOp llhsOp); 627 AffineExpr parseAffineHighPrecOpExpr(AffineExpr llhs, AffineHighPrecOp llhsOp, 628 SMLoc llhsOpLoc); 629 630 Parser &parser; 631 std::function<AffineExpr(StringRef)> bareIdFallback; 632 std::function<llvm::Optional<AffineExpr>()> attrUseCallback; 633 AffineDimList &dims; 634 AffineSymbolList &symbols; 635 }; 636 } // end anonymous namespace 637 638 /// Create an affine binary high precedence op expression (mul's, div's, mod). 
639 /// opLoc is the location of the op token to be used to report errors 640 /// for non-conforming expressions. 641 AffineExpr AffineParser::getAffineBinaryOpExpr(AffineHighPrecOp op, 642 AffineExpr lhs, AffineExpr rhs, 643 SMLoc opLoc) { 644 switch (op) { 645 case Mul: 646 if (!lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant()) { 647 (void)parser.emitError( 648 opLoc, "non-affine expression: at least one of the multiply " 649 "operands has to be either a constant or symbolic"); 650 return nullptr; 651 } 652 return lhs * rhs; 653 case FloorDiv: 654 if (!rhs.isSymbolicOrConstant()) { 655 (void)parser.emitError(opLoc, 656 "non-affine expression: right operand of floordiv " 657 "has to be either a constant or symbolic"); 658 return nullptr; 659 } 660 return lhs.floorDiv(rhs); 661 case CeilDiv: 662 if (!rhs.isSymbolicOrConstant()) { 663 (void)parser.emitError(opLoc, 664 "non-affine expression: right operand of ceildiv " 665 "has to be either a constant or symbolic"); 666 return nullptr; 667 } 668 return lhs.ceilDiv(rhs); 669 case Mod: 670 if (!rhs.isSymbolicOrConstant()) { 671 (void)parser.emitError(opLoc, 672 "non-affine expression: right operand of mod " 673 "has to be either a constant or symbolic"); 674 return nullptr; 675 } 676 return lhs % rhs; 677 case HNoOp: 678 llvm_unreachable("can't create affine expression for null high prec op"); 679 return nullptr; 680 } 681 llvm_unreachable("Unknown AffineHighPrecOp"); 682 } 683 684 /// Create an affine binary low precedence op expression (add, sub). 
685 AffineExpr AffineParser::getAffineBinaryOpExpr(AffineLowPrecOp op, 686 AffineExpr lhs, AffineExpr rhs) { 687 switch (op) { 688 case AffineLowPrecOp::Add: 689 return lhs + rhs; 690 case AffineLowPrecOp::Sub: 691 return lhs - rhs; 692 case AffineLowPrecOp::LNoOp: 693 llvm_unreachable("can't create affine expression for null low prec op"); 694 return nullptr; 695 } 696 llvm_unreachable("Unknown AffineLowPrecOp"); 697 } 698 699 /// Consume this token if it is a lower precedence affine op (there are only 700 /// two precedence levels). 701 AffineLowPrecOp AffineParser::consumeIfLowPrecOp() { 702 switch (parser.curToken.getKind()) { 703 case Token::Kind::plus: 704 parser.consumeToken(); 705 return AffineLowPrecOp::Add; 706 case Token::Kind::minus: 707 parser.consumeToken(); 708 return AffineLowPrecOp::Sub; 709 default: 710 return AffineLowPrecOp::LNoOp; 711 } 712 } 713 714 /// Consume this token if it is a higher precedence affine op (there are only 715 /// two precedence levels) 716 AffineHighPrecOp AffineParser::consumeIfHighPrecOp() { 717 switch (parser.curToken.getKind()) { 718 case Token::Kind::star: 719 parser.consumeToken(Token::Kind::star); 720 return Mul; 721 case Token::Kind::kw_floordiv: 722 parser.consumeToken(Token::Kind::kw_floordiv); 723 return FloorDiv; 724 case Token::Kind::kw_ceildiv: 725 parser.consumeToken(Token::Kind::kw_ceildiv); 726 return CeilDiv; 727 case Token::Kind::kw_mod: 728 parser.consumeToken(Token::Kind::kw_mod); 729 return Mod; 730 default: 731 return HNoOp; 732 } 733 } 734 735 /// Parse a high precedence op expression list: mul, div, and mod are high 736 /// precedence binary ops, i.e., parse a 737 /// expr_1 op_1 expr_2 op_2 ... expr_n 738 /// where op_1, op_2 are all a AffineHighPrecOp (mul, div, mod). 739 /// All affine binary ops are left associative. 740 /// Given llhs, returns (llhs llhsOp lhs) op rhs, or (lhs op rhs) if llhs is 741 /// null. If no rhs can be found, returns (llhs llhsOp lhs) or lhs if llhs is 742 /// null. 
llhsOpLoc is the location of the llhsOp token that will be used to 743 /// report an error for non-conforming expressions. 744 AffineExpr AffineParser::parseAffineHighPrecOpExpr(AffineExpr llhs, 745 AffineHighPrecOp llhsOp, 746 SMLoc llhsOpLoc) { 747 AffineExpr lhs = parseAffineOperandExpr(llhs); 748 if (!lhs) 749 return nullptr; 750 751 // Found an LHS. Parse the remaining expression. 752 auto opLoc = parser.curToken.getLoc(); 753 if (AffineHighPrecOp op = consumeIfHighPrecOp()) { 754 if (llhs) { 755 AffineExpr expr = getAffineBinaryOpExpr(llhsOp, llhs, lhs, opLoc); 756 if (!expr) 757 return nullptr; 758 return parseAffineHighPrecOpExpr(expr, op, opLoc); 759 } 760 // No LLHS, get RHS 761 return parseAffineHighPrecOpExpr(lhs, op, opLoc); 762 } 763 764 // This is the last operand in this expression. 765 if (llhs) 766 return getAffineBinaryOpExpr(llhsOp, llhs, lhs, llhsOpLoc); 767 768 // No llhs, 'lhs' itself is the expression. 769 return lhs; 770 } 771 772 /// Parse an affine expression inside parentheses. 773 /// 774 /// affine-expr ::= `(` affine-expr `)` 775 AffineExpr AffineParser::parseParentheticalExpr() { 776 if (failed(parser.parseToken(Token::Kind::l_paren, "expected '('"))) 777 return nullptr; 778 if (parser.curToken.is(Token::Kind::r_paren)) 779 return ((void)parser.emitError("no expression inside parentheses"), 780 nullptr); 781 782 auto expr = parseAffineExpr(); 783 if (!expr) 784 return nullptr; 785 if (failed(parser.parseToken(Token::Kind::r_paren, "expected ')'"))) 786 return nullptr; 787 788 return expr; 789 } 790 791 /// Parse the negation expression. 
792 /// 793 /// affine-expr ::= `-` affine-expr 794 AffineExpr AffineParser::parseNegateExpression(AffineExpr lhs) { 795 if (failed(parser.parseToken(Token::Kind::minus, "expected '-'"))) 796 return nullptr; 797 798 AffineExpr operand = parseAffineOperandExpr(lhs); 799 // Since negation has the highest precedence of all ops (including high 800 // precedence ops) but lower than parentheses, we are only going to use 801 // parseAffineOperandExpr instead of parseAffineExpr here. 802 if (!operand) 803 // Extra error message although parseAffineOperandExpr would have 804 // complained. Leads to a better diagnostic. 805 return ((void)parser.emitError("missing operand of negation"), nullptr); 806 return (-1) * operand; 807 } 808 809 AffineExpr AffineParser::parseAttrUseOrBareIdExpr() { 810 if (llvm::Optional<AffineExpr> attrUse = attrUseCallback()) 811 return attrUse.getValue(); 812 return parseBareIdExpr(); 813 } 814 815 /// Parse a bare id that may appear in an affine expression. 816 /// 817 /// affine-expr ::= bare-id 818 AffineExpr AffineParser::parseBareIdExpr() { 819 if (parser.curToken.isNot(Token::Kind::id)) 820 return ((void)parser.emitError("expected id"), nullptr); 821 822 StringRef sRef = parser.curToken.getSpelling(); 823 for (auto &list : {dims, symbols}) { 824 for (auto entry : list) { 825 if (entry.first == sRef) { 826 parser.consumeToken(Token::Kind::id); 827 return entry.second; 828 } 829 } 830 } 831 832 // Not found, check fallback path. 833 AffineExpr expr = bareIdFallback(sRef); 834 if (expr) { 835 parser.consumeToken(Token::Kind::id); 836 return expr; 837 } 838 839 return ((void)parser.emitError("use of undeclared id"), nullptr); 840 } 841 842 /// Parse a positive integral constant appearing in an affine expression. 
843 /// 844 /// affine-expr ::= integer-literal 845 AffineExpr AffineParser::parseIntegerExpr() { 846 auto val = parser.curToken.getUInt64IntegerValue(); 847 if (!val.hasValue() || (int64_t)val.getValue() < 0) 848 return ((void)parser.emitError("constant too large for index"), nullptr); 849 850 parser.consumeToken(Token::Kind::integer); 851 return getAffineConstantExpr((int64_t)val.getValue(), parser.context); 852 } 853 854 /// Parses an expression that can be a valid operand of an affine expression. 855 /// lhs: if non-null, lhs is an affine expression that is the lhs of a binary 856 /// operator, the rhs of which is being parsed. This is used to determine 857 /// whether an error should be emitted for a missing right operand. 858 // Eg: for an expression without parentheses (like i + j + k + l), each 859 // of the four identifiers is an operand. For i + j*k + l, j*k is not an 860 // operand expression, it's an op expression and will be parsed via 861 // parseAffineHighPrecOpExpression(). However, for i + (j*k) + -l, (j*k) and 862 // -l are valid operands that will be parsed by this function. 
863 AffineExpr AffineParser::parseAffineOperandExpr(AffineExpr lhs) { 864 switch (parser.curToken.getKind()) { 865 case Token::Kind::id: 866 return parseAttrUseOrBareIdExpr(); 867 case Token::Kind::integer: 868 return parseIntegerExpr(); 869 case Token::Kind::l_paren: 870 return parseParentheticalExpr(); 871 case Token::Kind::minus: 872 return parseNegateExpression(lhs); 873 case Token::Kind::kw_ceildiv: 874 case Token::Kind::kw_floordiv: 875 case Token::Kind::kw_mod: 876 case Token::Kind::plus: 877 case Token::Kind::star: 878 if (lhs) 879 (void)parser.emitError("missing right operand of binary operator"); 880 else 881 (void)parser.emitError("missing left operand of binary operator"); 882 return nullptr; 883 default: 884 if (lhs) 885 (void)parser.emitError("missing right operand of binary operator"); 886 else 887 (void)parser.emitError("expected affine expression"); 888 return nullptr; 889 } 890 } 891 892 /// Parse affine expressions that are bare-id's, integer constants, 893 /// parenthetical affine expressions, and affine op expressions that are a 894 /// composition of those. 895 /// 896 /// All binary op's associate from left to right. 897 /// 898 /// {add, sub} have lower precedence than {mul, div, and mod}. 899 /// 900 /// Add, sub'are themselves at the same precedence level. Mul, floordiv, 901 /// ceildiv, and mod are at the same higher precedence level. Negation has 902 /// higher precedence than any binary op. 903 /// 904 /// llhs: the affine expression appearing on the left of the one being parsed. 905 /// This function will return ((llhs llhsOp lhs) op rhs) if llhs is non null, 906 /// and lhs op rhs otherwise; if there is no rhs, llhs llhsOp lhs is returned 907 /// if llhs is non-null; otherwise lhs is returned. This is to deal with left 908 /// associativity. 
909 /// 910 /// Eg: when the expression is e1 + e2*e3 + e4, with e1 as llhs, this function 911 /// will return the affine expr equivalent of (e1 + (e2*e3)) + e4, where 912 /// (e2*e3) will be parsed using parseAffineHighPrecOpExpr(). 913 AffineExpr AffineParser::parseAffineLowPrecOpExpr(AffineExpr llhs, 914 AffineLowPrecOp llhsOp) { 915 AffineExpr lhs; 916 if (!(lhs = parseAffineOperandExpr(llhs))) 917 return nullptr; 918 919 // Found an LHS. Deal with the ops. 920 if (AffineLowPrecOp lOp = consumeIfLowPrecOp()) { 921 if (llhs) { 922 AffineExpr sum = getAffineBinaryOpExpr(llhsOp, llhs, lhs); 923 return parseAffineLowPrecOpExpr(sum, lOp); 924 } 925 // No LLHS, get RHS and form the expression. 926 return parseAffineLowPrecOpExpr(lhs, lOp); 927 } 928 auto opLoc = parser.curToken.getLoc(); 929 if (AffineHighPrecOp hOp = consumeIfHighPrecOp()) { 930 // We have a higher precedence op here. Get the rhs operand for the llhs 931 // through parseAffineHighPrecOpExpr. 932 AffineExpr highRes = parseAffineHighPrecOpExpr(lhs, hOp, opLoc); 933 if (!highRes) 934 return nullptr; 935 936 // If llhs is null, the product forms the first operand of the yet to be 937 // found expression. If non-null, the op to associate with llhs is llhsOp. 938 AffineExpr expr = 939 llhs ? getAffineBinaryOpExpr(llhsOp, llhs, highRes) : highRes; 940 941 // Recurse for subsequent low prec op's after the affine high prec op 942 // expression. 943 if (AffineLowPrecOp nextOp = consumeIfLowPrecOp()) 944 return parseAffineLowPrecOpExpr(expr, nextOp); 945 return expr; 946 } 947 // Last operand in the expression list. 948 if (llhs) 949 return getAffineBinaryOpExpr(llhsOp, llhs, lhs); 950 // No llhs, 'lhs' itself is the expression. 951 return lhs; 952 } 953 954 /// Parse an affine expression. 
955 /// affine-expr ::= `(` affine-expr `)` 956 /// | `-` affine-expr 957 /// | affine-expr `+` affine-expr 958 /// | affine-expr `-` affine-expr 959 /// | affine-expr `*` affine-expr 960 /// | affine-expr `floordiv` affine-expr 961 /// | affine-expr `ceildiv` affine-expr 962 /// | affine-expr `mod` affine-expr 963 /// | bare-id 964 /// | integer-literal 965 /// 966 /// Additional conditions are checked depending on the production. For eg., 967 /// one of the operands for `*` has to be either constant/symbolic; the second 968 /// operand for floordiv, ceildiv, and mod has to be a positive integer. 969 AffineExpr AffineParser::parseAffineExpr() { 970 return parseAffineLowPrecOpExpr(nullptr, AffineLowPrecOp::LNoOp); 971 } 972 973 SmallVector<AffineExpr, 4> AffineParser::parseAffineExprs(Token::Kind lDelim, 974 Token::Kind rDelim) { 975 if (failed(parser.parseToken(lDelim, 976 "expected lDelim at start of affine expr list"))) 977 return {}; 978 979 SmallVector<AffineExpr, 4> exprs; 980 auto parseElt = [&]() -> LogicalResult { 981 auto elt = parseAffineExpr(); 982 exprs.push_back(elt); 983 return elt ? success() : failure(); 984 }; 985 986 if (failed(parser.parseCommaSeparatedListUntil(rDelim, parseElt, 987 /*allowEmptyList=*/true))) 988 llvm_unreachable("Failed AffineExpr parsing"); 989 990 return exprs; 991 } 992 993 //===----------------------------------------------------------------------===// 994 // TC parsing. 995 //===----------------------------------------------------------------------===// 996 997 namespace { 998 999 /// Base class for expressions involved in TC parsing. 
1000 struct Expression { 1001 enum class Kind { 1002 Uninitialized = 0, 1003 TensorExpr = 1, 1004 TensorUse = 2, 1005 }; 1006 1007 explicit Expression(Kind k = Kind::Uninitialized) : kind(k) {} 1008 virtual ~Expression() = default; 1009 1010 operator bool() const { return kind != Kind::Uninitialized; } 1011 1012 Kind kind; 1013 }; 1014 1015 /// Encodes a tensor use of the form: 1016 /// 1017 /// affine-expr-list ::= affine-expr (`,` affine-expr)* 1018 /// tensor-use ::= bare-id `(` `)` 1019 /// | bare-id `(` affine-expr-list `)` 1020 /// 1021 /// The affine-expr-list is stored as an AffineMap. 1022 struct TensorUse : public Expression { 1023 TensorUse() : TensorUse("", AffineMap()) {} 1024 TensorUse(StringRef name, AffineMap map) 1025 : Expression(Kind::TensorUse), tensorId(name), indexingMap(map) {} 1026 1027 static bool classof(const Expression *e) { 1028 return e->kind == Kind::TensorUse; 1029 } 1030 1031 bool operator==(const TensorUse &other) const { 1032 return tensorId == other.tensorId && indexingMap == other.indexingMap; 1033 } 1034 1035 /// Visitation function. Performs preorder or postorder traversal depending on 1036 /// `PreOrder` and applies `callback` on each node. 1037 template <typename Lambda, bool PreOrder> void visit(Lambda callback) const; 1038 1039 StringRef tensorId; 1040 AffineMap indexingMap; 1041 }; 1042 1043 /// Encodes a tensor expression of the form: 1044 /// 1045 /// op-spec ::= bare-id `<` reduction-dims-list `>` 1046 /// | bare-id 1047 /// op-arg ::= tensor-expr 1048 /// | tensor-use 1049 /// op-arg-list ::= op-arg (`,` op-arg)* 1050 /// tensor-expr ::= op-spec `(` op-arg-list `)` 1051 /// 1052 /// Underlying op-arg are stored by unique_ptr to base class. 
1053 struct TensorExpr : public Expression { 1054 TensorExpr(StringRef name, 1055 SmallVectorImpl<std::unique_ptr<Expression>> &&exprs, 1056 ArrayRef<unsigned> reductionDims) 1057 : Expression(Kind::TensorExpr), operationName(name), 1058 expressions(std::move(exprs)), 1059 reductionDimensions(reductionDims.begin(), reductionDims.end()) {} 1060 1061 static bool classof(const Expression *e) { 1062 return e->kind == Kind::TensorExpr; 1063 } 1064 1065 bool operator==(const TensorExpr &other) const { 1066 if (operationName != other.operationName) 1067 return false; 1068 if (expressions.size() != other.expressions.size()) 1069 return false; 1070 for (unsigned i = 0, e = expressions.size(); i < e; ++i) 1071 if (*expressions[i] != *other.expressions[i]) 1072 return false; 1073 for (unsigned i = 0, e = reductionDimensions.size(); i < e; ++i) 1074 if (reductionDimensions[i] != other.reductionDimensions[i]) 1075 return false; 1076 return true; 1077 } 1078 1079 /// Visitation function. Performs preorder or postorder traversal depending on 1080 /// `PreOrder` and applies `callback` on each node. 1081 template <typename Lambda, bool PreOrder> void visit(Lambda callback) const; 1082 1083 StringRef operationName; 1084 SmallVector<std::unique_ptr<Expression>, 4> expressions; 1085 SetVector<unsigned> reductionDimensions; 1086 }; 1087 1088 /// This is a specialized parser for a TCDef. 1089 /// This maintains the dims it finds in an eager fashion. 1090 class TCParser { 1091 enum class EagerDiscoveryMode { None = 0, Symbols, Dimensions }; 1092 1093 public: 1094 explicit TCParser(Parser &p); 1095 1096 /// Uses the AffineParser to parse the affine exprs used in a tensor 1097 /// definition. If `discoveryMode` is set to Symbols (resp. Dimensions), new 1098 /// symbols (resp. dimensions) are added eagerly. Otherwise, an error is 1099 /// emitted on new identifiers. 
1100 SmallVector<AffineExpr, 4> 1101 parseAffineExprs(EagerDiscoveryMode discoveryMode, AffineDimList &dims, 1102 Token::Kind lDelim = Token::Kind::l_paren, 1103 Token::Kind rDelim = Token::Kind::r_paren); 1104 1105 /// Parse the information for a tensor def. 1106 /// All the affine-expr must be dimensionless (i.e. contain only expressions 1107 /// involving symbols and constants), but can otherwise contain arbitrary 1108 /// affine expressions. 1109 LogicalResult parseTensorDef(bool isOutput); 1110 1111 /// Parses a tensor use. 1112 struct ComprehensionParsingState { 1113 /// The number of operands (which includes inputs and outputs) in a 1114 /// comprehension. 1115 size_t numArgs; 1116 AffineDimList dims; 1117 SmallVector<std::unique_ptr<Expression>, 4> expressions; 1118 llvm::DenseMap<TensorUse, unsigned> orderedTensorArgs; 1119 }; 1120 LogicalResult parseTensorUse(TensorUse &result, 1121 ComprehensionParsingState &state); 1122 1123 /// Parses an attribute definition. 1124 LogicalResult parseAttrDef(); 1125 1126 /// Parses an optional attribute use. 1127 LogicalResult parseAttrUse(AttrUse &result); 1128 1129 /// Parses a tensor expression. 1130 LogicalResult parseExpression(TensorUse currentDefinition, 1131 std::unique_ptr<Expression> &result, 1132 ComprehensionParsingState &state); 1133 1134 /// Parse a single comprehension. 1135 LogicalResult parseOneComprehension(StringRef cppOpName, 1136 StringRef linalgOpName, 1137 ComprehensionParsingState &state); 1138 1139 /// Parse and print the information for a TC def. 1140 /// When `gen-ods-decl` is used, this prints the ODS declaration for the TC. 1141 /// When `gen-impl` is used, this prints the C++ implementation for the extra 1142 /// methods defined in ODS (`iterator_types`, `indexing_maps` and 1143 /// `regionBuilder`). 1144 LogicalResult parseAndEmitODSDef(llvm::raw_ostream &os); 1145 1146 /// Print the ODS class that defines a new `cppOpName` for a `linalgOpName`. 
1147 void printODS(llvm::raw_ostream &os, StringRef cppOpName, 1148 StringRef linalgOpName, ArrayRef<StringRef> interfaces, 1149 ComprehensionParsingState &state); 1150 1151 /// Print the C++ StructuredOpsInterface impl of `iterator_types`. 1152 void printReferenceIterators(llvm::raw_ostream &os, StringRef cppOpName, 1153 ComprehensionParsingState &state); 1154 1155 /// Print methods related to indexing map required attributes. 1156 /// 1157 /// Specifically, this prints the definitions for the following methods: 1158 /// bool hasDynamicIndexingMaps(); 1159 /// LogicalResult verifyIndexingMapRequiredAttributes(); 1160 void printIndexingMapRequiredAttrMethods(llvm::raw_ostream &os, 1161 StringRef cppOpName, 1162 ComprehensionParsingState &state); 1163 1164 /// Print the C++ StructuredOpsInterface impl of `indexing_maps`. 1165 void printReferenceIndexingMaps(llvm::raw_ostream &os, StringRef cppOpName, 1166 ComprehensionParsingState &state); 1167 1168 /// Print the C++ StructuredOpsInterface impl of `regionBuilder`. 1169 void printRegionBuilder(llvm::raw_ostream &os, StringRef cppOpName, 1170 ComprehensionParsingState &state); 1171 1172 /// Print the C++ impl for named ops canonicalizers and folders. 1173 void printCanonicalizersAndFolders(llvm::raw_ostream &os, 1174 StringRef cppOpName); 1175 1176 private: 1177 //===--------------------------------------------------------------------===// 1178 // Internal bookkeeping of tensors. 1179 //===--------------------------------------------------------------------===// 1180 struct RegisteredTensor { 1181 StringRef type; 1182 AffineMap shape; 1183 bool isOutput; 1184 AffineMap indexingMap; 1185 unsigned index; 1186 }; 1187 1188 //===--------------------------------------------------------------------===// 1189 // Internal bookkeeping of attributes. 
1190 //===--------------------------------------------------------------------===// 1191 struct RegisteredAttr { 1192 StringRef elementType; 1193 SmallVector<uint64_t, 4> vectorDims; 1194 bool isArray; 1195 bool isOptional; 1196 1197 // Returns the function to get values at the given indices from this 1198 // attribute. 1199 llvm::Optional<std::string> getValueFn(ArrayRef<uint64_t> indices) const; 1200 }; 1201 1202 //===--------------------------------------------------------------------===// 1203 // Per-TC def state. 1204 //===--------------------------------------------------------------------===// 1205 /// Symbols are per TC def. 1206 AffineSymbolList symbols; 1207 1208 /// Attribute usages in all affine expressions. 1209 SmallVector<AttrUse, 8> attrUses; 1210 1211 /// Tensors are per TC def. 1212 llvm::StringMap<RegisteredTensor> registeredTensors; 1213 unsigned nextRegisteredTensorIndex; 1214 1215 /// Attributes are per TC def. 1216 std::map<std::string, RegisteredAttr> registeredAttrs; 1217 1218 /// A map from AttrUse to AffineExpr symbol. 1219 llvm::StringMap<AffineExpr> registeredAttrUseToSymbol; 1220 1221 StringRef docString; 1222 1223 Parser &parser; 1224 }; 1225 } // namespace 1226 1227 namespace llvm { 1228 1229 template <> struct DenseMapInfo<TensorUse> { 1230 static TensorUse getEmptyKey() { return TensorUse("", AffineMap()); } 1231 static TensorUse getTombstoneKey() { 1232 return TensorUse(DenseMapInfo<StringRef>::getTombstoneKey(), 1233 DenseMapInfo<AffineMap>::getTombstoneKey()); 1234 } 1235 static unsigned getHashValue(const TensorUse &val) { 1236 return ::llvm::hash_value(val.tensorId); // don't care about collisions. 1237 } 1238 static bool isEqual(const TensorUse &LHS, const TensorUse &RHS) { 1239 return LHS == RHS; 1240 } 1241 }; 1242 1243 } // namespace llvm 1244 1245 //===----------------------------------------------------------------------===// 1246 // Visitation functions. 
1247 //===----------------------------------------------------------------------===// 1248 1249 template <typename Lambda, bool PreOrder> 1250 void visit(const Expression &expr, Lambda callback) { 1251 switch (expr.kind) { 1252 default: 1253 llvm_unreachable("Unexpected kind"); 1254 case Expression::Kind::TensorExpr: 1255 static_cast<const TensorExpr &>(expr).visit<Lambda, PreOrder>(callback); 1256 break; 1257 case Expression::Kind::TensorUse: 1258 static_cast<const TensorUse &>(expr).visit<Lambda, PreOrder>(callback); 1259 break; 1260 } 1261 } 1262 1263 template <typename Lambda> 1264 void visitPreorder(const Expression &expr, Lambda callback) { 1265 visit<Lambda, false>(expr, callback); 1266 } 1267 1268 template <typename Lambda> 1269 void visitPostorder(Expression &expr, Lambda callback) { 1270 visit<Lambda, true>(expr, callback); 1271 } 1272 1273 template <typename Lambda, bool PreOrder> 1274 void TensorExpr::visit(Lambda callback) const { 1275 if (!PreOrder) 1276 callback(*this); 1277 for (auto &e : expressions) 1278 ::visit<Lambda, PreOrder>(*e, callback); 1279 if (PreOrder) 1280 callback(*this); 1281 } 1282 1283 template <typename Lambda, bool PreOrder> 1284 void TensorUse::visit(Lambda callback) const { 1285 callback(*this); 1286 } 1287 1288 //===----------------------------------------------------------------------===// 1289 // TC parsing functions. 1290 //===----------------------------------------------------------------------===// 1291 TCParser::TCParser(Parser &p) 1292 : symbols(), registeredTensors(), nextRegisteredTensorIndex(0), parser(p) {} 1293 1294 /// Uses the AffineParser to parse the affine exprs used in a tensor 1295 /// definition. All identifiers are interpreted as symbols, new symbols are 1296 /// added eagerly. 
1297 SmallVector<AffineExpr, 4> 1298 TCParser::parseAffineExprs(EagerDiscoveryMode discoveryMode, 1299 AffineDimList &dims, Token::Kind lDelim, 1300 Token::Kind rDelim) { 1301 auto createAffineBareId = [&](StringRef sRef) { 1302 AffineExpr expr; 1303 if (discoveryMode == EagerDiscoveryMode::Symbols) { 1304 expr = getAffineSymbolExpr(symbols.size(), parser.context); 1305 symbols.emplace_back(sRef, expr); 1306 } else if (discoveryMode == EagerDiscoveryMode::Dimensions) { 1307 expr = getAffineDimExpr(dims.size(), parser.context); 1308 dims.emplace_back(sRef, expr); 1309 } 1310 return expr; 1311 }; 1312 1313 auto tryToParseAttrUse = [&]() -> llvm::Optional<AffineExpr> { 1314 if (!parser.curToken.is(Token::Kind::id)) 1315 return llvm::None; 1316 1317 StringRef attrName = parser.curToken.getSpelling(); 1318 auto it = registeredAttrs.find(attrName.str()); 1319 if (it == registeredAttrs.end()) 1320 return llvm::None; 1321 1322 AttrUse result; 1323 if (failed(parseAttrUse(result))) 1324 return llvm::None; 1325 1326 auto symbolIt = registeredAttrUseToSymbol.find(result.getKey()); 1327 if (symbolIt == registeredAttrUseToSymbol.end()) { 1328 result.symbol = getAffineSymbolExpr(symbols.size(), parser.context); 1329 symbols.emplace_back("<attr-use>", result.symbol); 1330 registeredAttrUseToSymbol[result.getKey()] = result.symbol; 1331 attrUses.push_back(result); 1332 } else { 1333 result.symbol = symbolIt->second; 1334 } 1335 1336 return result.symbol; 1337 }; 1338 1339 AffineParser affineParser(parser, createAffineBareId, tryToParseAttrUse, dims, 1340 symbols); 1341 return affineParser.parseAffineExprs(lDelim, rDelim); 1342 } 1343 1344 /// Parse the information for a tensor def of the form: 1345 /// 1346 /// affine-expr-list ::= affine-expr (`,` affine-expr )* 1347 /// tensor-typedef ::= type `(` `)` 1348 /// | type `(` affine-expr-list `)` 1349 /// tensor-def ::= bare-id `:` tensor-typedef 1350 LogicalResult TCParser::parseTensorDef(bool isOutput) { 1351 StringRef tensorId = 
parser.curToken.getSpelling(); 1352 if (failed(parser.parseToken(Token::Kind::id, "expected an id")) || 1353 failed(parser.parseToken(Token::Kind::colon, "expected colon"))) 1354 return failure(); 1355 1356 StringRef tensorType = parser.curToken.getSpelling(); 1357 if (failed(parser.parseToken(Token::Kind::id, "expected an id"))) 1358 return failure(); 1359 1360 AffineDimList emptyDims; 1361 auto exprs = parseAffineExprs(EagerDiscoveryMode::Symbols, emptyDims); 1362 assert(emptyDims.empty() && "Unexpected dimension in tensor def"); 1363 AffineMap map = 1364 AffineMap::get(/*dimCount=*/0, symbols.size(), exprs, parser.context); 1365 1366 auto iterBoolPair = registeredTensors.try_emplace( 1367 tensorId, RegisteredTensor{tensorType, map, isOutput, AffineMap(), 1368 nextRegisteredTensorIndex++}); 1369 (void)iterBoolPair; 1370 assert(iterBoolPair.second && "Could not emplace tensor registration"); 1371 LLVM_DEBUG(llvm::dbgs() << "Recorded: " << tensorId << " " 1372 << "with typeString: " << tensorType << " " 1373 << "and shape: " << map << "\n"); 1374 1375 return success(); 1376 } 1377 1378 /// Parses a tensor use of the form: 1379 /// 1380 /// affine-expr-list ::= affine-expr (`,` affine-expr)* 1381 /// tensor-use ::= bare-id `(` `)` 1382 /// | bare-id `(` affine-expr-list `)` 1383 LogicalResult TCParser::parseTensorUse(TensorUse &result, 1384 ComprehensionParsingState &state) { 1385 StringRef tensorId = parser.curToken.getSpelling(); 1386 if (failed(parser.parseToken(Token::Kind::id, "expected an id"))) 1387 return failure(); 1388 1389 auto exprs = parseAffineExprs(EagerDiscoveryMode::Dimensions, state.dims); 1390 AffineMap map = 1391 AffineMap::get(state.dims.size(), symbols.size(), exprs, parser.context); 1392 LLVM_DEBUG(llvm::dbgs() << "Use of tensor: " << tensorId << " map: " << map 1393 << "\n"); 1394 1395 result = TensorUse(tensorId, map); 1396 return success(); 1397 } 1398 1399 /// Parse the information for an attribute def of the form: 1400 /// 1401 /// 
affine-expr-list ::= affine-expr (`,` affine-expr )* 1402 /// attr-id ::= bare-id (`?`)? 1403 /// dim-list ::= (integer-literal 'x')+ 1404 /// attr-typedef ::= dim-list? type (`[` `]`)? 1405 /// attr-def ::= attr-id `:` attr-typedef 1406 LogicalResult TCParser::parseAttrDef() { 1407 auto attrLoc = parser.curToken.getLoc(); 1408 StringRef attrName = parser.curToken.getSpelling(); 1409 if (failed(parser.parseToken(Token::Kind::id, "expected an id"))) 1410 return failure(); 1411 bool isOptional = succeeded(parser.parseOptionalToken(Token::Kind::question)); 1412 if (failed(parser.parseToken(Token::Kind::colon, "expected colon"))) 1413 return failure(); 1414 1415 // Parse the attribute's type. We don't expect the type to be arbitrary 1416 // complex, so just use this ad-hoc handling here. 1417 1418 // Parse potential dimension list 1419 SmallVector<uint64_t, 4> vectorDims; 1420 while (parser.curToken.is(Token::Kind::integer)) { 1421 uint64_t value; 1422 if (failed(parser.parseInteger(value))) 1423 return failure(); 1424 vectorDims.push_back(value); 1425 1426 StringRef spelling = parser.curToken.getSpelling(); 1427 if (spelling[0] != 'x') 1428 return parser.emitError(parser.curToken.getLoc(), 1429 "expected 'x' in dimension list"); 1430 1431 // If we had a prefix of 'x', lex the next token immediately after the 'x'. 
1432 if (spelling.size() != 1) 1433 parser.lexer.resetPointer(spelling.data() + 1); 1434 1435 parser.consumeToken(); 1436 } 1437 1438 StringRef elementType = parser.curToken.getSpelling(); 1439 if (failed(parser.parseToken(Token::Kind::id, "expected an id"))) 1440 return failure(); 1441 1442 bool isArray = false; 1443 auto arrayLoc = parser.curToken.getLoc(); 1444 if (succeeded(parser.parseOptionalToken(Token::Kind::l_square))) { 1445 isArray = true; 1446 if (failed(parser.parseToken(Token::Kind::r_square, "expected ']'"))) 1447 return failure(); 1448 } 1449 1450 if (!vectorDims.empty() && isArray) 1451 return parser.emitError(arrayLoc, "unsupported vector array attribute"); 1452 1453 auto iterBoolPair = registeredAttrs.emplace( 1454 attrName.str(), 1455 RegisteredAttr{elementType, vectorDims, isArray, isOptional}); 1456 if (!iterBoolPair.second) 1457 return parser.emitError(attrLoc, 1458 "Failed to register attribute '" + attrName + "'"); 1459 1460 LLVM_DEBUG(llvm::dbgs() << "Recorded: " << (isOptional ? "[optional]" : "") 1461 << " " << attrName << " " 1462 << "with type: " << elementType 1463 << (isArray ? "[]" : "") << "\n"); 1464 1465 return success(); 1466 } 1467 1468 LogicalResult TCParser::parseAttrUse(AttrUse &result) { 1469 result.attrName = parser.curToken.getSpelling(); 1470 if (failed(parser.parseToken(Token::Kind::id, "expected an id"))) 1471 return failure(); 1472 1473 auto it = registeredAttrs.find(result.attrName.str()); 1474 assert(it != registeredAttrs.end()); 1475 const RegisteredAttr &attr = it->second; 1476 1477 if (!attr.vectorDims.empty() || attr.isArray) { 1478 // This is a vector/array attribute. Parse indices for it. 
1479 auto indexLoc = parser.curToken.getLoc(); 1480 1481 if (failed(parser.parseToken(Token::Kind::l_square, "expected '['"))) 1482 return failure(); 1483 1484 auto parseIndex = [&]() { 1485 uint64_t value; 1486 if (failed(parser.parseInteger(value))) 1487 return failure(); 1488 result.indices.push_back(value); 1489 return success(); 1490 }; 1491 if (failed(parser.parseCommaSeparatedListUntil( 1492 Token::Kind::r_square, parseIndex, /*allowEmptyList=*/false))) 1493 return failure(); 1494 1495 size_t rank = attr.isArray ? 1 : attr.vectorDims.size(); 1496 if (result.indices.size() != rank) 1497 return parser.emitError(indexLoc, 1498 "number of indices mismatch: expected " + 1499 std::to_string(rank) + ", but found " + 1500 std::to_string(result.indices.size())); 1501 } 1502 1503 return success(); 1504 } 1505 1506 /// Parses a tensor expression of the form: 1507 /// 1508 /// op-spec ::= bare-id `<` reduction-dims-list `>` 1509 /// | bare-id 1510 /// op-arg ::= tensor-expr 1511 /// | tensor-use 1512 /// op-arg-list ::= op-arg (`,` op-arg)* 1513 /// tensor-expr ::= op-spec `(` op-arg-list `)` 1514 LogicalResult TCParser::parseExpression(TensorUse currentDefinition, 1515 std::unique_ptr<Expression> &result, 1516 ComprehensionParsingState &state) { 1517 StringRef opOrTensor = parser.curToken.getSpelling(); 1518 if (registeredTensors.count(opOrTensor) > 0) { 1519 TensorUse use; 1520 auto res = parseTensorUse(use, state); 1521 if (failed(res)) 1522 return res; 1523 result = std::make_unique<TensorUse>(use); 1524 return success(); 1525 } 1526 1527 if (failed(parser.parseToken(Token::Kind::id, "expected an operation"))) 1528 return failure(); 1529 1530 // This is an op. 1531 SmallVector<unsigned, 4> reductionDims; 1532 SmallVector<std::unique_ptr<Expression>, 4> expressions; 1533 1534 // Check if it has a reduction set, discover dimensions eagerly. 
1535 if (parser.curToken.is(Token::Kind::lt)) { 1536 auto iters = parseAffineExprs(EagerDiscoveryMode::Dimensions, state.dims, 1537 Token::Kind::lt, Token::Kind::gt); 1538 for (auto iter : iters) 1539 reductionDims.push_back(iter.cast<AffineDimExpr>().getPosition()); 1540 } 1541 1542 auto parseExpr = [&]() -> LogicalResult { 1543 std::unique_ptr<Expression> e; 1544 if (failed(parseExpression(currentDefinition, e, state))) 1545 return failure(); 1546 expressions.push_back(std::move(e)); 1547 return success(); 1548 }; 1549 if (failed(parser.parseToken(Token::Kind::l_paren, "expected '('")) || 1550 failed(parser.parseCommaSeparatedListUntil( 1551 Token::Kind::r_paren, parseExpr, /*allowEmptyList=*/true))) 1552 return failure(); 1553 1554 result = std::make_unique<TensorExpr>(opOrTensor, std::move(expressions), 1555 reductionDims); 1556 1557 return success(); 1558 } 1559 1560 //===----------------------------------------------------------------------===// 1561 // Parse and Emit functions. 1562 //===----------------------------------------------------------------------===// 1563 1564 /// Parse the information for a single comprehension. 1565 /// 1566 /// tensor-def-list ::= tensor-def (`,` tensor-def)* 1567 /// tensor-expr-list ::= tensor-expr (`,` tensor-expr)* 1568 /// comprehension ::= tensor-def-list `=` tensor-expr-list `;` 1569 LogicalResult 1570 TCParser::parseOneComprehension(StringRef cppOpName, StringRef linalgOpName, 1571 ComprehensionParsingState &state) { 1572 // 1. Parse LHS of `=`, these become the definitions that appear as the output 1573 // tensors or read/write buffers. 
1574 SmallVector<TensorUse, 4> definitions; 1575 auto parseUse = [&]() -> LogicalResult { 1576 TensorUse use; 1577 if (failed(parseTensorUse(use, state))) 1578 return failure(); 1579 definitions.push_back(use); 1580 return success(); 1581 }; 1582 if (failed(parser.parseCommaSeparatedListUntil(Token::Kind::equal, parseUse, 1583 /*allowEmptyList=*/true))) 1584 return failure(); 1585 1586 // 2. Parse RHS of `=`, this becomes the expressions from which we emit 1587 // computations. 1588 unsigned idx = 0; 1589 auto parseExpr = [&]() -> LogicalResult { 1590 std::unique_ptr<Expression> expr; 1591 if (idx >= definitions.size()) 1592 return parser.emitError("Fewer LHS definitions than RHS expressions"); 1593 if (failed(parseExpression(definitions[idx++], expr, state))) 1594 return failure(); 1595 state.expressions.push_back(std::move(expr)); 1596 return success(); 1597 }; 1598 if (failed(parser.parseCommaSeparatedListUntil( 1599 Token::Kind::semicolon, parseExpr, /*allowEmptyList=*/true))) 1600 return failure(); 1601 if (idx != definitions.size()) 1602 return parser.emitError("Fewer RHS expressions than LHS definitions"); 1603 1604 // 3. Postprocess. 1605 // 3.a. Normalize all maps to the proper state.dims and symbols counts. 1606 SmallVector<TensorUse, 4> allUses; 1607 allUses.reserve(registeredTensors.size()); 1608 for (auto &def : definitions) 1609 allUses.push_back(def); 1610 for (auto &pExpr : state.expressions) 1611 visitPostorder(*pExpr, [&](const Expression &e) { 1612 if (auto *use = dyn_cast<TensorUse>(&e)) 1613 allUses.push_back(*use); 1614 }); 1615 for (auto &use : allUses) 1616 use.indexingMap = 1617 AffineMap::get(state.dims.size(), symbols.size(), 1618 use.indexingMap.getResults(), parser.context); 1619 1620 // 3.b. 
Traverse definitions 1621 llvm::DenseSet<StringRef> seenDefs; 1622 for (auto &def : definitions) { 1623 if (seenDefs.count(def.tensorId) > 0) 1624 return parser.emitError("Unexpected multi-write to a single tensor"); 1625 seenDefs.insert(def.tensorId); 1626 auto tensorIter = registeredTensors.find(def.tensorId); 1627 assert(tensorIter != registeredTensors.end() && "unregistered tensor"); 1628 auto &tensor = tensorIter->getValue(); 1629 tensor.indexingMap = def.indexingMap; 1630 state.orderedTensorArgs[def] = tensor.index; 1631 } 1632 1633 bool failed = false; 1634 for (auto &pExpr : state.expressions) 1635 visitPostorder(*pExpr, [&](const Expression &e) { 1636 auto *pUse = dyn_cast<TensorUse>(&e); 1637 if (failed || !pUse) 1638 return; 1639 auto &use = *pUse; 1640 LLVM_DEBUG(llvm::dbgs() 1641 << "\nuse: " << use.tensorId << " map: " << use.indexingMap); 1642 auto tensorIter = registeredTensors.find(use.tensorId); 1643 assert(tensorIter != registeredTensors.end() && "unregistered tensor"); 1644 auto &tensor = tensorIter->getValue(); 1645 if (tensor.indexingMap && state.orderedTensorArgs.count(use) == 0 && 1646 tensor.indexingMap.getResults() != use.indexingMap.getResults()) { 1647 LLVM_DEBUG(llvm::dbgs() << "\nexisting: " << tensor.indexingMap); 1648 (void)parser.emitError( 1649 "Unexpected multi-read of a tensor with different accesses"); 1650 failed = true; 1651 return; 1652 } 1653 seenDefs.insert(use.tensorId); 1654 tensor.indexingMap = use.indexingMap; 1655 state.orderedTensorArgs[use] = tensor.index; 1656 }); 1657 // If more than one definitions are less. They are shaped-only operand, which 1658 // are used to define reduction loops. For now, only accept exactly one 1659 // shaped-only operand. 
1660 if (state.numArgs > seenDefs.size() + 1) { 1661 failed = true; 1662 } else if (state.numArgs == seenDefs.size() + 1) { 1663 for (auto &tensorIter : registeredTensors) { 1664 auto &tensor = tensorIter.getValue(); 1665 if (tensor.indexingMap) 1666 continue; 1667 if (auto *pTensorExpr = 1668 dyn_cast<TensorExpr>(state.expressions[0].get())) { 1669 SmallVector<AffineExpr, 4> exprs; 1670 for (auto dim : pTensorExpr->reductionDimensions) 1671 exprs.push_back(getAffineDimExpr(dim, parser.context)); 1672 tensor.indexingMap = AffineMap::get(state.dims.size(), symbols.size(), 1673 exprs, parser.context); 1674 } 1675 } 1676 } 1677 if (failed) 1678 return failure(); 1679 1680 return success(); 1681 } 1682 1683 /// Parse and print the information for a ODS def. 1684 /// 1685 /// tensor-def-list ::= tensor-def (`,` tensor-def )* 1686 /// attr-def-list ::= attr-def (`,` attr-def )* 1687 /// 1688 /// comprehension-list ::= comprehension comprehension* 1689 /// 1690 /// tc-attr-def ::= `attr` `(` attr-def-list `)` 1691 /// tc-def ::= `def` bare-id `(`tensor-def-list`)` `->` `(` tensor-def-list`)` 1692 /// (tc-attr-def)? 1693 /// `{` comprehension-list `}` 1694 /// 1695 /// implements-interface ::= 1696 /// `implements_interface` `<` bare-id (`,` bare-id)* `>` `:` tc-def 1697 /// 1698 /// ods-def ::= `ods_def` `<` bare-id `>` 1699 /// (implements-interface)? `:` 1700 /// tc-def 1701 /// 1702 /// All the affine-expr in a `tensor-typedef` must be dimensionless (i.e. 1703 /// contain only expressions involving symbols and constants), but can 1704 /// otherwise contain arbitrary affine expressions. 
1705 LogicalResult TCParser::parseAndEmitODSDef(llvm::raw_ostream &os) { 1706 // Parse ods-def header (including C++ op name) 1707 if (failed(parser.parseToken(Token::Kind::kw_ods_def, 1708 "expected 'ods_def' to define a TC ODS")) || 1709 failed(parser.parseToken(Token::Kind::lt, "expected '<'"))) 1710 return failure(); 1711 StringRef cppOpName = parser.curToken.getSpelling(); 1712 LLVM_DEBUG(llvm::dbgs() << "\n\nStart parsing ODS: " << cppOpName << "\n"); 1713 if (failed(parser.parseToken(Token::Kind::id, "expected id")) || 1714 failed(parser.parseToken(Token::Kind::gt, "expected '>'"))) 1715 return failure(); 1716 1717 // Parse optional implements-interface header (including C++ op names) 1718 SmallVector<StringRef> interfaces; 1719 bool implementsInterface = succeeded( 1720 parser.parseOptionalToken(Token::Kind::kw_implements_interface)); 1721 if (implementsInterface) { 1722 auto parseInterfaceString = [&]() -> LogicalResult { 1723 StringRef interfaceName = parser.curToken.getSpelling(); 1724 if (failed(parser.parseToken(Token::Kind::id, "expected id"))) 1725 return failure(); 1726 interfaces.push_back(interfaceName); 1727 return success(); 1728 }; 1729 if (failed(parser.parseToken(Token::Kind::lt, "expected '<'")) || 1730 failed(parser.parseCommaSeparatedListUntil( 1731 Token::Kind::gt, parseInterfaceString, /*allowEmptyList=*/false))) 1732 return failure(); 1733 } 1734 1735 // Parse column. 1736 if (failed(parser.parseToken(Token::Kind::colon, "expected ':'"))) 1737 return failure(); 1738 1739 // Parse TC op name. 
1740 if (failed(parser.parseToken(Token::Kind::kw_def, 1741 "expected 'def' to define a TC"))) 1742 return failure(); 1743 StringRef tcName = parser.curToken.getSpelling(); 1744 LLVM_DEBUG(llvm::dbgs() << "\n\nStart parsing TC: " << tcName << "\n"); 1745 1746 // Parse input/output tensor definitions 1747 if (failed(parser.parseToken(Token::Kind::id, "expected id")) || 1748 failed(parser.parseToken(Token::Kind::l_paren, "expected '('"))) 1749 return failure(); 1750 1751 auto parseInputDef = [&]() -> LogicalResult { 1752 return parseTensorDef(/*isOutput=*/false); 1753 }; 1754 if (failed(parser.parseCommaSeparatedListUntil( 1755 Token::Kind::r_paren, parseInputDef, /*allowEmptyList=*/false))) 1756 return failure(); 1757 1758 if (failed(parser.parseToken(Token::Kind::minus, "expected '-'")) || 1759 failed(parser.parseToken(Token::Kind::gt, "expected '>'")) || 1760 failed(parser.parseToken(Token::Kind::l_paren, "expected '('"))) 1761 return failure(); 1762 auto parseOutputDef = [&]() -> LogicalResult { 1763 return parseTensorDef(/*isOutput=*/true); 1764 }; 1765 if (failed(parser.parseCommaSeparatedListUntil( 1766 Token::Kind::r_paren, parseOutputDef, /*allowEmptyList=*/false))) 1767 return failure(); 1768 1769 // Parse optional attribute definitions 1770 if (succeeded(parser.parseOptionalToken(Token::Kind::kw_attr_def))) { 1771 if (failed(parser.parseToken(Token::Kind::l_paren, "expected '('"))) 1772 return failure(); 1773 if (failed(parser.parseCommaSeparatedListUntil( 1774 Token::Kind::r_paren, std::bind(&TCParser::parseAttrDef, this), 1775 /*allowEmptyList=*/false))) 1776 return failure(); 1777 } 1778 1779 // Parse optional doc string 1780 if (parser.curToken.is(Token::Kind::doc_str)) { 1781 docString = parser.curToken.getSpelling(); 1782 parser.consumeToken(); 1783 LLVM_DEBUG(llvm::dbgs() 1784 << "parsed doc string: '''" << docString << "'''\n"); 1785 } 1786 1787 // Since we don't declare symbols separately, we discover them eagerly: each 1788 // newly encountered 
id in a tensor shape expression is treated as a new 1789 // symbolic. At this point, all tensors have been parsed and all the symbols 1790 // that could be discovered eagerly are now known. Resize all AffineMaps to 1791 // normalize the number of eagerly discovered symbols. 1792 for (auto &tensor : registeredTensors) { 1793 auto &map = tensor.getValue().shape; 1794 map = AffineMap::get(/*dimCount=*/0, symbols.size(), map.getResults(), 1795 parser.context); 1796 } 1797 1798 if (failed(parser.parseToken(Token::Kind::l_brace, "expected '{'"))) 1799 return failure(); 1800 1801 SmallVector<ComprehensionParsingState, 4> perComprehensionStates; 1802 while (parser.curToken.isNot(Token::Kind::r_brace)) { 1803 perComprehensionStates.push_back(ComprehensionParsingState()); 1804 perComprehensionStates.back().numArgs = registeredTensors.size(); 1805 if (failed(parseOneComprehension(cppOpName, tcName, 1806 perComprehensionStates.back()))) 1807 return failure(); 1808 }; 1809 if (failed(parser.parseToken(Token::Kind::r_brace, "expected '}'"))) 1810 return failure(); 1811 1812 // Print. 
1813 auto nComprehensions = perComprehensionStates.size(); 1814 if (nComprehensions != 1) 1815 return parser.emitError("only 1 comprehension supported for now, got: " + 1816 llvm::Twine(nComprehensions)); 1817 if (genODSDecl) { 1818 auto &state = perComprehensionStates.back(); 1819 printODS(os, cppOpName, tcName, interfaces, state); 1820 os << "\n"; 1821 } 1822 if (genODSImpl) { 1823 auto &state = perComprehensionStates.back(); 1824 std::string extraMethods; 1825 llvm::raw_string_ostream ss(extraMethods); 1826 printReferenceIterators(ss, cppOpName, state); 1827 printIndexingMapRequiredAttrMethods(ss, cppOpName, state); 1828 printReferenceIndexingMaps(ss, cppOpName, state); 1829 printRegionBuilder(ss, cppOpName, state); 1830 printCanonicalizersAndFolders(ss, cppOpName); 1831 ss.flush(); 1832 os << extraMethods << "\n"; 1833 } 1834 1835 return success(); 1836 } 1837 1838 //===----------------------------------------------------------------------===// 1839 // Printing functions 1840 //===----------------------------------------------------------------------===// 1841 1842 /// Print the ODS class that defines a new `cppOpName` for a `linalgOpName`. 
void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
                        StringRef linalgOpName, ArrayRef<StringRef> interfaces,
                        ComprehensionParsingState &state) {
  // Build one ODS argument entry of the form `<ods-type>:$<name>` for every
  // registered attribute, e.g. `OptionalAttr<I64ArrayAttr>:$strides`.
  SmallVector<std::string, 4> attributes;
  for (const auto &attr : registeredAttrs) {
    llvm::StringRef name = attr.first;

    // Map the element type to its ODS spelling. An unsupported element type is
    // reported and aborts printing of the whole op: nothing has been written
    // to `os` at this point.
    llvm::StringRef elementType = attr.second.elementType;
    std::string odsType = llvm::StringSwitch<std::string>(elementType)
                              .Case("f32", "F32")
                              .Case("i32", "I32")
                              .Case("i64", "I64")
                              .Default("");
    if (odsType.empty()) {
      (void)parser.emitError(
          "unimplemented support for attribute element type: " + elementType);
      return;
    }

    const auto &dims = attr.second.vectorDims;
    if (!dims.empty()) {
      // Vector case: the static shape is spelled out inside the ODS type.
      SmallVector<std::string, 4> dimStrs;
      for (uint64_t dim : dims)
        dimStrs.push_back(std::to_string(dim));
      odsType = llvm::formatv("Ranked{0}ElementsAttr<[{1}]>", odsType,
                              llvm::join(dimStrs, ", "));
    } else if (attr.second.isArray) {
      // Array case
      odsType = llvm::formatv("{0}ArrayAttr", odsType);
    } else {
      // Scalar case
      odsType = llvm::formatv("{0}Attr", odsType);
    }

    if (attr.second.isOptional)
      odsType = llvm::formatv("OptionalAttr<{0}>", odsType);

    attributes.push_back(llvm::formatv("{0}:${1}", odsType, name));
  }

  // The attribute list is spliced directly after `$outputs` in the template
  // (placeholder {4}), so it needs a leading separator when non-empty.
  std::string attrList = llvm::join(attributes, ",\n");
  if (!attrList.empty())
    attrList = ",\n" + attrList;

  // Template for Linalg named ops' ODS definitions. Parameters:
  // {0}: ODS/C++ op name
  // {1}: assembly op mnemonic
  // {2}: op interface list
  // {3}: documentation (summary + description)
  // {4}: op attribute list
  // {5}: the number of arguments for the op region
  // {6}: builder methods taking standalone attribute parameters
  // {7}: additional methods for attributes used by indexing maps
  // Note: `{{` is llvm::formatv's escape for a literal `{`.
  const char *header = R"FMT(  def {0} : LinalgStructuredBase_Op<"{1}", [
    AttrSizedOperandSegments,
    DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
    SingleBlockImplicitTerminator<"YieldOp">
    /*extraInterfaces=*/{2}]> {
      {3}
      let arguments = (ins
        Variadic<AnyShaped>:$inputs,
        Variadic<AnyShaped>:$outputs{4}
      );
      let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
      let regions = (region AnyRegion:$region);

      let skipDefaultBuilders = 1;
      let builders = [
        OpBuilder<
        (ins "ValueRange":$inputs, "ValueRange":$outputs),
        [{{
          $_state.addOperands(inputs);
          $_state.addOperands(outputs);
          $_state.addAttribute(
            "operand_segment_sizes",
            $_builder.getI32VectorAttr({{
              static_cast<int32_t>(inputs.size()),
              static_cast<int32_t>(outputs.size())}));
          createAndFillStructuredOpRegion<{0}>(
            $_builder,
            $_state,
            TypeRange(inputs),
            TypeRange(outputs));
        }]>,
        OpBuilder<
        (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
             "ValueRange":$outputs),
        [{{
          $_state.addOperands(inputs);
          $_state.addOperands(outputs);
          $_state.addTypes(resultTensorTypes);
          $_state.addAttribute(
            "operand_segment_sizes",
            $_builder.getI32VectorAttr({{
              static_cast<int32_t>(inputs.size()),
              static_cast<int32_t>(outputs.size())}));
          createAndFillStructuredOpRegion<{0}>(
            $_builder,
            $_state,
            TypeRange(inputs),
            TypeRange(outputs));
        }]>,
        OpBuilder<
        (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
             CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
        [{{
          $_state.addOperands(operands);
          $_state.addAttributes(attributes);
          $_state.addTypes(resultTensorTypes);
          (void)$_state.addRegion();
        }]>
        {6}
      ];
      let printer = [{{ return ::printNamedStructuredOp(p, *this); }];
      let parser = [{{
        return ::parseNamedStructuredOp<{0}>(parser, result);
      }];
      let hasFolder = 1;

      let extraClassDeclaration = structuredOpsBaseDecls # [{{
        // Auto-generated.
        ArrayAttr iterator_types();
        ArrayAttr indexing_maps();
        static void regionBuilder(ImplicitLocOpBuilder &b, Block &block);
        static std::function<void(ImplicitLocOpBuilder &b, Block &)>
        getRegionBuilder() {{
          return regionBuilder;
        }

        // Generic methods.
        static unsigned getNumRegionArgs() {{ return {5}; }
        std::string getLibraryCallName() {{
          return generateLibraryCallName(getOperation());
        }

        {7}
      }];
  })FMT";

  // Generate the list of extra implemented interfaces.
  std::string interfaceNameList;
  if (!interfaces.empty()) {
    llvm::raw_string_ostream ss(interfaceNameList);
    ss << ", "; // Leading comma to concat to existing list of interfaces.
    llvm::interleaveComma(interfaces, ss);
    ss.flush();
  }

  // Generate documentation. The first line of the doc string becomes the
  // summary and everything after the first newline the description.
  std::string doc;
  if (!docString.empty()) {
    const char *docFmt = R"FMT(
      let summary = [{ {0} }];
      let description = [{
        {1}
      }];
    )FMT";

    StringRef summary, description;
    std::tie(summary, description) = docString.trim().split('\n');
    doc = llvm::formatv(docFmt, summary.trim(), description.trim());
  }

  // Generate an additional builder that has parameters for attributes. Its
  // text starts with a comma so that it appends to the `builders` list in the
  // template (placeholder {6}).
  std::string attrBuilder;
  if (!registeredAttrs.empty()) {
    SmallVector<std::string, 4> attrParams, attrStmts;
    for (const auto &attr : registeredAttrs) {
      llvm::StringRef name = attr.first;
      attrParams.push_back(llvm::formatv("\"Attribute\":${0}", name));
      attrStmts.push_back(
          llvm::formatv("$_state.addAttribute(\"{0}\", {0});", name));
    }
    std::string attrParamsList = llvm::join(attrParams, ", ");
    std::string attrStmtsList = llvm::join(attrStmts, "\n");

    const char *builderFmt = R"FMT(
      , OpBuilder<
      (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
           "ValueRange":$outputs, {1}),
      [{{
        $_state.addOperands(inputs);
        $_state.addOperands(outputs);
        $_state.addTypes(resultTensorTypes);
        $_state.addAttribute(
          "operand_segment_sizes",
          $_builder.getI32VectorAttr({{
            static_cast<int32_t>(inputs.size()),
            static_cast<int32_t>(outputs.size())}));
        createAndFillStructuredOpRegion<{0}>(
          $_builder,
          $_state,
          TypeRange(inputs),
          TypeRange(outputs));
        {2}
      }]>
    )FMT";
    attrBuilder =
        llvm::formatv(builderFmt, cppOpName, attrParamsList, attrStmtsList);
  }

  // Declarations only; the matching definitions are printed by
  // printIndexingMapRequiredAttrMethods under the same `registeredAttrs`
  // non-empty condition.
  std::string attrMethods;
  if (!registeredAttrs.empty()) {
    attrMethods = R"(
      bool hasDynamicIndexingMaps();
      LogicalResult verifyIndexingMapRequiredAttributes();
    )";
  }

  // Finally put everything together.
  os << llvm::formatv(header, cppOpName, linalgOpName, interfaceNameList, doc,
                      attrList, state.numArgs, attrBuilder, attrMethods);
}

/// Print the C++ StructuredOpsInterface impl of `iterator_types`.
2059 void TCParser::printReferenceIterators(llvm::raw_ostream &os, 2060 StringRef cppOpName, 2061 ComprehensionParsingState &state) { 2062 const char *referenceReferenceIteratorsFmt = 2063 R"FMT( 2064 ArrayAttr {0}::iterator_types() { 2065 return Builder(getContext()).getStrArrayAttr(SmallVector<StringRef, 8>{{ {1} }); 2066 })FMT"; 2067 2068 std::string iteratorsStr; 2069 llvm::raw_string_ostream ss(iteratorsStr); 2070 unsigned pos = 0; 2071 llvm::interleaveComma( 2072 state.dims, ss, [&](std::pair<StringRef, AffineExpr> p) { 2073 bool reduction = false; 2074 for (auto &expr : state.expressions) { 2075 visitPostorder(*expr, [&](const Expression &e) { 2076 if (auto *pTensorExpr = dyn_cast<TensorExpr>(&e)) { 2077 if (pTensorExpr->reductionDimensions.count(pos) > 0) 2078 reduction = true; 2079 } 2080 }); 2081 if (reduction) 2082 break; 2083 } 2084 ss << (reduction ? "getReductionIteratorTypeName()" 2085 : "getParallelIteratorTypeName()"); 2086 pos++; 2087 }); 2088 ss.flush(); 2089 2090 os << llvm::formatv(referenceReferenceIteratorsFmt, cppOpName, iteratorsStr); 2091 } 2092 2093 void TCParser::printCanonicalizersAndFolders(llvm::raw_ostream &os, 2094 StringRef cppOpName) { 2095 const char *foldersFmt = R"FMT( 2096 LogicalResult {0}::fold(ArrayRef<Attribute>, 2097 SmallVectorImpl<OpFoldResult> &) {{ 2098 return foldMemRefCast(*this); 2099 } 2100 void {0}::getEffects(SmallVectorImpl< 2101 SideEffects::EffectInstance<MemoryEffects::Effect> >&effects) {{ 2102 SmallVector<Value> inputBuffers = getInputBufferOperands(); 2103 SmallVector<Value> outputBuffers = getOutputBufferOperands(); 2104 getGenericEffectsImpl(effects, 2105 getOperation()->getResults(), inputBuffers, outputBuffers); 2106 })FMT"; 2107 os << llvm::formatv(foldersFmt, cppOpName); 2108 } 2109 2110 // Prints methods for querying whether the current named op has attributes that 2111 // are used by its indexing maps and for verifying those attributes have the 2112 // expected type. 
2113 void TCParser::printIndexingMapRequiredAttrMethods( 2114 llvm::raw_ostream &os, StringRef cppOpName, 2115 ComprehensionParsingState &state) { 2116 // If there are no attribute used by the whole definition, then we are done. 2117 if (registeredAttrs.empty()) 2118 return; 2119 2120 // Otherwise, go through each attribute and generate code to verify it's 2121 // valid per the spec. 2122 SmallVector<std::string, 4> attributes; 2123 for (const auto &attr : registeredAttrs) { 2124 if (attr.second.isOptional) 2125 continue; 2126 2127 llvm::StringRef name = attr.first; 2128 llvm::StringRef elementType = attr.second.elementType; 2129 const auto &dims = attr.second.vectorDims; 2130 2131 // Get the method call to check the element type is of the expected kind. 2132 std::string elemTypeCheck = llvm::StringSwitch<std::string>(elementType) 2133 .Case("f32", "isF32()") 2134 .Case("i32", "isInteger(32)") 2135 .Case("i64", "isInteger(64)") 2136 .Default(""); 2137 if (elemTypeCheck.empty()) { 2138 (void)parser.emitError( 2139 "unimplemented support for attribute element type: " + elementType); 2140 return; 2141 } 2142 2143 // Scalar case. 2144 if (dims.empty() && !attr.second.isArray) { 2145 const char *attrFmt = R"FMT( 2146 if (auto attr = op->getAttr("{0}")) {{ 2147 if (!attr.getType().{1}) return op->emitError( 2148 "incorrect type for indexing map required attribute '{0}'"); 2149 } else {{ 2150 return op->emitError( 2151 "missing indexing map required attribute '{0}'"); 2152 } 2153 )FMT"; 2154 2155 attributes.push_back(llvm::formatv(attrFmt, name, elemTypeCheck)); 2156 continue; 2157 } 2158 2159 // Vector case. 
2160 if (!dims.empty()) { 2161 SmallVector<std::string, 4> dimStrs; 2162 for (uint64_t dim : dims) 2163 dimStrs.push_back(std::to_string(dim)); 2164 2165 const char *attrFmt = R"FMT( 2166 if (auto attr = op->getAttrOfType<DenseElementsAttr>("{0}")) {{ 2167 if (!attr.getType().getElementType().{1}) return op->emitError( 2168 "incorrect element type for indexing map required attribute '{0}'"); 2169 if (attr.getType().getShape() != ArrayRef<int64_t>{{ {2} }) 2170 return op->emitError( 2171 "incorrect shape for indexing map required attribute '{0}'"); 2172 } else { 2173 return op->emitError( 2174 "missing indexing map required attribute '{0}'"); 2175 } 2176 )FMT"; 2177 2178 attributes.push_back(llvm::formatv(attrFmt, name, elemTypeCheck, 2179 llvm::join(dimStrs, ", "))); 2180 continue; 2181 } 2182 2183 // Array case. 2184 { 2185 const char *attrFmt = R"FMT( 2186 if (auto attr = op->getAttrOfType<ArrayAttr>("{0}")) {{ 2187 for (Attribute element : attr) {{ 2188 if (!element.getType().{1}) return emitError( 2189 "incorrect element type for indexing map required attribute '{0}'"); 2190 } 2191 } else {{ 2192 return op->emitError( 2193 "missing indexing map required attribute '{0}'"); 2194 } 2195 )FMT"; 2196 2197 attributes.push_back(llvm::formatv(attrFmt, name, elemTypeCheck)); 2198 } 2199 } 2200 2201 const char *methodFmt = R"FMT( 2202 bool {0}::hasDynamicIndexingMaps() {{ return true; } 2203 2204 LogicalResult {0}::verifyIndexingMapRequiredAttributes() {{ 2205 Operation *op = getOperation(); 2206 {1} 2207 return success(); 2208 } 2209 )FMT"; 2210 2211 // Print everything out. 2212 os << llvm::formatv(methodFmt, cppOpName, llvm::join(attributes, "\n")); 2213 } 2214 2215 /// Print the C++ StructuredOpsInterface impl of `referenceIndexingMaps`. 2216 void TCParser::printReferenceIndexingMaps(llvm::raw_ostream &os, 2217 StringRef cppOpName, 2218 ComprehensionParsingState &state) { 2219 // 1. Generic string template for specifying reference indexing maps. 
2220 const char *referenceIndexingMapsFmt = 2221 R"FMT( 2222 // This is temporary until we transition out of manually specified ops that 2223 // should be auto-generated with linalg-ods-gen. 2224 ArrayAttr {0}::indexing_maps() { 2225 MLIRContext *context = getContext(); 2226 AffineExpr {1}; 2227 bindDims(context, {1}); 2228 {2} 2229 return Builder(context).getAffineMapArrayAttr({ {3} }); 2230 })FMT"; 2231 2232 // 2. Print a comma-separated list of identifiers for the AffineExpr in 2233 // `state.dims`. These will replace the `{1}` placeholder in both 2234 // `AffineExpr {1}` and `bindDims(context, {1})` ensuring the AffineExpr 2235 // identifiers are bound in the right order to the proper AffineDimExpr. 2236 std::string dimsStr; 2237 llvm::raw_string_ostream ss(dimsStr); 2238 llvm::interleaveComma( 2239 state.dims, ss, 2240 [&](std::pair<StringRef, AffineExpr> p) { ss << p.second; }); 2241 ss.flush(); 2242 2243 // 3. Get the list of affine maps for each input/output. The AffineExpr use 2244 // the common arithmetic operators on AffineExpr. These affine maps will 2245 // replace the `{2}` placeholder. 2246 std::string mapsStr; 2247 llvm::raw_string_ostream mapsStringStream(mapsStr); 2248 2249 // Create a list of all symbols. 2250 SmallVector<std::string, 4> symbolReplacements; 2251 symbolReplacements.reserve(symbols.size()); 2252 for (unsigned i = 0; i < symbols.size(); ++i) { 2253 const char *symFmt = 2254 "\n\tauto s{0} = getAffineSymbolExpr({0}, context); (void)s{0};"; 2255 mapsStringStream << llvm::formatv(symFmt, i); 2256 symbolReplacements.push_back(llvm::formatv("s{0}", i)); 2257 } 2258 2259 // Create the affine constant expressions to replace symbols for attributes. 
2260 for (auto attrUse : llvm::enumerate(attrUses)) { 2261 StringRef attrName = attrUse.value().attrName; 2262 auto it = registeredAttrs.find(attrName.str()); 2263 assert(it != registeredAttrs.end() && "uses should point to valid attr!"); 2264 llvm::Optional<std::string> getValueFn = 2265 it->second.getValueFn(attrUse.value().indices); 2266 if (!getValueFn) { 2267 (void)parser.emitError("unimplemented getValueFn for attribute: " + 2268 attrName); 2269 return; 2270 } 2271 std::string cstVal = llvm::formatv("{0}(){1}", attrName, *getValueFn); 2272 const char *cstFmt = 2273 "\n\tauto cst{0} = getAffineConstantExpr({1}, context);"; 2274 mapsStringStream << llvm::formatv(cstFmt, attrUse.index(), cstVal); 2275 2276 unsigned position = 2277 attrUse.value().symbol.cast<AffineSymbolExpr>().getPosition(); 2278 symbolReplacements[position] = llvm::formatv("cst{0}", attrUse.index()); 2279 } 2280 2281 // For each registered tensor, construct the affine map, replace symbols by 2282 // the corresponding attribute values, and simplify the affine map. 2283 for (auto &tensorIter : registeredTensors) { 2284 auto &tensor = tensorIter.getValue(); 2285 auto indexingMap = tensor.indexingMap; 2286 const char *mapFmt = 2287 "\n\tauto map{0} = AffineMap::get({1}, {2}, {3}, context);"; 2288 2289 std::string exprsStr; 2290 llvm::raw_string_ostream exprsStringStream(exprsStr); 2291 exprsStringStream << "{"; 2292 llvm::interleaveComma(indexingMap.getResults(), exprsStringStream); 2293 exprsStringStream << "}"; 2294 exprsStringStream.flush(); 2295 mapsStringStream << llvm::formatv(mapFmt, tensor.index, state.dims.size(), 2296 indexingMap.getNumSymbols(), exprsStr); 2297 2298 std::string replaceSymbolList = 2299 llvm::formatv("{ {0} }", llvm::join(symbolReplacements, ", ")); 2300 2301 // Note that we use `0` as the result affine map's number of symbols. All 2302 // symbols representing attribute usages should be folded away. 
But there 2303 // may exist additional symbols for tensor dimension upper bounds. Linalg 2304 // does not handle such cases right now. This needs to be fixed once we 2305 // need that. 2306 const char *replaceFmt = 2307 "\n\tmap{0} = map{0}.replaceDimsAndSymbols({{}, {1}, {2}, 0);"; 2308 mapsStringStream << llvm::formatv(replaceFmt, tensor.index, 2309 replaceSymbolList, state.dims.size()); 2310 const char *simplifyFmt = "\n\tmap{0} = simplifyAffineMap(map{0});"; 2311 mapsStringStream << llvm::formatv(simplifyFmt, tensor.index); 2312 } 2313 2314 mapsStringStream.flush(); 2315 2316 SmallVector<std::string, 4> mapList; 2317 mapList.reserve(state.numArgs); 2318 for (auto i : llvm::seq<unsigned>(0, state.numArgs)) 2319 mapList.push_back(llvm::formatv("map{0}", i)); 2320 2321 // 4. Apply format to 1. using 2. and 3. 2322 os << llvm::formatv(referenceIndexingMapsFmt, cppOpName, dimsStr, mapsStr, 2323 llvm::join(mapList, ", ")); 2324 } 2325 2326 /// Print the C++ StructuredOpsInterface impl of `regionBuilder`. 
2327 void TCParser::printRegionBuilder(llvm::raw_ostream &os, StringRef cppOpName, 2328 ComprehensionParsingState &state) { 2329 unsigned count = state.numArgs; 2330 llvm::DenseMap<const TensorExpr *, unsigned> subExprsMap; 2331 std::function<void(llvm::raw_ostream & os, const Expression &)> printExpr; 2332 printExpr = [&](llvm::raw_ostream &os, const Expression &e) -> void { 2333 if (auto *pUse = dyn_cast<TensorUse>(&e)) { 2334 os << "_" << state.orderedTensorArgs.find(*pUse)->second; 2335 return; 2336 } 2337 auto *pTensorExpr = cast<TensorExpr>(&e); 2338 if (subExprsMap.count(pTensorExpr) > 0) { 2339 os << "_" << subExprsMap[pTensorExpr]; 2340 } else { 2341 std::string subExprs; 2342 llvm::raw_string_ostream subExprsStringStream(subExprs); 2343 llvm::interleaveComma(pTensorExpr->expressions, subExprsStringStream, 2344 [&](const std::unique_ptr<Expression> &e) { 2345 printExpr(subExprsStringStream, *e); 2346 }); 2347 subExprsStringStream.flush(); 2348 const char *tensorExprFmt = "\n Value _{0} = b.create<{1}>({2});"; 2349 os << llvm::formatv(tensorExprFmt, ++count, pTensorExpr->operationName, 2350 subExprs); 2351 subExprsMap[pTensorExpr] = count; 2352 } 2353 }; 2354 2355 const char *regionBuilderFmt = R"FMT( 2356 void {0}::regionBuilder(ImplicitLocOpBuilder &b, Block &block) { 2357 auto args = block.getArguments(); 2358 Value {1}; 2359 {2} 2360 b.create<linalg::YieldOp>(ValueRange{ {3} }); 2361 })FMT"; 2362 2363 std::string valueHandleStr; 2364 llvm::raw_string_ostream valueHandleStringStream(valueHandleStr); 2365 std::set<unsigned> usedTensorId; 2366 for (const auto &iter : state.orderedTensorArgs) 2367 usedTensorId.insert(iter.second); 2368 llvm::interleaveComma(usedTensorId, valueHandleStringStream, [&](auto idx) { 2369 valueHandleStringStream << "_" << idx << "(args[" << idx << "])"; 2370 }); 2371 2372 std::string expressionsStr; 2373 llvm::raw_string_ostream expressionStringStream(expressionsStr); 2374 for (auto &expr : state.expressions) 2375 
visitPostorder(*expr, [&](const Expression &e) { 2376 if (e.kind == Expression::Kind::TensorExpr) 2377 printExpr(expressionStringStream, e); 2378 }); 2379 expressionStringStream.flush(); 2380 substituteOpAliases(expressionsStr); 2381 2382 std::string yieldStr; 2383 llvm::raw_string_ostream yieldStringStream(yieldStr); 2384 llvm::interleaveComma(state.expressions, yieldStringStream, 2385 [&](const std::unique_ptr<Expression> &e) { 2386 printExpr(yieldStringStream, *e); 2387 }); 2388 2389 valueHandleStringStream.flush(); 2390 yieldStringStream.flush(); 2391 2392 os << llvm::formatv(regionBuilderFmt, cppOpName, valueHandleStr, 2393 expressionsStr, yieldStr); 2394 } 2395 2396 llvm::Optional<std::string> 2397 TCParser::RegisteredAttr::getValueFn(ArrayRef<uint64_t> indices) const { 2398 if (isArray) 2399 return llvm::None; 2400 2401 if (!vectorDims.empty()) { 2402 SmallVector<std::string, 4> indexStrs; 2403 for (uint64_t index : indices) 2404 indexStrs.push_back(std::to_string(index)); 2405 std::string indexList = llvm::join(indexStrs, ", "); 2406 if (elementType == "f32") 2407 return llvm::formatv(".getValue<float>({ {0} })", indexList).str(); 2408 if (elementType == "i32") 2409 return llvm::formatv(".getValue<int>({ {0} })", indexList).str(); 2410 if (elementType == "i64") 2411 return llvm::formatv(".getValue<int64_t>({ {0} })", indexList).str(); 2412 2413 return llvm::None; 2414 } 2415 2416 if (elementType == "f32") 2417 return std::string(".convertToFloat()"); 2418 if (elementType == "i32" || elementType == "i64") 2419 return std::string(""); 2420 return llvm::None; 2421 } 2422 2423 /// Iterate over each Tensor Comprehension def. 
2424 LogicalResult parseAndEmitAllTensorComprehensions(llvm::raw_ostream &os, 2425 Parser &parser) { 2426 while (parser.curToken.getKind() != Token::Kind::eof) { 2427 TCParser tcParser(parser); 2428 if (failed(tcParser.parseAndEmitODSDef(os))) 2429 return failure(); 2430 } 2431 return success(); 2432 } 2433 2434 int main(int argc, char **argv) { 2435 llvm::cl::ParseCommandLineOptions(argc, argv, "Linalg ODS Gen"); 2436 2437 // Set up the input file. 2438 std::string errorMessage; 2439 std::unique_ptr<llvm::MemoryBuffer> file = 2440 mlir::openInputFile(inputFilename, &errorMessage); 2441 if (!file) { 2442 llvm::errs() << errorMessage << "\n"; 2443 return 1; 2444 } 2445 2446 std::unique_ptr<llvm::ToolOutputFile> output = 2447 openOutputFile(outputFilename, &errorMessage); 2448 if (!output) { 2449 llvm::errs() << errorMessage << "\n"; 2450 exit(1); 2451 } 2452 2453 // Include the proper Linalg header for end-to-end tblgen testing without 2454 // resorting to non-portable shell manipulations. 2455 if (testEmitIncludeTdHeader) 2456 output->os() << "include \"mlir/Dialect/Linalg/IR/LinalgStructuredOps.td\""; 2457 2458 MLIRContext context; 2459 llvm::SourceMgr mgr; 2460 mgr.AddNewSourceBuffer(std::move(file), llvm::SMLoc()); 2461 Parser parser(mgr, &context); 2462 (void)parseAndEmitAllTensorComprehensions(output->os(), parser); 2463 output->keep(); 2464 2465 return 0; 2466 } 2467