1 //===- OpFormatGen.cpp - MLIR operation asm format generator --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "OpFormatGen.h"
10 #include "mlir/Support/LogicalResult.h"
11 #include "mlir/TableGen/Format.h"
12 #include "mlir/TableGen/GenInfo.h"
13 #include "mlir/TableGen/Interfaces.h"
14 #include "mlir/TableGen/OpClass.h"
15 #include "mlir/TableGen/OpTrait.h"
16 #include "mlir/TableGen/Operator.h"
17 #include "llvm/ADT/MapVector.h"
18 #include "llvm/ADT/Sequence.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/SmallBitVector.h"
21 #include "llvm/ADT/StringExtras.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 #include "llvm/Support/CommandLine.h"
24 #include "llvm/Support/Signals.h"
25 #include "llvm/TableGen/Error.h"
26 #include "llvm/TableGen/Record.h"
27 
28 #define DEBUG_TYPE "mlir-tblgen-opformatgen"
29 
30 using namespace mlir;
31 using namespace mlir::tblgen;
32 
33 static llvm::cl::opt<bool> formatErrorIsFatal(
34     "asmformat-error-is-fatal",
35     llvm::cl::desc("Emit a fatal error if format parsing fails"),
36     llvm::cl::init(true));
37 
38 /// Returns true if the given string can be formatted as a keyword.
canFormatStringAsKeyword(StringRef value)39 static bool canFormatStringAsKeyword(StringRef value) {
40   if (!isalpha(value.front()) && value.front() != '_')
41     return false;
42   return llvm::all_of(value.drop_front(), [](char c) {
43     return isalnum(c) || c == '_' || c == '$' || c == '.';
44   });
45 }
46 
47 //===----------------------------------------------------------------------===//
48 // Element
49 //===----------------------------------------------------------------------===//
50 
namespace {
/// This class represents a single format element. Every concrete element
/// records its `Kind`, which subclasses inspect in their `classof` hooks so
/// that `isa`/`dyn_cast` can be used on elements.
class Element {
public:
  enum class Kind {
    /// This element is a directive.
    AttrDictDirective,
    CustomDirective,
    FunctionalTypeDirective,
    OperandsDirective,
    RegionsDirective,
    ResultsDirective,
    SuccessorsDirective,
    TypeDirective,
    TypeRefDirective,

    /// This element is a literal.
    Literal,

    /// This element is whitespace.
    Newline,
    Space,

    /// This element is a variable value.
    AttributeVariable,
    OperandVariable,
    RegionVariable,
    ResultVariable,
    SuccessorVariable,

    /// This element is an optional element.
    Optional,
  };
  /// Explicit so that a bare `Kind` is never silently converted to an element.
  explicit Element(Kind kind) : kind(kind) {}
  virtual ~Element() = default;

  /// Return the kind of this element.
  Kind getKind() const { return kind; }

private:
  /// The kind of this element.
  Kind kind;
};
} // namespace
95 
96 //===----------------------------------------------------------------------===//
97 // VariableElement
98 
99 namespace {
100 /// This class represents an instance of an variable element. A variable refers
101 /// to something registered on the operation itself, e.g. an argument, result,
102 /// etc.
103 template <typename VarT, Element::Kind kindVal>
104 class VariableElement : public Element {
105 public:
VariableElement(const VarT * var)106   VariableElement(const VarT *var) : Element(kindVal), var(var) {}
classof(const Element * element)107   static bool classof(const Element *element) {
108     return element->getKind() == kindVal;
109   }
getVar()110   const VarT *getVar() { return var; }
111 
112 protected:
113   const VarT *var;
114 };
115 
116 /// This class represents a variable that refers to an attribute argument.
117 struct AttributeVariable
118     : public VariableElement<NamedAttribute, Element::Kind::AttributeVariable> {
119   using VariableElement<NamedAttribute,
120                         Element::Kind::AttributeVariable>::VariableElement;
121 
122   /// Return the constant builder call for the type of this attribute, or None
123   /// if it doesn't have one.
getTypeBuilder__anon7e5610690311::AttributeVariable124   Optional<StringRef> getTypeBuilder() const {
125     Optional<Type> attrType = var->attr.getValueType();
126     return attrType ? attrType->getBuilderCall() : llvm::None;
127   }
128 
129   /// Return if this attribute refers to a UnitAttr.
isUnitAttr__anon7e5610690311::AttributeVariable130   bool isUnitAttr() const {
131     return var->attr.getBaseAttr().getAttrDefName() == "UnitAttr";
132   }
133 };
134 
135 /// This class represents a variable that refers to an operand argument.
136 using OperandVariable =
137     VariableElement<NamedTypeConstraint, Element::Kind::OperandVariable>;
138 
139 /// This class represents a variable that refers to a region.
140 using RegionVariable =
141     VariableElement<NamedRegion, Element::Kind::RegionVariable>;
142 
143 /// This class represents a variable that refers to a result.
144 using ResultVariable =
145     VariableElement<NamedTypeConstraint, Element::Kind::ResultVariable>;
146 
147 /// This class represents a variable that refers to a successor.
148 using SuccessorVariable =
149     VariableElement<NamedSuccessor, Element::Kind::SuccessorVariable>;
150 } // end anonymous namespace
151 
152 //===----------------------------------------------------------------------===//
153 // DirectiveElement
154 
155 namespace {
156 /// This class implements single kind directives.
157 template <Element::Kind type>
158 class DirectiveElement : public Element {
159 public:
DirectiveElement()160   DirectiveElement() : Element(type){};
classof(const Element * ele)161   static bool classof(const Element *ele) { return ele->getKind() == type; }
162 };
163 /// This class represents the `operands` directive. This directive represents
164 /// all of the operands of an operation.
165 using OperandsDirective = DirectiveElement<Element::Kind::OperandsDirective>;
166 
167 /// This class represents the `regions` directive. This directive represents
168 /// all of the regions of an operation.
169 using RegionsDirective = DirectiveElement<Element::Kind::RegionsDirective>;
170 
171 /// This class represents the `results` directive. This directive represents
172 /// all of the results of an operation.
173 using ResultsDirective = DirectiveElement<Element::Kind::ResultsDirective>;
174 
175 /// This class represents the `successors` directive. This directive represents
176 /// all of the successors of an operation.
177 using SuccessorsDirective =
178     DirectiveElement<Element::Kind::SuccessorsDirective>;
179 
180 /// This class represents the `attr-dict` directive. This directive represents
181 /// the attribute dictionary of the operation.
182 class AttrDictDirective
183     : public DirectiveElement<Element::Kind::AttrDictDirective> {
184 public:
AttrDictDirective(bool withKeyword)185   explicit AttrDictDirective(bool withKeyword) : withKeyword(withKeyword) {}
isWithKeyword() const186   bool isWithKeyword() const { return withKeyword; }
187 
188 private:
189   /// If the dictionary should be printed with the 'attributes' keyword.
190   bool withKeyword;
191 };
192 
193 /// This class represents a custom format directive that is implemented by the
194 /// user in C++.
195 class CustomDirective : public Element {
196 public:
CustomDirective(StringRef name,std::vector<std::unique_ptr<Element>> && arguments)197   CustomDirective(StringRef name,
198                   std::vector<std::unique_ptr<Element>> &&arguments)
199       : Element{Kind::CustomDirective}, name(name),
200         arguments(std::move(arguments)) {}
201 
classof(const Element * element)202   static bool classof(const Element *element) {
203     return element->getKind() == Kind::CustomDirective;
204   }
205 
206   /// Return the name of this optional element.
getName() const207   StringRef getName() const { return name; }
208 
209   /// Return the arguments to the custom directive.
getArguments() const210   auto getArguments() const { return llvm::make_pointee_range(arguments); }
211 
212 private:
213   /// The user provided name of the directive.
214   StringRef name;
215 
216   /// The arguments to the custom directive.
217   std::vector<std::unique_ptr<Element>> arguments;
218 };
219 
220 /// This class represents the `functional-type` directive. This directive takes
221 /// two arguments and formats them, respectively, as the inputs and results of a
222 /// FunctionType.
223 class FunctionalTypeDirective
224     : public DirectiveElement<Element::Kind::FunctionalTypeDirective> {
225 public:
FunctionalTypeDirective(std::unique_ptr<Element> inputs,std::unique_ptr<Element> results)226   FunctionalTypeDirective(std::unique_ptr<Element> inputs,
227                           std::unique_ptr<Element> results)
228       : inputs(std::move(inputs)), results(std::move(results)) {}
getInputs() const229   Element *getInputs() const { return inputs.get(); }
getResults() const230   Element *getResults() const { return results.get(); }
231 
232 private:
233   /// The input and result arguments.
234   std::unique_ptr<Element> inputs, results;
235 };
236 
237 /// This class represents the `type` directive.
238 class TypeDirective : public DirectiveElement<Element::Kind::TypeDirective> {
239 public:
TypeDirective(std::unique_ptr<Element> arg)240   TypeDirective(std::unique_ptr<Element> arg) : operand(std::move(arg)) {}
getOperand() const241   Element *getOperand() const { return operand.get(); }
242 
243 private:
244   /// The operand that is used to format the directive.
245   std::unique_ptr<Element> operand;
246 };
247 
248 /// This class represents the `type_ref` directive.
249 class TypeRefDirective
250     : public DirectiveElement<Element::Kind::TypeRefDirective> {
251 public:
TypeRefDirective(std::unique_ptr<Element> arg)252   TypeRefDirective(std::unique_ptr<Element> arg) : operand(std::move(arg)) {}
getOperand() const253   Element *getOperand() const { return operand.get(); }
254 
255 private:
256   /// The operand that is used to format the directive.
257   std::unique_ptr<Element> operand;
258 };
259 } // namespace
260 
261 //===----------------------------------------------------------------------===//
262 // LiteralElement
263 
264 namespace {
265 /// This class represents an instance of a literal element.
266 class LiteralElement : public Element {
267 public:
LiteralElement(StringRef literal)268   LiteralElement(StringRef literal)
269       : Element{Kind::Literal}, literal(literal) {}
classof(const Element * element)270   static bool classof(const Element *element) {
271     return element->getKind() == Kind::Literal;
272   }
273 
274   /// Return the literal for this element.
getLiteral() const275   StringRef getLiteral() const { return literal; }
276 
277   /// Returns true if the given string is a valid literal.
278   static bool isValidLiteral(StringRef value);
279 
280 private:
281   /// The spelling of the literal for this element.
282   StringRef literal;
283 };
284 } // end anonymous namespace
285 
isValidLiteral(StringRef value)286 bool LiteralElement::isValidLiteral(StringRef value) {
287   if (value.empty())
288     return false;
289   char front = value.front();
290 
291   // If there is only one character, this must either be punctuation or a
292   // single character bare identifier.
293   if (value.size() == 1)
294     return isalpha(front) || StringRef("_:,=<>()[]{}?+*").contains(front);
295 
296   // Check the punctuation that are larger than a single character.
297   if (value == "->")
298     return true;
299 
300   // Otherwise, this must be an identifier.
301   return canFormatStringAsKeyword(value);
302 }
303 
304 //===----------------------------------------------------------------------===//
305 // WhitespaceElement
306 
307 namespace {
308 /// This class represents a whitespace element, e.g. newline or space. It's a
309 /// literal that is printed but never parsed.
310 class WhitespaceElement : public Element {
311 public:
WhitespaceElement(Kind kind)312   WhitespaceElement(Kind kind) : Element{kind} {}
classof(const Element * element)313   static bool classof(const Element *element) {
314     Kind kind = element->getKind();
315     return kind == Kind::Newline || kind == Kind::Space;
316   }
317 };
318 
319 /// This class represents an instance of a newline element. It's a literal that
320 /// prints a newline. It is ignored by the parser.
321 class NewlineElement : public WhitespaceElement {
322 public:
NewlineElement()323   NewlineElement() : WhitespaceElement(Kind::Newline) {}
classof(const Element * element)324   static bool classof(const Element *element) {
325     return element->getKind() == Kind::Newline;
326   }
327 };
328 
329 /// This class represents an instance of a space element. It's a literal that
330 /// prints or omits printing a space. It is ignored by the parser.
331 class SpaceElement : public WhitespaceElement {
332 public:
SpaceElement(bool value)333   SpaceElement(bool value) : WhitespaceElement(Kind::Space), value(value) {}
classof(const Element * element)334   static bool classof(const Element *element) {
335     return element->getKind() == Kind::Space;
336   }
337 
338   /// Returns true if this element should print as a space. Otherwise, the
339   /// element should omit printing a space between the surrounding elements.
getValue() const340   bool getValue() const { return value; }
341 
342 private:
343   bool value;
344 };
345 } // end anonymous namespace
346 
347 //===----------------------------------------------------------------------===//
348 // OptionalElement
349 
350 namespace {
351 /// This class represents a group of elements that are optionally emitted based
352 /// upon an optional variable of the operation.
353 class OptionalElement : public Element {
354 public:
OptionalElement(std::vector<std::unique_ptr<Element>> && elements,unsigned anchor,unsigned parseStart)355   OptionalElement(std::vector<std::unique_ptr<Element>> &&elements,
356                   unsigned anchor, unsigned parseStart)
357       : Element{Kind::Optional}, elements(std::move(elements)), anchor(anchor),
358         parseStart(parseStart) {}
classof(const Element * element)359   static bool classof(const Element *element) {
360     return element->getKind() == Kind::Optional;
361   }
362 
363   /// Return the nested elements of this grouping.
getElements() const364   auto getElements() const { return llvm::make_pointee_range(elements); }
365 
366   /// Return the anchor of this optional group.
getAnchor() const367   Element *getAnchor() const { return elements[anchor].get(); }
368 
369   /// Return the index of the first element that needs to be parsed.
getParseStart() const370   unsigned getParseStart() const { return parseStart; }
371 
372 private:
373   /// The child elements of this optional.
374   std::vector<std::unique_ptr<Element>> elements;
375   /// The index of the element that acts as the anchor for the optional group.
376   unsigned anchor;
377   /// The index of the first element that is parsed (is not a
378   /// WhitespaceElement).
379   unsigned parseStart;
380 };
381 } // end anonymous namespace
382 
383 //===----------------------------------------------------------------------===//
384 // OperationFormat
385 //===----------------------------------------------------------------------===//
386 
387 namespace {
388 
389 using ConstArgument =
390     llvm::PointerUnion<const NamedAttribute *, const NamedTypeConstraint *>;
391 
392 struct OperationFormat {
393   /// This class represents a specific resolver for an operand or result type.
394   class TypeResolution {
395   public:
396     TypeResolution() = default;
397 
398     /// Get the index into the buildable types for this type, or None.
getBuilderIdx() const399     Optional<int> getBuilderIdx() const { return builderIdx; }
setBuilderIdx(int idx)400     void setBuilderIdx(int idx) { builderIdx = idx; }
401 
402     /// Get the variable this type is resolved to, or nullptr.
getVariable() const403     const NamedTypeConstraint *getVariable() const {
404       return resolver.dyn_cast<const NamedTypeConstraint *>();
405     }
406     /// Get the attribute this type is resolved to, or nullptr.
getAttribute() const407     const NamedAttribute *getAttribute() const {
408       return resolver.dyn_cast<const NamedAttribute *>();
409     }
410     /// Get the transformer for the type of the variable, or None.
getVarTransformer() const411     Optional<StringRef> getVarTransformer() const {
412       return variableTransformer;
413     }
setResolver(ConstArgument arg,Optional<StringRef> transformer)414     void setResolver(ConstArgument arg, Optional<StringRef> transformer) {
415       resolver = arg;
416       variableTransformer = transformer;
417       assert(getVariable() || getAttribute());
418     }
419 
420   private:
421     /// If the type is resolved with a buildable type, this is the index into
422     /// 'buildableTypes' in the parent format.
423     Optional<int> builderIdx;
424     /// If the type is resolved based upon another operand or result, this is
425     /// the variable or the attribute that this type is resolved to.
426     ConstArgument resolver;
427     /// If the type is resolved based upon another operand or result, this is
428     /// a transformer to apply to the variable when resolving.
429     Optional<StringRef> variableTransformer;
430   };
431 
OperationFormat__anon7e5610690811::OperationFormat432   OperationFormat(const Operator &op)
433       : allOperands(false), allOperandTypes(false), allResultTypes(false) {
434     operandTypes.resize(op.getNumOperands(), TypeResolution());
435     resultTypes.resize(op.getNumResults(), TypeResolution());
436 
437     hasImplicitTermTrait =
438         llvm::any_of(op.getTraits(), [](const OpTrait &trait) {
439           return trait.getDef().isSubClassOf("SingleBlockImplicitTerminator");
440         });
441   }
442 
443   /// Generate the operation parser from this format.
444   void genParser(Operator &op, OpClass &opClass);
445   /// Generate the parser code for a specific format element.
446   void genElementParser(Element *element, OpMethodBody &body,
447                         FmtContext &attrTypeCtx);
448   /// Generate the c++ to resolve the types of operands and results during
449   /// parsing.
450   void genParserTypeResolution(Operator &op, OpMethodBody &body);
451   /// Generate the c++ to resolve regions during parsing.
452   void genParserRegionResolution(Operator &op, OpMethodBody &body);
453   /// Generate the c++ to resolve successors during parsing.
454   void genParserSuccessorResolution(Operator &op, OpMethodBody &body);
455   /// Generate the c++ to handling variadic segment size traits.
456   void genParserVariadicSegmentResolution(Operator &op, OpMethodBody &body);
457 
458   /// Generate the operation printer from this format.
459   void genPrinter(Operator &op, OpClass &opClass);
460 
461   /// Generate the printer code for a specific format element.
462   void genElementPrinter(Element *element, OpMethodBody &body, Operator &op,
463                          bool &shouldEmitSpace, bool &lastWasPunctuation);
464 
465   /// The various elements in this format.
466   std::vector<std::unique_ptr<Element>> elements;
467 
468   /// A flag indicating if all operand/result types were seen. If the format
469   /// contains these, it can not contain individual type resolvers.
470   bool allOperands, allOperandTypes, allResultTypes;
471 
472   /// A flag indicating if this operation has the SingleBlockImplicitTerminator
473   /// trait.
474   bool hasImplicitTermTrait;
475 
476   /// A map of buildable types to indices.
477   llvm::MapVector<StringRef, int, llvm::StringMap<int>> buildableTypes;
478 
479   /// The index of the buildable type, if valid, for every operand and result.
480   std::vector<TypeResolution> operandTypes, resultTypes;
481 
482   /// The set of attributes explicitly used within the format.
483   SmallVector<const NamedAttribute *, 8> usedAttributes;
484 };
485 } // end anonymous namespace
486 
487 //===----------------------------------------------------------------------===//
488 // Parser Gen
489 
490 /// Returns true if we can format the given attribute as an EnumAttr in the
491 /// parser format.
canFormatEnumAttr(const NamedAttribute * attr)492 static bool canFormatEnumAttr(const NamedAttribute *attr) {
493   Attribute baseAttr = attr->attr.getBaseAttr();
494   const EnumAttr *enumAttr = dyn_cast<EnumAttr>(&baseAttr);
495   if (!enumAttr)
496     return false;
497 
498   // The attribute must have a valid underlying type and a constant builder.
499   return !enumAttr->getUnderlyingType().empty() &&
500          !enumAttr->getConstBuilderTemplate().empty();
501 }
502 
503 /// Returns if we should format the given attribute as an SymbolNameAttr.
shouldFormatSymbolNameAttr(const NamedAttribute * attr)504 static bool shouldFormatSymbolNameAttr(const NamedAttribute *attr) {
505   return attr->attr.getBaseAttr().getAttrDefName() == "SymbolNameAttr";
506 }
507 
/// The code snippet used to generate a parser call for an attribute.
///
/// {0}: The name of the attribute.
/// {1}: The type for the attribute.
const char *const attrParserCode = R"(
  if (parser.parseAttribute({0}Attr{1}, "{0}", result.attributes))
    return ::mlir::failure();
)";
/// Optional variant of `attrParserCode`, using the same placeholders. A parse
/// error is only reported when the optional parse produced a result and that
/// result is a failure.
const char *const optionalAttrParserCode = R"(
  {
    ::mlir::OptionalParseResult parseResult =
      parser.parseOptionalAttribute({0}Attr{1}, "{0}", result.attributes);
    if (parseResult.hasValue() && failed(*parseResult))
      return ::mlir::failure();
  }
)";

