1 //===- mlir-linalg-ods-gen.cpp - Linalg ODS generation from math form -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains the implementation for the Tensor Comprehension-inspired
10 // parser and ODS pretty-printer for specifying Linalg "named ops" from a
11 // mathematical form.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "mlir/IR/AffineExpr.h"
16 #include "mlir/IR/AffineMap.h"
17 #include "mlir/IR/MLIRContext.h"
18 #include "mlir/IR/OpImplementation.h"
19 #include "mlir/Support/FileUtilities.h"
20 #include "mlir/Support/LLVM.h"
21 #include "mlir/Support/LogicalResult.h"
22 #include "llvm/ADT/DenseMap.h"
23 #include "llvm/ADT/Optional.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/SetVector.h"
26 #include "llvm/ADT/SmallVector.h"
27 #include "llvm/ADT/StringExtras.h"
28 #include "llvm/ADT/StringRef.h"
29 #include "llvm/ADT/StringSwitch.h"
30 #include "llvm/ADT/Twine.h"
31 #include "llvm/Support/Casting.h"
32 #include "llvm/Support/CommandLine.h"
33 #include "llvm/Support/Debug.h"
34 #include "llvm/Support/FormatVariadic.h"
35 #include "llvm/Support/MemoryBuffer.h"
36 #include "llvm/Support/SourceMgr.h"
37 #include "llvm/Support/ToolOutputFile.h"
38 
39 #include <map>
40 #include <set>
41 
42 #define DEBUG_TYPE "linalg-ods-gen"
43 
44 static llvm::cl::OptionCategory ODSGenCat("Linalg ODS Gen");
45 
46 // Commandline options
47 static llvm::cl::opt<std::string>
48     inputFilename(llvm::cl::Positional, llvm::cl::desc("<input file>"),
49                   llvm::cl::init("-"), llvm::cl::value_desc("filename"));
50 
51 static llvm::cl::opt<std::string>
52     outputFilename("o", llvm::cl::desc("Output filename"),
53                    llvm::cl::value_desc("filename"), llvm::cl::init("-"));
54 
55 static llvm::cl::opt<bool>
56     genODSDecl("gen-ods-decl", llvm::cl::desc("Emit the ODS ops declarations."),
57                llvm::cl::cat(ODSGenCat));
58 
59 static llvm::cl::opt<bool>
60     genODSImpl("gen-impl", llvm::cl::desc("Emit the ops implementations"),
61                llvm::cl::init(false), llvm::cl::cat(ODSGenCat));
62 
63 static llvm::cl::opt<bool> testEmitIncludeTdHeader(
64     "test-emit-include-td-header",
65     llvm::cl::desc("Include LinalgStructuredOps.td for end-to-end "
66                    "tblgen testing."),
67     llvm::cl::init(false), llvm::cl::cat(ODSGenCat));
68 
69 using llvm::SMLoc;
70 using llvm::StringRef;
71 using llvm::Twine;
72 
73 using namespace mlir;
74 
75 //===----------------------------------------------------------------------===//
76 // Special "op aliases" substitutions.
77 //===----------------------------------------------------------------------===//
78 
/// Perform substitutions of known special ops.
/// This is a poor man's way of achieving "op aliases": i.e. giving an op a
/// name.
/// This is hacky and temporary until migration to the python opdsl is complete.
///
/// Every occurrence of an alias spelling in `expressionsStr` is replaced in
/// place by its expansion. Scanning resumes right after the inserted text, so
/// an expansion is never itself re-scanned for the alias that produced it.
static void substituteOpAliases(std::string &expressionsStr) {
  // (alias spelling, expansion) pairs. The table is built once and iterated
  // by reference, instead of re-allocating and copying the strings per call.
  static const std::pair<std::string, std::string> aliases[] = {
      {"b.create<CmpIOpSGT>(", "b.create<CmpIOp>(CmpIPredicate::sgt, "},
      {"b.create<CmpFOpOGT>(", "b.create<CmpFOp>(CmpFPredicate::OGT, "},
      {"b.create<CmpFOpOLT>(", "b.create<CmpFOp>(CmpFPredicate::OLT, "},
      {"b.create<SignExtendIOp32>(",
       "b.create<SignExtendIOp>(b.getI32Type(), "},
  };

  for (const auto &alias : aliases) {
    const std::string &from = alias.first;
    const std::string &to = alias.second;
    for (size_t pos = expressionsStr.find(from); pos != std::string::npos;
         pos = expressionsStr.find(from, pos + to.size()))
      expressionsStr.replace(pos, from.size(), to);
  }
}
98 
99 //===----------------------------------------------------------------------===//
100 // Lexer
101 //===----------------------------------------------------------------------===//
102 
103 namespace {
104 /// This class represents a specific token in the input format.
105 class Token {
106 public:
107   enum class Kind {
108     // Markers.
109     eof,
110     error,
111 
112     // Tokens with no info.
113     colon,
114     comma,
115     doc_str,
116     equal,
117     gt,
118     l_brace,
119     l_paren,
120     l_square,
121     lt,
122     minus,
123     plus,
124     question,
125     r_brace,
126     r_paren,
127     r_square,
128     semicolon,
129     star,
130 
131     // Keywords.
132     kw_def,
133     FIRST_KEYWORD = kw_def,
134     kw_ods_def,
135     kw_implements_interface,
136     kw_attr_def,
137     kw_floordiv,
138     kw_ceildiv,
139     kw_mod,
140     LAST_KEYWORD = kw_mod,
141 
142     // String valued tokens.
143     id,
144     integer,
145   };
146 
147   Token(Kind kind, StringRef spelling) : kind(kind), spelling(spelling) {}
148 
149   /// Return the bytes that make up this token.
150   StringRef getSpelling() const { return spelling; }
151 
152   /// Return the kind of this token.
153   Kind getKind() const { return kind; }
154 
155   /// Return a location for this token.
156   llvm::SMLoc getLoc() const {
157     return llvm::SMLoc::getFromPointer(spelling.data());
158   }
159 
160   /// Return if this token is a keyword.
161   bool isKeyword() const {
162     return kind >= Kind::FIRST_KEYWORD && kind <= Kind::LAST_KEYWORD;
163   }
164   bool is(Kind k) const { return kind == k; }
165   bool isNot(Kind k) const { return kind != k; }
166 
167   Optional<uint64_t> getUInt64IntegerValue() const {
168     bool isHex = spelling.size() > 1 && spelling[1] == 'x';
169 
170     uint64_t result = 0;
171     if (spelling.getAsInteger(isHex ? 0 : 10, result))
172       return None;
173     return result;
174   }
175 
176 private:
177   /// Discriminator that indicates the kind of token this is.
178   Kind kind;
179 
180   /// A reference to the entire token contents; this is always a pointer into
181   /// a memory buffer owned by the source manager.
182   StringRef spelling;
183 };
184 
185 /// This class implements a simple lexer.
186 class Lexer {
187 public:
188   Lexer(llvm::SourceMgr &mgr);
189 
190   /// Lex the next token and return it.
191   Token lexToken();
192 
193   /// Emit an error to the lexer with the given location and message.
194   Token emitError(llvm::SMLoc loc, const Twine &msg);
195   Token emitError(const char *loc, const Twine &msg);
196 
197   /// Change the position of the lexer cursor. The next token we lex will start
198   /// at the designated point in the input.
199   void resetPointer(const char *newPtr) { curPtr = newPtr; }
200 
201 private:
202   Token formToken(Token::Kind kind, const char *tokStart) {
203     return Token(kind, StringRef(tokStart, curPtr - tokStart));
204   }
205 
206   /// Return the next character in the stream.
207   int getNextChar();
208 
209   /// Lex an identifier.
210   Token lexIdentifier(const char *tokStart);
211 
212   // Lex an integer.
213   Token lexInteger(const char *tokStart);
214 
215   // Lex a string.
216   Token lexString(const char *tokStart);
217 
218   // Skip a comment line, starting with a '//'.
219   void skipComment();
220 
221   llvm::SourceMgr &srcMgr;
222   StringRef curBuffer;
223   const char *curPtr;
224 };
225 } // end anonymous namespace
226 
227 Lexer::Lexer(llvm::SourceMgr &mgr) : srcMgr(mgr) {
228   curBuffer = srcMgr.getMemoryBuffer(mgr.getMainFileID())->getBuffer();
229   curPtr = curBuffer.begin();
230 }
231 
232 Token Lexer::emitError(llvm::SMLoc loc, const Twine &msg) {
233   srcMgr.PrintMessage(loc, llvm::SourceMgr::DK_Error, msg);
234   return formToken(Token::Kind::error, loc.getPointer());
235 }
236 Token Lexer::emitError(const char *loc, const Twine &msg) {
237   return emitError(llvm::SMLoc::getFromPointer(loc), msg);
238 }
239 
240 int Lexer::getNextChar() {
241   char curChar = *curPtr++;
242   switch (curChar) {
243   default:
244     return (unsigned char)curChar;
245   case 0: {
246     // A nul character in the stream is either the end of the current buffer
247     // or a random nul in the file. Disambiguate that here.
248     if (curPtr - 1 != curBuffer.end())
249       return 0;
250 
251     // Otherwise, return end of file.
252     --curPtr;
253     return EOF;
254   }
255   case '\n':
256   case '\r':
257     // Handle the newline character by ignoring it and incrementing the line
258     // count. However, be careful about 'dos style' files with \n\r in them.
259     // Only treat a \n\r or \r\n as a single line.
260     if ((*curPtr == '\n' || (*curPtr == '\r')) && *curPtr != curChar)
261       ++curPtr;
262     return '\n';
263   }
264 }
265 
266 Token Lexer::lexToken() {
267   while (true) {
268     const char *tokStart = curPtr;
269 
270     // This always consumes at least one character.
271     int curChar = getNextChar();
272     switch (curChar) {
273     default:
274       // Handle identifiers: [a-zA-Z_]
275       if (isalpha(curChar) || curChar == '_')
276         return lexIdentifier(tokStart);
277 
278       // Handle integers: [0-9]
279       if (isdigit(curChar))
280         return lexInteger(tokStart);
281 
282       // Unknown character, emit an error.
283       return emitError(tokStart, "unexpected character");
284 
285     case EOF:
286       // Return EOF denoting the end of lexing.
287       return formToken(Token::Kind::eof, tokStart);
288 
289     // Lex punctuation.
290     case ':':
291       return formToken(Token::Kind::colon, tokStart);
292     case ',':
293       return formToken(Token::Kind::comma, tokStart);
294     case '=':
295       return formToken(Token::Kind::equal, tokStart);
296     case '{':
297       return formToken(Token::Kind::l_brace, tokStart);
298     case '(':
299       return formToken(Token::Kind::l_paren, tokStart);
300     case '[':
301       return formToken(Token::Kind::l_square, tokStart);
302     case '}':
303       return formToken(Token::Kind::r_brace, tokStart);
304     case ')':
305       return formToken(Token::Kind::r_paren, tokStart);
306     case ']':
307       return formToken(Token::Kind::r_square, tokStart);
308     case '<':
309       return formToken(Token::Kind::lt, tokStart);
310     case '>':
311       return formToken(Token::Kind::gt, tokStart);
312     case '+':
313       return formToken(Token::Kind::plus, tokStart);
314     case '-':
315       return formToken(Token::Kind::minus, tokStart);
316     case ';':
317       return formToken(Token::Kind::semicolon, tokStart);
318     case '*':
319       return formToken(Token::Kind::star, tokStart);
320     case '?':
321       return formToken(Token::Kind::question, tokStart);
322     case '"':
323       return lexString(tokStart);
324     case '/':
325       if (*curPtr == '/') {
326         skipComment();
327         continue;
328       }
329       // Unknown character, emit an error.
330       return emitError(tokStart, "unexpected character: not a comment");
331 
332     // Ignore whitespace characters.
333     case 0:
334     case ' ':
335     case '\t':
336     case '\n':
337       return lexToken();
338     }
339   }
340 }
341 
342 Token Lexer::lexIdentifier(const char *tokStart) {
343   // Match the rest of the identifier regex: [0-9a-zA-Z_\-]*
344   while (isalnum(*curPtr) || *curPtr == '_' || *curPtr == '-')
345     ++curPtr;
346 
347   // Check to see if this identifier is a keyword.
348   StringRef str(tokStart, curPtr - tokStart);
349   Token::Kind kind =
350       StringSwitch<Token::Kind>(str)
351           .Case("attr", Token::Kind::kw_attr_def)
352           .Case("def", Token::Kind::kw_def)
353           .Case("ods_def", Token::Kind::kw_ods_def)
354           .Case("implements_interface", Token::Kind::kw_implements_interface)
355           .Case("floordiv", Token::Kind::kw_floordiv)
356           .Case("ceildiv", Token::Kind::kw_ceildiv)
357           .Case("mod", Token::Kind::kw_mod)
358           .Default(Token::Kind::id);
359 
360   return Token(kind, str);
361 }
362 
363 Token Lexer::lexInteger(const char *tokStart) {
364   // Match the rest of the identifier regex: [0-9a-zA-Z_\-]*
365   while (isdigit(*curPtr))
366     ++curPtr;
367 
368   StringRef str(tokStart, curPtr - tokStart);
369   return Token(Token::Kind::integer, str);
370 }
371 
372 Token Lexer::lexString(const char *tokStart) {
373   assert(curPtr[-1] == '"');
374 
375   if (*curPtr == '"' && *(curPtr + 1) == '"') {
376     curPtr += 2;
377     while (true) {
378       switch (*curPtr++) {
379       case '"':
380         if (*curPtr == '"' && *(curPtr + 1) == '"') {
381           Token token(Token::Kind::doc_str,
382                       StringRef(tokStart + 3, curPtr - tokStart - 4));
383           curPtr += 2;
384           return token;
385         }
386         continue;
387       case 0:
388         // If this is a random nul character in the middle of the doc string,
389         // just include it.  If it is the end of file, then it is an error.
390         if (curPtr - 1 != curBuffer.end())
391           continue;
392         return emitError(curPtr - 1, "expected '\"\"\"' to end doc string");
393       default:
394         continue;
395       }
396     }
397   }
398 
399   return emitError(curPtr - 1, "expected '\"\"\"' to start doc string");
400 }
401 
402 /// Skip a comment line, starting with a '//'.
403 void Lexer::skipComment() {
404   // Advance over the second '/' in a '//' comment.
405   assert(*curPtr == '/');
406   ++curPtr;
407 
408   while (true) {
409     switch (*curPtr++) {
410     case '\n':
411     case '\r':
412       // Newline is end of comment.
413       return;
414     case 0:
415       // If this is the end of the buffer, end the comment.
416       if (curPtr - 1 == curBuffer.end()) {
417         --curPtr;
418         return;
419       }
420       LLVM_FALLTHROUGH;
421     default:
422       // Skip over other characters.
423       break;
424     }
425   }
426 }
427 
428 namespace {
429 
430 class Parser {
431 public:
432   Parser(llvm::SourceMgr &mgr, MLIRContext *ctx)
433       : lexer(mgr), curToken(lexer.lexToken()), context(ctx) {}
434 
435   //===--------------------------------------------------------------------===//
436   // Lexer Utilities
437   //===--------------------------------------------------------------------===//
438 
439   LogicalResult parseInteger(uint64_t &value) {
440     if (!curToken.is(Token::Kind::integer))
441       return emitError(curToken.getLoc(), "expected integer");
442     value = curToken.getUInt64IntegerValue().getValue();
443     consumeToken();
444     return success();
445   }
446 
447   /// Advance the current lexer onto the next token.
448   void consumeToken() {
449     assert(curToken.getKind() != Token::Kind::eof &&
450            curToken.getKind() != Token::Kind::error &&
451            "shouldn't advance past EOF or errors");
452     curToken = lexer.lexToken();
453   }
454 
455   void consumeToken(Token::Kind kind) {
456     assert(curToken.getKind() == kind && "unexpected token");
457     curToken = lexer.lexToken();
458   }
459 
460   LogicalResult parseToken(Token::Kind kind, const Twine &msg) {
461     if (curToken.getKind() != kind)
462       return emitError(curToken.getLoc(), msg);
463     consumeToken();
464     return success();
465   }
466 
467   /// Parses an optional token and returns failure if failed to parse.
468   LogicalResult parseOptionalToken(Token::Kind kind) {
469     return success(consumeIf(kind));
470   }
471 
472   LogicalResult emitError(llvm::SMLoc loc, const Twine &msg) {
473     lexer.emitError(loc, msg);
474     return failure();
475   }
476 
477   LogicalResult emitError(const Twine &msg) {
478     return emitError(curToken.getLoc(), msg);
479   }
480 
481   bool consumeIf(Token::Kind kind) {
482     if (curToken.isNot(kind))
483       return false;
484     consumeToken(kind);
485     return true;
486   }
487 
488   LogicalResult
489   parseCommaSeparatedList(llvm::function_ref<ParseResult()> parseElement) {
490     // Non-empty case starts with an element.
491     if (parseElement())
492       return failure();
493 
494     // Otherwise we have a list of comma separated elements.
495     while (consumeIf(Token::Kind::comma)) {
496       if (parseElement())
497         return failure();
498     }
499     return success();
500   }
501 
502   LogicalResult
503   parseCommaSeparatedListUntil(Token::Kind rightToken,
504                                llvm::function_ref<ParseResult()> parseElement,
505                                bool allowEmptyList) {
506     // Handle the empty case.
507     if (curToken.is(rightToken)) {
508       if (!allowEmptyList)
509         return emitError("expected list element");
510       consumeToken(rightToken);
511       return success();
512     }
513 
514     if (failed(parseCommaSeparatedList(parseElement)) ||
515         failed(
516             parseToken(rightToken, "expected ',' or right-terminating token")))
517       return failure();
518 
519     return success();
520   }
521 
522   Lexer lexer;
523   Token curToken;
524   MLIRContext *context;
525 };
526 } // namespace
527 
528 /// Encodes an attribute use of the form:
529 ///
530 ///   index-list ::= integer-literal (`,` integer-literal)*
531 ///   attr-use ::= bare-id `[` index-list `]`
532 struct AttrUse {
533   // Referenced attribute
534   StringRef attrName;
535   // Indices into the attribute
536   SmallVector<uint64_t, 4> indices;
537   /// Affine symbol for this usage.
538   /// This is represented as an affine symbol because at the time of parsing the
539   /// spec and generating the op's ODS/C++, we don't know the concrete constant
540   /// value. But they should be replaced with constants read from the attribute
541   /// and thus folded away for concrete op instances.
542   AffineExpr symbol;
543 
544   std::string getKey() {
545     SmallVector<std::string, 4> indexStrs;
546     for (uint64_t index : indices)
547       indexStrs.push_back(std::to_string(index));
548     return llvm::formatv("{0}[{1}]", attrName, llvm::join(indexStrs, ","));
549   }
550 };
551 
552 //===----------------------------------------------------------------------===//
553 // Affine parsing.
554 //===----------------------------------------------------------------------===//
555 
556 namespace {
557 
/// Binary operators of the lower precedence level (they all share one level).
/// LNoOp is zero so that it tests as false in a boolean context.
enum AffineLowPrecOp {
  /// Null value.
  LNoOp = 0,
  Add = 1,
  Sub = 2
};
566 
/// Binary operators of the higher precedence level (they all share one
/// level). HNoOp is zero so that it tests as false in a boolean context.
enum AffineHighPrecOp {
  /// Null value.
  HNoOp = 0,
  Mul = 1,
  FloorDiv = 2,
  CeilDiv = 3,
  Mod = 4
};
577 
578 using AffineDimList = SmallVector<std::pair<StringRef, AffineExpr>, 4>;
579 using AffineSymbolList = SmallVector<std::pair<StringRef, AffineExpr>, 4>;
580 
581 /// This is a specialized parser for affine expressions.
582 class AffineParser {
583 public:
584   /// Creates an affine parser that parses tokens from `p`.
585   ///
586   /// The affine parser introduces new dimensions and symbols eagerly as new
587   /// `id` are discovered. To additionally support attribute use `id`s, for a
588   /// parsed `id`, the resolution mechanism proceeds as follows:
589   /// 1. Try to parse `id` as an attribute use (using the `attrUseParsingHook`).
590   /// 2. If unsuccessful, try to match `id` to a known dim or symbol.
591   /// 3. If still unsuccessful, eagerly create a new dim or symbol and add it to
592   ///    the known dims or symbols (using the `bareIdParsingHook`).
593   explicit AffineParser(
594       Parser &p, std::function<AffineExpr(StringRef)> bareIdParsingHook,
595       std::function<llvm::Optional<AffineExpr>()> attrUseParsingHook,
596       AffineDimList &dimList, AffineSymbolList &symbolList)
597       : parser(p), bareIdFallback(bareIdParsingHook),
598         attrUseCallback(attrUseParsingHook), dims(dimList),
599         symbols(symbolList) {}
600 
601   /// Parse a comma-separated list of affine exprs.
602   SmallVector<AffineExpr, 4>
603   parseAffineExprs(Token::Kind lDelim = Token::Kind::l_paren,
604                    Token::Kind rDelim = Token::Kind::r_paren);
605 
606   /// Parse a single affine expr.`.
607   AffineExpr parseAffineExpr();
608 
609 private:
610   // Binary affine op parsing.
611   AffineLowPrecOp consumeIfLowPrecOp();
612   AffineHighPrecOp consumeIfHighPrecOp();
613 
614   // AffineExpr parsing.
615   AffineExpr parseParentheticalExpr();
616   AffineExpr parseNegateExpression(AffineExpr lhs);
617   AffineExpr parseIntegerExpr();
618   AffineExpr parseAttrUseOrBareIdExpr();
619   AffineExpr parseBareIdExpr();
620 
621   AffineExpr getAffineBinaryOpExpr(AffineHighPrecOp op, AffineExpr lhs,
622                                    AffineExpr rhs, SMLoc opLoc);
623   AffineExpr getAffineBinaryOpExpr(AffineLowPrecOp op, AffineExpr lhs,
624                                    AffineExpr rhs);
625   AffineExpr parseAffineOperandExpr(AffineExpr lhs);
626   AffineExpr parseAffineLowPrecOpExpr(AffineExpr llhs, AffineLowPrecOp llhsOp);
627   AffineExpr parseAffineHighPrecOpExpr(AffineExpr llhs, AffineHighPrecOp llhsOp,
628                                        SMLoc llhsOpLoc);
629 
630   Parser &parser;
631   std::function<AffineExpr(StringRef)> bareIdFallback;
632   std::function<llvm::Optional<AffineExpr>()> attrUseCallback;
633   AffineDimList &dims;
634   AffineSymbolList &symbols;
635 };
636 } // end anonymous namespace
637 
638 /// Create an affine binary high precedence op expression (mul's, div's, mod).
639 /// opLoc is the location of the op token to be used to report errors
640 /// for non-conforming expressions.
641 AffineExpr AffineParser::getAffineBinaryOpExpr(AffineHighPrecOp op,
642                                                AffineExpr lhs, AffineExpr rhs,
643                                                SMLoc opLoc) {
644   switch (op) {
645   case Mul:
646     if (!lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant()) {
647       (void)parser.emitError(
648           opLoc, "non-affine expression: at least one of the multiply "
649                  "operands has to be either a constant or symbolic");
650       return nullptr;
651     }
652     return lhs * rhs;
653   case FloorDiv:
654     if (!rhs.isSymbolicOrConstant()) {
655       (void)parser.emitError(opLoc,
656                              "non-affine expression: right operand of floordiv "
657                              "has to be either a constant or symbolic");
658       return nullptr;
659     }
660     return lhs.floorDiv(rhs);
661   case CeilDiv:
662     if (!rhs.isSymbolicOrConstant()) {
663       (void)parser.emitError(opLoc,
664                              "non-affine expression: right operand of ceildiv "
665                              "has to be either a constant or symbolic");
666       return nullptr;
667     }
668     return lhs.ceilDiv(rhs);
669   case Mod:
670     if (!rhs.isSymbolicOrConstant()) {
671       (void)parser.emitError(opLoc,
672                              "non-affine expression: right operand of mod "
673                              "has to be either a constant or symbolic");
674       return nullptr;
675     }
676     return lhs % rhs;
677   case HNoOp:
678     llvm_unreachable("can't create affine expression for null high prec op");
679     return nullptr;
680   }
681   llvm_unreachable("Unknown AffineHighPrecOp");
682 }
683 
684 /// Create an affine binary low precedence op expression (add, sub).
685 AffineExpr AffineParser::getAffineBinaryOpExpr(AffineLowPrecOp op,
686                                                AffineExpr lhs, AffineExpr rhs) {
687   switch (op) {
688   case AffineLowPrecOp::Add:
689     return lhs + rhs;
690   case AffineLowPrecOp::Sub:
691     return lhs - rhs;
692   case AffineLowPrecOp::LNoOp:
693     llvm_unreachable("can't create affine expression for null low prec op");
694     return nullptr;
695   }
696   llvm_unreachable("Unknown AffineLowPrecOp");
697 }
698 
699 /// Consume this token if it is a lower precedence affine op (there are only
700 /// two precedence levels).
701 AffineLowPrecOp AffineParser::consumeIfLowPrecOp() {
702   switch (parser.curToken.getKind()) {
703   case Token::Kind::plus:
704     parser.consumeToken();
705     return AffineLowPrecOp::Add;
706   case Token::Kind::minus:
707     parser.consumeToken();
708     return AffineLowPrecOp::Sub;
709   default:
710     return AffineLowPrecOp::LNoOp;
711   }
712 }
713 
714 /// Consume this token if it is a higher precedence affine op (there are only
715 /// two precedence levels)
716 AffineHighPrecOp AffineParser::consumeIfHighPrecOp() {
717   switch (parser.curToken.getKind()) {
718   case Token::Kind::star:
719     parser.consumeToken(Token::Kind::star);
720     return Mul;
721   case Token::Kind::kw_floordiv:
722     parser.consumeToken(Token::Kind::kw_floordiv);
723     return FloorDiv;
724   case Token::Kind::kw_ceildiv:
725     parser.consumeToken(Token::Kind::kw_ceildiv);
726     return CeilDiv;
727   case Token::Kind::kw_mod:
728     parser.consumeToken(Token::Kind::kw_mod);
729     return Mod;
730   default:
731     return HNoOp;
732   }
733 }
734 
735 /// Parse a high precedence op expression list: mul, div, and mod are high
736 /// precedence binary ops, i.e., parse a
737 ///   expr_1 op_1 expr_2 op_2 ... expr_n
738 /// where op_1, op_2 are all a AffineHighPrecOp (mul, div, mod).
739 /// All affine binary ops are left associative.
740 /// Given llhs, returns (llhs llhsOp lhs) op rhs, or (lhs op rhs) if llhs is
741 /// null. If no rhs can be found, returns (llhs llhsOp lhs) or lhs if llhs is
742 /// null. llhsOpLoc is the location of the llhsOp token that will be used to
743 /// report an error for non-conforming expressions.
744 AffineExpr AffineParser::parseAffineHighPrecOpExpr(AffineExpr llhs,
745                                                    AffineHighPrecOp llhsOp,
746                                                    SMLoc llhsOpLoc) {
747   AffineExpr lhs = parseAffineOperandExpr(llhs);
748   if (!lhs)
749     return nullptr;
750 
751   // Found an LHS. Parse the remaining expression.
752   auto opLoc = parser.curToken.getLoc();
753   if (AffineHighPrecOp op = consumeIfHighPrecOp()) {
754     if (llhs) {
755       AffineExpr expr = getAffineBinaryOpExpr(llhsOp, llhs, lhs, opLoc);
756       if (!expr)
757         return nullptr;
758       return parseAffineHighPrecOpExpr(expr, op, opLoc);
759     }
760     // No LLHS, get RHS
761     return parseAffineHighPrecOpExpr(lhs, op, opLoc);
762   }
763 
764   // This is the last operand in this expression.
765   if (llhs)
766     return getAffineBinaryOpExpr(llhsOp, llhs, lhs, llhsOpLoc);
767 
768   // No llhs, 'lhs' itself is the expression.
769   return lhs;
770 }
771 
772 /// Parse an affine expression inside parentheses.
773 ///
774 ///   affine-expr ::= `(` affine-expr `)`
775 AffineExpr AffineParser::parseParentheticalExpr() {
776   if (failed(parser.parseToken(Token::Kind::l_paren, "expected '('")))
777     return nullptr;
778   if (parser.curToken.is(Token::Kind::r_paren))
779     return ((void)parser.emitError("no expression inside parentheses"),
780             nullptr);
781 
782   auto expr = parseAffineExpr();
783   if (!expr)
784     return nullptr;
785   if (failed(parser.parseToken(Token::Kind::r_paren, "expected ')'")))
786     return nullptr;
787 
788   return expr;
789 }
790 
791 /// Parse the negation expression.
792 ///
793 ///   affine-expr ::= `-` affine-expr
794 AffineExpr AffineParser::parseNegateExpression(AffineExpr lhs) {
795   if (failed(parser.parseToken(Token::Kind::minus, "expected '-'")))
796     return nullptr;
797 
798   AffineExpr operand = parseAffineOperandExpr(lhs);
799   // Since negation has the highest precedence of all ops (including high
800   // precedence ops) but lower than parentheses, we are only going to use
801   // parseAffineOperandExpr instead of parseAffineExpr here.
802   if (!operand)
803     // Extra error message although parseAffineOperandExpr would have
804     // complained. Leads to a better diagnostic.
805     return ((void)parser.emitError("missing operand of negation"), nullptr);
806   return (-1) * operand;
807 }
808 
809 AffineExpr AffineParser::parseAttrUseOrBareIdExpr() {
810   if (llvm::Optional<AffineExpr> attrUse = attrUseCallback())
811     return attrUse.getValue();
812   return parseBareIdExpr();
813 }
814 
815 /// Parse a bare id that may appear in an affine expression.
816 ///
817 ///   affine-expr ::= bare-id
818 AffineExpr AffineParser::parseBareIdExpr() {
819   if (parser.curToken.isNot(Token::Kind::id))
820     return ((void)parser.emitError("expected id"), nullptr);
821 
822   StringRef sRef = parser.curToken.getSpelling();
823   for (auto &list : {dims, symbols}) {
824     for (auto entry : list) {
825       if (entry.first == sRef) {
826         parser.consumeToken(Token::Kind::id);
827         return entry.second;
828       }
829     }
830   }
831 
832   // Not found, check fallback path.
833   AffineExpr expr = bareIdFallback(sRef);
834   if (expr) {
835     parser.consumeToken(Token::Kind::id);
836     return expr;
837   }
838 
839   return ((void)parser.emitError("use of undeclared id"), nullptr);
840 }
841 
842 /// Parse a positive integral constant appearing in an affine expression.
843 ///
844 ///   affine-expr ::= integer-literal
845 AffineExpr AffineParser::parseIntegerExpr() {
846   auto val = parser.curToken.getUInt64IntegerValue();
847   if (!val.hasValue() || (int64_t)val.getValue() < 0)
848     return ((void)parser.emitError("constant too large for index"), nullptr);
849 
850   parser.consumeToken(Token::Kind::integer);
851   return getAffineConstantExpr((int64_t)val.getValue(), parser.context);
852 }
853 
854 /// Parses an expression that can be a valid operand of an affine expression.
855 /// lhs: if non-null, lhs is an affine expression that is the lhs of a binary
856 /// operator, the rhs of which is being parsed. This is used to determine
857 /// whether an error should be emitted for a missing right operand.
858 //  Eg: for an expression without parentheses (like i + j + k + l), each
859 //  of the four identifiers is an operand. For i + j*k + l, j*k is not an
860 //  operand expression, it's an op expression and will be parsed via
861 //  parseAffineHighPrecOpExpression(). However, for i + (j*k) + -l, (j*k) and
862 //  -l are valid operands that will be parsed by this function.
863 AffineExpr AffineParser::parseAffineOperandExpr(AffineExpr lhs) {
864   switch (parser.curToken.getKind()) {
865   case Token::Kind::id:
866     return parseAttrUseOrBareIdExpr();
867   case Token::Kind::integer:
868     return parseIntegerExpr();
869   case Token::Kind::l_paren:
870     return parseParentheticalExpr();
871   case Token::Kind::minus:
872     return parseNegateExpression(lhs);
873   case Token::Kind::kw_ceildiv:
874   case Token::Kind::kw_floordiv:
875   case Token::Kind::kw_mod:
876   case Token::Kind::plus:
877   case Token::Kind::star:
878     if (lhs)
879       (void)parser.emitError("missing right operand of binary operator");
880     else
881       (void)parser.emitError("missing left operand of binary operator");
882     return nullptr;
883   default:
884     if (lhs)
885       (void)parser.emitError("missing right operand of binary operator");
886     else
887       (void)parser.emitError("expected affine expression");
888     return nullptr;
889   }
890 }
891 
892 /// Parse affine expressions that are bare-id's, integer constants,
893 /// parenthetical affine expressions, and affine op expressions that are a
894 /// composition of those.
895 ///
896 /// All binary op's associate from left to right.
897 ///
898 /// {add, sub} have lower precedence than {mul, div, and mod}.
899 ///
900 /// Add, sub'are themselves at the same precedence level. Mul, floordiv,
901 /// ceildiv, and mod are at the same higher precedence level. Negation has
902 /// higher precedence than any binary op.
903 ///
904 /// llhs: the affine expression appearing on the left of the one being parsed.
905 /// This function will return ((llhs llhsOp lhs) op rhs) if llhs is non null,
906 /// and lhs op rhs otherwise; if there is no rhs, llhs llhsOp lhs is returned
907 /// if llhs is non-null; otherwise lhs is returned. This is to deal with left
908 /// associativity.
909 ///
910 /// Eg: when the expression is e1 + e2*e3 + e4, with e1 as llhs, this function
911 /// will return the affine expr equivalent of (e1 + (e2*e3)) + e4, where
912 /// (e2*e3) will be parsed using parseAffineHighPrecOpExpr().
913 AffineExpr AffineParser::parseAffineLowPrecOpExpr(AffineExpr llhs,
914                                                   AffineLowPrecOp llhsOp) {
915   AffineExpr lhs;
916   if (!(lhs = parseAffineOperandExpr(llhs)))
917     return nullptr;
918 
919   // Found an LHS. Deal with the ops.
920   if (AffineLowPrecOp lOp = consumeIfLowPrecOp()) {
921     if (llhs) {
922       AffineExpr sum = getAffineBinaryOpExpr(llhsOp, llhs, lhs);
923       return parseAffineLowPrecOpExpr(sum, lOp);
924     }
925     // No LLHS, get RHS and form the expression.
926     return parseAffineLowPrecOpExpr(lhs, lOp);
927   }
928   auto opLoc = parser.curToken.getLoc();
929   if (AffineHighPrecOp hOp = consumeIfHighPrecOp()) {
930     // We have a higher precedence op here. Get the rhs operand for the llhs
931     // through parseAffineHighPrecOpExpr.
932     AffineExpr highRes = parseAffineHighPrecOpExpr(lhs, hOp, opLoc);
933     if (!highRes)
934       return nullptr;
935 
936     // If llhs is null, the product forms the first operand of the yet to be
937     // found expression. If non-null, the op to associate with llhs is llhsOp.
938     AffineExpr expr =
939         llhs ? getAffineBinaryOpExpr(llhsOp, llhs, highRes) : highRes;
940 
941     // Recurse for subsequent low prec op's after the affine high prec op
942     // expression.
943     if (AffineLowPrecOp nextOp = consumeIfLowPrecOp())
944       return parseAffineLowPrecOpExpr(expr, nextOp);
945     return expr;
946   }
947   // Last operand in the expression list.
948   if (llhs)
949     return getAffineBinaryOpExpr(llhsOp, llhs, lhs);
950   // No llhs, 'lhs' itself is the expression.
951   return lhs;
952 }
953 
954 /// Parse an affine expression.
955 ///  affine-expr ::= `(` affine-expr `)`
956 ///                | `-` affine-expr
957 ///                | affine-expr `+` affine-expr
958 ///                | affine-expr `-` affine-expr
959 ///                | affine-expr `*` affine-expr
960 ///                | affine-expr `floordiv` affine-expr
961 ///                | affine-expr `ceildiv` affine-expr
962 ///                | affine-expr `mod` affine-expr
963 ///                | bare-id
964 ///                | integer-literal
965 ///
966 /// Additional conditions are checked depending on the production. For eg.,
967 /// one of the operands for `*` has to be either constant/symbolic; the second
968 /// operand for floordiv, ceildiv, and mod has to be a positive integer.
969 AffineExpr AffineParser::parseAffineExpr() {
970   return parseAffineLowPrecOpExpr(nullptr, AffineLowPrecOp::LNoOp);
971 }
972 
973 SmallVector<AffineExpr, 4> AffineParser::parseAffineExprs(Token::Kind lDelim,
974                                                           Token::Kind rDelim) {
975   if (failed(parser.parseToken(lDelim,
976                                "expected lDelim at start of affine expr list")))
977     return {};
978 
979   SmallVector<AffineExpr, 4> exprs;
980   auto parseElt = [&]() -> LogicalResult {
981     auto elt = parseAffineExpr();
982     exprs.push_back(elt);
983     return elt ? success() : failure();
984   };
985 
986   if (failed(parser.parseCommaSeparatedListUntil(rDelim, parseElt,
987                                                  /*allowEmptyList=*/true)))
988     llvm_unreachable("Failed AffineExpr parsing");
989 
990   return exprs;
991 }
992 
993 //===----------------------------------------------------------------------===//
994 // TC parsing.
995 //===----------------------------------------------------------------------===//
996 
997 namespace {
998 
/// Common base of the expression nodes built while parsing a TC definition.
/// The `kind` tag identifies the concrete node type (see the `classof`
/// functions of the derived structs).
struct Expression {
  /// Tag identifying the concrete node type. `Uninitialized` marks a node
  /// that was not given a concrete kind.
  enum class Kind {
    Uninitialized = 0,
    TensorExpr = 1,
    TensorUse = 2,
  };

  explicit Expression(Kind initialKind = Kind::Uninitialized)
      : kind(initialKind) {}
  virtual ~Expression() = default;

  /// True for every node whose kind is not `Uninitialized`.
  operator bool() const { return !(kind == Kind::Uninitialized); }

  Kind kind;
};
1014 
1015 /// Encodes a tensor use of the form:
1016 ///
1017 ///   affine-expr-list ::= affine-expr (`,` affine-expr)*
1018 ///   tensor-use ::= bare-id `(` `)`
1019 ///                | bare-id `(` affine-expr-list `)`
1020 ///
1021 /// The affine-expr-list is stored as an AffineMap.
1022 struct TensorUse : public Expression {
1023   TensorUse() : TensorUse("", AffineMap()) {}
1024   TensorUse(StringRef name, AffineMap map)
1025       : Expression(Kind::TensorUse), tensorId(name), indexingMap(map) {}
1026 
1027   static bool classof(const Expression *e) {
1028     return e->kind == Kind::TensorUse;
1029   }
1030 
1031   bool operator==(const TensorUse &other) const {
1032     return tensorId == other.tensorId && indexingMap == other.indexingMap;
1033   }
1034 
1035   /// Visitation function. Performs preorder or postorder traversal depending on
1036   /// `PreOrder` and applies `callback` on each node.
1037   template <typename Lambda, bool PreOrder> void visit(Lambda callback) const;
1038 
1039   StringRef tensorId;
1040   AffineMap indexingMap;
1041 };
1042 
1043 /// Encodes a tensor expression of the form:
1044 ///
1045 ///   op-spec ::= bare-id `<` reduction-dims-list `>`
1046 ///             | bare-id
1047 ///   op-arg ::= tensor-expr
1048 ///            | tensor-use
1049 ///   op-arg-list ::= op-arg (`,` op-arg)*
1050 ///   tensor-expr ::= op-spec `(` op-arg-list `)`
1051 ///
1052 /// Underlying op-arg are stored by unique_ptr to base class.
1053 struct TensorExpr : public Expression {
1054   TensorExpr(StringRef name,
1055              SmallVectorImpl<std::unique_ptr<Expression>> &&exprs,
1056              ArrayRef<unsigned> reductionDims)
1057       : Expression(Kind::TensorExpr), operationName(name),
1058         expressions(std::move(exprs)),
1059         reductionDimensions(reductionDims.begin(), reductionDims.end()) {}
1060 
1061   static bool classof(const Expression *e) {
1062     return e->kind == Kind::TensorExpr;
1063   }
1064 
1065   bool operator==(const TensorExpr &other) const {
1066     if (operationName != other.operationName)
1067       return false;
1068     if (expressions.size() != other.expressions.size())
1069       return false;
1070     for (unsigned i = 0, e = expressions.size(); i < e; ++i)
1071       if (*expressions[i] != *other.expressions[i])
1072         return false;
1073     for (unsigned i = 0, e = reductionDimensions.size(); i < e; ++i)
1074       if (reductionDimensions[i] != other.reductionDimensions[i])
1075         return false;
1076     return true;
1077   }
1078 
1079   /// Visitation function. Performs preorder or postorder traversal depending on
1080   /// `PreOrder` and applies `callback` on each node.
1081   template <typename Lambda, bool PreOrder> void visit(Lambda callback) const;
1082 
1083   StringRef operationName;
1084   SmallVector<std::unique_ptr<Expression>, 4> expressions;
1085   SetVector<unsigned> reductionDimensions;
1086 };
1087 
1088 /// This is a specialized parser for a TCDef.
1089 /// This maintains the dims it finds in an eager fashion.
1090 class TCParser {
1091   enum class EagerDiscoveryMode { None = 0, Symbols, Dimensions };
1092 
1093 public:
1094   explicit TCParser(Parser &p);
1095 
1096   /// Uses the AffineParser to parse the affine exprs used in a tensor
1097   /// definition. If `discoveryMode` is set to Symbols (resp. Dimensions), new
1098   /// symbols (resp. dimensions) are added eagerly. Otherwise, an error is
1099   /// emitted on new identifiers.
1100   SmallVector<AffineExpr, 4>
1101   parseAffineExprs(EagerDiscoveryMode discoveryMode, AffineDimList &dims,
1102                    Token::Kind lDelim = Token::Kind::l_paren,
1103                    Token::Kind rDelim = Token::Kind::r_paren);
1104 
1105   /// Parse the information for a tensor def.
1106   /// All the affine-expr must be dimensionless (i.e. contain only expressions
1107   /// involving symbols and constants), but can otherwise contain arbitrary
1108   /// affine expressions.
1109   LogicalResult parseTensorDef(bool isOutput);
1110 
1111   /// Parses a tensor use.
1112   struct ComprehensionParsingState {
1113     /// The number of operands (which includes inputs and outputs) in a
1114     /// comprehension.
1115     size_t numArgs;
1116     AffineDimList dims;
1117     SmallVector<std::unique_ptr<Expression>, 4> expressions;
1118     llvm::DenseMap<TensorUse, unsigned> orderedTensorArgs;
1119   };
1120   LogicalResult parseTensorUse(TensorUse &result,
1121                                ComprehensionParsingState &state);
1122 
1123   /// Parses an attribute definition.
1124   LogicalResult parseAttrDef();
1125 
1126   /// Parses an optional attribute use.
1127   LogicalResult parseAttrUse(AttrUse &result);
1128 
1129   /// Parses a tensor expression.
1130   LogicalResult parseExpression(TensorUse currentDefinition,
1131                                 std::unique_ptr<Expression> &result,
1132                                 ComprehensionParsingState &state);
1133 
1134   /// Parse a single comprehension.
1135   LogicalResult parseOneComprehension(StringRef cppOpName,
1136                                       StringRef linalgOpName,
1137                                       ComprehensionParsingState &state);
1138 
1139   /// Parse and print the information for a TC def.
1140   /// When `gen-ods-decl` is used, this prints the ODS declaration for the TC.
1141   /// When `gen-impl` is used, this prints the C++ implementation for the extra
1142   /// methods defined in ODS (`iterator_types`, `indexing_maps` and
1143   /// `regionBuilder`).
1144   LogicalResult parseAndEmitODSDef(llvm::raw_ostream &os);
1145 
1146   /// Print the ODS class that defines a new `cppOpName` for a `linalgOpName`.
1147   void printODS(llvm::raw_ostream &os, StringRef cppOpName,
1148                 StringRef linalgOpName, ArrayRef<StringRef> interfaces,
1149                 ComprehensionParsingState &state);
1150 
1151   /// Print the C++ StructuredOpsInterface impl of `iterator_types`.
1152   void printReferenceIterators(llvm::raw_ostream &os, StringRef cppOpName,
1153                                ComprehensionParsingState &state);
1154 
1155   /// Print methods related to indexing map required attributes.
1156   ///
1157   /// Specifically, this prints the definitions for the following methods:
1158   ///   bool hasDynamicIndexingMaps();
1159   ///   LogicalResult verifyIndexingMapRequiredAttributes();
1160   void printIndexingMapRequiredAttrMethods(llvm::raw_ostream &os,
1161                                            StringRef cppOpName,
1162                                            ComprehensionParsingState &state);
1163 
1164   /// Print the C++ StructuredOpsInterface impl of `indexing_maps`.
1165   void printReferenceIndexingMaps(llvm::raw_ostream &os, StringRef cppOpName,
1166                                   ComprehensionParsingState &state);
1167 
1168   /// Print the C++ StructuredOpsInterface impl of `regionBuilder`.
1169   void printRegionBuilder(llvm::raw_ostream &os, StringRef cppOpName,
1170                           ComprehensionParsingState &state);
1171 
1172   /// Print the C++ impl for named ops canonicalizers and folders.
1173   void printCanonicalizersAndFolders(llvm::raw_ostream &os,
1174                                      StringRef cppOpName);
1175 
1176 private:
1177   //===--------------------------------------------------------------------===//
1178   // Internal bookkeeping of tensors.
1179   //===--------------------------------------------------------------------===//
1180   struct RegisteredTensor {
1181     StringRef type;
1182     AffineMap shape;
1183     bool isOutput;
1184     AffineMap indexingMap;
1185     unsigned index;
1186   };
1187 
1188   //===--------------------------------------------------------------------===//
1189   // Internal bookkeeping of attributes.
1190   //===--------------------------------------------------------------------===//
1191   struct RegisteredAttr {
1192     StringRef elementType;
1193     SmallVector<uint64_t, 4> vectorDims;
1194     bool isArray;
1195     bool isOptional;
1196 
1197     // Returns the function to get values at the given indices from this
1198     // attribute.
1199     llvm::Optional<std::string> getValueFn(ArrayRef<uint64_t> indices) const;
1200   };
1201 
1202   //===--------------------------------------------------------------------===//
1203   // Per-TC def state.
1204   //===--------------------------------------------------------------------===//
1205   /// Symbols are per TC def.
1206   AffineSymbolList symbols;
1207 
1208   /// Attribute usages in all affine expressions.
1209   SmallVector<AttrUse, 8> attrUses;
1210 
1211   /// Tensors are per TC def.
1212   llvm::StringMap<RegisteredTensor> registeredTensors;
1213   unsigned nextRegisteredTensorIndex;
1214 
1215   /// Attributes are per TC def.
1216   std::map<std::string, RegisteredAttr> registeredAttrs;
1217 
1218   /// A map from AttrUse to AffineExpr symbol.
1219   llvm::StringMap<AffineExpr> registeredAttrUseToSymbol;
1220 
1221   StringRef docString;
1222 
1223   Parser &parser;
1224 };
1225 } // namespace
1226 
1227 namespace llvm {
1228 
1229 template <> struct DenseMapInfo<TensorUse> {
1230   static TensorUse getEmptyKey() { return TensorUse("", AffineMap()); }
1231   static TensorUse getTombstoneKey() {
1232     return TensorUse(DenseMapInfo<StringRef>::getTombstoneKey(),
1233                      DenseMapInfo<AffineMap>::getTombstoneKey());
1234   }
1235   static unsigned getHashValue(const TensorUse &val) {
1236     return ::llvm::hash_value(val.tensorId); // don't care about collisions.
1237   }
1238   static bool isEqual(const TensorUse &LHS, const TensorUse &RHS) {
1239     return LHS == RHS;
1240   }
1241 };
1242 
1243 } // namespace llvm
1244 
1245 //===----------------------------------------------------------------------===//
1246 // Visitation functions.
1247 //===----------------------------------------------------------------------===//
1248 
1249 template <typename Lambda, bool PreOrder>
1250 void visit(const Expression &expr, Lambda callback) {
1251   switch (expr.kind) {
1252   default:
1253     llvm_unreachable("Unexpected kind");
1254   case Expression::Kind::TensorExpr:
1255     static_cast<const TensorExpr &>(expr).visit<Lambda, PreOrder>(callback);
1256     break;
1257   case Expression::Kind::TensorUse:
1258     static_cast<const TensorUse &>(expr).visit<Lambda, PreOrder>(callback);
1259     break;
1260   }
1261 }
1262 
1263 template <typename Lambda>
1264 void visitPreorder(const Expression &expr, Lambda callback) {
1265   visit<Lambda, false>(expr, callback);
1266 }
1267 
1268 template <typename Lambda>
1269 void visitPostorder(Expression &expr, Lambda callback) {
1270   visit<Lambda, true>(expr, callback);
1271 }
1272 
1273 template <typename Lambda, bool PreOrder>
1274 void TensorExpr::visit(Lambda callback) const {
1275   if (!PreOrder)
1276     callback(*this);
1277   for (auto &e : expressions)
1278     ::visit<Lambda, PreOrder>(*e, callback);
1279   if (PreOrder)
1280     callback(*this);
1281 }
1282 
1283 template <typename Lambda, bool PreOrder>
1284 void TensorUse::visit(Lambda callback) const {
1285   callback(*this);
1286 }
1287 
1288 //===----------------------------------------------------------------------===//
1289 // TC parsing functions.
1290 //===----------------------------------------------------------------------===//
1291 TCParser::TCParser(Parser &p)
1292     : symbols(), registeredTensors(), nextRegisteredTensorIndex(0), parser(p) {}
1293 
1294 /// Uses the AffineParser to parse the affine exprs used in a tensor
1295 /// definition. All identifiers are interpreted as symbols, new symbols are
1296 /// added eagerly.
1297 SmallVector<AffineExpr, 4>
1298 TCParser::parseAffineExprs(EagerDiscoveryMode discoveryMode,
1299                            AffineDimList &dims, Token::Kind lDelim,
1300                            Token::Kind rDelim) {
1301   auto createAffineBareId = [&](StringRef sRef) {
1302     AffineExpr expr;
1303     if (discoveryMode == EagerDiscoveryMode::Symbols) {
1304       expr = getAffineSymbolExpr(symbols.size(), parser.context);
1305       symbols.emplace_back(sRef, expr);
1306     } else if (discoveryMode == EagerDiscoveryMode::Dimensions) {
1307       expr = getAffineDimExpr(dims.size(), parser.context);
1308       dims.emplace_back(sRef, expr);
1309     }
1310     return expr;
1311   };
1312 
1313   auto tryToParseAttrUse = [&]() -> llvm::Optional<AffineExpr> {
1314     if (!parser.curToken.is(Token::Kind::id))
1315       return llvm::None;
1316 
1317     StringRef attrName = parser.curToken.getSpelling();
1318     auto it = registeredAttrs.find(attrName.str());
1319     if (it == registeredAttrs.end())
1320       return llvm::None;
1321 
1322     AttrUse result;
1323     if (failed(parseAttrUse(result)))
1324       return llvm::None;
1325 
1326     auto symbolIt = registeredAttrUseToSymbol.find(result.getKey());
1327     if (symbolIt == registeredAttrUseToSymbol.end()) {
1328       result.symbol = getAffineSymbolExpr(symbols.size(), parser.context);
1329       symbols.emplace_back("<attr-use>", result.symbol);
1330       registeredAttrUseToSymbol[result.getKey()] = result.symbol;
1331       attrUses.push_back(result);
1332     } else {
1333       result.symbol = symbolIt->second;
1334     }
1335 
1336     return result.symbol;
1337   };
1338 
1339   AffineParser affineParser(parser, createAffineBareId, tryToParseAttrUse, dims,
1340                             symbols);
1341   return affineParser.parseAffineExprs(lDelim, rDelim);
1342 }
1343 
1344 /// Parse the information for a tensor def of the form:
1345 ///
1346 ///   affine-expr-list ::= affine-expr (`,` affine-expr )*
1347 ///   tensor-typedef ::= type `(` `)`
1348 ///                    | type `(` affine-expr-list `)`
1349 ///   tensor-def ::= bare-id `:` tensor-typedef
1350 LogicalResult TCParser::parseTensorDef(bool isOutput) {
1351   StringRef tensorId = parser.curToken.getSpelling();
1352   if (failed(parser.parseToken(Token::Kind::id, "expected an id")) ||
1353       failed(parser.parseToken(Token::Kind::colon, "expected colon")))
1354     return failure();
1355 
1356   StringRef tensorType = parser.curToken.getSpelling();
1357   if (failed(parser.parseToken(Token::Kind::id, "expected an id")))
1358     return failure();
1359 
1360   AffineDimList emptyDims;
1361   auto exprs = parseAffineExprs(EagerDiscoveryMode::Symbols, emptyDims);
1362   assert(emptyDims.empty() && "Unexpected dimension in tensor def");
1363   AffineMap map =
1364       AffineMap::get(/*dimCount=*/0, symbols.size(), exprs, parser.context);
1365 
1366   auto iterBoolPair = registeredTensors.try_emplace(
1367       tensorId, RegisteredTensor{tensorType, map, isOutput, AffineMap(),
1368                                  nextRegisteredTensorIndex++});
1369   (void)iterBoolPair;
1370   assert(iterBoolPair.second && "Could not emplace tensor registration");
1371   LLVM_DEBUG(llvm::dbgs() << "Recorded: " << tensorId << " "
1372                           << "with typeString: " << tensorType << " "
1373                           << "and shape: " << map << "\n");
1374 
1375   return success();
1376 }
1377 
1378 /// Parses a tensor use of the form:
1379 ///
1380 ///   affine-expr-list ::= affine-expr (`,` affine-expr)*
1381 ///   tensor-use ::= bare-id `(` `)`
1382 ///                | bare-id `(` affine-expr-list `)`
1383 LogicalResult TCParser::parseTensorUse(TensorUse &result,
1384                                        ComprehensionParsingState &state) {
1385   StringRef tensorId = parser.curToken.getSpelling();
1386   if (failed(parser.parseToken(Token::Kind::id, "expected an id")))
1387     return failure();
1388 
1389   auto exprs = parseAffineExprs(EagerDiscoveryMode::Dimensions, state.dims);
1390   AffineMap map =
1391       AffineMap::get(state.dims.size(), symbols.size(), exprs, parser.context);
1392   LLVM_DEBUG(llvm::dbgs() << "Use of tensor: " << tensorId << " map: " << map
1393                           << "\n");
1394 
1395   result = TensorUse(tensorId, map);
1396   return success();
1397 }
1398 
1399 /// Parse the information for an attribute def of the form:
1400 ///
1401 ///   affine-expr-list ::= affine-expr (`,` affine-expr )*
1402 ///   attr-id ::= bare-id (`?`)?
1403 ///   dim-list ::= (integer-literal 'x')+
1404 ///   attr-typedef ::= dim-list? type (`[` `]`)?
1405 ///   attr-def ::= attr-id `:` attr-typedef
1406 LogicalResult TCParser::parseAttrDef() {
1407   auto attrLoc = parser.curToken.getLoc();
1408   StringRef attrName = parser.curToken.getSpelling();
1409   if (failed(parser.parseToken(Token::Kind::id, "expected an id")))
1410     return failure();
1411   bool isOptional = succeeded(parser.parseOptionalToken(Token::Kind::question));
1412   if (failed(parser.parseToken(Token::Kind::colon, "expected colon")))
1413     return failure();
1414 
1415   // Parse the attribute's type. We don't expect the type to be arbitrary
1416   // complex, so just use this ad-hoc handling here.
1417 
1418   // Parse potential dimension list
1419   SmallVector<uint64_t, 4> vectorDims;
1420   while (parser.curToken.is(Token::Kind::integer)) {
1421     uint64_t value;
1422     if (failed(parser.parseInteger(value)))
1423       return failure();
1424     vectorDims.push_back(value);
1425 
1426     StringRef spelling = parser.curToken.getSpelling();
1427     if (spelling[0] != 'x')
1428       return parser.emitError(parser.curToken.getLoc(),
1429                               "expected 'x' in dimension list");
1430 
1431     // If we had a prefix of 'x', lex the next token immediately after the 'x'.
1432     if (spelling.size() != 1)
1433       parser.lexer.resetPointer(spelling.data() + 1);
1434 
1435     parser.consumeToken();
1436   }
1437 
1438   StringRef elementType = parser.curToken.getSpelling();
1439   if (failed(parser.parseToken(Token::Kind::id, "expected an id")))
1440     return failure();
1441 
1442   bool isArray = false;
1443   auto arrayLoc = parser.curToken.getLoc();
1444   if (succeeded(parser.parseOptionalToken(Token::Kind::l_square))) {
1445     isArray = true;
1446     if (failed(parser.parseToken(Token::Kind::r_square, "expected ']'")))
1447       return failure();
1448   }
1449 
1450   if (!vectorDims.empty() && isArray)
1451     return parser.emitError(arrayLoc, "unsupported vector array attribute");
1452 
1453   auto iterBoolPair = registeredAttrs.emplace(
1454       attrName.str(),
1455       RegisteredAttr{elementType, vectorDims, isArray, isOptional});
1456   if (!iterBoolPair.second)
1457     return parser.emitError(attrLoc,
1458                             "Failed to register attribute '" + attrName + "'");
1459 
1460   LLVM_DEBUG(llvm::dbgs() << "Recorded: " << (isOptional ? "[optional]" : "")
1461                           << " " << attrName << " "
1462                           << "with type: " << elementType
1463                           << (isArray ? "[]" : "") << "\n");
1464 
1465   return success();
1466 }
1467 
1468 LogicalResult TCParser::parseAttrUse(AttrUse &result) {
1469   result.attrName = parser.curToken.getSpelling();
1470   if (failed(parser.parseToken(Token::Kind::id, "expected an id")))
1471     return failure();
1472 
1473   auto it = registeredAttrs.find(result.attrName.str());
1474   assert(it != registeredAttrs.end());
1475   const RegisteredAttr &attr = it->second;
1476 
1477   if (!attr.vectorDims.empty() || attr.isArray) {
1478     // This is a vector/array attribute. Parse indices for it.
1479     auto indexLoc = parser.curToken.getLoc();
1480 
1481     if (failed(parser.parseToken(Token::Kind::l_square, "expected '['")))
1482       return failure();
1483 
1484     auto parseIndex = [&]() {
1485       uint64_t value;
1486       if (failed(parser.parseInteger(value)))
1487         return failure();
1488       result.indices.push_back(value);
1489       return success();
1490     };
1491     if (failed(parser.parseCommaSeparatedListUntil(
1492             Token::Kind::r_square, parseIndex, /*allowEmptyList=*/false)))
1493       return failure();
1494 
1495     size_t rank = attr.isArray ? 1 : attr.vectorDims.size();
1496     if (result.indices.size() != rank)
1497       return parser.emitError(indexLoc,
1498                               "number of indices mismatch: expected " +
1499                                   std::to_string(rank) + ", but found " +
1500                                   std::to_string(result.indices.size()));
1501   }
1502 
1503   return success();
1504 }
1505 
1506 /// Parses a tensor expression of the form:
1507 ///
1508 ///   op-spec ::= bare-id `<` reduction-dims-list `>`
1509 ///             | bare-id
1510 ///   op-arg ::= tensor-expr
1511 ///            | tensor-use
1512 ///   op-arg-list ::= op-arg (`,` op-arg)*
1513 ///   tensor-expr ::= op-spec `(` op-arg-list `)`
1514 LogicalResult TCParser::parseExpression(TensorUse currentDefinition,
1515                                         std::unique_ptr<Expression> &result,
1516                                         ComprehensionParsingState &state) {
1517   StringRef opOrTensor = parser.curToken.getSpelling();
1518   if (registeredTensors.count(opOrTensor) > 0) {
1519     TensorUse use;
1520     auto res = parseTensorUse(use, state);
1521     if (failed(res))
1522       return res;
1523     result = std::make_unique<TensorUse>(use);
1524     return success();
1525   }
1526 
1527   if (failed(parser.parseToken(Token::Kind::id, "expected an operation")))
1528     return failure();
1529 
1530   // This is an op.
1531   SmallVector<unsigned, 4> reductionDims;
1532   SmallVector<std::unique_ptr<Expression>, 4> expressions;
1533 
1534   // Check if it has a reduction set, discover dimensions eagerly.
1535   if (parser.curToken.is(Token::Kind::lt)) {
1536     auto iters = parseAffineExprs(EagerDiscoveryMode::Dimensions, state.dims,
1537                                   Token::Kind::lt, Token::Kind::gt);
1538     for (auto iter : iters)
1539       reductionDims.push_back(iter.cast<AffineDimExpr>().getPosition());
1540   }
1541 
1542   auto parseExpr = [&]() -> LogicalResult {
1543     std::unique_ptr<Expression> e;
1544     if (failed(parseExpression(currentDefinition, e, state)))
1545       return failure();
1546     expressions.push_back(std::move(e));
1547     return success();
1548   };
1549   if (failed(parser.parseToken(Token::Kind::l_paren, "expected '('")) ||
1550       failed(parser.parseCommaSeparatedListUntil(
1551           Token::Kind::r_paren, parseExpr, /*allowEmptyList=*/true)))
1552     return failure();
1553 
1554   result = std::make_unique<TensorExpr>(opOrTensor, std::move(expressions),
1555                                         reductionDims);
1556 
1557   return success();
1558 }
1559 
1560 //===----------------------------------------------------------------------===//
1561 // Parse and Emit functions.
1562 //===----------------------------------------------------------------------===//
1563 
1564 /// Parse the information for a single comprehension.
1565 ///
1566 ///   tensor-def-list ::= tensor-def (`,` tensor-def)*
1567 ///   tensor-expr-list ::= tensor-expr (`,` tensor-expr)*
1568 ///   comprehension ::= tensor-def-list `=` tensor-expr-list `;`
1569 LogicalResult
1570 TCParser::parseOneComprehension(StringRef cppOpName, StringRef linalgOpName,
1571                                 ComprehensionParsingState &state) {
1572   // 1. Parse LHS of `=`, these become the definitions that appear as the output
1573   // tensors or read/write buffers.
1574   SmallVector<TensorUse, 4> definitions;
1575   auto parseUse = [&]() -> LogicalResult {
1576     TensorUse use;
1577     if (failed(parseTensorUse(use, state)))
1578       return failure();
1579     definitions.push_back(use);
1580     return success();
1581   };
1582   if (failed(parser.parseCommaSeparatedListUntil(Token::Kind::equal, parseUse,
1583                                                  /*allowEmptyList=*/true)))
1584     return failure();
1585 
1586   // 2. Parse RHS of `=`, this becomes the expressions from which we emit
1587   // computations.
1588   unsigned idx = 0;
1589   auto parseExpr = [&]() -> LogicalResult {
1590     std::unique_ptr<Expression> expr;
1591     if (idx >= definitions.size())
1592       return parser.emitError("Fewer LHS definitions than RHS expressions");
1593     if (failed(parseExpression(definitions[idx++], expr, state)))
1594       return failure();
1595     state.expressions.push_back(std::move(expr));
1596     return success();
1597   };
1598   if (failed(parser.parseCommaSeparatedListUntil(
1599           Token::Kind::semicolon, parseExpr, /*allowEmptyList=*/true)))
1600     return failure();
1601   if (idx != definitions.size())
1602     return parser.emitError("Fewer RHS expressions than LHS definitions");
1603 
1604   // 3. Postprocess.
1605   // 3.a. Normalize all maps to the proper state.dims and symbols counts.
1606   SmallVector<TensorUse, 4> allUses;
1607   allUses.reserve(registeredTensors.size());
1608   for (auto &def : definitions)
1609     allUses.push_back(def);
1610   for (auto &pExpr : state.expressions)
1611     visitPostorder(*pExpr, [&](const Expression &e) {
1612       if (auto *use = dyn_cast<TensorUse>(&e))
1613         allUses.push_back(*use);
1614     });
1615   for (auto &use : allUses)
1616     use.indexingMap =
1617         AffineMap::get(state.dims.size(), symbols.size(),
1618                        use.indexingMap.getResults(), parser.context);
1619 
1620   // 3.b. Traverse definitions
1621   llvm::DenseSet<StringRef> seenDefs;
1622   for (auto &def : definitions) {
1623     if (seenDefs.count(def.tensorId) > 0)
1624       return parser.emitError("Unexpected multi-write to a single tensor");
1625     seenDefs.insert(def.tensorId);
1626     auto tensorIter = registeredTensors.find(def.tensorId);
1627     assert(tensorIter != registeredTensors.end() && "unregistered tensor");
1628     auto &tensor = tensorIter->getValue();
1629     tensor.indexingMap = def.indexingMap;
1630     state.orderedTensorArgs[def] = tensor.index;
1631   }
1632 
1633   bool failed = false;
1634   for (auto &pExpr : state.expressions)
1635     visitPostorder(*pExpr, [&](const Expression &e) {
1636       auto *pUse = dyn_cast<TensorUse>(&e);
1637       if (failed || !pUse)
1638         return;
1639       auto &use = *pUse;
1640       LLVM_DEBUG(llvm::dbgs()
1641                  << "\nuse: " << use.tensorId << " map: " << use.indexingMap);
1642       auto tensorIter = registeredTensors.find(use.tensorId);
1643       assert(tensorIter != registeredTensors.end() && "unregistered tensor");
1644       auto &tensor = tensorIter->getValue();
1645       if (tensor.indexingMap && state.orderedTensorArgs.count(use) == 0 &&
1646           tensor.indexingMap.getResults() != use.indexingMap.getResults()) {
1647         LLVM_DEBUG(llvm::dbgs() << "\nexisting: " << tensor.indexingMap);
1648         (void)parser.emitError(
1649             "Unexpected multi-read of a tensor with different accesses");
1650         failed = true;
1651         return;
1652       }
1653       seenDefs.insert(use.tensorId);
1654       tensor.indexingMap = use.indexingMap;
1655       state.orderedTensorArgs[use] = tensor.index;
1656     });
  // If the comprehension references fewer tensors than were declared, the
  // remaining ones are shaped-only operands, which are used to define reduction
  // loops. For now, at most one shaped-only operand is accepted.
1660   if (state.numArgs > seenDefs.size() + 1) {
1661     failed = true;
1662   } else if (state.numArgs == seenDefs.size() + 1) {
1663     for (auto &tensorIter : registeredTensors) {
1664       auto &tensor = tensorIter.getValue();
1665       if (tensor.indexingMap)
1666         continue;
1667       if (auto *pTensorExpr =
1668               dyn_cast<TensorExpr>(state.expressions[0].get())) {
1669         SmallVector<AffineExpr, 4> exprs;
1670         for (auto dim : pTensorExpr->reductionDimensions)
1671           exprs.push_back(getAffineDimExpr(dim, parser.context));
1672         tensor.indexingMap = AffineMap::get(state.dims.size(), symbols.size(),
1673                                             exprs, parser.context);
1674       }
1675     }
1676   }
1677   if (failed)
1678     return failure();
1679 
1680   return success();
1681 }
1682 
1683 /// Parse and print the information for a ODS def.
1684 ///
1685 ///   tensor-def-list ::= tensor-def (`,` tensor-def )*
1686 ///   attr-def-list ::= attr-def (`,` attr-def )*
1687 ///
1688 ///   comprehension-list ::= comprehension comprehension*
1689 ///
1690 ///   tc-attr-def ::= `attr` `(` attr-def-list `)`
1691 ///   tc-def ::= `def` bare-id `(`tensor-def-list`)` `->` `(` tensor-def-list`)`
1692 ///     (tc-attr-def)?
1693 ///     `{` comprehension-list `}`
1694 ///
1695 ///   implements-interface ::=
1696 ///     `implements_interface` `<` bare-id (`,` bare-id)* `>` `:` tc-def
1697 ///
1698 ///   ods-def ::= `ods_def` `<` bare-id `>`
1699 ///               (implements-interface)? `:`
1700 ///               tc-def
1701 ///
1702 /// All the affine-expr in a `tensor-typedef` must be dimensionless (i.e.
1703 /// contain only expressions involving symbols and constants), but can
1704 /// otherwise contain arbitrary affine expressions.
1705 LogicalResult TCParser::parseAndEmitODSDef(llvm::raw_ostream &os) {
1706   // Parse ods-def header (including C++ op name)
1707   if (failed(parser.parseToken(Token::Kind::kw_ods_def,
1708                                "expected 'ods_def' to define a TC ODS")) ||
1709       failed(parser.parseToken(Token::Kind::lt, "expected '<'")))
1710     return failure();
1711   StringRef cppOpName = parser.curToken.getSpelling();
1712   LLVM_DEBUG(llvm::dbgs() << "\n\nStart parsing ODS: " << cppOpName << "\n");
1713   if (failed(parser.parseToken(Token::Kind::id, "expected id")) ||
1714       failed(parser.parseToken(Token::Kind::gt, "expected '>'")))
1715     return failure();
1716 
1717   // Parse optional implements-interface header (including C++ op names)
1718   SmallVector<StringRef> interfaces;
1719   bool implementsInterface = succeeded(
1720       parser.parseOptionalToken(Token::Kind::kw_implements_interface));
1721   if (implementsInterface) {
1722     auto parseInterfaceString = [&]() -> LogicalResult {
1723       StringRef interfaceName = parser.curToken.getSpelling();
1724       if (failed(parser.parseToken(Token::Kind::id, "expected id")))
1725         return failure();
1726       interfaces.push_back(interfaceName);
1727       return success();
1728     };
1729     if (failed(parser.parseToken(Token::Kind::lt, "expected '<'")) ||
1730         failed(parser.parseCommaSeparatedListUntil(
1731             Token::Kind::gt, parseInterfaceString, /*allowEmptyList=*/false)))
1732       return failure();
1733   }
1734 
1735   // Parse column.
1736   if (failed(parser.parseToken(Token::Kind::colon, "expected ':'")))
1737     return failure();
1738 
1739   // Parse TC op name.
1740   if (failed(parser.parseToken(Token::Kind::kw_def,
1741                                "expected 'def' to define a TC")))
1742     return failure();
1743   StringRef tcName = parser.curToken.getSpelling();
1744   LLVM_DEBUG(llvm::dbgs() << "\n\nStart parsing TC: " << tcName << "\n");
1745 
1746   // Parse input/output tensor definitions
1747   if (failed(parser.parseToken(Token::Kind::id, "expected id")) ||
1748       failed(parser.parseToken(Token::Kind::l_paren, "expected '('")))
1749     return failure();
1750 
1751   auto parseInputDef = [&]() -> LogicalResult {
1752     return parseTensorDef(/*isOutput=*/false);
1753   };
1754   if (failed(parser.parseCommaSeparatedListUntil(
1755           Token::Kind::r_paren, parseInputDef, /*allowEmptyList=*/false)))
1756     return failure();
1757 
1758   if (failed(parser.parseToken(Token::Kind::minus, "expected '-'")) ||
1759       failed(parser.parseToken(Token::Kind::gt, "expected '>'")) ||
1760       failed(parser.parseToken(Token::Kind::l_paren, "expected '('")))
1761     return failure();
1762   auto parseOutputDef = [&]() -> LogicalResult {
1763     return parseTensorDef(/*isOutput=*/true);
1764   };
1765   if (failed(parser.parseCommaSeparatedListUntil(
1766           Token::Kind::r_paren, parseOutputDef, /*allowEmptyList=*/false)))
1767     return failure();
1768 
1769   // Parse optional attribute definitions
1770   if (succeeded(parser.parseOptionalToken(Token::Kind::kw_attr_def))) {
1771     if (failed(parser.parseToken(Token::Kind::l_paren, "expected '('")))
1772       return failure();
1773     if (failed(parser.parseCommaSeparatedListUntil(
1774             Token::Kind::r_paren, std::bind(&TCParser::parseAttrDef, this),
1775             /*allowEmptyList=*/false)))
1776       return failure();
1777   }
1778 
1779   // Parse optional doc string
1780   if (parser.curToken.is(Token::Kind::doc_str)) {
1781     docString = parser.curToken.getSpelling();
1782     parser.consumeToken();
1783     LLVM_DEBUG(llvm::dbgs()
1784                << "parsed doc string: '''" << docString << "'''\n");
1785   }
1786 
1787   // Since we don't declare symbols separately, we discover them eagerly: each
1788   // newly encountered id in a tensor shape expression is treated as a new
1789   // symbolic. At this point, all tensors have been parsed and all the symbols
1790   // that could be discovered eagerly are now known. Resize all AffineMaps to
1791   // normalize the number of eagerly discovered symbols.
1792   for (auto &tensor : registeredTensors) {
1793     auto &map = tensor.getValue().shape;
1794     map = AffineMap::get(/*dimCount=*/0, symbols.size(), map.getResults(),
1795                          parser.context);
1796   }
1797 
1798   if (failed(parser.parseToken(Token::Kind::l_brace, "expected '{'")))
1799     return failure();
1800 
1801   SmallVector<ComprehensionParsingState, 4> perComprehensionStates;
1802   while (parser.curToken.isNot(Token::Kind::r_brace)) {
1803     perComprehensionStates.push_back(ComprehensionParsingState());
1804     perComprehensionStates.back().numArgs = registeredTensors.size();
1805     if (failed(parseOneComprehension(cppOpName, tcName,
1806                                      perComprehensionStates.back())))
1807       return failure();
1808   };
1809   if (failed(parser.parseToken(Token::Kind::r_brace, "expected '}'")))
1810     return failure();
1811 
1812   // Print.
1813   auto nComprehensions = perComprehensionStates.size();
1814   if (nComprehensions != 1)
1815     return parser.emitError("only 1 comprehension supported for now, got: " +
1816                             llvm::Twine(nComprehensions));
1817   if (genODSDecl) {
1818     auto &state = perComprehensionStates.back();
1819     printODS(os, cppOpName, tcName, interfaces, state);
1820     os << "\n";
1821   }
1822   if (genODSImpl) {
1823     auto &state = perComprehensionStates.back();
1824     std::string extraMethods;
1825     llvm::raw_string_ostream ss(extraMethods);
1826     printReferenceIterators(ss, cppOpName, state);
1827     printIndexingMapRequiredAttrMethods(ss, cppOpName, state);
1828     printReferenceIndexingMaps(ss, cppOpName, state);
1829     printRegionBuilder(ss, cppOpName, state);
1830     printCanonicalizersAndFolders(ss, cppOpName);
1831     ss.flush();
1832     os << extraMethods << "\n";
1833   }
1834 
1835   return success();
1836 }
1837 
1838 //===----------------------------------------------------------------------===//
1839 // Printing functions
1840 //===----------------------------------------------------------------------===//
1841 
1842 /// Print the ODS class that defines a new `cppOpName` for a `linalgOpName`.
1843 void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
1844                         StringRef linalgOpName, ArrayRef<StringRef> interfaces,
1845                         ComprehensionParsingState &state) {
1846   SmallVector<std::string, 4> attributes;
1847   for (const auto &attr : registeredAttrs) {
1848     llvm::StringRef name = attr.first;
1849 
1850     llvm::StringRef elementType = attr.second.elementType;
1851     std::string odsType = llvm::StringSwitch<std::string>(elementType)
1852                               .Case("f32", "F32")
1853                               .Case("i32", "I32")
1854                               .Case("i64", "I64")
1855                               .Default("");
1856     if (odsType.empty()) {
1857       (void)parser.emitError(
1858           "unimplemented support for attribute element type: " + elementType);
1859       return;
1860     }
1861 
1862     const auto &dims = attr.second.vectorDims;
1863     if (!dims.empty()) {
1864       // Vector case
1865       SmallVector<std::string, 4> dimStrs;
1866       for (uint64_t dim : dims)
1867         dimStrs.push_back(std::to_string(dim));
1868       odsType = llvm::formatv("Ranked{0}ElementsAttr<[{1}]>", odsType,
1869                               llvm::join(dimStrs, ", "));
1870     } else if (attr.second.isArray) {
1871       // Array case
1872       odsType = llvm::formatv("{0}ArrayAttr", odsType);
1873     } else {
1874       // Scalar case
1875       odsType = llvm::formatv("{0}Attr", odsType);
1876     }
1877 
1878     if (attr.second.isOptional)
1879       odsType = llvm::formatv("OptionalAttr<{0}>", odsType);
1880 
1881     attributes.push_back(llvm::formatv("{0}:${1}", odsType, name));
1882   }
1883 
1884   std::string attrList = llvm::join(attributes, ",\n");
1885   if (!attrList.empty())
1886     attrList = ",\n" + attrList;
1887 
1888   // Template for Linalg named ops' ODS definitions. Parameters:
1889   // {0}: ODS/C++ op name
1890   // {1}: assembly op mnemonic
1891   // {2}: op interface list
1892   // {3}: documentation (summary + description)
1893   // {4}: op attribute list
1894   // {5}: the number of arguments for the op region
1895   // {6}: builder methods taking standalone attribute parameters
1896   // {7}: additional methods for attributes used by indexing maps
1897   const char *header = R"FMT(  def {0} : LinalgStructuredBase_Op<"{1}", [
1898     AttrSizedOperandSegments,
1899     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
1900     SingleBlockImplicitTerminator<"YieldOp">
1901     /*extraInterfaces=*/{2}]> {
1902       {3}
1903       let arguments = (ins
1904         Variadic<AnyShaped>:$inputs,
1905         Variadic<AnyShaped>:$outputs{4}
1906       );
1907       let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
1908       let regions = (region AnyRegion:$region);
1909 
1910       let skipDefaultBuilders = 1;
1911       let builders = [
1912         OpBuilder<
1913         (ins "ValueRange":$inputs, "ValueRange":$outputs),
1914         [{{
1915           $_state.addOperands(inputs);
1916           $_state.addOperands(outputs);
1917           $_state.addAttribute(
1918             "operand_segment_sizes",
1919             $_builder.getI32VectorAttr({{
1920               static_cast<int32_t>(inputs.size()),
1921               static_cast<int32_t>(outputs.size())}));
1922           createAndFillStructuredOpRegion<{0}>(
1923             $_builder,
1924             $_state,
1925             TypeRange(inputs),
1926             TypeRange(outputs));
1927         }]>,
1928         OpBuilder<
1929         (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
1930              "ValueRange":$outputs),
1931         [{{
1932           $_state.addOperands(inputs);
1933           $_state.addOperands(outputs);
1934           $_state.addTypes(resultTensorTypes);
1935           $_state.addAttribute(
1936             "operand_segment_sizes",
1937             $_builder.getI32VectorAttr({{
1938               static_cast<int32_t>(inputs.size()),
1939               static_cast<int32_t>(outputs.size())}));
1940           createAndFillStructuredOpRegion<{0}>(
1941             $_builder,
1942             $_state,
1943             TypeRange(inputs),
1944             TypeRange(outputs));
1945         }]>,
1946         OpBuilder<
1947         (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
1948              CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
1949         [{{
1950           $_state.addOperands(operands);
1951           $_state.addAttributes(attributes);
1952           $_state.addTypes(resultTensorTypes);
1953           (void)$_state.addRegion();
1954         }]>
1955         {6}
1956       ];
1957       let printer = [{{ return ::printNamedStructuredOp(p, *this); }];
1958       let parser = [{{
1959         return ::parseNamedStructuredOp<{0}>(parser, result);
1960       }];
1961       let hasFolder = 1;
1962 
1963       let extraClassDeclaration = structuredOpsBaseDecls # [{{
1964         // Auto-generated.
1965         ArrayAttr iterator_types();
1966         ArrayAttr indexing_maps();
1967         static void regionBuilder(ImplicitLocOpBuilder &b, Block &block);
1968         static std::function<void(ImplicitLocOpBuilder &b, Block &)>
1969         getRegionBuilder() {{
1970           return regionBuilder;
1971         }
1972 
1973         // Generic methods.
1974         static unsigned getNumRegionArgs() {{ return {5}; }
1975         std::string getLibraryCallName() {{
1976           return generateLibraryCallName(getOperation());
1977         }
1978 
1979         {7}
1980       }];
1981   })FMT";
1982 
1983   // Generate the list of extra implemented interfaces.
1984   std::string interfaceNameList;
1985   if (!interfaces.empty()) {
1986     llvm::raw_string_ostream ss(interfaceNameList);
1987     ss << ", "; // Leading comma to concat to existing list of interfaces.
1988     llvm::interleaveComma(interfaces, ss);
1989     ss.flush();
1990   }
1991 
1992   // Generate documentation.
1993   std::string doc;
1994   if (!docString.empty()) {
1995     const char *docFmt = R"FMT(
1996       let summary = [{ {0} }];
1997       let description = [{
1998         {1}
1999       }];
2000     )FMT";
2001 
2002     StringRef summary, description;
2003     std::tie(summary, description) = docString.trim().split('\n');
2004     doc = llvm::formatv(docFmt, summary.trim(), description.trim());
2005   }
2006 
2007   // Generate an additional builder that has parameters for attributes.
2008   std::string attrBuilder;
2009   if (!registeredAttrs.empty()) {
2010     SmallVector<std::string, 4> attrParams, attrStmts;
2011     for (const auto &attr : registeredAttrs) {
2012       llvm::StringRef name = attr.first;
2013       attrParams.push_back(llvm::formatv("\"Attribute\":${0}", name));
2014       attrStmts.push_back(
2015           llvm::formatv("$_state.addAttribute(\"{0}\", {0});", name));
2016     }
2017     std::string attrParamsList = llvm::join(attrParams, ", ");
2018     std::string attrStmtsList = llvm::join(attrStmts, "\n");
2019 
2020     const char *builderFmt = R"FMT(
2021       , OpBuilder<
2022       (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
2023            "ValueRange":$outputs, {1}),
2024       [{{
2025         $_state.addOperands(inputs);
2026         $_state.addOperands(outputs);
2027         $_state.addTypes(resultTensorTypes);
2028         $_state.addAttribute(
2029           "operand_segment_sizes",
2030           $_builder.getI32VectorAttr({{
2031             static_cast<int32_t>(inputs.size()),
2032             static_cast<int32_t>(outputs.size())}));
2033         createAndFillStructuredOpRegion<{0}>(
2034           $_builder,
2035           $_state,
2036           TypeRange(inputs),
2037           TypeRange(outputs));
2038         {2}
2039       }]>
2040     )FMT";
2041     attrBuilder =
2042         llvm::formatv(builderFmt, cppOpName, attrParamsList, attrStmtsList);
2043   }
2044 
2045   std::string attrMethods;
2046   if (!registeredAttrs.empty()) {
2047     attrMethods = R"(
2048       bool hasDynamicIndexingMaps();
2049       LogicalResult verifyIndexingMapRequiredAttributes();
2050     )";
2051   }
2052 
2053   // Finally put everything together.
2054   os << llvm::formatv(header, cppOpName, linalgOpName, interfaceNameList, doc,
2055                       attrList, state.numArgs, attrBuilder, attrMethods);
2056 }
2057 
2058 /// Print the C++ StructuredOpsInterface impl of `iterator_types`.
2059 void TCParser::printReferenceIterators(llvm::raw_ostream &os,
2060                                        StringRef cppOpName,
2061                                        ComprehensionParsingState &state) {
2062   const char *referenceReferenceIteratorsFmt =
2063       R"FMT(
2064     ArrayAttr {0}::iterator_types() {
2065       return Builder(getContext()).getStrArrayAttr(SmallVector<StringRef, 8>{{ {1} });
2066     })FMT";
2067 
2068   std::string iteratorsStr;
2069   llvm::raw_string_ostream ss(iteratorsStr);
2070   unsigned pos = 0;
2071   llvm::interleaveComma(
2072       state.dims, ss, [&](std::pair<StringRef, AffineExpr> p) {
2073         bool reduction = false;
2074         for (auto &expr : state.expressions) {
2075           visitPostorder(*expr, [&](const Expression &e) {
2076             if (auto *pTensorExpr = dyn_cast<TensorExpr>(&e)) {
2077               if (pTensorExpr->reductionDimensions.count(pos) > 0)
2078                 reduction = true;
2079             }
2080           });
2081           if (reduction)
2082             break;
2083         }
2084         ss << (reduction ? "getReductionIteratorTypeName()"
2085                          : "getParallelIteratorTypeName()");
2086         pos++;
2087       });
2088   ss.flush();
2089 
2090   os << llvm::formatv(referenceReferenceIteratorsFmt, cppOpName, iteratorsStr);
2091 }
2092 
2093 void TCParser::printCanonicalizersAndFolders(llvm::raw_ostream &os,
2094                                              StringRef cppOpName) {
2095   const char *foldersFmt = R"FMT(
2096     LogicalResult {0}::fold(ArrayRef<Attribute>,
2097                             SmallVectorImpl<OpFoldResult> &) {{
2098       return foldMemRefCast(*this);
2099     }
2100     void {0}::getEffects(SmallVectorImpl<
2101         SideEffects::EffectInstance<MemoryEffects::Effect> >&effects) {{
2102       SmallVector<Value> inputBuffers = getInputBufferOperands();
2103       SmallVector<Value> outputBuffers = getOutputBufferOperands();
2104       getGenericEffectsImpl(effects,
2105         getOperation()->getResults(), inputBuffers, outputBuffers);
2106     })FMT";
2107   os << llvm::formatv(foldersFmt, cppOpName);
2108 }
2109 
2110 // Prints methods for querying whether the current named op has attributes that
2111 // are used by its indexing maps and for verifying those attributes have the
2112 // expected type.
2113 void TCParser::printIndexingMapRequiredAttrMethods(
2114     llvm::raw_ostream &os, StringRef cppOpName,
2115     ComprehensionParsingState &state) {
2116   // If there are no attribute used by the whole definition, then we are done.
2117   if (registeredAttrs.empty())
2118     return;
2119 
2120   // Otherwise, go through each attribute and generate code to verify it's
2121   // valid per the spec.
2122   SmallVector<std::string, 4> attributes;
2123   for (const auto &attr : registeredAttrs) {
2124     if (attr.second.isOptional)
2125       continue;
2126 
2127     llvm::StringRef name = attr.first;
2128     llvm::StringRef elementType = attr.second.elementType;
2129     const auto &dims = attr.second.vectorDims;
2130 
2131     // Get the method call to check the element type is of the expected kind.
2132     std::string elemTypeCheck = llvm::StringSwitch<std::string>(elementType)
2133                                     .Case("f32", "isF32()")
2134                                     .Case("i32", "isInteger(32)")
2135                                     .Case("i64", "isInteger(64)")
2136                                     .Default("");
2137     if (elemTypeCheck.empty()) {
2138       (void)parser.emitError(
2139           "unimplemented support for attribute element type: " + elementType);
2140       return;
2141     }
2142 
2143     // Scalar case.
2144     if (dims.empty() && !attr.second.isArray) {
2145       const char *attrFmt = R"FMT(
2146         if (auto attr = op->getAttr("{0}")) {{
2147           if (!attr.getType().{1}) return op->emitError(
2148             "incorrect type for indexing map required attribute '{0}'");
2149         } else {{
2150           return op->emitError(
2151             "missing indexing map required attribute '{0}'");
2152         }
2153       )FMT";
2154 
2155       attributes.push_back(llvm::formatv(attrFmt, name, elemTypeCheck));
2156       continue;
2157     }
2158 
2159     // Vector case.
2160     if (!dims.empty()) {
2161       SmallVector<std::string, 4> dimStrs;
2162       for (uint64_t dim : dims)
2163         dimStrs.push_back(std::to_string(dim));
2164 
2165       const char *attrFmt = R"FMT(
2166         if (auto attr = op->getAttrOfType<DenseElementsAttr>("{0}")) {{
2167           if (!attr.getType().getElementType().{1}) return op->emitError(
2168             "incorrect element type for indexing map required attribute '{0}'");
2169           if (attr.getType().getShape() != ArrayRef<int64_t>{{ {2} })
2170             return op->emitError(
2171               "incorrect shape for indexing map required attribute '{0}'");
2172         } else {
2173           return op->emitError(
2174             "missing indexing map required attribute '{0}'");
2175         }
2176       )FMT";
2177 
2178       attributes.push_back(llvm::formatv(attrFmt, name, elemTypeCheck,
2179                                          llvm::join(dimStrs, ", ")));
2180       continue;
2181     }
2182 
2183     // Array case.
2184     {
2185       const char *attrFmt = R"FMT(
2186         if (auto attr = op->getAttrOfType<ArrayAttr>("{0}")) {{
2187           for (Attribute element : attr) {{
2188             if (!element.getType().{1}) return emitError(
2189               "incorrect element type for indexing map required attribute '{0}'");
2190           }
2191         } else {{
2192           return op->emitError(
2193             "missing indexing map required attribute '{0}'");
2194         }
2195       )FMT";
2196 
2197       attributes.push_back(llvm::formatv(attrFmt, name, elemTypeCheck));
2198     }
2199   }
2200 
2201   const char *methodFmt = R"FMT(
2202   bool {0}::hasDynamicIndexingMaps() {{ return true; }
2203 
2204   LogicalResult {0}::verifyIndexingMapRequiredAttributes() {{
2205     Operation *op = getOperation();
2206     {1}
2207     return success();
2208   }
2209   )FMT";
2210 
2211   // Print everything out.
2212   os << llvm::formatv(methodFmt, cppOpName, llvm::join(attributes, "\n"));
2213 }
2214 
2215 /// Print the C++ StructuredOpsInterface impl of `referenceIndexingMaps`.
2216 void TCParser::printReferenceIndexingMaps(llvm::raw_ostream &os,
2217                                           StringRef cppOpName,
2218                                           ComprehensionParsingState &state) {
2219   // 1. Generic string template for specifying reference indexing maps.
2220   const char *referenceIndexingMapsFmt =
2221       R"FMT(
2222   // This is temporary until we transition out of manually specified ops that
2223   // should be auto-generated with linalg-ods-gen.
2224   ArrayAttr {0}::indexing_maps() {
2225     MLIRContext *context = getContext();
2226     AffineExpr {1};
2227     bindDims(context, {1});
2228     {2}
2229     return Builder(context).getAffineMapArrayAttr({ {3} });
2230   })FMT";
2231 
2232   // 2. Print a comma-separated list of identifiers for the AffineExpr in
2233   // `state.dims`. These will replace the `{1}` placeholder in both
2234   // `AffineExpr {1}` and `bindDims(context, {1})` ensuring the AffineExpr
2235   // identifiers are bound in the right order to the proper AffineDimExpr.
2236   std::string dimsStr;
2237   llvm::raw_string_ostream ss(dimsStr);
2238   llvm::interleaveComma(
2239       state.dims, ss,
2240       [&](std::pair<StringRef, AffineExpr> p) { ss << p.second; });
2241   ss.flush();
2242 
2243   // 3. Get the list of affine maps for each input/output. The AffineExpr use
2244   // the common arithmetic operators on AffineExpr. These affine maps will
2245   // replace the `{2}` placeholder.
2246   std::string mapsStr;
2247   llvm::raw_string_ostream mapsStringStream(mapsStr);
2248 
2249   // Create a list of all symbols.
2250   SmallVector<std::string, 4> symbolReplacements;
2251   symbolReplacements.reserve(symbols.size());
2252   for (unsigned i = 0; i < symbols.size(); ++i) {
2253     const char *symFmt =
2254         "\n\tauto s{0} = getAffineSymbolExpr({0}, context); (void)s{0};";
2255     mapsStringStream << llvm::formatv(symFmt, i);
2256     symbolReplacements.push_back(llvm::formatv("s{0}", i));
2257   }
2258 
2259   // Create the affine constant expressions to replace symbols for attributes.
2260   for (auto attrUse : llvm::enumerate(attrUses)) {
2261     StringRef attrName = attrUse.value().attrName;
2262     auto it = registeredAttrs.find(attrName.str());
2263     assert(it != registeredAttrs.end() && "uses should point to valid attr!");
2264     llvm::Optional<std::string> getValueFn =
2265         it->second.getValueFn(attrUse.value().indices);
2266     if (!getValueFn) {
2267       (void)parser.emitError("unimplemented getValueFn for attribute: " +
2268                              attrName);
2269       return;
2270     }
2271     std::string cstVal = llvm::formatv("{0}(){1}", attrName, *getValueFn);
2272     const char *cstFmt =
2273         "\n\tauto cst{0} = getAffineConstantExpr({1}, context);";
2274     mapsStringStream << llvm::formatv(cstFmt, attrUse.index(), cstVal);
2275 
2276     unsigned position =
2277         attrUse.value().symbol.cast<AffineSymbolExpr>().getPosition();
2278     symbolReplacements[position] = llvm::formatv("cst{0}", attrUse.index());
2279   }
2280 
2281   // For each registered tensor, construct the affine map, replace symbols by
2282   // the corresponding attribute values, and simplify the affine map.
2283   for (auto &tensorIter : registeredTensors) {
2284     auto &tensor = tensorIter.getValue();
2285     auto indexingMap = tensor.indexingMap;
2286     const char *mapFmt =
2287         "\n\tauto map{0} = AffineMap::get({1}, {2}, {3}, context);";
2288 
2289     std::string exprsStr;
2290     llvm::raw_string_ostream exprsStringStream(exprsStr);
2291     exprsStringStream << "{";
2292     llvm::interleaveComma(indexingMap.getResults(), exprsStringStream);
2293     exprsStringStream << "}";
2294     exprsStringStream.flush();
2295     mapsStringStream << llvm::formatv(mapFmt, tensor.index, state.dims.size(),
2296                                       indexingMap.getNumSymbols(), exprsStr);
2297 
2298     std::string replaceSymbolList =
2299         llvm::formatv("{ {0} }", llvm::join(symbolReplacements, ", "));
2300 
2301     // Note that we use `0` as the result affine map's number of symbols. All
2302     // symbols representing attribute usages should be folded away. But there
2303     // may exist additional symbols for tensor dimension upper bounds. Linalg
2304     // does not handle such cases right now. This needs to be fixed once we
2305     // need that.
2306     const char *replaceFmt =
2307         "\n\tmap{0} = map{0}.replaceDimsAndSymbols({{}, {1}, {2}, 0);";
2308     mapsStringStream << llvm::formatv(replaceFmt, tensor.index,
2309                                       replaceSymbolList, state.dims.size());
2310     const char *simplifyFmt = "\n\tmap{0} = simplifyAffineMap(map{0});";
2311     mapsStringStream << llvm::formatv(simplifyFmt, tensor.index);
2312   }
2313 
2314   mapsStringStream.flush();
2315 
2316   SmallVector<std::string, 4> mapList;
2317   mapList.reserve(state.numArgs);
2318   for (auto i : llvm::seq<unsigned>(0, state.numArgs))
2319     mapList.push_back(llvm::formatv("map{0}", i));
2320 
2321   // 4. Apply format to 1. using 2. and 3.
2322   os << llvm::formatv(referenceIndexingMapsFmt, cppOpName, dimsStr, mapsStr,
2323                       llvm::join(mapList, ", "));
2324 }
2325 
2326 /// Print the C++ StructuredOpsInterface impl of `regionBuilder`.
2327 void TCParser::printRegionBuilder(llvm::raw_ostream &os, StringRef cppOpName,
2328                                   ComprehensionParsingState &state) {
2329   unsigned count = state.numArgs;
2330   llvm::DenseMap<const TensorExpr *, unsigned> subExprsMap;
2331   std::function<void(llvm::raw_ostream & os, const Expression &)> printExpr;
2332   printExpr = [&](llvm::raw_ostream &os, const Expression &e) -> void {
2333     if (auto *pUse = dyn_cast<TensorUse>(&e)) {
2334       os << "_" << state.orderedTensorArgs.find(*pUse)->second;
2335       return;
2336     }
2337     auto *pTensorExpr = cast<TensorExpr>(&e);
2338     if (subExprsMap.count(pTensorExpr) > 0) {
2339       os << "_" << subExprsMap[pTensorExpr];
2340     } else {
2341       std::string subExprs;
2342       llvm::raw_string_ostream subExprsStringStream(subExprs);
2343       llvm::interleaveComma(pTensorExpr->expressions, subExprsStringStream,
2344                             [&](const std::unique_ptr<Expression> &e) {
2345                               printExpr(subExprsStringStream, *e);
2346                             });
2347       subExprsStringStream.flush();
2348       const char *tensorExprFmt = "\n    Value _{0} = b.create<{1}>({2});";
2349       os << llvm::formatv(tensorExprFmt, ++count, pTensorExpr->operationName,
2350                           subExprs);
2351       subExprsMap[pTensorExpr] = count;
2352     }
2353   };
2354 
2355   const char *regionBuilderFmt = R"FMT(
2356   void {0}::regionBuilder(ImplicitLocOpBuilder &b, Block &block) {
2357     auto args = block.getArguments();
2358     Value {1};
2359     {2}
2360     b.create<linalg::YieldOp>(ValueRange{ {3} });
2361   })FMT";
2362 
2363   std::string valueHandleStr;
2364   llvm::raw_string_ostream valueHandleStringStream(valueHandleStr);
2365   std::set<unsigned> usedTensorId;
2366   for (const auto &iter : state.orderedTensorArgs)
2367     usedTensorId.insert(iter.second);
2368   llvm::interleaveComma(usedTensorId, valueHandleStringStream, [&](auto idx) {
2369     valueHandleStringStream << "_" << idx << "(args[" << idx << "])";
2370   });
2371 
2372   std::string expressionsStr;
2373   llvm::raw_string_ostream expressionStringStream(expressionsStr);
2374   for (auto &expr : state.expressions)
2375     visitPostorder(*expr, [&](const Expression &e) {
2376       if (e.kind == Expression::Kind::TensorExpr)
2377         printExpr(expressionStringStream, e);
2378     });
2379   expressionStringStream.flush();
2380   substituteOpAliases(expressionsStr);
2381 
2382   std::string yieldStr;
2383   llvm::raw_string_ostream yieldStringStream(yieldStr);
2384   llvm::interleaveComma(state.expressions, yieldStringStream,
2385                         [&](const std::unique_ptr<Expression> &e) {
2386                           printExpr(yieldStringStream, *e);
2387                         });
2388 
2389   valueHandleStringStream.flush();
2390   yieldStringStream.flush();
2391 
2392   os << llvm::formatv(regionBuilderFmt, cppOpName, valueHandleStr,
2393                       expressionsStr, yieldStr);
2394 }
2395 
2396 llvm::Optional<std::string>
2397 TCParser::RegisteredAttr::getValueFn(ArrayRef<uint64_t> indices) const {
2398   if (isArray)
2399     return llvm::None;
2400 
2401   if (!vectorDims.empty()) {
2402     SmallVector<std::string, 4> indexStrs;
2403     for (uint64_t index : indices)
2404       indexStrs.push_back(std::to_string(index));
2405     std::string indexList = llvm::join(indexStrs, ", ");
2406     if (elementType == "f32")
2407       return llvm::formatv(".getValue<float>({ {0} })", indexList).str();
2408     if (elementType == "i32")
2409       return llvm::formatv(".getValue<int>({ {0} })", indexList).str();
2410     if (elementType == "i64")
2411       return llvm::formatv(".getValue<int64_t>({ {0} })", indexList).str();
2412 
2413     return llvm::None;
2414   }
2415 
2416   if (elementType == "f32")
2417     return std::string(".convertToFloat()");
2418   if (elementType == "i32" || elementType == "i64")
2419     return std::string("");
2420   return llvm::None;
2421 }
2422 
2423 /// Iterate over each Tensor Comprehension def.
2424 LogicalResult parseAndEmitAllTensorComprehensions(llvm::raw_ostream &os,
2425                                                   Parser &parser) {
2426   while (parser.curToken.getKind() != Token::Kind::eof) {
2427     TCParser tcParser(parser);
2428     if (failed(tcParser.parseAndEmitODSDef(os)))
2429       return failure();
2430   }
2431   return success();
2432 }
2433 
2434 int main(int argc, char **argv) {
2435   llvm::cl::ParseCommandLineOptions(argc, argv, "Linalg ODS Gen");
2436 
2437   // Set up the input file.
2438   std::string errorMessage;
2439   std::unique_ptr<llvm::MemoryBuffer> file =
2440       mlir::openInputFile(inputFilename, &errorMessage);
2441   if (!file) {
2442     llvm::errs() << errorMessage << "\n";
2443     return 1;
2444   }
2445 
2446   std::unique_ptr<llvm::ToolOutputFile> output =
2447       openOutputFile(outputFilename, &errorMessage);
2448   if (!output) {
2449     llvm::errs() << errorMessage << "\n";
2450     exit(1);
2451   }
2452 
2453   // Include the proper Linalg header for end-to-end tblgen testing without
2454   // resorting to non-portable shell manipulations.
2455   if (testEmitIncludeTdHeader)
2456     output->os() << "include \"mlir/Dialect/Linalg/IR/LinalgStructuredOps.td\"";
2457 
2458   MLIRContext context;
2459   llvm::SourceMgr mgr;
2460   mgr.AddNewSourceBuffer(std::move(file), llvm::SMLoc());
2461   Parser parser(mgr, &context);
2462   (void)parseAndEmitAllTensorComprehensions(output->os(), parser);
2463   output->keep();
2464 
2465   return 0;
2466 }
2467