/// The code snippet used to generate a parser call for a symbol name attribute.
///
/// {0}: The name of the attribute.
const char *const symbolNameAttrParserCode = R"(
  if (parser.parseSymbolName({0}Attr, "{0}", result.attributes))
    return ::mlir::failure();
)";
/// Optional variant of `symbolNameAttrParserCode`, using the same placeholder.
/// The generated code discards the result of the optional parse.
const char *const optionalSymbolNameAttrParserCode = R"(
  // Parsing an optional symbol name doesn't fail, so no need to check the
  // result.
  (void)parser.parseOptionalSymbolName({0}Attr, "{0}", result.attributes);
)";
537 
/// The code snippet used to generate a parser call for an enum attribute.
///
/// {0}: The name of the attribute.
/// {1}: The c++ namespace for the enum symbolize functions.
/// {2}: The function to symbolize a string of the enum.
/// {3}: The constant builder call to create an attribute of the enum type.
/// {4}: The set of allowed enum keywords.
/// {5}: The error message on failure when the enum isn't present.
///
/// The generated code first tries `parseOptionalKeyword` with the allowed
/// keywords. When that call reports failure, it falls back to an optional
/// string attribute and uses the attribute's value as the enum spelling. A
/// non-empty spelling is then symbolized and stored as the `{0}` attribute.
/// Where the generated code needs a literal `{` that would otherwise open a
/// placeholder, it is written as `{{` (presumably the formatv escape; confirm
/// against the expanding call site).
const char *const enumAttrParserCode = R"(
  {
    ::llvm::StringRef attrStr;
    ::mlir::NamedAttrList attrStorage;
    auto loc = parser.getCurrentLocation();
    if (parser.parseOptionalKeyword(&attrStr, {4})) {
      ::mlir::StringAttr attrVal;
      ::mlir::OptionalParseResult parseResult =
        parser.parseOptionalAttribute(attrVal,
                                      parser.getBuilder().getNoneType(),
                                      "{0}", attrStorage);
      if (parseResult.hasValue()) {{
        if (failed(*parseResult))
          return ::mlir::failure();
        attrStr = attrVal.getValue();
      } else {
        {5}
      }
    }
    if (!attrStr.empty()) {
      auto attrOptional = {1}::{2}(attrStr);
      if (!attrOptional)
        return parser.emitError(loc, "invalid ")
               << "{0} attribute specification: \"" << attrStr << '"';

      {0}Attr = {3};
      result.addAttribute("{0}", {0}Attr);
    }
  }
)";
576 
/// The code snippet used to generate a parser call for an operand.
///
/// {0}: The name of the operand.
///
/// Variadic variant: parses an operand list into `{0}Operands`.
const char *const variadicOperandParserCode = R"(
  {0}OperandsLoc = parser.getCurrentLocation();
  if (parser.parseOperandList({0}Operands))
    return ::mlir::failure();
)";
/// Optional variant: parses zero or one operand and appends it to
/// `{0}Operands` when present.
const char *const optionalOperandParserCode = R"(
  {
    {0}OperandsLoc = parser.getCurrentLocation();
    ::mlir::OpAsmParser::OperandType operand;
    ::mlir::OptionalParseResult parseResult =
                                    parser.parseOptionalOperand(operand);
    if (parseResult.hasValue()) {
      if (failed(*parseResult))
        return ::mlir::failure();
      {0}Operands.push_back(operand);
    }
  }
)";
/// Single-operand variant: parses exactly one operand into
/// `{0}RawOperands[0]`.
const char *const operandParserCode = R"(
  {0}OperandsLoc = parser.getCurrentLocation();
  if (parser.parseOperand({0}RawOperands[0]))
    return ::mlir::failure();
)";

/// The code snippet used to generate a parser call for a type list.
///
/// {0}: The name for the type list.
///
/// Variadic variant: parses a type list into `{0}Types`.
const char *const variadicTypeParserCode = R"(
  if (parser.parseTypeList({0}Types))
    return ::mlir::failure();
)";
/// Optional variant: parses zero or one type and appends it to `{0}Types`
/// when present.
const char *const optionalTypeParserCode = R"(
  {
    ::mlir::Type optionalType;
    ::mlir::OptionalParseResult parseResult =
                                    parser.parseOptionalType(optionalType);
    if (parseResult.hasValue()) {
      if (failed(*parseResult))
        return ::mlir::failure();
      {0}Types.push_back(optionalType);
    }
  }
)";
/// Single-type variant: parses exactly one type into `{0}RawTypes[0]`.
const char *const typeParserCode = R"(
  if (parser.parseType({0}RawTypes[0]))
    return ::mlir::failure();
)";

/// The code snippet used to generate a parser call for a functional type.
///
/// {0}: The name for the input type list.
/// {1}: The name for the result type list.
///
/// A single FunctionType is parsed and split into its inputs (`{0}Types`) and
/// results (`{1}Types`).
const char *const functionalTypeParserCode = R"(
  ::mlir::FunctionType {0}__{1}_functionType;
  if (parser.parseType({0}__{1}_functionType))
    return ::mlir::failure();
  {0}Types = {0}__{1}_functionType.getInputs();
  {1}Types = {0}__{1}_functionType.getResults();
)";
639 
// The region and successor snippets below are declared `const char *const`,
// like every other snippet in this file. This gives them internal linkage, so
// they cannot collide with same-named globals in other translation units, and
// it prevents them from being reassigned.

/// The code snippet used to generate a parser call for a region list.
///
/// {0}: The name for the region list.
const char *const regionListParserCode = R"(
  {
    std::unique_ptr<::mlir::Region> region;
    auto firstRegionResult = parser.parseOptionalRegion(region);
    if (firstRegionResult.hasValue()) {
      if (failed(*firstRegionResult))
        return ::mlir::failure();
      {0}Regions.emplace_back(std::move(region));

      // Parse any trailing regions.
      while (succeeded(parser.parseOptionalComma())) {
        region = std::make_unique<::mlir::Region>();
        if (parser.parseRegion(*region))
          return ::mlir::failure();
        {0}Regions.emplace_back(std::move(region));
      }
    }
  }
)";

/// The code snippet used to ensure a list of regions have terminators.
///
/// {0}: The name of the region list.
const char *const regionListEnsureTerminatorParserCode = R"(
  for (auto &region : {0}Regions)
    ensureTerminator(*region, parser.getBuilder(), result.location);
)";

/// The code snippet used to generate a parser call for an optional region.
///
/// {0}: The name of the region.
const char *const optionalRegionParserCode = R"(
  {
     auto parseResult = parser.parseOptionalRegion(*{0}Region);
     if (parseResult.hasValue() && failed(*parseResult))
       return ::mlir::failure();
  }
)";

/// The code snippet used to generate a parser call for a region.
///
/// {0}: The name of the region.
const char *const regionParserCode = R"(
  if (parser.parseRegion(*{0}Region))
    return ::mlir::failure();
)";

/// The code snippet used to ensure a region has a terminator.
///
/// {0}: The name of the region.
const char *const regionEnsureTerminatorParserCode = R"(
  ensureTerminator(*{0}Region, parser.getBuilder(), result.location);
)";

/// The code snippet used to generate a parser call for a successor list.
///
/// {0}: The name for the successor list.
const char *const successorListParserCode = R"(
  {
    ::mlir::Block *succ;
    auto firstSucc = parser.parseOptionalSuccessor(succ);
    if (firstSucc.hasValue()) {
      if (failed(*firstSucc))
        return ::mlir::failure();
      {0}Successors.emplace_back(succ);

      // Parse any trailing successors.
      while (succeeded(parser.parseOptionalComma())) {
        if (parser.parseSuccessor(succ))
          return ::mlir::failure();
        {0}Successors.emplace_back(succ);
      }
    }
  }
)";

/// The code snippet used to generate a parser call for a successor.
///
/// {0}: The name of the successor.
const char *const successorParserCode = R"(
  if (parser.parseSuccessor({0}Successor))
    return ::mlir::failure();
)";
726 
namespace {
/// How many values a single parse argument may expand to.
enum class ArgumentLengthKind {
  Variadic = 0, ///< Variadic: zero or more elements (0->N).
  Optional = 1, ///< Optional: zero or one element.
  Single = 2,   ///< Exactly one element, always present.
};
} // end anonymous namespace
738 
739 /// Get the length kind for the given constraint.
740 static ArgumentLengthKind
getArgumentLengthKind(const NamedTypeConstraint * var)741 getArgumentLengthKind(const NamedTypeConstraint *var) {
742   if (var->isOptional())
743     return ArgumentLengthKind::Optional;
744   if (var->isVariadic())
745     return ArgumentLengthKind::Variadic;
746   return ArgumentLengthKind::Single;
747 }
748 
749 /// Get the name used for the type list for the given type directive operand.
750 /// 'lengthKind' to the corresponding kind for the given argument.
751 static StringRef getTypeListName(Element *arg, ArgumentLengthKind &lengthKind) {
752   if (auto *operand = dyn_cast<OperandVariable>(arg)) {
753     lengthKind = getArgumentLengthKind(operand->getVar());
754     return operand->getVar()->name;
755   }
756   if (auto *result = dyn_cast<ResultVariable>(arg)) {
757     lengthKind = getArgumentLengthKind(result->getVar());
758     return result->getVar()->name;
759   }
760   lengthKind = ArgumentLengthKind::Variadic;
761   if (isa<OperandsDirective>(arg))
762     return "allOperand";
763   if (isa<ResultsDirective>(arg))
764     return "allResult";
765   llvm_unreachable("unknown 'type' directive argument");
766 }
767 
768 /// Generate the parser for a literal value.
769 static void genLiteralParser(StringRef value, OpMethodBody &body) {
770   // Handle the case of a keyword/identifier.
771   if (value.front() == '_' || isalpha(value.front())) {
772     body << "Keyword(\"" << value << "\")";
773     return;
774   }
775   body << (StringRef)StringSwitch<StringRef>(value)
776               .Case("->", "Arrow()")
777               .Case(":", "Colon()")
778               .Case(",", "Comma()")
779               .Case("=", "Equal()")
780               .Case("<", "Less()")
781               .Case(">", "Greater()")
782               .Case("{", "LBrace()")
783               .Case("}", "RBrace()")
784               .Case("(", "LParen()")
785               .Case(")", "RParen()")
786               .Case("[", "LSquare()")
787               .Case("]", "RSquare()")
788               .Case("?", "Question()")
789               .Case("+", "Plus()")
790               .Case("*", "Star()");
791 }
792 
793 /// Generate the storage code required for parsing the given element.
genElementParserStorage(Element * element,OpMethodBody & body)794 static void genElementParserStorage(Element *element, OpMethodBody &body) {
795   if (auto *optional = dyn_cast<OptionalElement>(element)) {
796     auto elements = optional->getElements();
797 
798     // If the anchor is a unit attribute, it won't be parsed directly so elide
799     // it.
800     auto *anchor = dyn_cast<AttributeVariable>(optional->getAnchor());
801     Element *elidedAnchorElement = nullptr;
802     if (anchor && anchor != &*elements.begin() && anchor->isUnitAttr())
803       elidedAnchorElement = anchor;
804     for (auto &childElement : elements)
805       if (&childElement != elidedAnchorElement)
806         genElementParserStorage(&childElement, body);
807 
808   } else if (auto *custom = dyn_cast<CustomDirective>(element)) {
809     for (auto &paramElement : custom->getArguments())
810       genElementParserStorage(&paramElement, body);
811 
812   } else if (isa<OperandsDirective>(element)) {
813     body << "  ::mlir::SmallVector<::mlir::OpAsmParser::OperandType, 4> "
814             "allOperands;\n";
815 
816   } else if (isa<RegionsDirective>(element)) {
817     body << "  ::llvm::SmallVector<std::unique_ptr<::mlir::Region>, 2> "
818             "fullRegions;\n";
819 
820   } else if (isa<SuccessorsDirective>(element)) {
821     body << "  ::llvm::SmallVector<::mlir::Block *, 2> fullSuccessors;\n";
822 
823   } else if (auto *attr = dyn_cast<AttributeVariable>(element)) {
824     const NamedAttribute *var = attr->getVar();
825     body << llvm::formatv("  {0} {1}Attr;\n", var->attr.getStorageType(),
826                           var->name);
827 
828   } else if (auto *operand = dyn_cast<OperandVariable>(element)) {
829     StringRef name = operand->getVar()->name;
830     if (operand->getVar()->isVariableLength()) {
831       body << "  ::mlir::SmallVector<::mlir::OpAsmParser::OperandType, 4> "
832            << name << "Operands;\n";
833     } else {
834       body << "  ::mlir::OpAsmParser::OperandType " << name
835            << "RawOperands[1];\n"
836            << "  ::llvm::ArrayRef<::mlir::OpAsmParser::OperandType> " << name
837            << "Operands(" << name << "RawOperands);";
838     }
839     body << llvm::formatv("  ::llvm::SMLoc {0}OperandsLoc;\n"
840                           "  (void){0}OperandsLoc;\n",
841                           name);
842 
843   } else if (auto *region = dyn_cast<RegionVariable>(element)) {
844     StringRef name = region->getVar()->name;
845     if (region->getVar()->isVariadic()) {
846       body << llvm::formatv(
847           "  ::llvm::SmallVector<std::unique_ptr<::mlir::Region>, 2> "
848           "{0}Regions;\n",
849           name);
850     } else {
851       body << llvm::formatv("  std::unique_ptr<::mlir::Region> {0}Region = "
852                             "std::make_unique<::mlir::Region>();\n",
853                             name);
854     }
855 
856   } else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
857     StringRef name = successor->getVar()->name;
858     if (successor->getVar()->isVariadic()) {
859       body << llvm::formatv("  ::llvm::SmallVector<::mlir::Block *, 2> "
860                             "{0}Successors;\n",
861                             name);
862     } else {
863       body << llvm::formatv("  ::mlir::Block *{0}Successor = nullptr;\n", name);
864     }
865 
866   } else if (auto *dir = dyn_cast<TypeDirective>(element)) {
867     ArgumentLengthKind lengthKind;
868     StringRef name = getTypeListName(dir->getOperand(), lengthKind);
869     if (lengthKind != ArgumentLengthKind::Single)
870       body << "  ::mlir::SmallVector<::mlir::Type, 1> " << name << "Types;\n";
871     else
872       body << llvm::formatv("  ::mlir::Type {0}RawTypes[1];\n", name)
873            << llvm::formatv(
874                   "  ::llvm::ArrayRef<::mlir::Type> {0}Types({0}RawTypes);\n",
875                   name);
876   } else if (auto *dir = dyn_cast<TypeRefDirective>(element)) {
877     ArgumentLengthKind lengthKind;
878     StringRef name = getTypeListName(dir->getOperand(), lengthKind);
879     // Refer to the previously encountered TypeDirective for name.
880     // Take a `const ::mlir::SmallVector<::mlir::Type, 1> &` in the declaration
881     // to properly track the types that will be parsed and pushed later on.
882     if (lengthKind != ArgumentLengthKind::Single)
883       body << "  const ::mlir::SmallVector<::mlir::Type, 1> &" << name
884            << "TypesRef(" << name << "Types);\n";
885     else
886       body << llvm::formatv(
887           "  ::llvm::ArrayRef<::mlir::Type> {0}RawTypesRef({0}RawTypes);\n",
888           name);
889   } else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) {
890     ArgumentLengthKind ignored;
891     body << "  ::llvm::ArrayRef<::mlir::Type> "
892          << getTypeListName(dir->getInputs(), ignored) << "Types;\n";
893     body << "  ::llvm::ArrayRef<::mlir::Type> "
894          << getTypeListName(dir->getResults(), ignored) << "Types;\n";
895   }
896 }
897 
898 /// Generate the parser for a parameter to a custom directive.
genCustomParameterParser(Element & param,OpMethodBody & body)899 static void genCustomParameterParser(Element &param, OpMethodBody &body) {
900   body << ", ";
901   if (auto *attr = dyn_cast<AttributeVariable>(&param)) {
902     body << attr->getVar()->name << "Attr";
903   } else if (isa<AttrDictDirective>(&param)) {
904     body << "result.attributes";
905   } else if (auto *operand = dyn_cast<OperandVariable>(&param)) {
906     StringRef name = operand->getVar()->name;
907     ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar());
908     if (lengthKind == ArgumentLengthKind::Variadic)
909       body << llvm::formatv("{0}Operands", name);
910     else if (lengthKind == ArgumentLengthKind::Optional)
911       body << llvm::formatv("{0}Operand", name);
912     else
913       body << formatv("{0}RawOperands[0]", name);
914 
915   } else if (auto *region = dyn_cast<RegionVariable>(&param)) {
916     StringRef name = region->getVar()->name;
917     if (region->getVar()->isVariadic())
918       body << llvm::formatv("{0}Regions", name);
919     else
920       body << llvm::formatv("*{0}Region", name);
921 
922   } else if (auto *successor = dyn_cast<SuccessorVariable>(&param)) {
923     StringRef name = successor->getVar()->name;
924     if (successor->getVar()->isVariadic())
925       body << llvm::formatv("{0}Successors", name);
926     else
927       body << llvm::formatv("{0}Successor", name);
928 
929   } else if (auto *dir = dyn_cast<TypeRefDirective>(&param)) {
930     ArgumentLengthKind lengthKind;
931     StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
932     if (lengthKind == ArgumentLengthKind::Variadic)
933       body << llvm::formatv("{0}TypesRef", listName);
934     else if (lengthKind == ArgumentLengthKind::Optional)
935       body << llvm::formatv("{0}TypeRef", listName);
936     else
937       body << formatv("{0}RawTypesRef[0]", listName);
938   } else if (auto *dir = dyn_cast<TypeDirective>(&param)) {
939     ArgumentLengthKind lengthKind;
940     StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
941     if (lengthKind == ArgumentLengthKind::Variadic)
942       body << llvm::formatv("{0}Types", listName);
943     else if (lengthKind == ArgumentLengthKind::Optional)
944       body << llvm::formatv("{0}Type", listName);
945     else
946       body << formatv("{0}RawTypes[0]", listName);
947   } else {
948     llvm_unreachable("unknown custom directive parameter");
949   }
950 }
951 
952 /// Generate the parser for a custom directive.
genCustomDirectiveParser(CustomDirective * dir,OpMethodBody & body)953 static void genCustomDirectiveParser(CustomDirective *dir, OpMethodBody &body) {
954   body << "  {\n";
955 
956   // Preprocess the directive variables.
957   // * Add a local variable for optional operands and types. This provides a
958   //   better API to the user defined parser methods.
959   // * Set the location of operand variables.
960   for (Element &param : dir->getArguments()) {
961     if (auto *operand = dyn_cast<OperandVariable>(&param)) {
962       body << "    " << operand->getVar()->name
963            << "OperandsLoc = parser.getCurrentLocation();\n";
964       if (operand->getVar()->isOptional()) {
965         body << llvm::formatv(
966             "    llvm::Optional<::mlir::OpAsmParser::OperandType> "
967             "{0}Operand;\n",
968             operand->getVar()->name);
969       }
970     } else if (auto *dir = dyn_cast<TypeRefDirective>(&param)) {
971       // Reference to an optional which may or may not have been set.
972       // Retrieve from vector if not empty.
973       ArgumentLengthKind lengthKind;
974       StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
975       if (lengthKind == ArgumentLengthKind::Optional)
976         body << llvm::formatv(
977             "    ::mlir::Type {0}TypeRef = {0}TypesRef.empty() "
978             "? Type() : {0}TypesRef[0];\n",
979             listName);
980     } else if (auto *dir = dyn_cast<TypeDirective>(&param)) {
981       ArgumentLengthKind lengthKind;
982       StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
983       if (lengthKind == ArgumentLengthKind::Optional)
984         body << llvm::formatv("    ::mlir::Type {0}Type;\n", listName);
985     }
986   }
987 
988   body << "    if (parse" << dir->getName() << "(parser";
989   for (Element &param : dir->getArguments())
990     genCustomParameterParser(param, body);
991 
992   body << "))\n"
993        << "      return ::mlir::failure();\n";
994 
995   // After parsing, add handling for any of the optional constructs.
996   for (Element &param : dir->getArguments()) {
997     if (auto *attr = dyn_cast<AttributeVariable>(&param)) {
998       const NamedAttribute *var = attr->getVar();
999       if (var->attr.isOptional())
1000         body << llvm::formatv("    if ({0}Attr)\n  ", var->name);
1001 
1002       body << llvm::formatv("    result.addAttribute(\"{0}\", {0}Attr);\n",
1003                             var->name);
1004     } else if (auto *operand = dyn_cast<OperandVariable>(&param)) {
1005       const NamedTypeConstraint *var = operand->getVar();
1006       if (!var->isOptional())
1007         continue;
1008       body << llvm::formatv("    if ({0}Operand.hasValue())\n"
1009                             "      {0}Operands.push_back(*{0}Operand);\n",
1010                             var->name);
1011     } else if (isa<TypeRefDirective>(&param)) {
1012       // In the `type_ref` case, do not parse a new Type that needs to be added.
1013       // Just do nothing here.
1014     } else if (auto *dir = dyn_cast<TypeDirective>(&param)) {
1015       ArgumentLengthKind lengthKind;
1016       StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
1017       if (lengthKind == ArgumentLengthKind::Optional) {
1018         body << llvm::formatv("    if ({0}Type)\n"
1019                               "      {0}Types.push_back({0}Type);\n",
1020                               listName);
1021       }
1022     }
1023   }
1024 
1025   body << "  }\n";
1026 }
1027 
1028 /// Generate the parser for a enum attribute.
genEnumAttrParser(const NamedAttribute * var,OpMethodBody & body,FmtContext & attrTypeCtx)1029 static void genEnumAttrParser(const NamedAttribute *var, OpMethodBody &body,
1030                               FmtContext &attrTypeCtx) {
1031   Attribute baseAttr = var->attr.getBaseAttr();
1032   const EnumAttr &enumAttr = cast<EnumAttr>(baseAttr);
1033   std::vector<EnumAttrCase> cases = enumAttr.getAllCases();
1034 
1035   // Generate the code for building an attribute for this enum.
1036   std::string attrBuilderStr;
1037   {
1038     llvm::raw_string_ostream os(attrBuilderStr);
1039     os << tgfmt(enumAttr.getConstBuilderTemplate(), &attrTypeCtx,
1040                 "attrOptional.getValue()");
1041   }
1042 
1043   // Build a string containing the cases that can be formatted as a keyword.
1044   std::string validCaseKeywordsStr = "{";
1045   llvm::raw_string_ostream validCaseKeywordsOS(validCaseKeywordsStr);
1046   for (const EnumAttrCase &attrCase : cases)
1047     if (canFormatStringAsKeyword(attrCase.getStr()))
1048       validCaseKeywordsOS << '"' << attrCase.getStr() << "\",";
1049   validCaseKeywordsOS.str().back() = '}';
1050 
1051   // If the attribute is not optional, build an error message for the missing
1052   // attribute.
1053   std::string errorMessage;
1054   if (!var->attr.isOptional()) {
1055     llvm::raw_string_ostream errorMessageOS(errorMessage);
1056     errorMessageOS
1057         << "return parser.emitError(loc, \"expected string or "
1058            "keyword containing one of the following enum values for attribute '"
1059         << var->name << "' [";
1060     llvm::interleaveComma(cases, errorMessageOS, [&](const auto &attrCase) {
1061       errorMessageOS << attrCase.getStr();
1062     });
1063     errorMessageOS << "]\");";
1064   }
1065 
1066   body << formatv(enumAttrParserCode, var->name, enumAttr.getCppNamespace(),
1067                   enumAttr.getStringToSymbolFnName(), attrBuilderStr,
1068                   validCaseKeywordsStr, errorMessage);
1069 }
1070 
genParser(Operator & op,OpClass & opClass)1071 void OperationFormat::genParser(Operator &op, OpClass &opClass) {
1072   llvm::SmallVector<OpMethodParameter, 4> paramList;
1073   paramList.emplace_back("::mlir::OpAsmParser &", "parser");
1074   paramList.emplace_back("::mlir::OperationState &", "result");
1075 
1076   auto *method =
1077       opClass.addMethodAndPrune("::mlir::ParseResult", "parse",
1078                                 OpMethod::MP_Static, std::move(paramList));
1079   auto &body = method->body();
1080 
1081   // Generate variables to store the operands and type within the format. This
1082   // allows for referencing these variables in the presence of optional
1083   // groupings.
1084   for (auto &element : elements)
1085     genElementParserStorage(&*element, body);
1086 
1087   // A format context used when parsing attributes with buildable types.
1088   FmtContext attrTypeCtx;
1089   attrTypeCtx.withBuilder("parser.getBuilder()");
1090 
1091   // Generate parsers for each of the elements.
1092   for (auto &element : elements)
1093     genElementParser(element.get(), body, attrTypeCtx);
1094 
1095   // Generate the code to resolve the operand/result types and successors now
1096   // that they have been parsed.
1097   genParserTypeResolution(op, body);
1098   genParserRegionResolution(op, body);
1099   genParserSuccessorResolution(op, body);
1100   genParserVariadicSegmentResolution(op, body);
1101 
1102   body << "  return ::mlir::success();\n";
1103 }
1104 
genElementParser(Element * element,OpMethodBody & body,FmtContext & attrTypeCtx)1105 void OperationFormat::genElementParser(Element *element, OpMethodBody &body,
1106                                        FmtContext &attrTypeCtx) {
1107   /// Optional Group.
1108   if (auto *optional = dyn_cast<OptionalElement>(element)) {
1109     auto elements =
1110         llvm::drop_begin(optional->getElements(), optional->getParseStart());
1111 
1112     // Generate a special optional parser for the first element to gate the
1113     // parsing of the rest of the elements.
1114     Element *firstElement = &*elements.begin();
1115     if (auto *attrVar = dyn_cast<AttributeVariable>(firstElement)) {
1116       genElementParser(attrVar, body, attrTypeCtx);
1117       body << "  if (" << attrVar->getVar()->name << "Attr) {\n";
1118     } else if (auto *literal = dyn_cast<LiteralElement>(firstElement)) {
1119       body << "  if (succeeded(parser.parseOptional";
1120       genLiteralParser(literal->getLiteral(), body);
1121       body << ")) {\n";
1122     } else if (auto *opVar = dyn_cast<OperandVariable>(firstElement)) {
1123       genElementParser(opVar, body, attrTypeCtx);
1124       body << "  if (!" << opVar->getVar()->name << "Operands.empty()) {\n";
1125     } else if (auto *regionVar = dyn_cast<RegionVariable>(firstElement)) {
1126       const NamedRegion *region = regionVar->getVar();
1127       if (region->isVariadic()) {
1128         genElementParser(regionVar, body, attrTypeCtx);
1129         body << "  if (!" << region->name << "Regions.empty()) {\n";
1130       } else {
1131         body << llvm::formatv(optionalRegionParserCode, region->name);
1132         body << "  if (!" << region->name << "Region->empty()) {\n  ";
1133         if (hasImplicitTermTrait)
1134           body << llvm::formatv(regionEnsureTerminatorParserCode, region->name);
1135       }
1136     }
1137 
1138     // If the anchor is a unit attribute, we don't need to print it. When
1139     // parsing, we will add this attribute if this group is present.
1140     Element *elidedAnchorElement = nullptr;
1141     auto *anchorAttr = dyn_cast<AttributeVariable>(optional->getAnchor());
1142     if (anchorAttr && anchorAttr != firstElement && anchorAttr->isUnitAttr()) {
1143       elidedAnchorElement = anchorAttr;
1144 
1145       // Add the anchor unit attribute to the operation state.
1146       body << "    result.addAttribute(\"" << anchorAttr->getVar()->name
1147            << "\", parser.getBuilder().getUnitAttr());\n";
1148     }
1149 
1150     // Generate the rest of the elements normally.
1151     for (Element &childElement : llvm::drop_begin(elements, 1)) {
1152       if (&childElement != elidedAnchorElement)
1153         genElementParser(&childElement, body, attrTypeCtx);
1154     }
1155     body << "  }\n";
1156 
1157     /// Literals.
1158   } else if (LiteralElement *literal = dyn_cast<LiteralElement>(element)) {
1159     body << "  if (parser.parse";
1160     genLiteralParser(literal->getLiteral(), body);
1161     body << ")\n    return ::mlir::failure();\n";
1162 
1163     /// Whitespaces.
1164   } else if (isa<WhitespaceElement>(element)) {
1165     // Nothing to parse.
1166 
1167     /// Arguments.
1168   } else if (auto *attr = dyn_cast<AttributeVariable>(element)) {
1169     const NamedAttribute *var = attr->getVar();
1170 
1171     // Check to see if we can parse this as an enum attribute.
1172     if (canFormatEnumAttr(var))
1173       return genEnumAttrParser(var, body, attrTypeCtx);
1174 
1175     // Check to see if we should parse this as a symbol name attribute.
1176     if (shouldFormatSymbolNameAttr(var)) {
1177       body << formatv(var->attr.isOptional() ? optionalSymbolNameAttrParserCode
1178                                              : symbolNameAttrParserCode,
1179                       var->name);
1180       return;
1181     }
1182 
1183     // If this attribute has a buildable type, use that when parsing the
1184     // attribute.
1185     std::string attrTypeStr;
1186     if (Optional<StringRef> typeBuilder = attr->getTypeBuilder()) {
1187       llvm::raw_string_ostream os(attrTypeStr);
1188       os << ", " << tgfmt(*typeBuilder, &attrTypeCtx);
1189     }
1190 
1191     body << formatv(var->attr.isOptional() ? optionalAttrParserCode
1192                                            : attrParserCode,
1193                     var->name, attrTypeStr);
1194   } else if (auto *operand = dyn_cast<OperandVariable>(element)) {
1195     ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar());
1196     StringRef name = operand->getVar()->name;
1197     if (lengthKind == ArgumentLengthKind::Variadic)
1198       body << llvm::formatv(variadicOperandParserCode, name);
1199     else if (lengthKind == ArgumentLengthKind::Optional)
1200       body << llvm::formatv(optionalOperandParserCode, name);
1201     else
1202       body << formatv(operandParserCode, name);
1203 
1204   } else if (auto *region = dyn_cast<RegionVariable>(element)) {
1205     bool isVariadic = region->getVar()->isVariadic();
1206     body << llvm::formatv(isVariadic ? regionListParserCode : regionParserCode,
1207                           region->getVar()->name);
1208     if (hasImplicitTermTrait) {
1209       body << llvm::formatv(isVariadic ? regionListEnsureTerminatorParserCode
1210                                        : regionEnsureTerminatorParserCode,
1211                             region->getVar()->name);
1212     }
1213 
1214   } else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
1215     bool isVariadic = successor->getVar()->isVariadic();
1216     body << formatv(isVariadic ? successorListParserCode : successorParserCode,
1217                     successor->getVar()->name);
1218 
1219     /// Directives.
1220   } else if (auto *attrDict = dyn_cast<AttrDictDirective>(element)) {
1221     body << "  if (parser.parseOptionalAttrDict"
1222          << (attrDict->isWithKeyword() ? "WithKeyword" : "")
1223          << "(result.attributes))\n"
1224          << "    return ::mlir::failure();\n";
1225   } else if (auto *customDir = dyn_cast<CustomDirective>(element)) {
1226     genCustomDirectiveParser(customDir, body);
1227 
1228   } else if (isa<OperandsDirective>(element)) {
1229     body << "  ::llvm::SMLoc allOperandLoc = parser.getCurrentLocation();\n"
1230          << "  if (parser.parseOperandList(allOperands))\n"
1231          << "    return ::mlir::failure();\n";
1232 
1233   } else if (isa<RegionsDirective>(element)) {
1234     body << llvm::formatv(regionListParserCode, "full");
1235     if (hasImplicitTermTrait)
1236       body << llvm::formatv(regionListEnsureTerminatorParserCode, "full");
1237 
1238   } else if (isa<SuccessorsDirective>(element)) {
1239     body << llvm::formatv(successorListParserCode, "full");
1240 
1241   } else if (auto *dir = dyn_cast<TypeRefDirective>(element)) {
1242     ArgumentLengthKind lengthKind;
1243     StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
1244     if (lengthKind == ArgumentLengthKind::Variadic)
1245       body << llvm::formatv(variadicTypeParserCode, listName);
1246     else if (lengthKind == ArgumentLengthKind::Optional)
1247       body << llvm::formatv(optionalTypeParserCode, listName);
1248     else
1249       body << formatv(typeParserCode, listName);
1250   } else if (auto *dir = dyn_cast<TypeDirective>(element)) {
1251     ArgumentLengthKind lengthKind;
1252     StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
1253     if (lengthKind == ArgumentLengthKind::Variadic)
1254       body << llvm::formatv(variadicTypeParserCode, listName);
1255     else if (lengthKind == ArgumentLengthKind::Optional)
1256       body << llvm::formatv(optionalTypeParserCode, listName);
1257     else
1258       body << formatv(typeParserCode, listName);
1259   } else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) {
1260     ArgumentLengthKind ignored;
1261     body << formatv(functionalTypeParserCode,
1262                     getTypeListName(dir->getInputs(), ignored),
1263                     getTypeListName(dir->getResults(), ignored));
1264   } else {
1265     llvm_unreachable("unknown format element");
1266   }
1267 }
1268 
genParserTypeResolution(Operator & op,OpMethodBody & body)1269 void OperationFormat::genParserTypeResolution(Operator &op,
1270                                               OpMethodBody &body) {
1271   // If any of type resolutions use transformed variables, make sure that the
1272   // types of those variables are resolved.
1273   SmallPtrSet<const NamedTypeConstraint *, 8> verifiedVariables;
1274   FmtContext verifierFCtx;
1275   for (TypeResolution &resolver :
1276        llvm::concat<TypeResolution>(resultTypes, operandTypes)) {
1277     Optional<StringRef> transformer = resolver.getVarTransformer();
1278     if (!transformer)
1279       continue;
1280     // Ensure that we don't verify the same variables twice.
1281     const NamedTypeConstraint *variable = resolver.getVariable();
1282     if (!variable || !verifiedVariables.insert(variable).second)
1283       continue;
1284 
1285     auto constraint = variable->constraint;
1286     body << "  for (::mlir::Type type : " << variable->name << "Types) {\n"
1287          << "    (void)type;\n"
1288          << "    if (!("
1289          << tgfmt(constraint.getConditionTemplate(),
1290                   &verifierFCtx.withSelf("type"))
1291          << ")) {\n"
1292          << formatv("      return parser.emitError(parser.getNameLoc()) << "
1293                     "\"'{0}' must be {1}, but got \" << type;\n",
1294                     variable->name, constraint.getSummary())
1295          << "    }\n"
1296          << "  }\n";
1297   }
1298 
1299   // Initialize the set of buildable types.
1300   if (!buildableTypes.empty()) {
1301     FmtContext typeBuilderCtx;
1302     typeBuilderCtx.withBuilder("parser.getBuilder()");
1303     for (auto &it : buildableTypes)
1304       body << "  ::mlir::Type odsBuildableType" << it.second << " = "
1305            << tgfmt(it.first, &typeBuilderCtx) << ";\n";
1306   }
1307 
1308   // Emit the code necessary for a type resolver.
1309   auto emitTypeResolver = [&](TypeResolution &resolver, StringRef curVar) {
1310     if (Optional<int> val = resolver.getBuilderIdx()) {
1311       body << "odsBuildableType" << *val;
1312     } else if (const NamedTypeConstraint *var = resolver.getVariable()) {
1313       if (Optional<StringRef> tform = resolver.getVarTransformer()) {
1314         FmtContext fmtContext;
1315         if (var->isVariadic())
1316           fmtContext.withSelf(var->name + "Types");
1317         else
1318           fmtContext.withSelf(var->name + "Types[0]");
1319         body << tgfmt(*tform, &fmtContext);
1320       } else {
1321         body << var->name << "Types";
1322       }
1323     } else if (const NamedAttribute *attr = resolver.getAttribute()) {
1324       if (Optional<StringRef> tform = resolver.getVarTransformer())
1325         body << tgfmt(*tform,
1326                       &FmtContext().withSelf(attr->name + "Attr.getType()"));
1327       else
1328         body << attr->name << "Attr.getType()";
1329     } else {
1330       body << curVar << "Types";
1331     }
1332   };
1333 
1334   // Resolve each of the result types.
1335   if (allResultTypes) {
1336     body << "  result.addTypes(allResultTypes);\n";
1337   } else {
1338     for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) {
1339       body << "  result.addTypes(";
1340       emitTypeResolver(resultTypes[i], op.getResultName(i));
1341       body << ");\n";
1342     }
1343   }
1344 
1345   // Early exit if there are no operands.
1346   if (op.getNumOperands() == 0)
1347     return;
1348 
1349   // Handle the case where all operand types are in one group.
1350   if (allOperandTypes) {
1351     // If we have all operands together, use the full operand list directly.
1352     if (allOperands) {
1353       body << "  if (parser.resolveOperands(allOperands, allOperandTypes, "
1354               "allOperandLoc, result.operands))\n"
1355               "    return ::mlir::failure();\n";
1356       return;
1357     }
1358 
1359     // Otherwise, use llvm::concat to merge the disjoint operand lists together.
1360     // llvm::concat does not allow the case of a single range, so guard it here.
1361     body << "  if (parser.resolveOperands(";
1362     if (op.getNumOperands() > 1) {
1363       body << "::llvm::concat<const ::mlir::OpAsmParser::OperandType>(";
1364       llvm::interleaveComma(op.getOperands(), body, [&](auto &operand) {
1365         body << operand.name << "Operands";
1366       });
1367       body << ")";
1368     } else {
1369       body << op.operand_begin()->name << "Operands";
1370     }
1371     body << ", allOperandTypes, parser.getNameLoc(), result.operands))\n"
1372          << "    return ::mlir::failure();\n";
1373     return;
1374   }
1375   // Handle the case where all of the operands were grouped together.
1376   if (allOperands) {
1377     body << "  if (parser.resolveOperands(allOperands, ";
1378 
1379     // Group all of the operand types together to perform the resolution all at
1380     // once. Use llvm::concat to perform the merge. llvm::concat does not allow
1381     // the case of a single range, so guard it here.
1382     if (op.getNumOperands() > 1) {
1383       body << "::llvm::concat<const Type>(";
1384       llvm::interleaveComma(
1385           llvm::seq<int>(0, op.getNumOperands()), body, [&](int i) {
1386             body << "::llvm::ArrayRef<::mlir::Type>(";
1387             emitTypeResolver(operandTypes[i], op.getOperand(i).name);
1388             body << ")";
1389           });
1390       body << ")";
1391     } else {
1392       emitTypeResolver(operandTypes.front(), op.getOperand(0).name);
1393     }
1394 
1395     body << ", allOperandLoc, result.operands))\n"
1396          << "    return ::mlir::failure();\n";
1397     return;
1398   }
1399 
1400   // The final case is the one where each of the operands types are resolved
1401   // separately.
1402   for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i) {
1403     NamedTypeConstraint &operand = op.getOperand(i);
1404     body << "  if (parser.resolveOperands(" << operand.name << "Operands, ";
1405 
1406     // Resolve the type of this operand.
1407     TypeResolution &operandType = operandTypes[i];
1408     emitTypeResolver(operandType, operand.name);
1409 
1410     // If the type is resolved by a non-variadic variable, index into the
1411     // resolved type list. This allows for resolving the types of a variadic
1412     // operand list from a non-variadic variable.
1413     bool verifyOperandAndTypeSize = true;
1414     if (auto *resolverVar = operandType.getVariable()) {
1415       if (!resolverVar->isVariadic() && !operandType.getVarTransformer()) {
1416         body << "[0]";
1417         verifyOperandAndTypeSize = false;
1418       }
1419     } else {
1420       verifyOperandAndTypeSize = !operandType.getBuilderIdx();
1421     }
1422 
1423     // Check to see if the sizes between the types and operands must match. If
1424     // they do, provide the operand location to select the proper resolution
1425     // overload.
1426     if (verifyOperandAndTypeSize)
1427       body << ", " << operand.name << "OperandsLoc";
1428     body << ", result.operands))\n    return ::mlir::failure();\n";
1429   }
1430 }
1431 
genParserRegionResolution(Operator & op,OpMethodBody & body)1432 void OperationFormat::genParserRegionResolution(Operator &op,
1433                                                 OpMethodBody &body) {
1434   // Check for the case where all regions were parsed.
1435   bool hasAllRegions = llvm::any_of(
1436       elements, [](auto &elt) { return isa<RegionsDirective>(elt.get()); });
1437   if (hasAllRegions) {
1438     body << "  result.addRegions(fullRegions);\n";
1439     return;
1440   }
1441 
1442   // Otherwise, handle each region individually.
1443   for (const NamedRegion &region : op.getRegions()) {
1444     if (region.isVariadic())
1445       body << "  result.addRegions(" << region.name << "Regions);\n";
1446     else
1447       body << "  result.addRegion(std::move(" << region.name << "Region));\n";
1448   }
1449 }
1450 
genParserSuccessorResolution(Operator & op,OpMethodBody & body)1451 void OperationFormat::genParserSuccessorResolution(Operator &op,
1452                                                    OpMethodBody &body) {
1453   // Check for the case where all successors were parsed.
1454   bool hasAllSuccessors = llvm::any_of(
1455       elements, [](auto &elt) { return isa<SuccessorsDirective>(elt.get()); });
1456   if (hasAllSuccessors) {
1457     body << "  result.addSuccessors(fullSuccessors);\n";
1458     return;
1459   }
1460 
1461   // Otherwise, handle each successor individually.
1462   for (const NamedSuccessor &successor : op.getSuccessors()) {
1463     if (successor.isVariadic())
1464       body << "  result.addSuccessors(" << successor.name << "Successors);\n";
1465     else
1466       body << "  result.addSuccessors(" << successor.name << "Successor);\n";
1467   }
1468 }
1469 
genParserVariadicSegmentResolution(Operator & op,OpMethodBody & body)1470 void OperationFormat::genParserVariadicSegmentResolution(Operator &op,
1471                                                          OpMethodBody &body) {
1472   if (!allOperands &&
1473       op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
1474     body << "  result.addAttribute(\"operand_segment_sizes\", "
1475          << "parser.getBuilder().getI32VectorAttr({";
1476     auto interleaveFn = [&](const NamedTypeConstraint &operand) {
1477       // If the operand is variadic emit the parsed size.
1478       if (operand.isVariableLength())
1479         body << "static_cast<int32_t>(" << operand.name << "Operands.size())";
1480       else
1481         body << "1";
1482     };
1483     llvm::interleaveComma(op.getOperands(), body, interleaveFn);
1484     body << "}));\n";
1485   }
1486 
1487   if (!allResultTypes &&
1488       op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) {
1489     body << "  result.addAttribute(\"result_segment_sizes\", "
1490          << "parser.getBuilder().getI32VectorAttr({";
1491     auto interleaveFn = [&](const NamedTypeConstraint &result) {
1492       // If the result is variadic emit the parsed size.
1493       if (result.isVariableLength())
1494         body << "static_cast<int32_t>(" << result.name << "Types.size())";
1495       else
1496         body << "1";
1497     };
1498     llvm::interleaveComma(op.getResults(), body, interleaveFn);
1499     body << "}));\n";
1500   }
1501 }
1502 
//===----------------------------------------------------------------------===//
// PrinterGen
//===----------------------------------------------------------------------===//
1505 
/// The code snippet used to generate a printer call for a region of an
/// operation that has the SingleBlockImplicitTerminator trait.
///
/// The emitted code prints the region with its entry block arguments. The
/// terminator of the region's first block is printed only if it has
/// attributes, operands, or results; for an empty region the terminator flag
/// stays true. The snippet is expanded with llvm::formatv, so `{{` yields a
/// literal `{` in the generated code.
///
/// {0}: The name of the region.
const char *regionSingleBlockImplicitTerminatorPrinterCode = R"(
  {
    bool printTerminator = true;
    if (auto *term = {0}.empty() ? nullptr : {0}.begin()->getTerminator()) {{
      printTerminator = !term->getAttrDictionary().empty() ||
                        term->getNumOperands() != 0 ||
                        term->getNumResults() != 0;
    }
    p.printRegion({0}, /*printEntryBlockArgs=*/true,
                  /*printBlockTerminators=*/printTerminator);
  }
)";
1522 
/// The code snippet used to generate a printer call for an enum that has cases
/// that can't be represented with a keyword.
///
/// The snippet declares `caseValue` (the attribute's case value) and
/// `caseValueStr` (its string form) for use by the code emitted after it. It
/// opens a `{` scope that it does NOT close; the code emitted afterwards must
/// supply the matching `}`.
///
/// {0}: The name of the enum attribute.
/// {1}: The name of the enum attributes symbolToString function.
const char *enumAttrBeginPrinterCode = R"(
  {
    auto caseValue = {0}();
    auto caseValueStr = {1}(caseValue);
)";
1533 
1534 /// Generate the printer for the 'attr-dict' directive.
genAttrDictPrinter(OperationFormat & fmt,Operator & op,OpMethodBody & body,bool withKeyword)1535 static void genAttrDictPrinter(OperationFormat &fmt, Operator &op,
1536                                OpMethodBody &body, bool withKeyword) {
1537   body << "  p.printOptionalAttrDict" << (withKeyword ? "WithKeyword" : "")
1538        << "(getAttrs(), /*elidedAttrs=*/{";
1539   // Elide the variadic segment size attributes if necessary.
1540   if (!fmt.allOperands &&
1541       op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments"))
1542     body << "\"operand_segment_sizes\", ";
1543   if (!fmt.allResultTypes &&
1544       op.getTrait("::mlir::OpTrait::AttrSizedResultSegments"))
1545     body << "\"result_segment_sizes\", ";
1546   llvm::interleaveComma(
1547       fmt.usedAttributes, body,
1548       [&](const NamedAttribute *attr) { body << "\"" << attr->name << "\""; });
1549   body << "});\n";
1550 }
1551 
1552 /// Generate the printer for a literal value. `shouldEmitSpace` is true if a
1553 /// space should be emitted before this element. `lastWasPunctuation` is true if
1554 /// the previous element was a punctuation literal.
genLiteralPrinter(StringRef value,OpMethodBody & body,bool & shouldEmitSpace,bool & lastWasPunctuation)1555 static void genLiteralPrinter(StringRef value, OpMethodBody &body,
1556                               bool &shouldEmitSpace, bool &lastWasPunctuation) {
1557   body << "  p";
1558 
1559   // Don't insert a space for certain punctuation.
1560   auto shouldPrintSpaceBeforeLiteral = [&] {
1561     if (value.size() != 1 && value != "->")
1562       return true;
1563     if (lastWasPunctuation)
1564       return !StringRef(">)}],").contains(value.front());
1565     return !StringRef("<>(){}[],").contains(value.front());
1566   };
1567   if (shouldEmitSpace && shouldPrintSpaceBeforeLiteral())
1568     body << " << ' '";
1569   body << " << \"" << value << "\";\n";
1570 
1571   // Insert a space after certain literals.
1572   shouldEmitSpace =
1573       value.size() != 1 || !StringRef("<({[").contains(value.front());
1574   lastWasPunctuation = !(value.front() == '_' || isalpha(value.front()));
1575 }
1576 
1577 /// Generate the printer for a space. `shouldEmitSpace` and `lastWasPunctuation`
1578 /// are set to false.
genSpacePrinter(bool value,OpMethodBody & body,bool & shouldEmitSpace,bool & lastWasPunctuation)1579 static void genSpacePrinter(bool value, OpMethodBody &body,
1580                             bool &shouldEmitSpace, bool &lastWasPunctuation) {
1581   if (value) {
1582     body << "  p << ' ';\n";
1583     lastWasPunctuation = false;
1584   }
1585   shouldEmitSpace = false;
1586 }
1587 
1588 /// Generate the printer for a custom directive.
genCustomDirectivePrinter(CustomDirective * customDir,OpMethodBody & body)1589 static void genCustomDirectivePrinter(CustomDirective *customDir,
1590                                       OpMethodBody &body) {
1591   body << "  print" << customDir->getName() << "(p, *this";
1592   for (Element &param : customDir->getArguments()) {
1593     body << ", ";
1594     if (auto *attr = dyn_cast<AttributeVariable>(&param)) {
1595       body << attr->getVar()->name << "Attr()";
1596 
1597     } else if (isa<AttrDictDirective>(&param)) {
1598       body << "getOperation()->getAttrDictionary()";
1599 
1600     } else if (auto *operand = dyn_cast<OperandVariable>(&param)) {
1601       body << operand->getVar()->name << "()";
1602 
1603     } else if (auto *region = dyn_cast<RegionVariable>(&param)) {
1604       body << region->getVar()->name << "()";
1605 
1606     } else if (auto *successor = dyn_cast<SuccessorVariable>(&param)) {
1607       body << successor->getVar()->name << "()";
1608 
1609     } else if (auto *dir = dyn_cast<TypeRefDirective>(&param)) {
1610       auto *typeOperand = dir->getOperand();
1611       auto *operand = dyn_cast<OperandVariable>(typeOperand);
1612       auto *var = operand ? operand->getVar()
1613                           : cast<ResultVariable>(typeOperand)->getVar();
1614       if (var->isVariadic())
1615         body << var->name << "().getTypes()";
1616       else if (var->isOptional())
1617         body << llvm::formatv("({0}() ? {0}().getType() : Type())", var->name);
1618       else
1619         body << var->name << "().getType()";
1620     } else if (auto *dir = dyn_cast<TypeDirective>(&param)) {
1621       auto *typeOperand = dir->getOperand();
1622       auto *operand = dyn_cast<OperandVariable>(typeOperand);
1623       auto *var = operand ? operand->getVar()
1624                           : cast<ResultVariable>(typeOperand)->getVar();
1625       if (var->isVariadic())
1626         body << var->name << "().getTypes()";
1627       else if (var->isOptional())
1628         body << llvm::formatv("({0}() ? {0}().getType() : Type())", var->name);
1629       else
1630         body << var->name << "().getType()";
1631     } else {
1632       llvm_unreachable("unknown custom directive parameter");
1633     }
1634   }
1635 
1636   body << ");\n";
1637 }
1638 
1639 /// Generate the printer for a region with the given variable name.
genRegionPrinter(const Twine & regionName,OpMethodBody & body,bool hasImplicitTermTrait)1640 static void genRegionPrinter(const Twine &regionName, OpMethodBody &body,
1641                              bool hasImplicitTermTrait) {
1642   if (hasImplicitTermTrait)
1643     body << llvm::formatv(regionSingleBlockImplicitTerminatorPrinterCode,
1644                           regionName);
1645   else
1646     body << "  p.printRegion(" << regionName << ");\n";
1647 }
genVariadicRegionPrinter(const Twine & regionListName,OpMethodBody & body,bool hasImplicitTermTrait)1648 static void genVariadicRegionPrinter(const Twine &regionListName,
1649                                      OpMethodBody &body,
1650                                      bool hasImplicitTermTrait) {
1651   body << "    llvm::interleaveComma(" << regionListName
1652        << ", p, [&](::mlir::Region &region) {\n      ";
1653   genRegionPrinter("region", body, hasImplicitTermTrait);
1654   body << "    });\n";
1655 }
1656 
1657 /// Generate the C++ for an operand to a (*-)type directive.
genTypeOperandPrinter(Element * arg,OpMethodBody & body)1658 static OpMethodBody &genTypeOperandPrinter(Element *arg, OpMethodBody &body) {
1659   if (isa<OperandsDirective>(arg))
1660     return body << "getOperation()->getOperandTypes()";
1661   if (isa<ResultsDirective>(arg))
1662     return body << "getOperation()->getResultTypes()";
1663   auto *operand = dyn_cast<OperandVariable>(arg);
1664   auto *var = operand ? operand->getVar() : cast<ResultVariable>(arg)->getVar();
1665   if (var->isVariadic())
1666     return body << var->name << "().getTypes()";
1667   if (var->isOptional())
1668     return body << llvm::formatv(
1669                "({0}() ? ::llvm::ArrayRef<::mlir::Type>({0}().getType()) : "
1670                "::llvm::ArrayRef<::mlir::Type>())",
1671                var->name);
1672   return body << "::llvm::ArrayRef<::mlir::Type>(" << var->name
1673               << "().getType())";
1674 }
1675 
1676 /// Generate the printer for an enum attribute.
genEnumAttrPrinter(const NamedAttribute * var,OpMethodBody & body)1677 static void genEnumAttrPrinter(const NamedAttribute *var, OpMethodBody &body) {
1678   Attribute baseAttr = var->attr.getBaseAttr();
1679   const EnumAttr &enumAttr = cast<EnumAttr>(baseAttr);
1680   std::vector<EnumAttrCase> cases = enumAttr.getAllCases();
1681 
1682   body << llvm::formatv(enumAttrBeginPrinterCode,
1683                         (var->attr.isOptional() ? "*" : "") + var->name,
1684                         enumAttr.getSymbolToStringFnName());
1685 
1686   // Get a string containing all of the cases that can't be represented with a
1687   // keyword.
1688   llvm::BitVector nonKeywordCases(cases.size());
1689   bool hasStrCase = false;
1690   for (auto it : llvm::enumerate(cases)) {
1691     hasStrCase = it.value().isStrCase();
1692     if (!canFormatStringAsKeyword(it.value().getStr()))
1693       nonKeywordCases.set(it.index());
1694   }
1695 
1696   // If this is a string enum, use the case string to determine which cases
1697   // need to use the string form.
1698   if (hasStrCase) {
1699     if (nonKeywordCases.any()) {
1700       body << "    if (llvm::is_contained(llvm::ArrayRef<llvm::StringRef>(";
1701       llvm::interleaveComma(nonKeywordCases.set_bits(), body, [&](unsigned it) {
1702         body << '"' << cases[it].getStr() << '"';
1703       });
1704       body << ")))\n"
1705               "      p << '\"' << caseValueStr << '\"';\n"
1706               "    else\n  ";
1707     }
1708     body << "    p << caseValueStr;\n"
1709             "  }\n";
1710     return;
1711   }
1712 
1713   // Otherwise if this is a bit enum attribute, don't allow cases that may
1714   // overlap with other cases. For simplicity sake, only allow cases with a
1715   // single bit value.
1716   if (enumAttr.isBitEnum()) {
1717     for (auto it : llvm::enumerate(cases)) {
1718       int64_t value = it.value().getValue();
1719       if (value < 0 || !llvm::isPowerOf2_64(value))
1720         nonKeywordCases.set(it.index());
1721     }
1722   }
1723 
1724   // If there are any cases that can't be used with a keyword, switch on the
1725   // case value to determine when to print in the string form.
1726   if (nonKeywordCases.any()) {
1727     body << "    switch (caseValue) {\n";
1728     StringRef cppNamespace = enumAttr.getCppNamespace();
1729     StringRef enumName = enumAttr.getEnumClassName();
1730     for (auto it : llvm::enumerate(cases)) {
1731       if (nonKeywordCases.test(it.index()))
1732         continue;
1733       StringRef symbol = it.value().getSymbol();
1734       body << llvm::formatv("    case {0}::{1}::{2}:\n", cppNamespace, enumName,
1735                             llvm::isDigit(symbol.front()) ? ("_" + symbol)
1736                                                           : symbol);
1737     }
1738     body << "      p << caseValueStr;\n"
1739             "      break;\n"
1740             "    default:\n"
1741             "      p << '\"' << caseValueStr << '\"';\n"
1742             "      break;\n"
1743             "    }\n"
1744             "  }\n";
1745     return;
1746   }
1747 
1748   body << "    p << caseValueStr;\n"
1749           "  }\n";
1750 }
1751 
1752 /// Generate the check for the anchor of an optional group.
genOptionalGroupPrinterAnchor(Element * anchor,OpMethodBody & body)1753 static void genOptionalGroupPrinterAnchor(Element *anchor, OpMethodBody &body) {
1754   TypeSwitch<Element *>(anchor)
1755       .Case<OperandVariable, ResultVariable>([&](auto *element) {
1756         const NamedTypeConstraint *var = element->getVar();
1757         if (var->isOptional())
1758           body << "  if (" << var->name << "()) {\n";
1759         else if (var->isVariadic())
1760           body << "  if (!" << var->name << "().empty()) {\n";
1761       })
1762       .Case<RegionVariable>([&](RegionVariable *element) {
1763         const NamedRegion *var = element->getVar();
1764         // TODO: Add a check for optional regions here when ODS supports it.
1765         body << "  if (!" << var->name << "().empty()) {\n";
1766       })
1767       .Case<TypeDirective>([&](TypeDirective *element) {
1768         genOptionalGroupPrinterAnchor(element->getOperand(), body);
1769       })
1770       .Case<FunctionalTypeDirective>([&](FunctionalTypeDirective *element) {
1771         genOptionalGroupPrinterAnchor(element->getInputs(), body);
1772       })
1773       .Case<AttributeVariable>([&](AttributeVariable *attr) {
1774         body << "  if ((*this)->getAttr(\"" << attr->getVar()->name
1775              << "\")) {\n";
1776       });
1777 }
1778 
genElementPrinter(Element * element,OpMethodBody & body,Operator & op,bool & shouldEmitSpace,bool & lastWasPunctuation)1779 void OperationFormat::genElementPrinter(Element *element, OpMethodBody &body,
1780                                         Operator &op, bool &shouldEmitSpace,
1781                                         bool &lastWasPunctuation) {
1782   if (LiteralElement *literal = dyn_cast<LiteralElement>(element))
1783     return genLiteralPrinter(literal->getLiteral(), body, shouldEmitSpace,
1784                              lastWasPunctuation);
1785 
1786   // Emit a whitespace element.
1787   if (isa<NewlineElement>(element)) {
1788     body << "  p.printNewline();\n";
1789     return;
1790   }
1791   if (SpaceElement *space = dyn_cast<SpaceElement>(element))
1792     return genSpacePrinter(space->getValue(), body, shouldEmitSpace,
1793                            lastWasPunctuation);
1794 
1795   // Emit an optional group.
1796   if (OptionalElement *optional = dyn_cast<OptionalElement>(element)) {
1797     // Emit the check for the presence of the anchor element.
1798     Element *anchor = optional->getAnchor();
1799     genOptionalGroupPrinterAnchor(anchor, body);
1800 
1801     // If the anchor is a unit attribute, we don't need to print it. When
1802     // parsing, we will add this attribute if this group is present.
1803     auto elements = optional->getElements();
1804     Element *elidedAnchorElement = nullptr;
1805     auto *anchorAttr = dyn_cast<AttributeVariable>(anchor);
1806     if (anchorAttr && anchorAttr != &*elements.begin() &&
1807         anchorAttr->isUnitAttr()) {
1808       elidedAnchorElement = anchorAttr;
1809     }
1810 
1811     // Emit each of the elements.
1812     for (Element &childElement : elements) {
1813       if (&childElement != elidedAnchorElement) {
1814         genElementPrinter(&childElement, body, op, shouldEmitSpace,
1815                           lastWasPunctuation);
1816       }
1817     }
1818     body << "  }\n";
1819     return;
1820   }
1821 
1822   // Emit the attribute dictionary.
1823   if (auto *attrDict = dyn_cast<AttrDictDirective>(element)) {
1824     genAttrDictPrinter(*this, op, body, attrDict->isWithKeyword());
1825     lastWasPunctuation = false;
1826     return;
1827   }
1828 
1829   // Optionally insert a space before the next element. The AttrDict printer
1830   // already adds a space as necessary.
1831   if (shouldEmitSpace || !lastWasPunctuation)
1832     body << "  p << ' ';\n";
1833   lastWasPunctuation = false;
1834   shouldEmitSpace = true;
1835 
1836   if (auto *attr = dyn_cast<AttributeVariable>(element)) {
1837     const NamedAttribute *var = attr->getVar();
1838 
1839     // If we are formatting as an enum, symbolize the attribute as a string.
1840     if (canFormatEnumAttr(var))
1841       return genEnumAttrPrinter(var, body);
1842 
1843     // If we are formatting as a symbol name, handle it as a symbol name.
1844     if (shouldFormatSymbolNameAttr(var)) {
1845       body << "  p.printSymbolName(" << var->name << "Attr().getValue());\n";
1846       return;
1847     }
1848 
1849     // Elide the attribute type if it is buildable.
1850     if (attr->getTypeBuilder())
1851       body << "  p.printAttributeWithoutType(" << var->name << "Attr());\n";
1852     else
1853       body << "  p.printAttribute(" << var->name << "Attr());\n";
1854   } else if (auto *operand = dyn_cast<OperandVariable>(element)) {
1855     if (operand->getVar()->isOptional()) {
1856       body << "  if (::mlir::Value value = " << operand->getVar()->name
1857            << "())\n"
1858            << "    p << value;\n";
1859     } else {
1860       body << "  p << " << operand->getVar()->name << "();\n";
1861     }
1862   } else if (auto *region = dyn_cast<RegionVariable>(element)) {
1863     const NamedRegion *var = region->getVar();
1864     if (var->isVariadic()) {
1865       genVariadicRegionPrinter(var->name + "()", body, hasImplicitTermTrait);
1866     } else {
1867       genRegionPrinter(var->name + "()", body, hasImplicitTermTrait);
1868     }
1869   } else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
1870     const NamedSuccessor *var = successor->getVar();
1871     if (var->isVariadic())
1872       body << "  ::llvm::interleaveComma(" << var->name << "(), p);\n";
1873     else
1874       body << "  p << " << var->name << "();\n";
1875   } else if (auto *dir = dyn_cast<CustomDirective>(element)) {
1876     genCustomDirectivePrinter(dir, body);
1877   } else if (isa<OperandsDirective>(element)) {
1878     body << "  p << getOperation()->getOperands();\n";
1879   } else if (isa<RegionsDirective>(element)) {
1880     genVariadicRegionPrinter("getOperation()->getRegions()", body,
1881                              hasImplicitTermTrait);
1882   } else if (isa<SuccessorsDirective>(element)) {
1883     body << "  ::llvm::interleaveComma(getOperation()->getSuccessors(), p);\n";
1884   } else if (auto *dir = dyn_cast<TypeDirective>(element)) {
1885     body << "  p << ";
1886     genTypeOperandPrinter(dir->getOperand(), body) << ";\n";
1887   } else if (auto *dir = dyn_cast<TypeRefDirective>(element)) {
1888     body << "  p << ";
1889     genTypeOperandPrinter(dir->getOperand(), body) << ";\n";
1890   } else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) {
1891     body << "  p.printFunctionalType(";
1892     genTypeOperandPrinter(dir->getInputs(), body) << ", ";
1893     genTypeOperandPrinter(dir->getResults(), body) << ");\n";
1894   } else {
1895     llvm_unreachable("unknown format element");
1896   }
1897 }
1898 
genPrinter(Operator & op,OpClass & opClass)1899 void OperationFormat::genPrinter(Operator &op, OpClass &opClass) {
1900   auto *method =
1901       opClass.addMethodAndPrune("void", "print", "::mlir::OpAsmPrinter &p");
1902   auto &body = method->body();
1903 
1904   // Emit the operation name, trimming the prefix if this is the standard
1905   // dialect.
1906   body << "  p << \"";
1907   std::string opName = op.getOperationName();
1908   if (op.getDialectName() == "std")
1909     body << StringRef(opName).drop_front(4);
1910   else
1911     body << opName;
1912   body << "\";\n";
1913 
1914   // Flags for if we should emit a space, and if the last element was
1915   // punctuation.
1916   bool shouldEmitSpace = true, lastWasPunctuation = false;
1917   for (auto &element : elements)
1918     genElementPrinter(element.get(), body, op, shouldEmitSpace,
1919                       lastWasPunctuation);
1920 }
1921 
1922 //===----------------------------------------------------------------------===//
1923 // FormatLexer
1924 //===----------------------------------------------------------------------===//
1925 
1926 namespace {
1927 /// This class represents a specific token in the input format.
1928 class Token {
1929 public:
1930   enum Kind {
1931     // Markers.
1932     eof,
1933     error,
1934 
1935     // Tokens with no info.
1936     l_paren,
1937     r_paren,
1938     caret,
1939     comma,
1940     equal,
1941     less,
1942     greater,
1943     question,
1944 
1945     // Keywords.
1946     keyword_start,
1947     kw_attr_dict,
1948     kw_attr_dict_w_keyword,
1949     kw_custom,
1950     kw_functional_type,
1951     kw_operands,
1952     kw_regions,
1953     kw_results,
1954     kw_successors,
1955     kw_type,
1956     kw_type_ref,
1957     keyword_end,
1958 
1959     // String valued tokens.
1960     identifier,
1961     literal,
1962     variable,
1963   };
Token(Kind kind,StringRef spelling)1964   Token(Kind kind, StringRef spelling) : kind(kind), spelling(spelling) {}
1965 
1966   /// Return the bytes that make up this token.
getSpelling() const1967   StringRef getSpelling() const { return spelling; }
1968 
1969   /// Return the kind of this token.
getKind() const1970   Kind getKind() const { return kind; }
1971 
1972   /// Return a location for this token.
getLoc() const1973   llvm::SMLoc getLoc() const {
1974     return llvm::SMLoc::getFromPointer(spelling.data());
1975   }
1976 
1977   /// Return if this token is a keyword.
isKeyword() const1978   bool isKeyword() const { return kind > keyword_start && kind < keyword_end; }
1979 
1980 private:
1981   /// Discriminator that indicates the kind of token this is.
1982   Kind kind;
1983 
1984   /// A reference to the entire token contents; this is always a pointer into
1985   /// a memory buffer owned by the source manager.
1986   StringRef spelling;
1987 };
1988 
/// This class implements a simple lexer for operation assembly format strings.
/// It scans the main buffer of the given source manager. The operation is kept
/// so that reported errors can also point at the operation's definition.
class FormatLexer {
public:
  FormatLexer(llvm::SourceMgr &mgr, Operator &op);

  /// Lex the next token and return it.
  Token lexToken();

  /// Emit an error to the lexer with the given location and message, and
  /// return a token of kind `Token::error`.
  Token emitError(llvm::SMLoc loc, const Twine &msg);
  Token emitError(const char *loc, const Twine &msg);

  /// Emit an error like `emitError`, followed by an extra `note` at `loc`.
  Token emitErrorAndNote(llvm::SMLoc loc, const Twine &msg, const Twine &note);

private:
  /// Build a token of the given kind spanning from `tokStart` up to (but not
  /// including) the current lexer position.
  Token formToken(Token::Kind kind, const char *tokStart) {
    return Token(kind, StringRef(tokStart, curPtr - tokStart));
  }

  /// Return the next character in the stream.
  int getNextChar();

  /// Lex an identifier, literal, or variable. `tokStart` points at the first
  /// character of the token, which the caller has already consumed.
  Token lexIdentifier(const char *tokStart);
  Token lexLiteral(const char *tokStart);
  Token lexVariable(const char *tokStart);

  /// Source manager owning the format buffer; used for error reporting.
  llvm::SourceMgr &srcMgr;
  /// The operation whose assembly format is being lexed.
  Operator &op;
  /// The buffer being lexed.
  StringRef curBuffer;
  /// Position of the next character to read from `curBuffer`.
  const char *curPtr;
};
2021 } // end anonymous namespace
2022 
FormatLexer(llvm::SourceMgr & mgr,Operator & op)2023 FormatLexer::FormatLexer(llvm::SourceMgr &mgr, Operator &op)
2024     : srcMgr(mgr), op(op) {
2025   curBuffer = srcMgr.getMemoryBuffer(mgr.getMainFileID())->getBuffer();
2026   curPtr = curBuffer.begin();
2027 }
2028 
emitError(llvm::SMLoc loc,const Twine & msg)2029 Token FormatLexer::emitError(llvm::SMLoc loc, const Twine &msg) {
2030   srcMgr.PrintMessage(loc, llvm::SourceMgr::DK_Error, msg);
2031   llvm::SrcMgr.PrintMessage(op.getLoc()[0], llvm::SourceMgr::DK_Note,
2032                             "in custom assembly format for this operation");
2033   return formToken(Token::error, loc.getPointer());
2034 }
emitErrorAndNote(llvm::SMLoc loc,const Twine & msg,const Twine & note)2035 Token FormatLexer::emitErrorAndNote(llvm::SMLoc loc, const Twine &msg,
2036                                     const Twine &note) {
2037   srcMgr.PrintMessage(loc, llvm::SourceMgr::DK_Error, msg);
2038   llvm::SrcMgr.PrintMessage(op.getLoc()[0], llvm::SourceMgr::DK_Note,
2039                             "in custom assembly format for this operation");
2040   srcMgr.PrintMessage(loc, llvm::SourceMgr::DK_Note, note);
2041   return formToken(Token::error, loc.getPointer());
2042 }
emitError(const char * loc,const Twine & msg)2043 Token FormatLexer::emitError(const char *loc, const Twine &msg) {
2044   return emitError(llvm::SMLoc::getFromPointer(loc), msg);
2045 }
2046 
getNextChar()2047 int FormatLexer::getNextChar() {
2048   char curChar = *curPtr++;
2049   switch (curChar) {
2050   default:
2051     return (unsigned char)curChar;
2052   case 0: {
2053     // A nul character in the stream is either the end of the current buffer or
2054     // a random nul in the file. Disambiguate that here.
2055     if (curPtr - 1 != curBuffer.end())
2056       return 0;
2057 
2058     // Otherwise, return end of file.
2059     --curPtr;
2060     return EOF;
2061   }
2062   case '\n':
2063   case '\r':
2064     // Handle the newline character by ignoring it and incrementing the line
2065     // count. However, be careful about 'dos style' files with \n\r in them.
2066     // Only treat a \n\r or \r\n as a single line.
2067     if ((*curPtr == '\n' || (*curPtr == '\r')) && *curPtr != curChar)
2068       ++curPtr;
2069     return '\n';
2070   }
2071 }
2072 
lexToken()2073 Token FormatLexer::lexToken() {
2074   const char *tokStart = curPtr;
2075 
2076   // This always consumes at least one character.
2077   int curChar = getNextChar();
2078   switch (curChar) {
2079   default:
2080     // Handle identifiers: [a-zA-Z_]
2081     if (isalpha(curChar) || curChar == '_')
2082       return lexIdentifier(tokStart);
2083 
2084     // Unknown character, emit an error.
2085     return emitError(tokStart, "unexpected character");
2086   case EOF:
2087     // Return EOF denoting the end of lexing.
2088     return formToken(Token::eof, tokStart);
2089 
2090   // Lex punctuation.
2091   case '^':
2092     return formToken(Token::caret, tokStart);
2093   case ',':
2094     return formToken(Token::comma, tokStart);
2095   case '=':
2096     return formToken(Token::equal, tokStart);
2097   case '<':
2098     return formToken(Token::less, tokStart);
2099   case '>':
2100     return formToken(Token::greater, tokStart);
2101   case '?':
2102     return formToken(Token::question, tokStart);
2103   case '(':
2104     return formToken(Token::l_paren, tokStart);
2105   case ')':
2106     return formToken(Token::r_paren, tokStart);
2107 
2108   // Ignore whitespace characters.
2109   case 0:
2110   case ' ':
2111   case '\t':
2112   case '\n':
2113     return lexToken();
2114 
2115   case '`':
2116     return lexLiteral(tokStart);
2117   case '$':
2118     return lexVariable(tokStart);
2119   }
2120 }
2121 
lexLiteral(const char * tokStart)2122 Token FormatLexer::lexLiteral(const char *tokStart) {
2123   assert(curPtr[-1] == '`');
2124 
2125   // Lex a literal surrounded by ``.
2126   while (const char curChar = *curPtr++) {
2127     if (curChar == '`')
2128       return formToken(Token::literal, tokStart);
2129   }
2130   return emitError(curPtr - 1, "unexpected end of file in literal");
2131 }
2132 
lexVariable(const char * tokStart)2133 Token FormatLexer::lexVariable(const char *tokStart) {
2134   if (!isalpha(curPtr[0]) && curPtr[0] != '_')
2135     return emitError(curPtr - 1, "expected variable name");
2136 
2137   // Otherwise, consume the rest of the characters.
2138   while (isalnum(*curPtr) || *curPtr == '_')
2139     ++curPtr;
2140   return formToken(Token::variable, tokStart);
2141 }
2142 
lexIdentifier(const char * tokStart)2143 Token FormatLexer::lexIdentifier(const char *tokStart) {
2144   // Match the rest of the identifier regex: [0-9a-zA-Z_\-]*
2145   while (isalnum(*curPtr) || *curPtr == '_' || *curPtr == '-')
2146     ++curPtr;
2147 
2148   // Check to see if this identifier is a keyword.
2149   StringRef str(tokStart, curPtr - tokStart);
2150   Token::Kind kind =
2151       StringSwitch<Token::Kind>(str)
2152           .Case("attr-dict", Token::kw_attr_dict)
2153           .Case("attr-dict-with-keyword", Token::kw_attr_dict_w_keyword)
2154           .Case("custom", Token::kw_custom)
2155           .Case("functional-type", Token::kw_functional_type)
2156           .Case("operands", Token::kw_operands)
2157           .Case("regions", Token::kw_regions)
2158           .Case("results", Token::kw_results)
2159           .Case("successors", Token::kw_successors)
2160           .Case("type", Token::kw_type)
2161           .Case("type_ref", Token::kw_type_ref)
2162           .Default(Token::identifier);
2163   return Token(kind, str);
2164 }
2165 
2166 //===----------------------------------------------------------------------===//
2167 // FormatParser
2168 //===----------------------------------------------------------------------===//
2169 
2170 /// Function to find an element within the given range that has the same name as
2171 /// 'name'.
2172 template <typename RangeT>
findArg(RangeT && range,StringRef name)2173 static auto findArg(RangeT &&range, StringRef name) {
2174   auto it = llvm::find_if(range, [=](auto &arg) { return arg.name == name; });
2175   return it != range.end() ? &*it : nullptr;
2176 }
2177 
namespace {
/// This class implements a parser for an instance of an operation assembly
/// format.
///
/// It is constructed over the source manager holding the format string, the
/// OperationFormat to populate, and the operation being described. `parse()`
/// appends the parsed top-level elements to `fmt.elements`. It then verifies
/// the result: no ambiguous `:` after an attribute; every operand, region and
/// successor bound; every operand/result type printed, inferable, or
/// buildable. Diagnostics are reported through the lexer.
class FormatParser {
public:
  /// Note: the initializer list lexes the first token from `lexer`, which is
  /// only valid because `lexer` is declared before `curToken` below.
  FormatParser(llvm::SourceMgr &mgr, OperationFormat &format, Operator &op)
      : lexer(mgr, op), curToken(lexer.lexToken()), fmt(format), op(op),
        seenOperandTypes(op.getNumOperands()),
        seenResultTypes(op.getNumResults()) {}

  /// Parse the operation assembly format. On failure a diagnostic has already
  /// been emitted through the lexer.
  LogicalResult parse();

private:
  /// This struct represents a type resolution instance. It includes a specific
  /// type as well as an optional transformer to apply to that type in order to
  /// properly resolve the type of a variable.
  struct TypeResolutionInstance {
    ConstArgument resolver;
    Optional<StringRef> transformer;
  };

  /// An iterator over the elements of a format group.
  using ElementsIterT = llvm::pointee_iterator<
      std::vector<std::unique_ptr<Element>>::const_iterator>;

  /// Verify the state of operation attributes within the format.
  LogicalResult verifyAttributes(llvm::SMLoc loc);
  /// Verify the attribute elements at the back of the given stack of iterators.
  LogicalResult verifyAttributes(
      llvm::SMLoc loc,
      SmallVectorImpl<std::pair<ElementsIterT, ElementsIterT>> &iteratorStack);

  /// Verify the state of operation operands within the format.
  LogicalResult
  verifyOperands(llvm::SMLoc loc,
                 llvm::StringMap<TypeResolutionInstance> &variableTyResolver);

  /// Verify the state of operation regions within the format.
  LogicalResult verifyRegions(llvm::SMLoc loc);

  /// Verify the state of operation results within the format.
  LogicalResult
  verifyResults(llvm::SMLoc loc,
                llvm::StringMap<TypeResolutionInstance> &variableTyResolver);

  /// Verify the state of operation successors within the format.
  LogicalResult verifySuccessors(llvm::SMLoc loc);

  /// Given the values of an `AllTypesMatch` trait, check for inferable type
  /// resolution.
  void handleAllTypesMatchConstraint(
      ArrayRef<StringRef> values,
      llvm::StringMap<TypeResolutionInstance> &variableTyResolver);
  /// Check for inferable type resolution given all operands, and or results,
  /// have the same type. If 'includeResults' is true, the results also have the
  /// same type as all of the operands.
  void handleSameTypesConstraint(
      llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
      bool includeResults);
  /// Check for inferable type resolution based on another operand, result, or
  /// attribute.
  /// NOTE(review): `def` is taken by value, so every call copies the
  /// llvm::Record. `const llvm::Record &` would avoid that, but the
  /// declaration and the out-of-line definition must be changed together.
  void handleTypesMatchConstraint(
      llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
      llvm::Record def);

  /// Returns an argument or attribute with the given name that has been seen
  /// within the format.
  ConstArgument findSeenArg(StringRef name);

  /// Parse a specific element.
  LogicalResult parseElement(std::unique_ptr<Element> &element,
                             bool isTopLevel);
  LogicalResult parseVariable(std::unique_ptr<Element> &element,
                              bool isTopLevel);
  LogicalResult parseDirective(std::unique_ptr<Element> &element,
                               bool isTopLevel);
  LogicalResult parseLiteral(std::unique_ptr<Element> &element);
  LogicalResult parseOptional(std::unique_ptr<Element> &element,
                              bool isTopLevel);
  LogicalResult parseOptionalChildElement(
      std::vector<std::unique_ptr<Element>> &childElements,
      Optional<unsigned> &anchorIdx);
  LogicalResult verifyOptionalChildElement(Element *element,
                                           llvm::SMLoc childLoc, bool isAnchor);

  /// Parse the various different directives.
  LogicalResult parseAttrDictDirective(std::unique_ptr<Element> &element,
                                       llvm::SMLoc loc, bool isTopLevel,
                                       bool withKeyword);
  LogicalResult parseCustomDirective(std::unique_ptr<Element> &element,
                                     llvm::SMLoc loc, bool isTopLevel);
  LogicalResult parseCustomDirectiveParameter(
      std::vector<std::unique_ptr<Element>> &parameters);
  LogicalResult parseFunctionalTypeDirective(std::unique_ptr<Element> &element,
                                             Token tok, bool isTopLevel);
  LogicalResult parseOperandsDirective(std::unique_ptr<Element> &element,
                                       llvm::SMLoc loc, bool isTopLevel);
  LogicalResult parseRegionsDirective(std::unique_ptr<Element> &element,
                                      llvm::SMLoc loc, bool isTopLevel);
  LogicalResult parseResultsDirective(std::unique_ptr<Element> &element,
                                      llvm::SMLoc loc, bool isTopLevel);
  LogicalResult parseSuccessorsDirective(std::unique_ptr<Element> &element,
                                         llvm::SMLoc loc, bool isTopLevel);
  LogicalResult parseTypeDirective(std::unique_ptr<Element> &element, Token tok,
                                   bool isTopLevel, bool isTypeRef = false);
  LogicalResult parseTypeDirectiveOperand(std::unique_ptr<Element> &element,
                                          bool isTypeRef = false);

  //===--------------------------------------------------------------------===//
  // Lexer Utilities
  //===--------------------------------------------------------------------===//

  /// Advance the current lexer onto the next token.
  void consumeToken() {
    assert(curToken.getKind() != Token::eof &&
           curToken.getKind() != Token::error &&
           "shouldn't advance past EOF or errors");
    curToken = lexer.lexToken();
  }
  /// Consume the current token if it has the given kind; otherwise emit `msg`
  /// at the current token and fail without consuming anything.
  LogicalResult parseToken(Token::Kind kind, const Twine &msg) {
    if (curToken.getKind() != kind)
      return emitError(curToken.getLoc(), msg);
    consumeToken();
    return ::mlir::success();
  }
  /// Emit an error through the lexer and return failure, so callers can write
  /// `return emitError(...)`.
  LogicalResult emitError(llvm::SMLoc loc, const Twine &msg) {
    lexer.emitError(loc, msg);
    return ::mlir::failure();
  }
  /// Emit an error plus an attached note through the lexer and return failure.
  LogicalResult emitErrorAndNote(llvm::SMLoc loc, const Twine &msg,
                                 const Twine &note) {
    lexer.emitErrorAndNote(loc, msg, note);
    return ::mlir::failure();
  }

  //===--------------------------------------------------------------------===//
  // Fields
  //===--------------------------------------------------------------------===//

  /// Lexer over the format string. Must remain declared before `curToken`,
  /// which the constructor initializes by lexing from it.
  FormatLexer lexer;
  /// The token currently being looked at.
  Token curToken;
  /// The format being populated.
  OperationFormat &fmt;
  /// The operation whose assembly format is being parsed.
  Operator &op;

  // The following are various bits of format state used for verification
  // during parsing.
  /// Whether an `attr-dict` directive has been parsed.
  bool hasAttrDict = false;
  /// When set (presumably by the `regions` / `successors` directives, not
  /// visible here), every region / successor counts as bound and individual
  /// ones may not be bound again.
  bool hasAllRegions = false, hasAllSuccessors = false;
  /// One bit per operand / result (sized in the constructor); a set bit means
  /// the format provides that operand's / result's type.
  llvm::SmallBitVector seenOperandTypes, seenResultTypes;
  /// Attributes bound by top-level variables, in insertion order; moved into
  /// `fmt.usedAttributes` at the end of `parse()`.
  llvm::SmallSetVector<const NamedAttribute *, 8> seenAttrs;
  /// Operands, regions and successors bound by top-level variables.
  llvm::DenseSet<const NamedTypeConstraint *> seenOperands;
  llvm::DenseSet<const NamedRegion *> seenRegions;
  llvm::DenseSet<const NamedSuccessor *> seenSuccessors;
};
} // end anonymous namespace
2334 
parse()2335 LogicalResult FormatParser::parse() {
2336   llvm::SMLoc loc = curToken.getLoc();
2337 
2338   // Parse each of the format elements into the main format.
2339   while (curToken.getKind() != Token::eof) {
2340     std::unique_ptr<Element> element;
2341     if (failed(parseElement(element, /*isTopLevel=*/true)))
2342       return ::mlir::failure();
2343     fmt.elements.push_back(std::move(element));
2344   }
2345 
2346   // Check that the attribute dictionary is in the format.
2347   if (!hasAttrDict)
2348     return emitError(loc, "'attr-dict' directive not found in "
2349                           "custom assembly format");
2350 
2351   // Check for any type traits that we can use for inferring types.
2352   llvm::StringMap<TypeResolutionInstance> variableTyResolver;
2353   for (const OpTrait &trait : op.getTraits()) {
2354     const llvm::Record &def = trait.getDef();
2355     if (def.isSubClassOf("AllTypesMatch")) {
2356       handleAllTypesMatchConstraint(def.getValueAsListOfStrings("values"),
2357                                     variableTyResolver);
2358     } else if (def.getName() == "SameTypeOperands") {
2359       handleSameTypesConstraint(variableTyResolver, /*includeResults=*/false);
2360     } else if (def.getName() == "SameOperandsAndResultType") {
2361       handleSameTypesConstraint(variableTyResolver, /*includeResults=*/true);
2362     } else if (def.isSubClassOf("TypesMatchWith")) {
2363       handleTypesMatchConstraint(variableTyResolver, def);
2364     }
2365   }
2366 
2367   // Verify the state of the various operation components.
2368   if (failed(verifyAttributes(loc)) ||
2369       failed(verifyResults(loc, variableTyResolver)) ||
2370       failed(verifyOperands(loc, variableTyResolver)) ||
2371       failed(verifyRegions(loc)) || failed(verifySuccessors(loc)))
2372     return ::mlir::failure();
2373 
2374   // Collect the set of used attributes in the format.
2375   fmt.usedAttributes = seenAttrs.takeVector();
2376   return ::mlir::success();
2377 }
2378 
verifyAttributes(llvm::SMLoc loc)2379 LogicalResult FormatParser::verifyAttributes(llvm::SMLoc loc) {
2380   // Check that there are no `:` literals after an attribute without a constant
2381   // type. The attribute grammar contains an optional trailing colon type, which
2382   // can lead to unexpected and generally unintended behavior. Given that, it is
2383   // better to just error out here instead.
2384   using ElementsIterT = llvm::pointee_iterator<
2385       std::vector<std::unique_ptr<Element>>::const_iterator>;
2386   SmallVector<std::pair<ElementsIterT, ElementsIterT>, 1> iteratorStack;
2387   iteratorStack.emplace_back(fmt.elements.begin(), fmt.elements.end());
2388   while (!iteratorStack.empty())
2389     if (failed(verifyAttributes(loc, iteratorStack)))
2390       return ::mlir::failure();
2391   return ::mlir::success();
2392 }
/// Verify the attribute elements at the back of the given stack of iterators.
///
/// Elements are consumed from the innermost (back) [current, end) pair:
///  - on reaching an optional group, its elements are pushed as a new pair and
///    success is returned, so the caller's loop continues inside the group;
///  - once the pair is exhausted it is popped and success is returned;
///  - if an attribute without a type builder is followed by a `:` literal, an
///    error is emitted at `loc` and failure is returned.
LogicalResult FormatParser::verifyAttributes(
    llvm::SMLoc loc,
    SmallVectorImpl<std::pair<ElementsIterT, ElementsIterT>> &iteratorStack) {
  // `it` refers to the iterator stored in the stack entry, so advancing it
  // records progress that survives the early return below.
  auto &stackIt = iteratorStack.back();
  ElementsIterT &it = stackIt.first, e = stackIt.second;
  while (it != e) {
    Element *element = &*(it++);

    // Traverse into optional groups.
    // `it` was already advanced past the group. The push may reallocate the
    // stack and invalidate `stackIt`/`it`, which is why we return immediately
    // instead of continuing this loop.
    if (auto *optional = dyn_cast<OptionalElement>(element)) {
      auto elements = optional->getElements();
      iteratorStack.emplace_back(elements.begin(), elements.end());
      return ::mlir::success();
    }

    // We are checking for an attribute element followed by a `:`, so there is
    // no need to check the end.
    // Only the end of the top-level format is skipped. The end of a nested
    // group may still be followed by elements of an enclosing level.
    if (it == e && iteratorStack.size() == 1)
      break;

    // Only attributes *without* a constant type builder are ambiguous when
    // followed by a `:`; everything else is skipped.
    auto *prevAttr = dyn_cast<AttributeVariable>(element);
    if (!prevAttr || prevAttr->getTypeBuilder())
      continue;

    // Check the next iterator within the stack for literal elements.
    // Each level's saved position (outermost level first) is scanned for its
    // first element that is not whitespace, `attr-dict`, or an optional group.
    // If that element is a `:` literal, the format is rejected. Otherwise the
    // scan moves to the next level.
    // NOTE(review): levels are checked independently. A `:` following an
    // enclosing group is therefore reported even when the attribute is
    // followed by other elements inside its own group. Confirm this
    // conservatism is intended.
    for (auto &nextItPair : iteratorStack) {
      ElementsIterT nextIt = nextItPair.first, nextE = nextItPair.second;
      for (; nextIt != nextE; ++nextIt) {
        // Skip any trailing whitespace, attribute dictionaries, or optional
        // groups.
        if (isa<WhitespaceElement>(*nextIt) ||
            isa<AttrDictDirective>(*nextIt) || isa<OptionalElement>(*nextIt))
          continue;

        // We are only interested in `:` literals.
        auto *literal = dyn_cast<LiteralElement>(&*nextIt);
        if (!literal || literal->getLiteral() != ":")
          break;

        // TODO: Use the location of the literal element itself.
        return emitError(
            loc, llvm::formatv("format ambiguity caused by `:` literal found "
                               "after attribute `{0}` which does not have "
                               "a buildable type",
                               prevAttr->getVar()->name));
      }
    }
  }
  // This level is exhausted; resume with the enclosing one.
  iteratorStack.pop_back();
  return ::mlir::success();
}
2446 
verifyOperands(llvm::SMLoc loc,llvm::StringMap<TypeResolutionInstance> & variableTyResolver)2447 LogicalResult FormatParser::verifyOperands(
2448     llvm::SMLoc loc,
2449     llvm::StringMap<TypeResolutionInstance> &variableTyResolver) {
2450   // Check that all of the operands are within the format, and their types can
2451   // be inferred.
2452   auto &buildableTypes = fmt.buildableTypes;
2453   for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i) {
2454     NamedTypeConstraint &operand = op.getOperand(i);
2455 
2456     // Check that the operand itself is in the format.
2457     if (!fmt.allOperands && !seenOperands.count(&operand)) {
2458       return emitErrorAndNote(loc,
2459                               "operand #" + Twine(i) + ", named '" +
2460                                   operand.name + "', not found",
2461                               "suggest adding a '$" + operand.name +
2462                                   "' directive to the custom assembly format");
2463     }
2464 
2465     // Check that the operand type is in the format, or that it can be inferred.
2466     if (fmt.allOperandTypes || seenOperandTypes.test(i))
2467       continue;
2468 
2469     // Check to see if we can infer this type from another variable.
2470     auto varResolverIt = variableTyResolver.find(op.getOperand(i).name);
2471     if (varResolverIt != variableTyResolver.end()) {
2472       TypeResolutionInstance &resolver = varResolverIt->second;
2473       fmt.operandTypes[i].setResolver(resolver.resolver, resolver.transformer);
2474       continue;
2475     }
2476 
2477     // Similarly to results, allow a custom builder for resolving the type if
2478     // we aren't using the 'operands' directive.
2479     Optional<StringRef> builder = operand.constraint.getBuilderCall();
2480     if (!builder || (fmt.allOperands && operand.isVariableLength())) {
2481       return emitErrorAndNote(
2482           loc,
2483           "type of operand #" + Twine(i) + ", named '" + operand.name +
2484               "', is not buildable and a buildable type cannot be inferred",
2485           "suggest adding a type constraint to the operation or adding a "
2486           "'type($" +
2487               operand.name + ")' directive to the " + "custom assembly format");
2488     }
2489     auto it = buildableTypes.insert({*builder, buildableTypes.size()});
2490     fmt.operandTypes[i].setBuilderIdx(it.first->second);
2491   }
2492   return ::mlir::success();
2493 }
2494 
verifyRegions(llvm::SMLoc loc)2495 LogicalResult FormatParser::verifyRegions(llvm::SMLoc loc) {
2496   // Check that all of the regions are within the format.
2497   if (hasAllRegions)
2498     return ::mlir::success();
2499 
2500   for (unsigned i = 0, e = op.getNumRegions(); i != e; ++i) {
2501     const NamedRegion &region = op.getRegion(i);
2502     if (!seenRegions.count(&region)) {
2503       return emitErrorAndNote(loc,
2504                               "region #" + Twine(i) + ", named '" +
2505                                   region.name + "', not found",
2506                               "suggest adding a '$" + region.name +
2507                                   "' directive to the custom assembly format");
2508     }
2509   }
2510   return ::mlir::success();
2511 }
2512 
verifyResults(llvm::SMLoc loc,llvm::StringMap<TypeResolutionInstance> & variableTyResolver)2513 LogicalResult FormatParser::verifyResults(
2514     llvm::SMLoc loc,
2515     llvm::StringMap<TypeResolutionInstance> &variableTyResolver) {
2516   // If we format all of the types together, there is nothing to check.
2517   if (fmt.allResultTypes)
2518     return ::mlir::success();
2519 
2520   // Check that all of the result types can be inferred.
2521   auto &buildableTypes = fmt.buildableTypes;
2522   for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) {
2523     if (seenResultTypes.test(i))
2524       continue;
2525 
2526     // Check to see if we can infer this type from another variable.
2527     auto varResolverIt = variableTyResolver.find(op.getResultName(i));
2528     if (varResolverIt != variableTyResolver.end()) {
2529       TypeResolutionInstance resolver = varResolverIt->second;
2530       fmt.resultTypes[i].setResolver(resolver.resolver, resolver.transformer);
2531       continue;
2532     }
2533 
2534     // If the result is not variable length, allow for the case where the type
2535     // has a builder that we can use.
2536     NamedTypeConstraint &result = op.getResult(i);
2537     Optional<StringRef> builder = result.constraint.getBuilderCall();
2538     if (!builder || result.isVariableLength()) {
2539       return emitErrorAndNote(
2540           loc,
2541           "type of result #" + Twine(i) + ", named '" + result.name +
2542               "', is not buildable and a buildable type cannot be inferred",
2543           "suggest adding a type constraint to the operation or adding a "
2544           "'type($" +
2545               result.name + ")' directive to the " + "custom assembly format");
2546     }
2547     // Note in the format that this result uses the custom builder.
2548     auto it = buildableTypes.insert({*builder, buildableTypes.size()});
2549     fmt.resultTypes[i].setBuilderIdx(it.first->second);
2550   }
2551   return ::mlir::success();
2552 }
2553 
verifySuccessors(llvm::SMLoc loc)2554 LogicalResult FormatParser::verifySuccessors(llvm::SMLoc loc) {
2555   // Check that all of the successors are within the format.
2556   if (hasAllSuccessors)
2557     return ::mlir::success();
2558 
2559   for (unsigned i = 0, e = op.getNumSuccessors(); i != e; ++i) {
2560     const NamedSuccessor &successor = op.getSuccessor(i);
2561     if (!seenSuccessors.count(&successor)) {
2562       return emitErrorAndNote(loc,
2563                               "successor #" + Twine(i) + ", named '" +
2564                                   successor.name + "', not found",
2565                               "suggest adding a '$" + successor.name +
2566                                   "' directive to the custom assembly format");
2567     }
2568   }
2569   return ::mlir::success();
2570 }
2571 
handleAllTypesMatchConstraint(ArrayRef<StringRef> values,llvm::StringMap<TypeResolutionInstance> & variableTyResolver)2572 void FormatParser::handleAllTypesMatchConstraint(
2573     ArrayRef<StringRef> values,
2574     llvm::StringMap<TypeResolutionInstance> &variableTyResolver) {
2575   for (unsigned i = 0, e = values.size(); i != e; ++i) {
2576     // Check to see if this value matches a resolved operand or result type.
2577     ConstArgument arg = findSeenArg(values[i]);
2578     if (!arg)
2579       continue;
2580 
2581     // Mark this value as the type resolver for the other variables.
2582     for (unsigned j = 0; j != i; ++j)
2583       variableTyResolver[values[j]] = {arg, llvm::None};
2584     for (unsigned j = i + 1; j != e; ++j)
2585       variableTyResolver[values[j]] = {arg, llvm::None};
2586   }
2587 }
2588 
handleSameTypesConstraint(llvm::StringMap<TypeResolutionInstance> & variableTyResolver,bool includeResults)2589 void FormatParser::handleSameTypesConstraint(
2590     llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
2591     bool includeResults) {
2592   const NamedTypeConstraint *resolver = nullptr;
2593   int resolvedIt = -1;
2594 
2595   // Check to see if there is an operand or result to use for the resolution.
2596   if ((resolvedIt = seenOperandTypes.find_first()) != -1)
2597     resolver = &op.getOperand(resolvedIt);
2598   else if (includeResults && (resolvedIt = seenResultTypes.find_first()) != -1)
2599     resolver = &op.getResult(resolvedIt);
2600   else
2601     return;
2602 
2603   // Set the resolvers for each operand and result.
2604   for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i)
2605     if (!seenOperandTypes.test(i) && !op.getOperand(i).name.empty())
2606       variableTyResolver[op.getOperand(i).name] = {resolver, llvm::None};
2607   if (includeResults) {
2608     for (unsigned i = 0, e = op.getNumResults(); i != e; ++i)
2609       if (!seenResultTypes.test(i) && !op.getResultName(i).empty())
2610         variableTyResolver[op.getResultName(i)] = {resolver, llvm::None};
2611   }
2612 }
2613 
handleTypesMatchConstraint(llvm::StringMap<TypeResolutionInstance> & variableTyResolver,llvm::Record def)2614 void FormatParser::handleTypesMatchConstraint(
2615     llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
2616     llvm::Record def) {
2617   StringRef lhsName = def.getValueAsString("lhs");
2618   StringRef rhsName = def.getValueAsString("rhs");
2619   StringRef transformer = def.getValueAsString("transformer");
2620   if (ConstArgument arg = findSeenArg(lhsName))
2621     variableTyResolver[rhsName] = {arg, transformer};
2622 }
2623 
findSeenArg(StringRef name)2624 ConstArgument FormatParser::findSeenArg(StringRef name) {
2625   if (const NamedTypeConstraint *arg = findArg(op.getOperands(), name))
2626     return seenOperandTypes.test(arg - op.operand_begin()) ? arg : nullptr;
2627   if (const NamedTypeConstraint *arg = findArg(op.getResults(), name))
2628     return seenResultTypes.test(arg - op.result_begin()) ? arg : nullptr;
2629   if (const NamedAttribute *attr = findArg(op.getAttributes(), name))
2630     return seenAttrs.count(attr) ? attr : nullptr;
2631   return nullptr;
2632 }
2633 
parseElement(std::unique_ptr<Element> & element,bool isTopLevel)2634 LogicalResult FormatParser::parseElement(std::unique_ptr<Element> &element,
2635                                          bool isTopLevel) {
2636   // Directives.
2637   if (curToken.isKeyword())
2638     return parseDirective(element, isTopLevel);
2639   // Literals.
2640   if (curToken.getKind() == Token::literal)
2641     return parseLiteral(element);
2642   // Optionals.
2643   if (curToken.getKind() == Token::l_paren)
2644     return parseOptional(element, isTopLevel);
2645   // Variables.
2646   if (curToken.getKind() == Token::variable)
2647     return parseVariable(element, isTopLevel);
2648   return emitError(curToken.getLoc(),
2649                    "expected directive, literal, variable, or optional group");
2650 }
2651 
parseVariable(std::unique_ptr<Element> & element,bool isTopLevel)2652 LogicalResult FormatParser::parseVariable(std::unique_ptr<Element> &element,
2653                                           bool isTopLevel) {
2654   Token varTok = curToken;
2655   consumeToken();
2656 
2657   StringRef name = varTok.getSpelling().drop_front();
2658   llvm::SMLoc loc = varTok.getLoc();
2659 
2660   // Check that the parsed argument is something actually registered on the
2661   // op.
2662   /// Attributes
2663   if (const NamedAttribute *attr = findArg(op.getAttributes(), name)) {
2664     if (isTopLevel && !seenAttrs.insert(attr))
2665       return emitError(loc, "attribute '" + name + "' is already bound");
2666     element = std::make_unique<AttributeVariable>(attr);
2667     return ::mlir::success();
2668   }
2669   /// Operands
2670   if (const NamedTypeConstraint *operand = findArg(op.getOperands(), name)) {
2671     if (isTopLevel) {
2672       if (fmt.allOperands || !seenOperands.insert(operand).second)
2673         return emitError(loc, "operand '" + name + "' is already bound");
2674     }
2675     element = std::make_unique<OperandVariable>(operand);
2676     return ::mlir::success();
2677   }
2678   /// Regions
2679   if (const NamedRegion *region = findArg(op.getRegions(), name)) {
2680     if (!isTopLevel)
2681       return emitError(loc, "regions can only be used at the top level");
2682     if (hasAllRegions || !seenRegions.insert(region).second)
2683       return emitError(loc, "region '" + name + "' is already bound");
2684     element = std::make_unique<RegionVariable>(region);
2685     return ::mlir::success();
2686   }
2687   /// Results.
2688   if (const auto *result = findArg(op.getResults(), name)) {
2689     if (isTopLevel)
2690       return emitError(loc, "results can not be used at the top level");
2691     element = std::make_unique<ResultVariable>(result);
2692     return ::mlir::success();
2693   }
2694   /// Successors.
2695   if (const auto *successor = findArg(op.getSuccessors(), name)) {
2696     if (!isTopLevel)
2697       return emitError(loc, "successors can only be used at the top level");
2698     if (hasAllSuccessors || !seenSuccessors.insert(successor).second)
2699       return emitError(loc, "successor '" + name + "' is already bound");
2700     element = std::make_unique<SuccessorVariable>(successor);
2701     return ::mlir::success();
2702   }
2703   return emitError(loc, "expected variable to refer to an argument, region, "
2704                         "result, or successor");
2705 }
2706 
parseDirective(std::unique_ptr<Element> & element,bool isTopLevel)2707 LogicalResult FormatParser::parseDirective(std::unique_ptr<Element> &element,
2708                                            bool isTopLevel) {
2709   Token dirTok = curToken;
2710   consumeToken();
2711 
2712   switch (dirTok.getKind()) {
2713   case Token::kw_attr_dict:
2714     return parseAttrDictDirective(element, dirTok.getLoc(), isTopLevel,
2715                                   /*withKeyword=*/false);
2716   case Token::kw_attr_dict_w_keyword:
2717     return parseAttrDictDirective(element, dirTok.getLoc(), isTopLevel,
2718                                   /*withKeyword=*/true);
2719   case Token::kw_custom:
2720     return parseCustomDirective(element, dirTok.getLoc(), isTopLevel);
2721   case Token::kw_functional_type:
2722     return parseFunctionalTypeDirective(element, dirTok, isTopLevel);
2723   case Token::kw_operands:
2724     return parseOperandsDirective(element, dirTok.getLoc(), isTopLevel);
2725   case Token::kw_regions:
2726     return parseRegionsDirective(element, dirTok.getLoc(), isTopLevel);
2727   case Token::kw_results:
2728     return parseResultsDirective(element, dirTok.getLoc(), isTopLevel);
2729   case Token::kw_successors:
2730     return parseSuccessorsDirective(element, dirTok.getLoc(), isTopLevel);
2731   case Token::kw_type_ref:
2732     return parseTypeDirective(element, dirTok, isTopLevel, /*isTypeRef=*/true);
2733   case Token::kw_type:
2734     return parseTypeDirective(element, dirTok, isTopLevel);
2735 
2736   default:
2737     llvm_unreachable("unknown directive token");
2738   }
2739 }
2740 
parseLiteral(std::unique_ptr<Element> & element)2741 LogicalResult FormatParser::parseLiteral(std::unique_ptr<Element> &element) {
2742   Token literalTok = curToken;
2743   consumeToken();
2744 
2745   StringRef value = literalTok.getSpelling().drop_front().drop_back();
2746 
2747   // The parsed literal is a space element (`` or ` `).
2748   if (value.empty() || (value.size() == 1 && value.front() == ' ')) {
2749     element = std::make_unique<SpaceElement>(!value.empty());
2750     return ::mlir::success();
2751   }
2752   // The parsed literal is a newline element.
2753   if (value == "\\n") {
2754     element = std::make_unique<NewlineElement>();
2755     return ::mlir::success();
2756   }
2757 
2758   // Check that the parsed literal is valid.
2759   if (!LiteralElement::isValidLiteral(value))
2760     return emitError(literalTok.getLoc(), "expected valid literal");
2761 
2762   element = std::make_unique<LiteralElement>(value);
2763   return ::mlir::success();
2764 }
2765 
parseOptional(std::unique_ptr<Element> & element,bool isTopLevel)2766 LogicalResult FormatParser::parseOptional(std::unique_ptr<Element> &element,
2767                                           bool isTopLevel) {
2768   llvm::SMLoc curLoc = curToken.getLoc();
2769   if (!isTopLevel)
2770     return emitError(curLoc, "optional groups can only be used as top-level "
2771                              "elements");
2772   consumeToken();
2773 
2774   // Parse the child elements for this optional group.
2775   std::vector<std::unique_ptr<Element>> elements;
2776   Optional<unsigned> anchorIdx;
2777   do {
2778     if (failed(parseOptionalChildElement(elements, anchorIdx)))
2779       return ::mlir::failure();
2780   } while (curToken.getKind() != Token::r_paren);
2781   consumeToken();
2782   if (failed(parseToken(Token::question, "expected '?' after optional group")))
2783     return ::mlir::failure();
2784 
2785   // The optional group is required to have an anchor.
2786   if (!anchorIdx)
2787     return emitError(curLoc, "optional group specified no anchor element");
2788 
2789   // The first parsable element of the group must be able to be parsed in an
2790   // optional fashion.
2791   auto parseBegin = llvm::find_if_not(elements, [](auto &element) {
2792     return isa<WhitespaceElement>(element.get());
2793   });
2794   Element *firstElement = parseBegin->get();
2795   if (!isa<AttributeVariable>(firstElement) &&
2796       !isa<LiteralElement>(firstElement) &&
2797       !isa<OperandVariable>(firstElement) && !isa<RegionVariable>(firstElement))
2798     return emitError(curLoc,
2799                      "first parsable element of an operand group must be "
2800                      "an attribute, literal, operand, or region");
2801 
2802   auto parseStart = parseBegin - elements.begin();
2803   element = std::make_unique<OptionalElement>(std::move(elements), *anchorIdx,
2804                                               parseStart);
2805   return ::mlir::success();
2806 }
2807 
parseOptionalChildElement(std::vector<std::unique_ptr<Element>> & childElements,Optional<unsigned> & anchorIdx)2808 LogicalResult FormatParser::parseOptionalChildElement(
2809     std::vector<std::unique_ptr<Element>> &childElements,
2810     Optional<unsigned> &anchorIdx) {
2811   llvm::SMLoc childLoc = curToken.getLoc();
2812   childElements.push_back({});
2813   if (failed(parseElement(childElements.back(), /*isTopLevel=*/true)))
2814     return ::mlir::failure();
2815 
2816   // Check to see if this element is the anchor of the optional group.
2817   bool isAnchor = curToken.getKind() == Token::caret;
2818   if (isAnchor) {
2819     if (anchorIdx)
2820       return emitError(childLoc, "only one element can be marked as the anchor "
2821                                  "of an optional group");
2822     anchorIdx = childElements.size() - 1;
2823     consumeToken();
2824   }
2825 
2826   return verifyOptionalChildElement(childElements.back().get(), childLoc,
2827                                     isAnchor);
2828 }
2829 
verifyOptionalChildElement(Element * element,llvm::SMLoc childLoc,bool isAnchor)2830 LogicalResult FormatParser::verifyOptionalChildElement(Element *element,
2831                                                        llvm::SMLoc childLoc,
2832                                                        bool isAnchor) {
2833   return TypeSwitch<Element *, LogicalResult>(element)
2834       // All attributes can be within the optional group, but only optional
2835       // attributes can be the anchor.
2836       .Case([&](AttributeVariable *attrEle) {
2837         if (isAnchor && !attrEle->getVar()->attr.isOptional())
2838           return emitError(childLoc, "only optional attributes can be used to "
2839                                      "anchor an optional group");
2840         return ::mlir::success();
2841       })
2842       // Only optional-like(i.e. variadic) operands can be within an optional
2843       // group.
2844       .Case<OperandVariable>([&](OperandVariable *ele) {
2845         if (!ele->getVar()->isVariableLength())
2846           return emitError(childLoc, "only variable length operands can be "
2847                                      "used within an optional group");
2848         return ::mlir::success();
2849       })
2850       // Only optional-like(i.e. variadic) results can be within an optional
2851       // group.
2852       .Case<ResultVariable>([&](ResultVariable *ele) {
2853         if (!ele->getVar()->isVariableLength())
2854           return emitError(childLoc, "only variable length results can be "
2855                                      "used within an optional group");
2856         return ::mlir::success();
2857       })
2858       .Case<RegionVariable>([&](RegionVariable *) {
2859         // TODO: When ODS has proper support for marking "optional" regions, add
2860         // a check here.
2861         return ::mlir::success();
2862       })
2863       .Case<TypeDirective>([&](TypeDirective *ele) {
2864         return verifyOptionalChildElement(ele->getOperand(), childLoc,
2865                                           /*isAnchor=*/false);
2866       })
2867       .Case<FunctionalTypeDirective>([&](FunctionalTypeDirective *ele) {
2868         if (failed(verifyOptionalChildElement(ele->getInputs(), childLoc,
2869                                               /*isAnchor=*/false)))
2870           return failure();
2871         return verifyOptionalChildElement(ele->getResults(), childLoc,
2872                                           /*isAnchor=*/false);
2873       })
2874       // Literals, whitespace, and custom directives may be used, but they can't
2875       // anchor the group.
2876       .Case<LiteralElement, WhitespaceElement, CustomDirective,
2877             FunctionalTypeDirective, OptionalElement, TypeRefDirective>(
2878           [&](Element *) {
2879             if (isAnchor)
2880               return emitError(childLoc, "only variables and types can be used "
2881                                          "to anchor an optional group");
2882             return ::mlir::success();
2883           })
2884       .Default([&](Element *) {
2885         return emitError(childLoc, "only literals, types, and variables can be "
2886                                    "used within an optional group");
2887       });
2888 }
2889 
2890 LogicalResult
parseAttrDictDirective(std::unique_ptr<Element> & element,llvm::SMLoc loc,bool isTopLevel,bool withKeyword)2891 FormatParser::parseAttrDictDirective(std::unique_ptr<Element> &element,
2892                                      llvm::SMLoc loc, bool isTopLevel,
2893                                      bool withKeyword) {
2894   if (!isTopLevel)
2895     return emitError(loc, "'attr-dict' directive can only be used as a "
2896                           "top-level directive");
2897   if (hasAttrDict)
2898     return emitError(loc, "'attr-dict' directive has already been seen");
2899 
2900   hasAttrDict = true;
2901   element = std::make_unique<AttrDictDirective>(withKeyword);
2902   return ::mlir::success();
2903 }
2904 
2905 LogicalResult
parseCustomDirective(std::unique_ptr<Element> & element,llvm::SMLoc loc,bool isTopLevel)2906 FormatParser::parseCustomDirective(std::unique_ptr<Element> &element,
2907                                    llvm::SMLoc loc, bool isTopLevel) {
2908   llvm::SMLoc curLoc = curToken.getLoc();
2909 
2910   // Parse the custom directive name.
2911   if (failed(
2912           parseToken(Token::less, "expected '<' before custom directive name")))
2913     return ::mlir::failure();
2914 
2915   Token nameTok = curToken;
2916   if (failed(parseToken(Token::identifier,
2917                         "expected custom directive name identifier")) ||
2918       failed(parseToken(Token::greater,
2919                         "expected '>' after custom directive name")) ||
2920       failed(parseToken(Token::l_paren,
2921                         "expected '(' before custom directive parameters")))
2922     return ::mlir::failure();
2923 
2924   // Parse the child elements for this optional group.=
2925   std::vector<std::unique_ptr<Element>> elements;
2926   do {
2927     if (failed(parseCustomDirectiveParameter(elements)))
2928       return ::mlir::failure();
2929     if (curToken.getKind() != Token::comma)
2930       break;
2931     consumeToken();
2932   } while (true);
2933 
2934   if (failed(parseToken(Token::r_paren,
2935                         "expected ')' after custom directive parameters")))
2936     return ::mlir::failure();
2937 
2938   // After parsing all of the elements, ensure that all type directives refer
2939   // only to variables.
2940   for (auto &ele : elements) {
2941     if (auto *typeEle = dyn_cast<TypeRefDirective>(ele.get())) {
2942       if (!isa<OperandVariable, ResultVariable>(typeEle->getOperand())) {
2943         return emitError(curLoc,
2944                          "type_ref directives within a custom directive "
2945                          "may only refer to variables");
2946       }
2947     }
2948     if (auto *typeEle = dyn_cast<TypeDirective>(ele.get())) {
2949       if (!isa<OperandVariable, ResultVariable>(typeEle->getOperand())) {
2950         return emitError(curLoc, "type directives within a custom directive "
2951                                  "may only refer to variables");
2952       }
2953     }
2954   }
2955 
2956   element = std::make_unique<CustomDirective>(nameTok.getSpelling(),
2957                                               std::move(elements));
2958   return ::mlir::success();
2959 }
2960 
parseCustomDirectiveParameter(std::vector<std::unique_ptr<Element>> & parameters)2961 LogicalResult FormatParser::parseCustomDirectiveParameter(
2962     std::vector<std::unique_ptr<Element>> &parameters) {
2963   llvm::SMLoc childLoc = curToken.getLoc();
2964   parameters.push_back({});
2965   if (failed(parseElement(parameters.back(), /*isTopLevel=*/true)))
2966     return ::mlir::failure();
2967 
2968   // Verify that the element can be placed within a custom directive.
2969   if (!isa<TypeRefDirective, TypeDirective, AttrDictDirective,
2970            AttributeVariable, OperandVariable, RegionVariable,
2971            SuccessorVariable>(parameters.back().get())) {
2972     return emitError(childLoc, "only variables and types may be used as "
2973                                "parameters to a custom directive");
2974   }
2975   return ::mlir::success();
2976 }
2977 
2978 LogicalResult
parseFunctionalTypeDirective(std::unique_ptr<Element> & element,Token tok,bool isTopLevel)2979 FormatParser::parseFunctionalTypeDirective(std::unique_ptr<Element> &element,
2980                                            Token tok, bool isTopLevel) {
2981   llvm::SMLoc loc = tok.getLoc();
2982   if (!isTopLevel)
2983     return emitError(
2984         loc, "'functional-type' is only valid as a top-level directive");
2985 
2986   // Parse the main operand.
2987   std::unique_ptr<Element> inputs, results;
2988   if (failed(parseToken(Token::l_paren, "expected '(' before argument list")) ||
2989       failed(parseTypeDirectiveOperand(inputs)) ||
2990       failed(parseToken(Token::comma, "expected ',' after inputs argument")) ||
2991       failed(parseTypeDirectiveOperand(results)) ||
2992       failed(parseToken(Token::r_paren, "expected ')' after argument list")))
2993     return ::mlir::failure();
2994   element = std::make_unique<FunctionalTypeDirective>(std::move(inputs),
2995                                                       std::move(results));
2996   return ::mlir::success();
2997 }
2998 
2999 LogicalResult
parseOperandsDirective(std::unique_ptr<Element> & element,llvm::SMLoc loc,bool isTopLevel)3000 FormatParser::parseOperandsDirective(std::unique_ptr<Element> &element,
3001                                      llvm::SMLoc loc, bool isTopLevel) {
3002   if (isTopLevel) {
3003     if (fmt.allOperands || !seenOperands.empty())
3004       return emitError(loc, "'operands' directive creates overlap in format");
3005     fmt.allOperands = true;
3006   }
3007   element = std::make_unique<OperandsDirective>();
3008   return ::mlir::success();
3009 }
3010 
3011 LogicalResult
parseRegionsDirective(std::unique_ptr<Element> & element,llvm::SMLoc loc,bool isTopLevel)3012 FormatParser::parseRegionsDirective(std::unique_ptr<Element> &element,
3013                                     llvm::SMLoc loc, bool isTopLevel) {
3014   if (!isTopLevel)
3015     return emitError(loc, "'regions' is only valid as a top-level directive");
3016   if (hasAllRegions || !seenRegions.empty())
3017     return emitError(loc, "'regions' directive creates overlap in format");
3018   hasAllRegions = true;
3019   element = std::make_unique<RegionsDirective>();
3020   return ::mlir::success();
3021 }
3022 
3023 LogicalResult
parseResultsDirective(std::unique_ptr<Element> & element,llvm::SMLoc loc,bool isTopLevel)3024 FormatParser::parseResultsDirective(std::unique_ptr<Element> &element,
3025                                     llvm::SMLoc loc, bool isTopLevel) {
3026   if (isTopLevel)
3027     return emitError(loc, "'results' directive can not be used as a "
3028                           "top-level directive");
3029   element = std::make_unique<ResultsDirective>();
3030   return ::mlir::success();
3031 }
3032 
3033 LogicalResult
parseSuccessorsDirective(std::unique_ptr<Element> & element,llvm::SMLoc loc,bool isTopLevel)3034 FormatParser::parseSuccessorsDirective(std::unique_ptr<Element> &element,
3035                                        llvm::SMLoc loc, bool isTopLevel) {
3036   if (!isTopLevel)
3037     return emitError(loc,
3038                      "'successors' is only valid as a top-level directive");
3039   if (hasAllSuccessors || !seenSuccessors.empty())
3040     return emitError(loc, "'successors' directive creates overlap in format");
3041   hasAllSuccessors = true;
3042   element = std::make_unique<SuccessorsDirective>();
3043   return ::mlir::success();
3044 }
3045 
3046 LogicalResult
parseTypeDirective(std::unique_ptr<Element> & element,Token tok,bool isTopLevel,bool isTypeRef)3047 FormatParser::parseTypeDirective(std::unique_ptr<Element> &element, Token tok,
3048                                  bool isTopLevel, bool isTypeRef) {
3049   llvm::SMLoc loc = tok.getLoc();
3050   if (!isTopLevel)
3051     return emitError(loc, "'type' is only valid as a top-level directive");
3052 
3053   std::unique_ptr<Element> operand;
3054   if (failed(parseToken(Token::l_paren, "expected '(' before argument list")) ||
3055       failed(parseTypeDirectiveOperand(operand, isTypeRef)) ||
3056       failed(parseToken(Token::r_paren, "expected ')' after argument list")))
3057     return ::mlir::failure();
3058   if (isTypeRef)
3059     element = std::make_unique<TypeRefDirective>(std::move(operand));
3060   else
3061     element = std::make_unique<TypeDirective>(std::move(operand));
3062   return ::mlir::success();
3063 }
3064 
3065 LogicalResult
parseTypeDirectiveOperand(std::unique_ptr<Element> & element,bool isTypeRef)3066 FormatParser::parseTypeDirectiveOperand(std::unique_ptr<Element> &element,
3067                                         bool isTypeRef) {
3068   llvm::SMLoc loc = curToken.getLoc();
3069   if (failed(parseElement(element, /*isTopLevel=*/false)))
3070     return ::mlir::failure();
3071   if (isa<LiteralElement>(element.get()))
3072     return emitError(
3073         loc, "'type' directive operand expects variable or directive operand");
3074 
3075   if (auto *var = dyn_cast<OperandVariable>(element.get())) {
3076     unsigned opIdx = var->getVar() - op.operand_begin();
3077     if (!isTypeRef && (fmt.allOperandTypes || seenOperandTypes.test(opIdx)))
3078       return emitError(loc, "'type' of '" + var->getVar()->name +
3079                                 "' is already bound");
3080     if (isTypeRef && !(fmt.allOperandTypes || seenOperandTypes.test(opIdx)))
3081       return emitError(loc, "'type_ref' of '" + var->getVar()->name +
3082                                 "' is not bound by a prior 'type' directive");
3083     seenOperandTypes.set(opIdx);
3084   } else if (auto *var = dyn_cast<ResultVariable>(element.get())) {
3085     unsigned resIdx = var->getVar() - op.result_begin();
3086     if (!isTypeRef && (fmt.allResultTypes || seenResultTypes.test(resIdx)))
3087       return emitError(loc, "'type' of '" + var->getVar()->name +
3088                                 "' is already bound");
3089     if (isTypeRef && !(fmt.allResultTypes || seenResultTypes.test(resIdx)))
3090       return emitError(loc, "'type_ref' of '" + var->getVar()->name +
3091                                 "' is not bound by a prior 'type' directive");
3092     seenResultTypes.set(resIdx);
3093   } else if (isa<OperandsDirective>(&*element)) {
3094     if (!isTypeRef && (fmt.allOperandTypes || seenOperandTypes.any()))
3095       return emitError(loc, "'operands' 'type' is already bound");
3096     if (isTypeRef && !(fmt.allOperandTypes || seenOperandTypes.all()))
3097       return emitError(
3098           loc,
3099           "'operands' 'type_ref' is not bound by a prior 'type' directive");
3100     fmt.allOperandTypes = true;
3101   } else if (isa<ResultsDirective>(&*element)) {
3102     if (!isTypeRef && (fmt.allResultTypes || seenResultTypes.any()))
3103       return emitError(loc, "'results' 'type' is already bound");
3104     if (isTypeRef && !(fmt.allResultTypes || seenResultTypes.all()))
3105       return emitError(
3106           loc, "'results' 'type_ref' is not bound by a prior 'type' directive");
3107     fmt.allResultTypes = true;
3108   } else {
3109     return emitError(loc, "invalid argument to 'type' directive");
3110   }
3111   return ::mlir::success();
3112 }
3113 
3114 //===----------------------------------------------------------------------===//
3115 // Interface
3116 //===----------------------------------------------------------------------===//
3117 
generateOpFormat(const Operator & constOp,OpClass & opClass)3118 void mlir::tblgen::generateOpFormat(const Operator &constOp, OpClass &opClass) {
3119   // TODO: Operator doesn't expose all necessary functionality via
3120   // the const interface.
3121   Operator &op = const_cast<Operator &>(constOp);
3122   if (!op.hasAssemblyFormat())
3123     return;
3124 
3125   // Parse the format description.
3126   llvm::SourceMgr mgr;
3127   mgr.AddNewSourceBuffer(
3128       llvm::MemoryBuffer::getMemBuffer(op.getAssemblyFormat()), llvm::SMLoc());
3129   OperationFormat format(op);
3130   if (failed(FormatParser(mgr, format, op).parse())) {
3131     // Exit the process if format errors are treated as fatal.
3132     if (formatErrorIsFatal) {
3133       // Invoke the interrupt handlers to run the file cleanup handlers.
3134       llvm::sys::RunInterruptHandlers();
3135       std::exit(1);
3136     }
3137     return;
3138   }
3139 
3140   // Generate the printer and parser based on the parsed format.
3141   format.genParser(op, opClass);
3142   format.genPrinter(op, opClass);
3143 }
3144