1 //===- OpFormatGen.cpp - MLIR operation asm format generator --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "OpFormatGen.h"
10 #include "mlir/Support/LogicalResult.h"
11 #include "mlir/TableGen/Format.h"
12 #include "mlir/TableGen/GenInfo.h"
13 #include "mlir/TableGen/Interfaces.h"
14 #include "mlir/TableGen/OpClass.h"
15 #include "mlir/TableGen/OpTrait.h"
16 #include "mlir/TableGen/Operator.h"
17 #include "llvm/ADT/MapVector.h"
18 #include "llvm/ADT/Sequence.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/SmallBitVector.h"
21 #include "llvm/ADT/StringExtras.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 #include "llvm/Support/CommandLine.h"
24 #include "llvm/Support/Signals.h"
25 #include "llvm/TableGen/Error.h"
26 #include "llvm/TableGen/Record.h"
27
28 #define DEBUG_TYPE "mlir-tblgen-opformatgen"
29
30 using namespace mlir;
31 using namespace mlir::tblgen;
32
33 static llvm::cl::opt<bool> formatErrorIsFatal(
34 "asmformat-error-is-fatal",
35 llvm::cl::desc("Emit a fatal error if format parsing fails"),
36 llvm::cl::init(true));
37
38 /// Returns true if the given string can be formatted as a keyword.
canFormatStringAsKeyword(StringRef value)39 static bool canFormatStringAsKeyword(StringRef value) {
40 if (!isalpha(value.front()) && value.front() != '_')
41 return false;
42 return llvm::all_of(value.drop_front(), [](char c) {
43 return isalnum(c) || c == '_' || c == '$' || c == '.';
44 });
45 }
46
47 //===----------------------------------------------------------------------===//
48 // Element
49 //===----------------------------------------------------------------------===//
50
namespace {
/// This class represents a single format element. Concrete element classes
/// derive from it and identify themselves through `getKind()` (LLVM-style
/// RTTI via `classof`).
class Element {
public:
  enum class Kind {
    /// This element is a directive.
    AttrDictDirective,
    CustomDirective,
    FunctionalTypeDirective,
    OperandsDirective,
    RegionsDirective,
    ResultsDirective,
    SuccessorsDirective,
    TypeDirective,
    TypeRefDirective,

    /// This element is a literal.
    Literal,

    /// This element is a whitespace.
    Newline,
    Space,

    /// This element is a variable value.
    AttributeVariable,
    OperandVariable,
    RegionVariable,
    ResultVariable,
    SuccessorVariable,

    /// This element is an optional element.
    Optional,
  };
  /// Construct an element of the given kind. Explicit so that a bare `Kind`
  /// is never silently converted into an `Element`.
  explicit Element(Kind kind) : kind(kind) {}
  virtual ~Element() = default;

  /// Return the kind of this element.
  Kind getKind() const { return kind; }

private:
  /// The kind of this element.
  Kind kind;
};
} // namespace
95
96 //===----------------------------------------------------------------------===//
97 // VariableElement
98
99 namespace {
100 /// This class represents an instance of an variable element. A variable refers
101 /// to something registered on the operation itself, e.g. an argument, result,
102 /// etc.
103 template <typename VarT, Element::Kind kindVal>
104 class VariableElement : public Element {
105 public:
VariableElement(const VarT * var)106 VariableElement(const VarT *var) : Element(kindVal), var(var) {}
classof(const Element * element)107 static bool classof(const Element *element) {
108 return element->getKind() == kindVal;
109 }
getVar()110 const VarT *getVar() { return var; }
111
112 protected:
113 const VarT *var;
114 };
115
116 /// This class represents a variable that refers to an attribute argument.
117 struct AttributeVariable
118 : public VariableElement<NamedAttribute, Element::Kind::AttributeVariable> {
119 using VariableElement<NamedAttribute,
120 Element::Kind::AttributeVariable>::VariableElement;
121
122 /// Return the constant builder call for the type of this attribute, or None
123 /// if it doesn't have one.
getTypeBuilder__anon7e5610690311::AttributeVariable124 Optional<StringRef> getTypeBuilder() const {
125 Optional<Type> attrType = var->attr.getValueType();
126 return attrType ? attrType->getBuilderCall() : llvm::None;
127 }
128
129 /// Return if this attribute refers to a UnitAttr.
isUnitAttr__anon7e5610690311::AttributeVariable130 bool isUnitAttr() const {
131 return var->attr.getBaseAttr().getAttrDefName() == "UnitAttr";
132 }
133 };
134
135 /// This class represents a variable that refers to an operand argument.
136 using OperandVariable =
137 VariableElement<NamedTypeConstraint, Element::Kind::OperandVariable>;
138
139 /// This class represents a variable that refers to a region.
140 using RegionVariable =
141 VariableElement<NamedRegion, Element::Kind::RegionVariable>;
142
143 /// This class represents a variable that refers to a result.
144 using ResultVariable =
145 VariableElement<NamedTypeConstraint, Element::Kind::ResultVariable>;
146
147 /// This class represents a variable that refers to a successor.
148 using SuccessorVariable =
149 VariableElement<NamedSuccessor, Element::Kind::SuccessorVariable>;
150 } // end anonymous namespace
151
152 //===----------------------------------------------------------------------===//
153 // DirectiveElement
154
155 namespace {
156 /// This class implements single kind directives.
157 template <Element::Kind type>
158 class DirectiveElement : public Element {
159 public:
DirectiveElement()160 DirectiveElement() : Element(type){};
classof(const Element * ele)161 static bool classof(const Element *ele) { return ele->getKind() == type; }
162 };
163 /// This class represents the `operands` directive. This directive represents
164 /// all of the operands of an operation.
165 using OperandsDirective = DirectiveElement<Element::Kind::OperandsDirective>;
166
167 /// This class represents the `regions` directive. This directive represents
168 /// all of the regions of an operation.
169 using RegionsDirective = DirectiveElement<Element::Kind::RegionsDirective>;
170
171 /// This class represents the `results` directive. This directive represents
172 /// all of the results of an operation.
173 using ResultsDirective = DirectiveElement<Element::Kind::ResultsDirective>;
174
175 /// This class represents the `successors` directive. This directive represents
176 /// all of the successors of an operation.
177 using SuccessorsDirective =
178 DirectiveElement<Element::Kind::SuccessorsDirective>;
179
180 /// This class represents the `attr-dict` directive. This directive represents
181 /// the attribute dictionary of the operation.
182 class AttrDictDirective
183 : public DirectiveElement<Element::Kind::AttrDictDirective> {
184 public:
AttrDictDirective(bool withKeyword)185 explicit AttrDictDirective(bool withKeyword) : withKeyword(withKeyword) {}
isWithKeyword() const186 bool isWithKeyword() const { return withKeyword; }
187
188 private:
189 /// If the dictionary should be printed with the 'attributes' keyword.
190 bool withKeyword;
191 };
192
193 /// This class represents a custom format directive that is implemented by the
194 /// user in C++.
195 class CustomDirective : public Element {
196 public:
CustomDirective(StringRef name,std::vector<std::unique_ptr<Element>> && arguments)197 CustomDirective(StringRef name,
198 std::vector<std::unique_ptr<Element>> &&arguments)
199 : Element{Kind::CustomDirective}, name(name),
200 arguments(std::move(arguments)) {}
201
classof(const Element * element)202 static bool classof(const Element *element) {
203 return element->getKind() == Kind::CustomDirective;
204 }
205
206 /// Return the name of this optional element.
getName() const207 StringRef getName() const { return name; }
208
209 /// Return the arguments to the custom directive.
getArguments() const210 auto getArguments() const { return llvm::make_pointee_range(arguments); }
211
212 private:
213 /// The user provided name of the directive.
214 StringRef name;
215
216 /// The arguments to the custom directive.
217 std::vector<std::unique_ptr<Element>> arguments;
218 };
219
220 /// This class represents the `functional-type` directive. This directive takes
221 /// two arguments and formats them, respectively, as the inputs and results of a
222 /// FunctionType.
223 class FunctionalTypeDirective
224 : public DirectiveElement<Element::Kind::FunctionalTypeDirective> {
225 public:
FunctionalTypeDirective(std::unique_ptr<Element> inputs,std::unique_ptr<Element> results)226 FunctionalTypeDirective(std::unique_ptr<Element> inputs,
227 std::unique_ptr<Element> results)
228 : inputs(std::move(inputs)), results(std::move(results)) {}
getInputs() const229 Element *getInputs() const { return inputs.get(); }
getResults() const230 Element *getResults() const { return results.get(); }
231
232 private:
233 /// The input and result arguments.
234 std::unique_ptr<Element> inputs, results;
235 };
236
237 /// This class represents the `type` directive.
238 class TypeDirective : public DirectiveElement<Element::Kind::TypeDirective> {
239 public:
TypeDirective(std::unique_ptr<Element> arg)240 TypeDirective(std::unique_ptr<Element> arg) : operand(std::move(arg)) {}
getOperand() const241 Element *getOperand() const { return operand.get(); }
242
243 private:
244 /// The operand that is used to format the directive.
245 std::unique_ptr<Element> operand;
246 };
247
248 /// This class represents the `type_ref` directive.
249 class TypeRefDirective
250 : public DirectiveElement<Element::Kind::TypeRefDirective> {
251 public:
TypeRefDirective(std::unique_ptr<Element> arg)252 TypeRefDirective(std::unique_ptr<Element> arg) : operand(std::move(arg)) {}
getOperand() const253 Element *getOperand() const { return operand.get(); }
254
255 private:
256 /// The operand that is used to format the directive.
257 std::unique_ptr<Element> operand;
258 };
259 } // namespace
260
261 //===----------------------------------------------------------------------===//
262 // LiteralElement
263
264 namespace {
265 /// This class represents an instance of a literal element.
266 class LiteralElement : public Element {
267 public:
LiteralElement(StringRef literal)268 LiteralElement(StringRef literal)
269 : Element{Kind::Literal}, literal(literal) {}
classof(const Element * element)270 static bool classof(const Element *element) {
271 return element->getKind() == Kind::Literal;
272 }
273
274 /// Return the literal for this element.
getLiteral() const275 StringRef getLiteral() const { return literal; }
276
277 /// Returns true if the given string is a valid literal.
278 static bool isValidLiteral(StringRef value);
279
280 private:
281 /// The spelling of the literal for this element.
282 StringRef literal;
283 };
284 } // end anonymous namespace
285
isValidLiteral(StringRef value)286 bool LiteralElement::isValidLiteral(StringRef value) {
287 if (value.empty())
288 return false;
289 char front = value.front();
290
291 // If there is only one character, this must either be punctuation or a
292 // single character bare identifier.
293 if (value.size() == 1)
294 return isalpha(front) || StringRef("_:,=<>()[]{}?+*").contains(front);
295
296 // Check the punctuation that are larger than a single character.
297 if (value == "->")
298 return true;
299
300 // Otherwise, this must be an identifier.
301 return canFormatStringAsKeyword(value);
302 }
303
304 //===----------------------------------------------------------------------===//
305 // WhitespaceElement
306
307 namespace {
308 /// This class represents a whitespace element, e.g. newline or space. It's a
309 /// literal that is printed but never parsed.
310 class WhitespaceElement : public Element {
311 public:
WhitespaceElement(Kind kind)312 WhitespaceElement(Kind kind) : Element{kind} {}
classof(const Element * element)313 static bool classof(const Element *element) {
314 Kind kind = element->getKind();
315 return kind == Kind::Newline || kind == Kind::Space;
316 }
317 };
318
319 /// This class represents an instance of a newline element. It's a literal that
320 /// prints a newline. It is ignored by the parser.
321 class NewlineElement : public WhitespaceElement {
322 public:
NewlineElement()323 NewlineElement() : WhitespaceElement(Kind::Newline) {}
classof(const Element * element)324 static bool classof(const Element *element) {
325 return element->getKind() == Kind::Newline;
326 }
327 };
328
329 /// This class represents an instance of a space element. It's a literal that
330 /// prints or omits printing a space. It is ignored by the parser.
331 class SpaceElement : public WhitespaceElement {
332 public:
SpaceElement(bool value)333 SpaceElement(bool value) : WhitespaceElement(Kind::Space), value(value) {}
classof(const Element * element)334 static bool classof(const Element *element) {
335 return element->getKind() == Kind::Space;
336 }
337
338 /// Returns true if this element should print as a space. Otherwise, the
339 /// element should omit printing a space between the surrounding elements.
getValue() const340 bool getValue() const { return value; }
341
342 private:
343 bool value;
344 };
345 } // end anonymous namespace
346
347 //===----------------------------------------------------------------------===//
348 // OptionalElement
349
350 namespace {
351 /// This class represents a group of elements that are optionally emitted based
352 /// upon an optional variable of the operation.
353 class OptionalElement : public Element {
354 public:
OptionalElement(std::vector<std::unique_ptr<Element>> && elements,unsigned anchor,unsigned parseStart)355 OptionalElement(std::vector<std::unique_ptr<Element>> &&elements,
356 unsigned anchor, unsigned parseStart)
357 : Element{Kind::Optional}, elements(std::move(elements)), anchor(anchor),
358 parseStart(parseStart) {}
classof(const Element * element)359 static bool classof(const Element *element) {
360 return element->getKind() == Kind::Optional;
361 }
362
363 /// Return the nested elements of this grouping.
getElements() const364 auto getElements() const { return llvm::make_pointee_range(elements); }
365
366 /// Return the anchor of this optional group.
getAnchor() const367 Element *getAnchor() const { return elements[anchor].get(); }
368
369 /// Return the index of the first element that needs to be parsed.
getParseStart() const370 unsigned getParseStart() const { return parseStart; }
371
372 private:
373 /// The child elements of this optional.
374 std::vector<std::unique_ptr<Element>> elements;
375 /// The index of the element that acts as the anchor for the optional group.
376 unsigned anchor;
377 /// The index of the first element that is parsed (is not a
378 /// WhitespaceElement).
379 unsigned parseStart;
380 };
381 } // end anonymous namespace
382
383 //===----------------------------------------------------------------------===//
384 // OperationFormat
385 //===----------------------------------------------------------------------===//
386
387 namespace {
388
389 using ConstArgument =
390 llvm::PointerUnion<const NamedAttribute *, const NamedTypeConstraint *>;
391
392 struct OperationFormat {
393 /// This class represents a specific resolver for an operand or result type.
394 class TypeResolution {
395 public:
396 TypeResolution() = default;
397
398 /// Get the index into the buildable types for this type, or None.
getBuilderIdx() const399 Optional<int> getBuilderIdx() const { return builderIdx; }
setBuilderIdx(int idx)400 void setBuilderIdx(int idx) { builderIdx = idx; }
401
402 /// Get the variable this type is resolved to, or nullptr.
getVariable() const403 const NamedTypeConstraint *getVariable() const {
404 return resolver.dyn_cast<const NamedTypeConstraint *>();
405 }
406 /// Get the attribute this type is resolved to, or nullptr.
getAttribute() const407 const NamedAttribute *getAttribute() const {
408 return resolver.dyn_cast<const NamedAttribute *>();
409 }
410 /// Get the transformer for the type of the variable, or None.
getVarTransformer() const411 Optional<StringRef> getVarTransformer() const {
412 return variableTransformer;
413 }
setResolver(ConstArgument arg,Optional<StringRef> transformer)414 void setResolver(ConstArgument arg, Optional<StringRef> transformer) {
415 resolver = arg;
416 variableTransformer = transformer;
417 assert(getVariable() || getAttribute());
418 }
419
420 private:
421 /// If the type is resolved with a buildable type, this is the index into
422 /// 'buildableTypes' in the parent format.
423 Optional<int> builderIdx;
424 /// If the type is resolved based upon another operand or result, this is
425 /// the variable or the attribute that this type is resolved to.
426 ConstArgument resolver;
427 /// If the type is resolved based upon another operand or result, this is
428 /// a transformer to apply to the variable when resolving.
429 Optional<StringRef> variableTransformer;
430 };
431
OperationFormat__anon7e5610690811::OperationFormat432 OperationFormat(const Operator &op)
433 : allOperands(false), allOperandTypes(false), allResultTypes(false) {
434 operandTypes.resize(op.getNumOperands(), TypeResolution());
435 resultTypes.resize(op.getNumResults(), TypeResolution());
436
437 hasImplicitTermTrait =
438 llvm::any_of(op.getTraits(), [](const OpTrait &trait) {
439 return trait.getDef().isSubClassOf("SingleBlockImplicitTerminator");
440 });
441 }
442
443 /// Generate the operation parser from this format.
444 void genParser(Operator &op, OpClass &opClass);
445 /// Generate the parser code for a specific format element.
446 void genElementParser(Element *element, OpMethodBody &body,
447 FmtContext &attrTypeCtx);
448 /// Generate the c++ to resolve the types of operands and results during
449 /// parsing.
450 void genParserTypeResolution(Operator &op, OpMethodBody &body);
451 /// Generate the c++ to resolve regions during parsing.
452 void genParserRegionResolution(Operator &op, OpMethodBody &body);
453 /// Generate the c++ to resolve successors during parsing.
454 void genParserSuccessorResolution(Operator &op, OpMethodBody &body);
455 /// Generate the c++ to handling variadic segment size traits.
456 void genParserVariadicSegmentResolution(Operator &op, OpMethodBody &body);
457
458 /// Generate the operation printer from this format.
459 void genPrinter(Operator &op, OpClass &opClass);
460
461 /// Generate the printer code for a specific format element.
462 void genElementPrinter(Element *element, OpMethodBody &body, Operator &op,
463 bool &shouldEmitSpace, bool &lastWasPunctuation);
464
465 /// The various elements in this format.
466 std::vector<std::unique_ptr<Element>> elements;
467
468 /// A flag indicating if all operand/result types were seen. If the format
469 /// contains these, it can not contain individual type resolvers.
470 bool allOperands, allOperandTypes, allResultTypes;
471
472 /// A flag indicating if this operation has the SingleBlockImplicitTerminator
473 /// trait.
474 bool hasImplicitTermTrait;
475
476 /// A map of buildable types to indices.
477 llvm::MapVector<StringRef, int, llvm::StringMap<int>> buildableTypes;
478
479 /// The index of the buildable type, if valid, for every operand and result.
480 std::vector<TypeResolution> operandTypes, resultTypes;
481
482 /// The set of attributes explicitly used within the format.
483 SmallVector<const NamedAttribute *, 8> usedAttributes;
484 };
485 } // end anonymous namespace
486
487 //===----------------------------------------------------------------------===//
488 // Parser Gen
489
490 /// Returns true if we can format the given attribute as an EnumAttr in the
491 /// parser format.
canFormatEnumAttr(const NamedAttribute * attr)492 static bool canFormatEnumAttr(const NamedAttribute *attr) {
493 Attribute baseAttr = attr->attr.getBaseAttr();
494 const EnumAttr *enumAttr = dyn_cast<EnumAttr>(&baseAttr);
495 if (!enumAttr)
496 return false;
497
498 // The attribute must have a valid underlying type and a constant builder.
499 return !enumAttr->getUnderlyingType().empty() &&
500 !enumAttr->getConstBuilderTemplate().empty();
501 }
502
503 /// Returns if we should format the given attribute as an SymbolNameAttr.
shouldFormatSymbolNameAttr(const NamedAttribute * attr)504 static bool shouldFormatSymbolNameAttr(const NamedAttribute *attr) {
505 return attr->attr.getBaseAttr().getAttrDefName() == "SymbolNameAttr";
506 }
507
/// The code snippet used to generate a parser call for an attribute.
///
/// {0}: The name of the attribute.
/// {1}: The type for the attribute.
const char *const attrParserCode = R"(
if (parser.parseAttribute({0}Attr{1}, "{0}", result.attributes))
return ::mlir::failure();
)";
/// The code snippet used to generate a parser call for an optional attribute.
/// The generated code returns failure only when `parseResult` holds a value
/// and that value is a failure; otherwise parsing continues.
///
/// {0}: The name of the attribute.
/// {1}: The type for the attribute.
const char *const optionalAttrParserCode = R"(
{
::mlir::OptionalParseResult parseResult =
parser.parseOptionalAttribute({0}Attr{1}, "{0}", result.attributes);
if (parseResult.hasValue() && failed(*parseResult))
return ::mlir::failure();
}
)";
524
/// The code snippet used to generate a parser call for a symbol name attribute.
///
/// {0}: The name of the attribute.
const char *const symbolNameAttrParserCode = R"(
if (parser.parseSymbolName({0}Attr, "{0}", result.attributes))
return ::mlir::failure();
)";
/// The code snippet used to generate a parser call for an optional symbol name
/// attribute. The result of the parse call is deliberately discarded.
///
/// {0}: The name of the attribute.
const char *const optionalSymbolNameAttrParserCode = R"(
// Parsing an optional symbol name doesn't fail, so no need to check the
// result.
(void)parser.parseOptionalSymbolName({0}Attr, "{0}", result.attributes);
)";
537
/// The code snippet used to generate a parser call for an enum attribute. The
/// enum may be spelled either as a keyword from {4} or as a string attribute.
/// If either is present, it is converted with {1}::{2}, and an unknown spelling
/// produces an error.
///
/// Note: the snippet is a format string, so the literal `{{` below is an
/// escaped `{`.
///
/// {0}: The name of the attribute.
/// {1}: The c++ namespace for the enum symbolize functions.
/// {2}: The function to symbolize a string of the enum.
/// {3}: The constant builder call to create an attribute of the enum type.
/// {4}: The set of allowed enum keywords.
/// {5}: The error message on failure when the enum isn't present.
const char *const enumAttrParserCode = R"(
{
::llvm::StringRef attrStr;
::mlir::NamedAttrList attrStorage;
auto loc = parser.getCurrentLocation();
if (parser.parseOptionalKeyword(&attrStr, {4})) {
::mlir::StringAttr attrVal;
::mlir::OptionalParseResult parseResult =
parser.parseOptionalAttribute(attrVal,
parser.getBuilder().getNoneType(),
"{0}", attrStorage);
if (parseResult.hasValue()) {{
if (failed(*parseResult))
return ::mlir::failure();
attrStr = attrVal.getValue();
} else {
{5}
}
}
if (!attrStr.empty()) {
auto attrOptional = {1}::{2}(attrStr);
if (!attrOptional)
return parser.emitError(loc, "invalid ")
<< "{0} attribute specification: \"" << attrStr << '"';

{0}Attr = {3};
result.addAttribute("{0}", {0}Attr);
}
}
)";
576
/// The code snippet used to generate a parser call for an operand. This
/// variadic form parses a whole operand list into `{0}Operands`, recording the
/// start location in `{0}OperandsLoc`.
///
/// {0}: The name of the operand.
const char *const variadicOperandParserCode = R"(
{0}OperandsLoc = parser.getCurrentLocation();
if (parser.parseOperandList({0}Operands))
return ::mlir::failure();
)";
/// Parser snippet for an optional operand. At most one operand is parsed; when
/// one is present it is appended to `{0}Operands`.
///
/// {0}: The name of the operand.
const char *const optionalOperandParserCode = R"(
{
{0}OperandsLoc = parser.getCurrentLocation();
::mlir::OpAsmParser::OperandType operand;
::mlir::OptionalParseResult parseResult =
parser.parseOptionalOperand(operand);
if (parseResult.hasValue()) {
if (failed(*parseResult))
return ::mlir::failure();
{0}Operands.push_back(operand);
}
}
)";
/// Parser snippet for exactly one operand, parsed into `{0}RawOperands[0]`.
///
/// {0}: The name of the operand.
const char *const operandParserCode = R"(
{0}OperandsLoc = parser.getCurrentLocation();
if (parser.parseOperand({0}RawOperands[0]))
return ::mlir::failure();
)";
603
/// The code snippet used to generate a parser call for a type list. This
/// variadic form parses the whole list into `{0}Types`.
///
/// {0}: The name for the type list.
const char *const variadicTypeParserCode = R"(
if (parser.parseTypeList({0}Types))
return ::mlir::failure();
)";
/// Parser snippet for an optional type. At most one type is parsed; when one
/// is present it is appended to `{0}Types`.
///
/// {0}: The name for the type list.
const char *const optionalTypeParserCode = R"(
{
::mlir::Type optionalType;
::mlir::OptionalParseResult parseResult =
parser.parseOptionalType(optionalType);
if (parseResult.hasValue()) {
if (failed(*parseResult))
return ::mlir::failure();
{0}Types.push_back(optionalType);
}
}
)";
/// Parser snippet for exactly one type, parsed into `{0}RawTypes[0]`.
///
/// {0}: The name for the type list.
const char *const typeParserCode = R"(
if (parser.parseType({0}RawTypes[0]))
return ::mlir::failure();
)";
627
/// The code snippet used to generate a parser call for a functional type. The
/// generated code parses a FunctionType into a local variable and assigns its
/// inputs to `{0}Types` and its results to `{1}Types`.
///
/// {0}: The name for the input type list.
/// {1}: The name for the result type list.
const char *const functionalTypeParserCode = R"(
::mlir::FunctionType {0}__{1}_functionType;
if (parser.parseType({0}__{1}_functionType))
return ::mlir::failure();
{0}Types = {0}__{1}_functionType.getInputs();
{1}Types = {0}__{1}_functionType.getResults();
)";
639
/// The code snippet used to generate a parser call for a region list: an
/// optional first region followed by any number of comma-separated regions,
/// all appended to `{0}Regions`.
///
/// {0}: The name for the region list.
const char *const regionListParserCode = R"(
{
std::unique_ptr<::mlir::Region> region;
auto firstRegionResult = parser.parseOptionalRegion(region);
if (firstRegionResult.hasValue()) {
if (failed(*firstRegionResult))
return ::mlir::failure();
{0}Regions.emplace_back(std::move(region));

// Parse any trailing regions.
while (succeeded(parser.parseOptionalComma())) {
region = std::make_unique<::mlir::Region>();
if (parser.parseRegion(*region))
return ::mlir::failure();
{0}Regions.emplace_back(std::move(region));
}
}
}
)";
662
/// The code snippet used to ensure a list of regions have terminators.
///
/// {0}: The name of the region list.
const char *const regionListEnsureTerminatorParserCode = R"(
for (auto &region : {0}Regions)
ensureTerminator(*region, parser.getBuilder(), result.location);
)";
670
/// The code snippet used to generate a parser call for an optional region. The
/// generated code fails only when a region is present but fails to parse.
///
/// {0}: The name of the region.
const char *const optionalRegionParserCode = R"(
{
auto parseResult = parser.parseOptionalRegion(*{0}Region);
if (parseResult.hasValue() && failed(*parseResult))
return ::mlir::failure();
}
)";
681
/// The code snippet used to generate a parser call for a region.
///
/// {0}: The name of the region.
const char *const regionParserCode = R"(
if (parser.parseRegion(*{0}Region))
return ::mlir::failure();
)";
689
/// The code snippet used to ensure a region has a terminator.
///
/// {0}: The name of the region.
const char *const regionEnsureTerminatorParserCode = R"(
ensureTerminator(*{0}Region, parser.getBuilder(), result.location);
)";
696
/// The code snippet used to generate a parser call for a successor list: an
/// optional first successor followed by any number of comma-separated
/// successors, all appended to `{0}Successors`.
///
/// {0}: The name for the successor list.
const char *const successorListParserCode = R"(
{
::mlir::Block *succ;
auto firstSucc = parser.parseOptionalSuccessor(succ);
if (firstSucc.hasValue()) {
if (failed(*firstSucc))
return ::mlir::failure();
{0}Successors.emplace_back(succ);

// Parse any trailing successors.
while (succeeded(parser.parseOptionalComma())) {
if (parser.parseSuccessor(succ))
return ::mlir::failure();
{0}Successors.emplace_back(succ);
}
}
}
)";
718
/// The code snippet used to generate a parser call for a successor.
///
/// {0}: The name of the successor.
const char *const successorParserCode = R"(
if (parser.parseSuccessor({0}Successor))
return ::mlir::failure();
)";
726
namespace {
/// Describes how many elements a parsed argument may bind.
enum class ArgumentLengthKind {
  Variadic, ///< Variadic: may contain 0->N elements.
  Optional, ///< Optional: may contain 0 or 1 elements.
  Single,   ///< Single: always represents exactly 1 element.
};
} // end anonymous namespace
738
739 /// Get the length kind for the given constraint.
740 static ArgumentLengthKind
getArgumentLengthKind(const NamedTypeConstraint * var)741 getArgumentLengthKind(const NamedTypeConstraint *var) {
742 if (var->isOptional())
743 return ArgumentLengthKind::Optional;
744 if (var->isVariadic())
745 return ArgumentLengthKind::Variadic;
746 return ArgumentLengthKind::Single;
747 }
748
749 /// Get the name used for the type list for the given type directive operand.
750 /// 'lengthKind' to the corresponding kind for the given argument.
751 static StringRef getTypeListName(Element *arg, ArgumentLengthKind &lengthKind) {
752 if (auto *operand = dyn_cast<OperandVariable>(arg)) {
753 lengthKind = getArgumentLengthKind(operand->getVar());
754 return operand->getVar()->name;
755 }
756 if (auto *result = dyn_cast<ResultVariable>(arg)) {
757 lengthKind = getArgumentLengthKind(result->getVar());
758 return result->getVar()->name;
759 }
760 lengthKind = ArgumentLengthKind::Variadic;
761 if (isa<OperandsDirective>(arg))
762 return "allOperand";
763 if (isa<ResultsDirective>(arg))
764 return "allResult";
765 llvm_unreachable("unknown 'type' directive argument");
766 }
767
768 /// Generate the parser for a literal value.
769 static void genLiteralParser(StringRef value, OpMethodBody &body) {
770 // Handle the case of a keyword/identifier.
771 if (value.front() == '_' || isalpha(value.front())) {
772 body << "Keyword(\"" << value << "\")";
773 return;
774 }
775 body << (StringRef)StringSwitch<StringRef>(value)
776 .Case("->", "Arrow()")
777 .Case(":", "Colon()")
778 .Case(",", "Comma()")
779 .Case("=", "Equal()")
780 .Case("<", "Less()")
781 .Case(">", "Greater()")
782 .Case("{", "LBrace()")
783 .Case("}", "RBrace()")
784 .Case("(", "LParen()")
785 .Case(")", "RParen()")
786 .Case("[", "LSquare()")
787 .Case("]", "RSquare()")
788 .Case("?", "Question()")
789 .Case("+", "Plus()")
790 .Case("*", "Star()");
791 }
792
793 /// Generate the storage code required for parsing the given element.
genElementParserStorage(Element * element,OpMethodBody & body)794 static void genElementParserStorage(Element *element, OpMethodBody &body) {
795 if (auto *optional = dyn_cast<OptionalElement>(element)) {
796 auto elements = optional->getElements();
797
798 // If the anchor is a unit attribute, it won't be parsed directly so elide
799 // it.
800 auto *anchor = dyn_cast<AttributeVariable>(optional->getAnchor());
801 Element *elidedAnchorElement = nullptr;
802 if (anchor && anchor != &*elements.begin() && anchor->isUnitAttr())
803 elidedAnchorElement = anchor;
804 for (auto &childElement : elements)
805 if (&childElement != elidedAnchorElement)
806 genElementParserStorage(&childElement, body);
807
808 } else if (auto *custom = dyn_cast<CustomDirective>(element)) {
809 for (auto ¶mElement : custom->getArguments())
810 genElementParserStorage(¶mElement, body);
811
812 } else if (isa<OperandsDirective>(element)) {
813 body << " ::mlir::SmallVector<::mlir::OpAsmParser::OperandType, 4> "
814 "allOperands;\n";
815
816 } else if (isa<RegionsDirective>(element)) {
817 body << " ::llvm::SmallVector<std::unique_ptr<::mlir::Region>, 2> "
818 "fullRegions;\n";
819
820 } else if (isa<SuccessorsDirective>(element)) {
821 body << " ::llvm::SmallVector<::mlir::Block *, 2> fullSuccessors;\n";
822
823 } else if (auto *attr = dyn_cast<AttributeVariable>(element)) {
824 const NamedAttribute *var = attr->getVar();
825 body << llvm::formatv(" {0} {1}Attr;\n", var->attr.getStorageType(),
826 var->name);
827
828 } else if (auto *operand = dyn_cast<OperandVariable>(element)) {
829 StringRef name = operand->getVar()->name;
830 if (operand->getVar()->isVariableLength()) {
831 body << " ::mlir::SmallVector<::mlir::OpAsmParser::OperandType, 4> "
832 << name << "Operands;\n";
833 } else {
834 body << " ::mlir::OpAsmParser::OperandType " << name
835 << "RawOperands[1];\n"
836 << " ::llvm::ArrayRef<::mlir::OpAsmParser::OperandType> " << name
837 << "Operands(" << name << "RawOperands);";
838 }
839 body << llvm::formatv(" ::llvm::SMLoc {0}OperandsLoc;\n"
840 " (void){0}OperandsLoc;\n",
841 name);
842
843 } else if (auto *region = dyn_cast<RegionVariable>(element)) {
844 StringRef name = region->getVar()->name;
845 if (region->getVar()->isVariadic()) {
846 body << llvm::formatv(
847 " ::llvm::SmallVector<std::unique_ptr<::mlir::Region>, 2> "
848 "{0}Regions;\n",
849 name);
850 } else {
851 body << llvm::formatv(" std::unique_ptr<::mlir::Region> {0}Region = "
852 "std::make_unique<::mlir::Region>();\n",
853 name);
854 }
855
856 } else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
857 StringRef name = successor->getVar()->name;
858 if (successor->getVar()->isVariadic()) {
859 body << llvm::formatv(" ::llvm::SmallVector<::mlir::Block *, 2> "
860 "{0}Successors;\n",
861 name);
862 } else {
863 body << llvm::formatv(" ::mlir::Block *{0}Successor = nullptr;\n", name);
864 }
865
866 } else if (auto *dir = dyn_cast<TypeDirective>(element)) {
867 ArgumentLengthKind lengthKind;
868 StringRef name = getTypeListName(dir->getOperand(), lengthKind);
869 if (lengthKind != ArgumentLengthKind::Single)
870 body << " ::mlir::SmallVector<::mlir::Type, 1> " << name << "Types;\n";
871 else
872 body << llvm::formatv(" ::mlir::Type {0}RawTypes[1];\n", name)
873 << llvm::formatv(
874 " ::llvm::ArrayRef<::mlir::Type> {0}Types({0}RawTypes);\n",
875 name);
876 } else if (auto *dir = dyn_cast<TypeRefDirective>(element)) {
877 ArgumentLengthKind lengthKind;
878 StringRef name = getTypeListName(dir->getOperand(), lengthKind);
879 // Refer to the previously encountered TypeDirective for name.
880 // Take a `const ::mlir::SmallVector<::mlir::Type, 1> &` in the declaration
881 // to properly track the types that will be parsed and pushed later on.
882 if (lengthKind != ArgumentLengthKind::Single)
883 body << " const ::mlir::SmallVector<::mlir::Type, 1> &" << name
884 << "TypesRef(" << name << "Types);\n";
885 else
886 body << llvm::formatv(
887 " ::llvm::ArrayRef<::mlir::Type> {0}RawTypesRef({0}RawTypes);\n",
888 name);
889 } else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) {
890 ArgumentLengthKind ignored;
891 body << " ::llvm::ArrayRef<::mlir::Type> "
892 << getTypeListName(dir->getInputs(), ignored) << "Types;\n";
893 body << " ::llvm::ArrayRef<::mlir::Type> "
894 << getTypeListName(dir->getResults(), ignored) << "Types;\n";
895 }
896 }
897
898 /// Generate the parser for a parameter to a custom directive.
genCustomParameterParser(Element & param,OpMethodBody & body)899 static void genCustomParameterParser(Element ¶m, OpMethodBody &body) {
900 body << ", ";
901 if (auto *attr = dyn_cast<AttributeVariable>(¶m)) {
902 body << attr->getVar()->name << "Attr";
903 } else if (isa<AttrDictDirective>(¶m)) {
904 body << "result.attributes";
905 } else if (auto *operand = dyn_cast<OperandVariable>(¶m)) {
906 StringRef name = operand->getVar()->name;
907 ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar());
908 if (lengthKind == ArgumentLengthKind::Variadic)
909 body << llvm::formatv("{0}Operands", name);
910 else if (lengthKind == ArgumentLengthKind::Optional)
911 body << llvm::formatv("{0}Operand", name);
912 else
913 body << formatv("{0}RawOperands[0]", name);
914
915 } else if (auto *region = dyn_cast<RegionVariable>(¶m)) {
916 StringRef name = region->getVar()->name;
917 if (region->getVar()->isVariadic())
918 body << llvm::formatv("{0}Regions", name);
919 else
920 body << llvm::formatv("*{0}Region", name);
921
922 } else if (auto *successor = dyn_cast<SuccessorVariable>(¶m)) {
923 StringRef name = successor->getVar()->name;
924 if (successor->getVar()->isVariadic())
925 body << llvm::formatv("{0}Successors", name);
926 else
927 body << llvm::formatv("{0}Successor", name);
928
929 } else if (auto *dir = dyn_cast<TypeRefDirective>(¶m)) {
930 ArgumentLengthKind lengthKind;
931 StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
932 if (lengthKind == ArgumentLengthKind::Variadic)
933 body << llvm::formatv("{0}TypesRef", listName);
934 else if (lengthKind == ArgumentLengthKind::Optional)
935 body << llvm::formatv("{0}TypeRef", listName);
936 else
937 body << formatv("{0}RawTypesRef[0]", listName);
938 } else if (auto *dir = dyn_cast<TypeDirective>(¶m)) {
939 ArgumentLengthKind lengthKind;
940 StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
941 if (lengthKind == ArgumentLengthKind::Variadic)
942 body << llvm::formatv("{0}Types", listName);
943 else if (lengthKind == ArgumentLengthKind::Optional)
944 body << llvm::formatv("{0}Type", listName);
945 else
946 body << formatv("{0}RawTypes[0]", listName);
947 } else {
948 llvm_unreachable("unknown custom directive parameter");
949 }
950 }
951
952 /// Generate the parser for a custom directive.
genCustomDirectiveParser(CustomDirective * dir,OpMethodBody & body)953 static void genCustomDirectiveParser(CustomDirective *dir, OpMethodBody &body) {
954 body << " {\n";
955
956 // Preprocess the directive variables.
957 // * Add a local variable for optional operands and types. This provides a
958 // better API to the user defined parser methods.
959 // * Set the location of operand variables.
960 for (Element ¶m : dir->getArguments()) {
961 if (auto *operand = dyn_cast<OperandVariable>(¶m)) {
962 body << " " << operand->getVar()->name
963 << "OperandsLoc = parser.getCurrentLocation();\n";
964 if (operand->getVar()->isOptional()) {
965 body << llvm::formatv(
966 " llvm::Optional<::mlir::OpAsmParser::OperandType> "
967 "{0}Operand;\n",
968 operand->getVar()->name);
969 }
970 } else if (auto *dir = dyn_cast<TypeRefDirective>(¶m)) {
971 // Reference to an optional which may or may not have been set.
972 // Retrieve from vector if not empty.
973 ArgumentLengthKind lengthKind;
974 StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
975 if (lengthKind == ArgumentLengthKind::Optional)
976 body << llvm::formatv(
977 " ::mlir::Type {0}TypeRef = {0}TypesRef.empty() "
978 "? Type() : {0}TypesRef[0];\n",
979 listName);
980 } else if (auto *dir = dyn_cast<TypeDirective>(¶m)) {
981 ArgumentLengthKind lengthKind;
982 StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
983 if (lengthKind == ArgumentLengthKind::Optional)
984 body << llvm::formatv(" ::mlir::Type {0}Type;\n", listName);
985 }
986 }
987
988 body << " if (parse" << dir->getName() << "(parser";
989 for (Element ¶m : dir->getArguments())
990 genCustomParameterParser(param, body);
991
992 body << "))\n"
993 << " return ::mlir::failure();\n";
994
995 // After parsing, add handling for any of the optional constructs.
996 for (Element ¶m : dir->getArguments()) {
997 if (auto *attr = dyn_cast<AttributeVariable>(¶m)) {
998 const NamedAttribute *var = attr->getVar();
999 if (var->attr.isOptional())
1000 body << llvm::formatv(" if ({0}Attr)\n ", var->name);
1001
1002 body << llvm::formatv(" result.addAttribute(\"{0}\", {0}Attr);\n",
1003 var->name);
1004 } else if (auto *operand = dyn_cast<OperandVariable>(¶m)) {
1005 const NamedTypeConstraint *var = operand->getVar();
1006 if (!var->isOptional())
1007 continue;
1008 body << llvm::formatv(" if ({0}Operand.hasValue())\n"
1009 " {0}Operands.push_back(*{0}Operand);\n",
1010 var->name);
1011 } else if (isa<TypeRefDirective>(¶m)) {
1012 // In the `type_ref` case, do not parse a new Type that needs to be added.
1013 // Just do nothing here.
1014 } else if (auto *dir = dyn_cast<TypeDirective>(¶m)) {
1015 ArgumentLengthKind lengthKind;
1016 StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
1017 if (lengthKind == ArgumentLengthKind::Optional) {
1018 body << llvm::formatv(" if ({0}Type)\n"
1019 " {0}Types.push_back({0}Type);\n",
1020 listName);
1021 }
1022 }
1023 }
1024
1025 body << " }\n";
1026 }
1027
1028 /// Generate the parser for a enum attribute.
genEnumAttrParser(const NamedAttribute * var,OpMethodBody & body,FmtContext & attrTypeCtx)1029 static void genEnumAttrParser(const NamedAttribute *var, OpMethodBody &body,
1030 FmtContext &attrTypeCtx) {
1031 Attribute baseAttr = var->attr.getBaseAttr();
1032 const EnumAttr &enumAttr = cast<EnumAttr>(baseAttr);
1033 std::vector<EnumAttrCase> cases = enumAttr.getAllCases();
1034
1035 // Generate the code for building an attribute for this enum.
1036 std::string attrBuilderStr;
1037 {
1038 llvm::raw_string_ostream os(attrBuilderStr);
1039 os << tgfmt(enumAttr.getConstBuilderTemplate(), &attrTypeCtx,
1040 "attrOptional.getValue()");
1041 }
1042
1043 // Build a string containing the cases that can be formatted as a keyword.
1044 std::string validCaseKeywordsStr = "{";
1045 llvm::raw_string_ostream validCaseKeywordsOS(validCaseKeywordsStr);
1046 for (const EnumAttrCase &attrCase : cases)
1047 if (canFormatStringAsKeyword(attrCase.getStr()))
1048 validCaseKeywordsOS << '"' << attrCase.getStr() << "\",";
1049 validCaseKeywordsOS.str().back() = '}';
1050
1051 // If the attribute is not optional, build an error message for the missing
1052 // attribute.
1053 std::string errorMessage;
1054 if (!var->attr.isOptional()) {
1055 llvm::raw_string_ostream errorMessageOS(errorMessage);
1056 errorMessageOS
1057 << "return parser.emitError(loc, \"expected string or "
1058 "keyword containing one of the following enum values for attribute '"
1059 << var->name << "' [";
1060 llvm::interleaveComma(cases, errorMessageOS, [&](const auto &attrCase) {
1061 errorMessageOS << attrCase.getStr();
1062 });
1063 errorMessageOS << "]\");";
1064 }
1065
1066 body << formatv(enumAttrParserCode, var->name, enumAttr.getCppNamespace(),
1067 enumAttr.getStringToSymbolFnName(), attrBuilderStr,
1068 validCaseKeywordsStr, errorMessage);
1069 }
1070
genParser(Operator & op,OpClass & opClass)1071 void OperationFormat::genParser(Operator &op, OpClass &opClass) {
1072 llvm::SmallVector<OpMethodParameter, 4> paramList;
1073 paramList.emplace_back("::mlir::OpAsmParser &", "parser");
1074 paramList.emplace_back("::mlir::OperationState &", "result");
1075
1076 auto *method =
1077 opClass.addMethodAndPrune("::mlir::ParseResult", "parse",
1078 OpMethod::MP_Static, std::move(paramList));
1079 auto &body = method->body();
1080
1081 // Generate variables to store the operands and type within the format. This
1082 // allows for referencing these variables in the presence of optional
1083 // groupings.
1084 for (auto &element : elements)
1085 genElementParserStorage(&*element, body);
1086
1087 // A format context used when parsing attributes with buildable types.
1088 FmtContext attrTypeCtx;
1089 attrTypeCtx.withBuilder("parser.getBuilder()");
1090
1091 // Generate parsers for each of the elements.
1092 for (auto &element : elements)
1093 genElementParser(element.get(), body, attrTypeCtx);
1094
1095 // Generate the code to resolve the operand/result types and successors now
1096 // that they have been parsed.
1097 genParserTypeResolution(op, body);
1098 genParserRegionResolution(op, body);
1099 genParserSuccessorResolution(op, body);
1100 genParserVariadicSegmentResolution(op, body);
1101
1102 body << " return ::mlir::success();\n";
1103 }
1104
genElementParser(Element * element,OpMethodBody & body,FmtContext & attrTypeCtx)1105 void OperationFormat::genElementParser(Element *element, OpMethodBody &body,
1106 FmtContext &attrTypeCtx) {
1107 /// Optional Group.
1108 if (auto *optional = dyn_cast<OptionalElement>(element)) {
1109 auto elements =
1110 llvm::drop_begin(optional->getElements(), optional->getParseStart());
1111
1112 // Generate a special optional parser for the first element to gate the
1113 // parsing of the rest of the elements.
1114 Element *firstElement = &*elements.begin();
1115 if (auto *attrVar = dyn_cast<AttributeVariable>(firstElement)) {
1116 genElementParser(attrVar, body, attrTypeCtx);
1117 body << " if (" << attrVar->getVar()->name << "Attr) {\n";
1118 } else if (auto *literal = dyn_cast<LiteralElement>(firstElement)) {
1119 body << " if (succeeded(parser.parseOptional";
1120 genLiteralParser(literal->getLiteral(), body);
1121 body << ")) {\n";
1122 } else if (auto *opVar = dyn_cast<OperandVariable>(firstElement)) {
1123 genElementParser(opVar, body, attrTypeCtx);
1124 body << " if (!" << opVar->getVar()->name << "Operands.empty()) {\n";
1125 } else if (auto *regionVar = dyn_cast<RegionVariable>(firstElement)) {
1126 const NamedRegion *region = regionVar->getVar();
1127 if (region->isVariadic()) {
1128 genElementParser(regionVar, body, attrTypeCtx);
1129 body << " if (!" << region->name << "Regions.empty()) {\n";
1130 } else {
1131 body << llvm::formatv(optionalRegionParserCode, region->name);
1132 body << " if (!" << region->name << "Region->empty()) {\n ";
1133 if (hasImplicitTermTrait)
1134 body << llvm::formatv(regionEnsureTerminatorParserCode, region->name);
1135 }
1136 }
1137
1138 // If the anchor is a unit attribute, we don't need to print it. When
1139 // parsing, we will add this attribute if this group is present.
1140 Element *elidedAnchorElement = nullptr;
1141 auto *anchorAttr = dyn_cast<AttributeVariable>(optional->getAnchor());
1142 if (anchorAttr && anchorAttr != firstElement && anchorAttr->isUnitAttr()) {
1143 elidedAnchorElement = anchorAttr;
1144
1145 // Add the anchor unit attribute to the operation state.
1146 body << " result.addAttribute(\"" << anchorAttr->getVar()->name
1147 << "\", parser.getBuilder().getUnitAttr());\n";
1148 }
1149
1150 // Generate the rest of the elements normally.
1151 for (Element &childElement : llvm::drop_begin(elements, 1)) {
1152 if (&childElement != elidedAnchorElement)
1153 genElementParser(&childElement, body, attrTypeCtx);
1154 }
1155 body << " }\n";
1156
1157 /// Literals.
1158 } else if (LiteralElement *literal = dyn_cast<LiteralElement>(element)) {
1159 body << " if (parser.parse";
1160 genLiteralParser(literal->getLiteral(), body);
1161 body << ")\n return ::mlir::failure();\n";
1162
1163 /// Whitespaces.
1164 } else if (isa<WhitespaceElement>(element)) {
1165 // Nothing to parse.
1166
1167 /// Arguments.
1168 } else if (auto *attr = dyn_cast<AttributeVariable>(element)) {
1169 const NamedAttribute *var = attr->getVar();
1170
1171 // Check to see if we can parse this as an enum attribute.
1172 if (canFormatEnumAttr(var))
1173 return genEnumAttrParser(var, body, attrTypeCtx);
1174
1175 // Check to see if we should parse this as a symbol name attribute.
1176 if (shouldFormatSymbolNameAttr(var)) {
1177 body << formatv(var->attr.isOptional() ? optionalSymbolNameAttrParserCode
1178 : symbolNameAttrParserCode,
1179 var->name);
1180 return;
1181 }
1182
1183 // If this attribute has a buildable type, use that when parsing the
1184 // attribute.
1185 std::string attrTypeStr;
1186 if (Optional<StringRef> typeBuilder = attr->getTypeBuilder()) {
1187 llvm::raw_string_ostream os(attrTypeStr);
1188 os << ", " << tgfmt(*typeBuilder, &attrTypeCtx);
1189 }
1190
1191 body << formatv(var->attr.isOptional() ? optionalAttrParserCode
1192 : attrParserCode,
1193 var->name, attrTypeStr);
1194 } else if (auto *operand = dyn_cast<OperandVariable>(element)) {
1195 ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar());
1196 StringRef name = operand->getVar()->name;
1197 if (lengthKind == ArgumentLengthKind::Variadic)
1198 body << llvm::formatv(variadicOperandParserCode, name);
1199 else if (lengthKind == ArgumentLengthKind::Optional)
1200 body << llvm::formatv(optionalOperandParserCode, name);
1201 else
1202 body << formatv(operandParserCode, name);
1203
1204 } else if (auto *region = dyn_cast<RegionVariable>(element)) {
1205 bool isVariadic = region->getVar()->isVariadic();
1206 body << llvm::formatv(isVariadic ? regionListParserCode : regionParserCode,
1207 region->getVar()->name);
1208 if (hasImplicitTermTrait) {
1209 body << llvm::formatv(isVariadic ? regionListEnsureTerminatorParserCode
1210 : regionEnsureTerminatorParserCode,
1211 region->getVar()->name);
1212 }
1213
1214 } else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
1215 bool isVariadic = successor->getVar()->isVariadic();
1216 body << formatv(isVariadic ? successorListParserCode : successorParserCode,
1217 successor->getVar()->name);
1218
1219 /// Directives.
1220 } else if (auto *attrDict = dyn_cast<AttrDictDirective>(element)) {
1221 body << " if (parser.parseOptionalAttrDict"
1222 << (attrDict->isWithKeyword() ? "WithKeyword" : "")
1223 << "(result.attributes))\n"
1224 << " return ::mlir::failure();\n";
1225 } else if (auto *customDir = dyn_cast<CustomDirective>(element)) {
1226 genCustomDirectiveParser(customDir, body);
1227
1228 } else if (isa<OperandsDirective>(element)) {
1229 body << " ::llvm::SMLoc allOperandLoc = parser.getCurrentLocation();\n"
1230 << " if (parser.parseOperandList(allOperands))\n"
1231 << " return ::mlir::failure();\n";
1232
1233 } else if (isa<RegionsDirective>(element)) {
1234 body << llvm::formatv(regionListParserCode, "full");
1235 if (hasImplicitTermTrait)
1236 body << llvm::formatv(regionListEnsureTerminatorParserCode, "full");
1237
1238 } else if (isa<SuccessorsDirective>(element)) {
1239 body << llvm::formatv(successorListParserCode, "full");
1240
1241 } else if (auto *dir = dyn_cast<TypeRefDirective>(element)) {
1242 ArgumentLengthKind lengthKind;
1243 StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
1244 if (lengthKind == ArgumentLengthKind::Variadic)
1245 body << llvm::formatv(variadicTypeParserCode, listName);
1246 else if (lengthKind == ArgumentLengthKind::Optional)
1247 body << llvm::formatv(optionalTypeParserCode, listName);
1248 else
1249 body << formatv(typeParserCode, listName);
1250 } else if (auto *dir = dyn_cast<TypeDirective>(element)) {
1251 ArgumentLengthKind lengthKind;
1252 StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
1253 if (lengthKind == ArgumentLengthKind::Variadic)
1254 body << llvm::formatv(variadicTypeParserCode, listName);
1255 else if (lengthKind == ArgumentLengthKind::Optional)
1256 body << llvm::formatv(optionalTypeParserCode, listName);
1257 else
1258 body << formatv(typeParserCode, listName);
1259 } else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) {
1260 ArgumentLengthKind ignored;
1261 body << formatv(functionalTypeParserCode,
1262 getTypeListName(dir->getInputs(), ignored),
1263 getTypeListName(dir->getResults(), ignored));
1264 } else {
1265 llvm_unreachable("unknown format element");
1266 }
1267 }
1268
genParserTypeResolution(Operator & op,OpMethodBody & body)1269 void OperationFormat::genParserTypeResolution(Operator &op,
1270 OpMethodBody &body) {
1271 // If any of type resolutions use transformed variables, make sure that the
1272 // types of those variables are resolved.
1273 SmallPtrSet<const NamedTypeConstraint *, 8> verifiedVariables;
1274 FmtContext verifierFCtx;
1275 for (TypeResolution &resolver :
1276 llvm::concat<TypeResolution>(resultTypes, operandTypes)) {
1277 Optional<StringRef> transformer = resolver.getVarTransformer();
1278 if (!transformer)
1279 continue;
1280 // Ensure that we don't verify the same variables twice.
1281 const NamedTypeConstraint *variable = resolver.getVariable();
1282 if (!variable || !verifiedVariables.insert(variable).second)
1283 continue;
1284
1285 auto constraint = variable->constraint;
1286 body << " for (::mlir::Type type : " << variable->name << "Types) {\n"
1287 << " (void)type;\n"
1288 << " if (!("
1289 << tgfmt(constraint.getConditionTemplate(),
1290 &verifierFCtx.withSelf("type"))
1291 << ")) {\n"
1292 << formatv(" return parser.emitError(parser.getNameLoc()) << "
1293 "\"'{0}' must be {1}, but got \" << type;\n",
1294 variable->name, constraint.getSummary())
1295 << " }\n"
1296 << " }\n";
1297 }
1298
1299 // Initialize the set of buildable types.
1300 if (!buildableTypes.empty()) {
1301 FmtContext typeBuilderCtx;
1302 typeBuilderCtx.withBuilder("parser.getBuilder()");
1303 for (auto &it : buildableTypes)
1304 body << " ::mlir::Type odsBuildableType" << it.second << " = "
1305 << tgfmt(it.first, &typeBuilderCtx) << ";\n";
1306 }
1307
1308 // Emit the code necessary for a type resolver.
1309 auto emitTypeResolver = [&](TypeResolution &resolver, StringRef curVar) {
1310 if (Optional<int> val = resolver.getBuilderIdx()) {
1311 body << "odsBuildableType" << *val;
1312 } else if (const NamedTypeConstraint *var = resolver.getVariable()) {
1313 if (Optional<StringRef> tform = resolver.getVarTransformer()) {
1314 FmtContext fmtContext;
1315 if (var->isVariadic())
1316 fmtContext.withSelf(var->name + "Types");
1317 else
1318 fmtContext.withSelf(var->name + "Types[0]");
1319 body << tgfmt(*tform, &fmtContext);
1320 } else {
1321 body << var->name << "Types";
1322 }
1323 } else if (const NamedAttribute *attr = resolver.getAttribute()) {
1324 if (Optional<StringRef> tform = resolver.getVarTransformer())
1325 body << tgfmt(*tform,
1326 &FmtContext().withSelf(attr->name + "Attr.getType()"));
1327 else
1328 body << attr->name << "Attr.getType()";
1329 } else {
1330 body << curVar << "Types";
1331 }
1332 };
1333
1334 // Resolve each of the result types.
1335 if (allResultTypes) {
1336 body << " result.addTypes(allResultTypes);\n";
1337 } else {
1338 for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) {
1339 body << " result.addTypes(";
1340 emitTypeResolver(resultTypes[i], op.getResultName(i));
1341 body << ");\n";
1342 }
1343 }
1344
1345 // Early exit if there are no operands.
1346 if (op.getNumOperands() == 0)
1347 return;
1348
1349 // Handle the case where all operand types are in one group.
1350 if (allOperandTypes) {
1351 // If we have all operands together, use the full operand list directly.
1352 if (allOperands) {
1353 body << " if (parser.resolveOperands(allOperands, allOperandTypes, "
1354 "allOperandLoc, result.operands))\n"
1355 " return ::mlir::failure();\n";
1356 return;
1357 }
1358
1359 // Otherwise, use llvm::concat to merge the disjoint operand lists together.
1360 // llvm::concat does not allow the case of a single range, so guard it here.
1361 body << " if (parser.resolveOperands(";
1362 if (op.getNumOperands() > 1) {
1363 body << "::llvm::concat<const ::mlir::OpAsmParser::OperandType>(";
1364 llvm::interleaveComma(op.getOperands(), body, [&](auto &operand) {
1365 body << operand.name << "Operands";
1366 });
1367 body << ")";
1368 } else {
1369 body << op.operand_begin()->name << "Operands";
1370 }
1371 body << ", allOperandTypes, parser.getNameLoc(), result.operands))\n"
1372 << " return ::mlir::failure();\n";
1373 return;
1374 }
1375 // Handle the case where all of the operands were grouped together.
1376 if (allOperands) {
1377 body << " if (parser.resolveOperands(allOperands, ";
1378
1379 // Group all of the operand types together to perform the resolution all at
1380 // once. Use llvm::concat to perform the merge. llvm::concat does not allow
1381 // the case of a single range, so guard it here.
1382 if (op.getNumOperands() > 1) {
1383 body << "::llvm::concat<const Type>(";
1384 llvm::interleaveComma(
1385 llvm::seq<int>(0, op.getNumOperands()), body, [&](int i) {
1386 body << "::llvm::ArrayRef<::mlir::Type>(";
1387 emitTypeResolver(operandTypes[i], op.getOperand(i).name);
1388 body << ")";
1389 });
1390 body << ")";
1391 } else {
1392 emitTypeResolver(operandTypes.front(), op.getOperand(0).name);
1393 }
1394
1395 body << ", allOperandLoc, result.operands))\n"
1396 << " return ::mlir::failure();\n";
1397 return;
1398 }
1399
1400 // The final case is the one where each of the operands types are resolved
1401 // separately.
1402 for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i) {
1403 NamedTypeConstraint &operand = op.getOperand(i);
1404 body << " if (parser.resolveOperands(" << operand.name << "Operands, ";
1405
1406 // Resolve the type of this operand.
1407 TypeResolution &operandType = operandTypes[i];
1408 emitTypeResolver(operandType, operand.name);
1409
1410 // If the type is resolved by a non-variadic variable, index into the
1411 // resolved type list. This allows for resolving the types of a variadic
1412 // operand list from a non-variadic variable.
1413 bool verifyOperandAndTypeSize = true;
1414 if (auto *resolverVar = operandType.getVariable()) {
1415 if (!resolverVar->isVariadic() && !operandType.getVarTransformer()) {
1416 body << "[0]";
1417 verifyOperandAndTypeSize = false;
1418 }
1419 } else {
1420 verifyOperandAndTypeSize = !operandType.getBuilderIdx();
1421 }
1422
1423 // Check to see if the sizes between the types and operands must match. If
1424 // they do, provide the operand location to select the proper resolution
1425 // overload.
1426 if (verifyOperandAndTypeSize)
1427 body << ", " << operand.name << "OperandsLoc";
1428 body << ", result.operands))\n return ::mlir::failure();\n";
1429 }
1430 }
1431
genParserRegionResolution(Operator & op,OpMethodBody & body)1432 void OperationFormat::genParserRegionResolution(Operator &op,
1433 OpMethodBody &body) {
1434 // Check for the case where all regions were parsed.
1435 bool hasAllRegions = llvm::any_of(
1436 elements, [](auto &elt) { return isa<RegionsDirective>(elt.get()); });
1437 if (hasAllRegions) {
1438 body << " result.addRegions(fullRegions);\n";
1439 return;
1440 }
1441
1442 // Otherwise, handle each region individually.
1443 for (const NamedRegion ®ion : op.getRegions()) {
1444 if (region.isVariadic())
1445 body << " result.addRegions(" << region.name << "Regions);\n";
1446 else
1447 body << " result.addRegion(std::move(" << region.name << "Region));\n";
1448 }
1449 }
1450
genParserSuccessorResolution(Operator & op,OpMethodBody & body)1451 void OperationFormat::genParserSuccessorResolution(Operator &op,
1452 OpMethodBody &body) {
1453 // Check for the case where all successors were parsed.
1454 bool hasAllSuccessors = llvm::any_of(
1455 elements, [](auto &elt) { return isa<SuccessorsDirective>(elt.get()); });
1456 if (hasAllSuccessors) {
1457 body << " result.addSuccessors(fullSuccessors);\n";
1458 return;
1459 }
1460
1461 // Otherwise, handle each successor individually.
1462 for (const NamedSuccessor &successor : op.getSuccessors()) {
1463 if (successor.isVariadic())
1464 body << " result.addSuccessors(" << successor.name << "Successors);\n";
1465 else
1466 body << " result.addSuccessors(" << successor.name << "Successor);\n";
1467 }
1468 }
1469
genParserVariadicSegmentResolution(Operator & op,OpMethodBody & body)1470 void OperationFormat::genParserVariadicSegmentResolution(Operator &op,
1471 OpMethodBody &body) {
1472 if (!allOperands &&
1473 op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
1474 body << " result.addAttribute(\"operand_segment_sizes\", "
1475 << "parser.getBuilder().getI32VectorAttr({";
1476 auto interleaveFn = [&](const NamedTypeConstraint &operand) {
1477 // If the operand is variadic emit the parsed size.
1478 if (operand.isVariableLength())
1479 body << "static_cast<int32_t>(" << operand.name << "Operands.size())";
1480 else
1481 body << "1";
1482 };
1483 llvm::interleaveComma(op.getOperands(), body, interleaveFn);
1484 body << "}));\n";
1485 }
1486
1487 if (!allResultTypes &&
1488 op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) {
1489 body << " result.addAttribute(\"result_segment_sizes\", "
1490 << "parser.getBuilder().getI32VectorAttr({";
1491 auto interleaveFn = [&](const NamedTypeConstraint &result) {
1492 // If the result is variadic emit the parsed size.
1493 if (result.isVariableLength())
1494 body << "static_cast<int32_t>(" << result.name << "Types.size())";
1495 else
1496 body << "1";
1497 };
1498 llvm::interleaveComma(op.getResults(), body, interleaveFn);
1499 body << "}));\n";
1500 }
1501 }
1502
//===----------------------------------------------------------------------===//
// PrinterGen
//===----------------------------------------------------------------------===//

/// The code snippet used to generate a printer call for a region of an
/// operation that has the SingleBlockImplicitTerminator trait. The entry
/// block's terminator is printed only if it has attributes, operands, or
/// results; `{{` is the formatv escape for a literal `{`.
///
/// {0}: The name of the region.
const char *regionSingleBlockImplicitTerminatorPrinterCode = R"(
  {
    bool printTerminator = true;
    if (auto *term = {0}.empty() ? nullptr : {0}.begin()->getTerminator()) {{
      printTerminator = !term->getAttrDictionary().empty() ||
                        term->getNumOperands() != 0 ||
                        term->getNumResults() != 0;
    }
    p.printRegion({0}, /*printEntryBlockArgs=*/true,
                  /*printBlockTerminators=*/printTerminator);
  }
)";
1522
/// The code snippet used to generate the start of a printer for an enum that
/// has cases that can't be represented with a keyword. The snippet opens a
/// `{` scope; the matching `}` is not part of this snippet.
///
/// {0}: The name of the enum attribute.
/// {1}: The name of the enum attribute's symbolToString function.
const char *enumAttrBeginPrinterCode = R"(
  {
    auto caseValue = {0}();
    auto caseValueStr = {1}(caseValue);
)";
1533
1534 /// Generate the printer for the 'attr-dict' directive.
genAttrDictPrinter(OperationFormat & fmt,Operator & op,OpMethodBody & body,bool withKeyword)1535 static void genAttrDictPrinter(OperationFormat &fmt, Operator &op,
1536 OpMethodBody &body, bool withKeyword) {
1537 body << " p.printOptionalAttrDict" << (withKeyword ? "WithKeyword" : "")
1538 << "(getAttrs(), /*elidedAttrs=*/{";
1539 // Elide the variadic segment size attributes if necessary.
1540 if (!fmt.allOperands &&
1541 op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments"))
1542 body << "\"operand_segment_sizes\", ";
1543 if (!fmt.allResultTypes &&
1544 op.getTrait("::mlir::OpTrait::AttrSizedResultSegments"))
1545 body << "\"result_segment_sizes\", ";
1546 llvm::interleaveComma(
1547 fmt.usedAttributes, body,
1548 [&](const NamedAttribute *attr) { body << "\"" << attr->name << "\""; });
1549 body << "});\n";
1550 }
1551
1552 /// Generate the printer for a literal value. `shouldEmitSpace` is true if a
1553 /// space should be emitted before this element. `lastWasPunctuation` is true if
1554 /// the previous element was a punctuation literal.
genLiteralPrinter(StringRef value,OpMethodBody & body,bool & shouldEmitSpace,bool & lastWasPunctuation)1555 static void genLiteralPrinter(StringRef value, OpMethodBody &body,
1556 bool &shouldEmitSpace, bool &lastWasPunctuation) {
1557 body << " p";
1558
1559 // Don't insert a space for certain punctuation.
1560 auto shouldPrintSpaceBeforeLiteral = [&] {
1561 if (value.size() != 1 && value != "->")
1562 return true;
1563 if (lastWasPunctuation)
1564 return !StringRef(">)}],").contains(value.front());
1565 return !StringRef("<>(){}[],").contains(value.front());
1566 };
1567 if (shouldEmitSpace && shouldPrintSpaceBeforeLiteral())
1568 body << " << ' '";
1569 body << " << \"" << value << "\";\n";
1570
1571 // Insert a space after certain literals.
1572 shouldEmitSpace =
1573 value.size() != 1 || !StringRef("<({[").contains(value.front());
1574 lastWasPunctuation = !(value.front() == '_' || isalpha(value.front()));
1575 }
1576
1577 /// Generate the printer for a space. `shouldEmitSpace` and `lastWasPunctuation`
1578 /// are set to false.
genSpacePrinter(bool value,OpMethodBody & body,bool & shouldEmitSpace,bool & lastWasPunctuation)1579 static void genSpacePrinter(bool value, OpMethodBody &body,
1580 bool &shouldEmitSpace, bool &lastWasPunctuation) {
1581 if (value) {
1582 body << " p << ' ';\n";
1583 lastWasPunctuation = false;
1584 }
1585 shouldEmitSpace = false;
1586 }
1587
1588 /// Generate the printer for a custom directive.
genCustomDirectivePrinter(CustomDirective * customDir,OpMethodBody & body)1589 static void genCustomDirectivePrinter(CustomDirective *customDir,
1590 OpMethodBody &body) {
1591 body << " print" << customDir->getName() << "(p, *this";
1592 for (Element ¶m : customDir->getArguments()) {
1593 body << ", ";
1594 if (auto *attr = dyn_cast<AttributeVariable>(¶m)) {
1595 body << attr->getVar()->name << "Attr()";
1596
1597 } else if (isa<AttrDictDirective>(¶m)) {
1598 body << "getOperation()->getAttrDictionary()";
1599
1600 } else if (auto *operand = dyn_cast<OperandVariable>(¶m)) {
1601 body << operand->getVar()->name << "()";
1602
1603 } else if (auto *region = dyn_cast<RegionVariable>(¶m)) {
1604 body << region->getVar()->name << "()";
1605
1606 } else if (auto *successor = dyn_cast<SuccessorVariable>(¶m)) {
1607 body << successor->getVar()->name << "()";
1608
1609 } else if (auto *dir = dyn_cast<TypeRefDirective>(¶m)) {
1610 auto *typeOperand = dir->getOperand();
1611 auto *operand = dyn_cast<OperandVariable>(typeOperand);
1612 auto *var = operand ? operand->getVar()
1613 : cast<ResultVariable>(typeOperand)->getVar();
1614 if (var->isVariadic())
1615 body << var->name << "().getTypes()";
1616 else if (var->isOptional())
1617 body << llvm::formatv("({0}() ? {0}().getType() : Type())", var->name);
1618 else
1619 body << var->name << "().getType()";
1620 } else if (auto *dir = dyn_cast<TypeDirective>(¶m)) {
1621 auto *typeOperand = dir->getOperand();
1622 auto *operand = dyn_cast<OperandVariable>(typeOperand);
1623 auto *var = operand ? operand->getVar()
1624 : cast<ResultVariable>(typeOperand)->getVar();
1625 if (var->isVariadic())
1626 body << var->name << "().getTypes()";
1627 else if (var->isOptional())
1628 body << llvm::formatv("({0}() ? {0}().getType() : Type())", var->name);
1629 else
1630 body << var->name << "().getType()";
1631 } else {
1632 llvm_unreachable("unknown custom directive parameter");
1633 }
1634 }
1635
1636 body << ");\n";
1637 }
1638
1639 /// Generate the printer for a region with the given variable name.
genRegionPrinter(const Twine & regionName,OpMethodBody & body,bool hasImplicitTermTrait)1640 static void genRegionPrinter(const Twine ®ionName, OpMethodBody &body,
1641 bool hasImplicitTermTrait) {
1642 if (hasImplicitTermTrait)
1643 body << llvm::formatv(regionSingleBlockImplicitTerminatorPrinterCode,
1644 regionName);
1645 else
1646 body << " p.printRegion(" << regionName << ");\n";
1647 }
genVariadicRegionPrinter(const Twine & regionListName,OpMethodBody & body,bool hasImplicitTermTrait)1648 static void genVariadicRegionPrinter(const Twine ®ionListName,
1649 OpMethodBody &body,
1650 bool hasImplicitTermTrait) {
1651 body << " llvm::interleaveComma(" << regionListName
1652 << ", p, [&](::mlir::Region ®ion) {\n ";
1653 genRegionPrinter("region", body, hasImplicitTermTrait);
1654 body << " });\n";
1655 }
1656
1657 /// Generate the C++ for an operand to a (*-)type directive.
genTypeOperandPrinter(Element * arg,OpMethodBody & body)1658 static OpMethodBody &genTypeOperandPrinter(Element *arg, OpMethodBody &body) {
1659 if (isa<OperandsDirective>(arg))
1660 return body << "getOperation()->getOperandTypes()";
1661 if (isa<ResultsDirective>(arg))
1662 return body << "getOperation()->getResultTypes()";
1663 auto *operand = dyn_cast<OperandVariable>(arg);
1664 auto *var = operand ? operand->getVar() : cast<ResultVariable>(arg)->getVar();
1665 if (var->isVariadic())
1666 return body << var->name << "().getTypes()";
1667 if (var->isOptional())
1668 return body << llvm::formatv(
1669 "({0}() ? ::llvm::ArrayRef<::mlir::Type>({0}().getType()) : "
1670 "::llvm::ArrayRef<::mlir::Type>())",
1671 var->name);
1672 return body << "::llvm::ArrayRef<::mlir::Type>(" << var->name
1673 << "().getType())";
1674 }
1675
1676 /// Generate the printer for an enum attribute.
genEnumAttrPrinter(const NamedAttribute * var,OpMethodBody & body)1677 static void genEnumAttrPrinter(const NamedAttribute *var, OpMethodBody &body) {
1678 Attribute baseAttr = var->attr.getBaseAttr();
1679 const EnumAttr &enumAttr = cast<EnumAttr>(baseAttr);
1680 std::vector<EnumAttrCase> cases = enumAttr.getAllCases();
1681
1682 body << llvm::formatv(enumAttrBeginPrinterCode,
1683 (var->attr.isOptional() ? "*" : "") + var->name,
1684 enumAttr.getSymbolToStringFnName());
1685
1686 // Get a string containing all of the cases that can't be represented with a
1687 // keyword.
1688 llvm::BitVector nonKeywordCases(cases.size());
1689 bool hasStrCase = false;
1690 for (auto it : llvm::enumerate(cases)) {
1691 hasStrCase = it.value().isStrCase();
1692 if (!canFormatStringAsKeyword(it.value().getStr()))
1693 nonKeywordCases.set(it.index());
1694 }
1695
1696 // If this is a string enum, use the case string to determine which cases
1697 // need to use the string form.
1698 if (hasStrCase) {
1699 if (nonKeywordCases.any()) {
1700 body << " if (llvm::is_contained(llvm::ArrayRef<llvm::StringRef>(";
1701 llvm::interleaveComma(nonKeywordCases.set_bits(), body, [&](unsigned it) {
1702 body << '"' << cases[it].getStr() << '"';
1703 });
1704 body << ")))\n"
1705 " p << '\"' << caseValueStr << '\"';\n"
1706 " else\n ";
1707 }
1708 body << " p << caseValueStr;\n"
1709 " }\n";
1710 return;
1711 }
1712
1713 // Otherwise if this is a bit enum attribute, don't allow cases that may
1714 // overlap with other cases. For simplicity sake, only allow cases with a
1715 // single bit value.
1716 if (enumAttr.isBitEnum()) {
1717 for (auto it : llvm::enumerate(cases)) {
1718 int64_t value = it.value().getValue();
1719 if (value < 0 || !llvm::isPowerOf2_64(value))
1720 nonKeywordCases.set(it.index());
1721 }
1722 }
1723
1724 // If there are any cases that can't be used with a keyword, switch on the
1725 // case value to determine when to print in the string form.
1726 if (nonKeywordCases.any()) {
1727 body << " switch (caseValue) {\n";
1728 StringRef cppNamespace = enumAttr.getCppNamespace();
1729 StringRef enumName = enumAttr.getEnumClassName();
1730 for (auto it : llvm::enumerate(cases)) {
1731 if (nonKeywordCases.test(it.index()))
1732 continue;
1733 StringRef symbol = it.value().getSymbol();
1734 body << llvm::formatv(" case {0}::{1}::{2}:\n", cppNamespace, enumName,
1735 llvm::isDigit(symbol.front()) ? ("_" + symbol)
1736 : symbol);
1737 }
1738 body << " p << caseValueStr;\n"
1739 " break;\n"
1740 " default:\n"
1741 " p << '\"' << caseValueStr << '\"';\n"
1742 " break;\n"
1743 " }\n"
1744 " }\n";
1745 return;
1746 }
1747
1748 body << " p << caseValueStr;\n"
1749 " }\n";
1750 }
1751
1752 /// Generate the check for the anchor of an optional group.
genOptionalGroupPrinterAnchor(Element * anchor,OpMethodBody & body)1753 static void genOptionalGroupPrinterAnchor(Element *anchor, OpMethodBody &body) {
1754 TypeSwitch<Element *>(anchor)
1755 .Case<OperandVariable, ResultVariable>([&](auto *element) {
1756 const NamedTypeConstraint *var = element->getVar();
1757 if (var->isOptional())
1758 body << " if (" << var->name << "()) {\n";
1759 else if (var->isVariadic())
1760 body << " if (!" << var->name << "().empty()) {\n";
1761 })
1762 .Case<RegionVariable>([&](RegionVariable *element) {
1763 const NamedRegion *var = element->getVar();
1764 // TODO: Add a check for optional regions here when ODS supports it.
1765 body << " if (!" << var->name << "().empty()) {\n";
1766 })
1767 .Case<TypeDirective>([&](TypeDirective *element) {
1768 genOptionalGroupPrinterAnchor(element->getOperand(), body);
1769 })
1770 .Case<FunctionalTypeDirective>([&](FunctionalTypeDirective *element) {
1771 genOptionalGroupPrinterAnchor(element->getInputs(), body);
1772 })
1773 .Case<AttributeVariable>([&](AttributeVariable *attr) {
1774 body << " if ((*this)->getAttr(\"" << attr->getVar()->name
1775 << "\")) {\n";
1776 });
1777 }
1778
genElementPrinter(Element * element,OpMethodBody & body,Operator & op,bool & shouldEmitSpace,bool & lastWasPunctuation)1779 void OperationFormat::genElementPrinter(Element *element, OpMethodBody &body,
1780 Operator &op, bool &shouldEmitSpace,
1781 bool &lastWasPunctuation) {
1782 if (LiteralElement *literal = dyn_cast<LiteralElement>(element))
1783 return genLiteralPrinter(literal->getLiteral(), body, shouldEmitSpace,
1784 lastWasPunctuation);
1785
1786 // Emit a whitespace element.
1787 if (isa<NewlineElement>(element)) {
1788 body << " p.printNewline();\n";
1789 return;
1790 }
1791 if (SpaceElement *space = dyn_cast<SpaceElement>(element))
1792 return genSpacePrinter(space->getValue(), body, shouldEmitSpace,
1793 lastWasPunctuation);
1794
1795 // Emit an optional group.
1796 if (OptionalElement *optional = dyn_cast<OptionalElement>(element)) {
1797 // Emit the check for the presence of the anchor element.
1798 Element *anchor = optional->getAnchor();
1799 genOptionalGroupPrinterAnchor(anchor, body);
1800
1801 // If the anchor is a unit attribute, we don't need to print it. When
1802 // parsing, we will add this attribute if this group is present.
1803 auto elements = optional->getElements();
1804 Element *elidedAnchorElement = nullptr;
1805 auto *anchorAttr = dyn_cast<AttributeVariable>(anchor);
1806 if (anchorAttr && anchorAttr != &*elements.begin() &&
1807 anchorAttr->isUnitAttr()) {
1808 elidedAnchorElement = anchorAttr;
1809 }
1810
1811 // Emit each of the elements.
1812 for (Element &childElement : elements) {
1813 if (&childElement != elidedAnchorElement) {
1814 genElementPrinter(&childElement, body, op, shouldEmitSpace,
1815 lastWasPunctuation);
1816 }
1817 }
1818 body << " }\n";
1819 return;
1820 }
1821
1822 // Emit the attribute dictionary.
1823 if (auto *attrDict = dyn_cast<AttrDictDirective>(element)) {
1824 genAttrDictPrinter(*this, op, body, attrDict->isWithKeyword());
1825 lastWasPunctuation = false;
1826 return;
1827 }
1828
1829 // Optionally insert a space before the next element. The AttrDict printer
1830 // already adds a space as necessary.
1831 if (shouldEmitSpace || !lastWasPunctuation)
1832 body << " p << ' ';\n";
1833 lastWasPunctuation = false;
1834 shouldEmitSpace = true;
1835
1836 if (auto *attr = dyn_cast<AttributeVariable>(element)) {
1837 const NamedAttribute *var = attr->getVar();
1838
1839 // If we are formatting as an enum, symbolize the attribute as a string.
1840 if (canFormatEnumAttr(var))
1841 return genEnumAttrPrinter(var, body);
1842
1843 // If we are formatting as a symbol name, handle it as a symbol name.
1844 if (shouldFormatSymbolNameAttr(var)) {
1845 body << " p.printSymbolName(" << var->name << "Attr().getValue());\n";
1846 return;
1847 }
1848
1849 // Elide the attribute type if it is buildable.
1850 if (attr->getTypeBuilder())
1851 body << " p.printAttributeWithoutType(" << var->name << "Attr());\n";
1852 else
1853 body << " p.printAttribute(" << var->name << "Attr());\n";
1854 } else if (auto *operand = dyn_cast<OperandVariable>(element)) {
1855 if (operand->getVar()->isOptional()) {
1856 body << " if (::mlir::Value value = " << operand->getVar()->name
1857 << "())\n"
1858 << " p << value;\n";
1859 } else {
1860 body << " p << " << operand->getVar()->name << "();\n";
1861 }
1862 } else if (auto *region = dyn_cast<RegionVariable>(element)) {
1863 const NamedRegion *var = region->getVar();
1864 if (var->isVariadic()) {
1865 genVariadicRegionPrinter(var->name + "()", body, hasImplicitTermTrait);
1866 } else {
1867 genRegionPrinter(var->name + "()", body, hasImplicitTermTrait);
1868 }
1869 } else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
1870 const NamedSuccessor *var = successor->getVar();
1871 if (var->isVariadic())
1872 body << " ::llvm::interleaveComma(" << var->name << "(), p);\n";
1873 else
1874 body << " p << " << var->name << "();\n";
1875 } else if (auto *dir = dyn_cast<CustomDirective>(element)) {
1876 genCustomDirectivePrinter(dir, body);
1877 } else if (isa<OperandsDirective>(element)) {
1878 body << " p << getOperation()->getOperands();\n";
1879 } else if (isa<RegionsDirective>(element)) {
1880 genVariadicRegionPrinter("getOperation()->getRegions()", body,
1881 hasImplicitTermTrait);
1882 } else if (isa<SuccessorsDirective>(element)) {
1883 body << " ::llvm::interleaveComma(getOperation()->getSuccessors(), p);\n";
1884 } else if (auto *dir = dyn_cast<TypeDirective>(element)) {
1885 body << " p << ";
1886 genTypeOperandPrinter(dir->getOperand(), body) << ";\n";
1887 } else if (auto *dir = dyn_cast<TypeRefDirective>(element)) {
1888 body << " p << ";
1889 genTypeOperandPrinter(dir->getOperand(), body) << ";\n";
1890 } else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) {
1891 body << " p.printFunctionalType(";
1892 genTypeOperandPrinter(dir->getInputs(), body) << ", ";
1893 genTypeOperandPrinter(dir->getResults(), body) << ");\n";
1894 } else {
1895 llvm_unreachable("unknown format element");
1896 }
1897 }
1898
genPrinter(Operator & op,OpClass & opClass)1899 void OperationFormat::genPrinter(Operator &op, OpClass &opClass) {
1900 auto *method =
1901 opClass.addMethodAndPrune("void", "print", "::mlir::OpAsmPrinter &p");
1902 auto &body = method->body();
1903
1904 // Emit the operation name, trimming the prefix if this is the standard
1905 // dialect.
1906 body << " p << \"";
1907 std::string opName = op.getOperationName();
1908 if (op.getDialectName() == "std")
1909 body << StringRef(opName).drop_front(4);
1910 else
1911 body << opName;
1912 body << "\";\n";
1913
1914 // Flags for if we should emit a space, and if the last element was
1915 // punctuation.
1916 bool shouldEmitSpace = true, lastWasPunctuation = false;
1917 for (auto &element : elements)
1918 genElementPrinter(element.get(), body, op, shouldEmitSpace,
1919 lastWasPunctuation);
1920 }
1921
1922 //===----------------------------------------------------------------------===//
1923 // FormatLexer
1924 //===----------------------------------------------------------------------===//
1925
1926 namespace {
1927 /// This class represents a specific token in the input format.
1928 class Token {
1929 public:
1930 enum Kind {
1931 // Markers.
1932 eof,
1933 error,
1934
1935 // Tokens with no info.
1936 l_paren,
1937 r_paren,
1938 caret,
1939 comma,
1940 equal,
1941 less,
1942 greater,
1943 question,
1944
1945 // Keywords.
1946 keyword_start,
1947 kw_attr_dict,
1948 kw_attr_dict_w_keyword,
1949 kw_custom,
1950 kw_functional_type,
1951 kw_operands,
1952 kw_regions,
1953 kw_results,
1954 kw_successors,
1955 kw_type,
1956 kw_type_ref,
1957 keyword_end,
1958
1959 // String valued tokens.
1960 identifier,
1961 literal,
1962 variable,
1963 };
Token(Kind kind,StringRef spelling)1964 Token(Kind kind, StringRef spelling) : kind(kind), spelling(spelling) {}
1965
1966 /// Return the bytes that make up this token.
getSpelling() const1967 StringRef getSpelling() const { return spelling; }
1968
1969 /// Return the kind of this token.
getKind() const1970 Kind getKind() const { return kind; }
1971
1972 /// Return a location for this token.
getLoc() const1973 llvm::SMLoc getLoc() const {
1974 return llvm::SMLoc::getFromPointer(spelling.data());
1975 }
1976
1977 /// Return if this token is a keyword.
isKeyword() const1978 bool isKeyword() const { return kind > keyword_start && kind < keyword_end; }
1979
1980 private:
1981 /// Discriminator that indicates the kind of token this is.
1982 Kind kind;
1983
1984 /// A reference to the entire token contents; this is always a pointer into
1985 /// a memory buffer owned by the source manager.
1986 StringRef spelling;
1987 };
1988
1989 /// This class implements a simple lexer for operation assembly format strings.
1990 class FormatLexer {
1991 public:
1992 FormatLexer(llvm::SourceMgr &mgr, Operator &op);
1993
1994 /// Lex the next token and return it.
1995 Token lexToken();
1996
1997 /// Emit an error to the lexer with the given location and message.
1998 Token emitError(llvm::SMLoc loc, const Twine &msg);
1999 Token emitError(const char *loc, const Twine &msg);
2000
2001 Token emitErrorAndNote(llvm::SMLoc loc, const Twine &msg, const Twine ¬e);
2002
2003 private:
formToken(Token::Kind kind,const char * tokStart)2004 Token formToken(Token::Kind kind, const char *tokStart) {
2005 return Token(kind, StringRef(tokStart, curPtr - tokStart));
2006 }
2007
2008 /// Return the next character in the stream.
2009 int getNextChar();
2010
2011 /// Lex an identifier, literal, or variable.
2012 Token lexIdentifier(const char *tokStart);
2013 Token lexLiteral(const char *tokStart);
2014 Token lexVariable(const char *tokStart);
2015
2016 llvm::SourceMgr &srcMgr;
2017 Operator &op;
2018 StringRef curBuffer;
2019 const char *curPtr;
2020 };
2021 } // end anonymous namespace
2022
FormatLexer(llvm::SourceMgr & mgr,Operator & op)2023 FormatLexer::FormatLexer(llvm::SourceMgr &mgr, Operator &op)
2024 : srcMgr(mgr), op(op) {
2025 curBuffer = srcMgr.getMemoryBuffer(mgr.getMainFileID())->getBuffer();
2026 curPtr = curBuffer.begin();
2027 }
2028
emitError(llvm::SMLoc loc,const Twine & msg)2029 Token FormatLexer::emitError(llvm::SMLoc loc, const Twine &msg) {
2030 srcMgr.PrintMessage(loc, llvm::SourceMgr::DK_Error, msg);
2031 llvm::SrcMgr.PrintMessage(op.getLoc()[0], llvm::SourceMgr::DK_Note,
2032 "in custom assembly format for this operation");
2033 return formToken(Token::error, loc.getPointer());
2034 }
emitErrorAndNote(llvm::SMLoc loc,const Twine & msg,const Twine & note)2035 Token FormatLexer::emitErrorAndNote(llvm::SMLoc loc, const Twine &msg,
2036 const Twine ¬e) {
2037 srcMgr.PrintMessage(loc, llvm::SourceMgr::DK_Error, msg);
2038 llvm::SrcMgr.PrintMessage(op.getLoc()[0], llvm::SourceMgr::DK_Note,
2039 "in custom assembly format for this operation");
2040 srcMgr.PrintMessage(loc, llvm::SourceMgr::DK_Note, note);
2041 return formToken(Token::error, loc.getPointer());
2042 }
emitError(const char * loc,const Twine & msg)2043 Token FormatLexer::emitError(const char *loc, const Twine &msg) {
2044 return emitError(llvm::SMLoc::getFromPointer(loc), msg);
2045 }
2046
getNextChar()2047 int FormatLexer::getNextChar() {
2048 char curChar = *curPtr++;
2049 switch (curChar) {
2050 default:
2051 return (unsigned char)curChar;
2052 case 0: {
2053 // A nul character in the stream is either the end of the current buffer or
2054 // a random nul in the file. Disambiguate that here.
2055 if (curPtr - 1 != curBuffer.end())
2056 return 0;
2057
2058 // Otherwise, return end of file.
2059 --curPtr;
2060 return EOF;
2061 }
2062 case '\n':
2063 case '\r':
2064 // Handle the newline character by ignoring it and incrementing the line
2065 // count. However, be careful about 'dos style' files with \n\r in them.
2066 // Only treat a \n\r or \r\n as a single line.
2067 if ((*curPtr == '\n' || (*curPtr == '\r')) && *curPtr != curChar)
2068 ++curPtr;
2069 return '\n';
2070 }
2071 }
2072
lexToken()2073 Token FormatLexer::lexToken() {
2074 const char *tokStart = curPtr;
2075
2076 // This always consumes at least one character.
2077 int curChar = getNextChar();
2078 switch (curChar) {
2079 default:
2080 // Handle identifiers: [a-zA-Z_]
2081 if (isalpha(curChar) || curChar == '_')
2082 return lexIdentifier(tokStart);
2083
2084 // Unknown character, emit an error.
2085 return emitError(tokStart, "unexpected character");
2086 case EOF:
2087 // Return EOF denoting the end of lexing.
2088 return formToken(Token::eof, tokStart);
2089
2090 // Lex punctuation.
2091 case '^':
2092 return formToken(Token::caret, tokStart);
2093 case ',':
2094 return formToken(Token::comma, tokStart);
2095 case '=':
2096 return formToken(Token::equal, tokStart);
2097 case '<':
2098 return formToken(Token::less, tokStart);
2099 case '>':
2100 return formToken(Token::greater, tokStart);
2101 case '?':
2102 return formToken(Token::question, tokStart);
2103 case '(':
2104 return formToken(Token::l_paren, tokStart);
2105 case ')':
2106 return formToken(Token::r_paren, tokStart);
2107
2108 // Ignore whitespace characters.
2109 case 0:
2110 case ' ':
2111 case '\t':
2112 case '\n':
2113 return lexToken();
2114
2115 case '`':
2116 return lexLiteral(tokStart);
2117 case '$':
2118 return lexVariable(tokStart);
2119 }
2120 }
2121
lexLiteral(const char * tokStart)2122 Token FormatLexer::lexLiteral(const char *tokStart) {
2123 assert(curPtr[-1] == '`');
2124
2125 // Lex a literal surrounded by ``.
2126 while (const char curChar = *curPtr++) {
2127 if (curChar == '`')
2128 return formToken(Token::literal, tokStart);
2129 }
2130 return emitError(curPtr - 1, "unexpected end of file in literal");
2131 }
2132
lexVariable(const char * tokStart)2133 Token FormatLexer::lexVariable(const char *tokStart) {
2134 if (!isalpha(curPtr[0]) && curPtr[0] != '_')
2135 return emitError(curPtr - 1, "expected variable name");
2136
2137 // Otherwise, consume the rest of the characters.
2138 while (isalnum(*curPtr) || *curPtr == '_')
2139 ++curPtr;
2140 return formToken(Token::variable, tokStart);
2141 }
2142
lexIdentifier(const char * tokStart)2143 Token FormatLexer::lexIdentifier(const char *tokStart) {
2144 // Match the rest of the identifier regex: [0-9a-zA-Z_\-]*
2145 while (isalnum(*curPtr) || *curPtr == '_' || *curPtr == '-')
2146 ++curPtr;
2147
2148 // Check to see if this identifier is a keyword.
2149 StringRef str(tokStart, curPtr - tokStart);
2150 Token::Kind kind =
2151 StringSwitch<Token::Kind>(str)
2152 .Case("attr-dict", Token::kw_attr_dict)
2153 .Case("attr-dict-with-keyword", Token::kw_attr_dict_w_keyword)
2154 .Case("custom", Token::kw_custom)
2155 .Case("functional-type", Token::kw_functional_type)
2156 .Case("operands", Token::kw_operands)
2157 .Case("regions", Token::kw_regions)
2158 .Case("results", Token::kw_results)
2159 .Case("successors", Token::kw_successors)
2160 .Case("type", Token::kw_type)
2161 .Case("type_ref", Token::kw_type_ref)
2162 .Default(Token::identifier);
2163 return Token(kind, str);
2164 }
2165
2166 //===----------------------------------------------------------------------===//
2167 // FormatParser
2168 //===----------------------------------------------------------------------===//
2169
2170 /// Function to find an element within the given range that has the same name as
2171 /// 'name'.
2172 template <typename RangeT>
findArg(RangeT && range,StringRef name)2173 static auto findArg(RangeT &&range, StringRef name) {
2174 auto it = llvm::find_if(range, [=](auto &arg) { return arg.name == name; });
2175 return it != range.end() ? &*it : nullptr;
2176 }
2177
2178 namespace {
2179 /// This class implements a parser for an instance of an operation assembly
2180 /// format.
2181 class FormatParser {
2182 public:
FormatParser(llvm::SourceMgr & mgr,OperationFormat & format,Operator & op)2183 FormatParser(llvm::SourceMgr &mgr, OperationFormat &format, Operator &op)
2184 : lexer(mgr, op), curToken(lexer.lexToken()), fmt(format), op(op),
2185 seenOperandTypes(op.getNumOperands()),
2186 seenResultTypes(op.getNumResults()) {}
2187
2188 /// Parse the operation assembly format.
2189 LogicalResult parse();
2190
2191 private:
2192 /// This struct represents a type resolution instance. It includes a specific
2193 /// type as well as an optional transformer to apply to that type in order to
2194 /// properly resolve the type of a variable.
2195 struct TypeResolutionInstance {
2196 ConstArgument resolver;
2197 Optional<StringRef> transformer;
2198 };
2199
2200 /// An iterator over the elements of a format group.
2201 using ElementsIterT = llvm::pointee_iterator<
2202 std::vector<std::unique_ptr<Element>>::const_iterator>;
2203
2204 /// Verify the state of operation attributes within the format.
2205 LogicalResult verifyAttributes(llvm::SMLoc loc);
2206 /// Verify the attribute elements at the back of the given stack of iterators.
2207 LogicalResult verifyAttributes(
2208 llvm::SMLoc loc,
2209 SmallVectorImpl<std::pair<ElementsIterT, ElementsIterT>> &iteratorStack);
2210
2211 /// Verify the state of operation operands within the format.
2212 LogicalResult
2213 verifyOperands(llvm::SMLoc loc,
2214 llvm::StringMap<TypeResolutionInstance> &variableTyResolver);
2215
2216 /// Verify the state of operation regions within the format.
2217 LogicalResult verifyRegions(llvm::SMLoc loc);
2218
2219 /// Verify the state of operation results within the format.
2220 LogicalResult
2221 verifyResults(llvm::SMLoc loc,
2222 llvm::StringMap<TypeResolutionInstance> &variableTyResolver);
2223
2224 /// Verify the state of operation successors within the format.
2225 LogicalResult verifySuccessors(llvm::SMLoc loc);
2226
2227 /// Given the values of an `AllTypesMatch` trait, check for inferable type
2228 /// resolution.
2229 void handleAllTypesMatchConstraint(
2230 ArrayRef<StringRef> values,
2231 llvm::StringMap<TypeResolutionInstance> &variableTyResolver);
2232 /// Check for inferable type resolution given all operands, and or results,
2233 /// have the same type. If 'includeResults' is true, the results also have the
2234 /// same type as all of the operands.
2235 void handleSameTypesConstraint(
2236 llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
2237 bool includeResults);
2238 /// Check for inferable type resolution based on another operand, result, or
2239 /// attribute.
2240 void handleTypesMatchConstraint(
2241 llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
2242 llvm::Record def);
2243
2244 /// Returns an argument or attribute with the given name that has been seen
2245 /// within the format.
2246 ConstArgument findSeenArg(StringRef name);
2247
2248 /// Parse a specific element.
2249 LogicalResult parseElement(std::unique_ptr<Element> &element,
2250 bool isTopLevel);
2251 LogicalResult parseVariable(std::unique_ptr<Element> &element,
2252 bool isTopLevel);
2253 LogicalResult parseDirective(std::unique_ptr<Element> &element,
2254 bool isTopLevel);
2255 LogicalResult parseLiteral(std::unique_ptr<Element> &element);
2256 LogicalResult parseOptional(std::unique_ptr<Element> &element,
2257 bool isTopLevel);
2258 LogicalResult parseOptionalChildElement(
2259 std::vector<std::unique_ptr<Element>> &childElements,
2260 Optional<unsigned> &anchorIdx);
2261 LogicalResult verifyOptionalChildElement(Element *element,
2262 llvm::SMLoc childLoc, bool isAnchor);
2263
2264 /// Parse the various different directives.
2265 LogicalResult parseAttrDictDirective(std::unique_ptr<Element> &element,
2266 llvm::SMLoc loc, bool isTopLevel,
2267 bool withKeyword);
2268 LogicalResult parseCustomDirective(std::unique_ptr<Element> &element,
2269 llvm::SMLoc loc, bool isTopLevel);
2270 LogicalResult parseCustomDirectiveParameter(
2271 std::vector<std::unique_ptr<Element>> ¶meters);
2272 LogicalResult parseFunctionalTypeDirective(std::unique_ptr<Element> &element,
2273 Token tok, bool isTopLevel);
2274 LogicalResult parseOperandsDirective(std::unique_ptr<Element> &element,
2275 llvm::SMLoc loc, bool isTopLevel);
2276 LogicalResult parseRegionsDirective(std::unique_ptr<Element> &element,
2277 llvm::SMLoc loc, bool isTopLevel);
2278 LogicalResult parseResultsDirective(std::unique_ptr<Element> &element,
2279 llvm::SMLoc loc, bool isTopLevel);
2280 LogicalResult parseSuccessorsDirective(std::unique_ptr<Element> &element,
2281 llvm::SMLoc loc, bool isTopLevel);
2282 LogicalResult parseTypeDirective(std::unique_ptr<Element> &element, Token tok,
2283 bool isTopLevel, bool isTypeRef = false);
2284 LogicalResult parseTypeDirectiveOperand(std::unique_ptr<Element> &element,
2285 bool isTypeRef = false);
2286
2287 //===--------------------------------------------------------------------===//
2288 // Lexer Utilities
2289 //===--------------------------------------------------------------------===//
2290
2291 /// Advance the current lexer onto the next token.
consumeToken()2292 void consumeToken() {
2293 assert(curToken.getKind() != Token::eof &&
2294 curToken.getKind() != Token::error &&
2295 "shouldn't advance past EOF or errors");
2296 curToken = lexer.lexToken();
2297 }
parseToken(Token::Kind kind,const Twine & msg)2298 LogicalResult parseToken(Token::Kind kind, const Twine &msg) {
2299 if (curToken.getKind() != kind)
2300 return emitError(curToken.getLoc(), msg);
2301 consumeToken();
2302 return ::mlir::success();
2303 }
emitError(llvm::SMLoc loc,const Twine & msg)2304 LogicalResult emitError(llvm::SMLoc loc, const Twine &msg) {
2305 lexer.emitError(loc, msg);
2306 return ::mlir::failure();
2307 }
emitErrorAndNote(llvm::SMLoc loc,const Twine & msg,const Twine & note)2308 LogicalResult emitErrorAndNote(llvm::SMLoc loc, const Twine &msg,
2309 const Twine ¬e) {
2310 lexer.emitErrorAndNote(loc, msg, note);
2311 return ::mlir::failure();
2312 }
2313
2314 //===--------------------------------------------------------------------===//
2315 // Fields
2316 //===--------------------------------------------------------------------===//
2317
2318 FormatLexer lexer;
2319 Token curToken;
2320 OperationFormat &fmt;
2321 Operator &op;
2322
2323 // The following are various bits of format state used for verification
2324 // during parsing.
2325 bool hasAttrDict = false;
2326 bool hasAllRegions = false, hasAllSuccessors = false;
2327 llvm::SmallBitVector seenOperandTypes, seenResultTypes;
2328 llvm::SmallSetVector<const NamedAttribute *, 8> seenAttrs;
2329 llvm::DenseSet<const NamedTypeConstraint *> seenOperands;
2330 llvm::DenseSet<const NamedRegion *> seenRegions;
2331 llvm::DenseSet<const NamedSuccessor *> seenSuccessors;
2332 };
2333 } // end anonymous namespace
2334
parse()2335 LogicalResult FormatParser::parse() {
2336 llvm::SMLoc loc = curToken.getLoc();
2337
2338 // Parse each of the format elements into the main format.
2339 while (curToken.getKind() != Token::eof) {
2340 std::unique_ptr<Element> element;
2341 if (failed(parseElement(element, /*isTopLevel=*/true)))
2342 return ::mlir::failure();
2343 fmt.elements.push_back(std::move(element));
2344 }
2345
2346 // Check that the attribute dictionary is in the format.
2347 if (!hasAttrDict)
2348 return emitError(loc, "'attr-dict' directive not found in "
2349 "custom assembly format");
2350
2351 // Check for any type traits that we can use for inferring types.
2352 llvm::StringMap<TypeResolutionInstance> variableTyResolver;
2353 for (const OpTrait &trait : op.getTraits()) {
2354 const llvm::Record &def = trait.getDef();
2355 if (def.isSubClassOf("AllTypesMatch")) {
2356 handleAllTypesMatchConstraint(def.getValueAsListOfStrings("values"),
2357 variableTyResolver);
2358 } else if (def.getName() == "SameTypeOperands") {
2359 handleSameTypesConstraint(variableTyResolver, /*includeResults=*/false);
2360 } else if (def.getName() == "SameOperandsAndResultType") {
2361 handleSameTypesConstraint(variableTyResolver, /*includeResults=*/true);
2362 } else if (def.isSubClassOf("TypesMatchWith")) {
2363 handleTypesMatchConstraint(variableTyResolver, def);
2364 }
2365 }
2366
2367 // Verify the state of the various operation components.
2368 if (failed(verifyAttributes(loc)) ||
2369 failed(verifyResults(loc, variableTyResolver)) ||
2370 failed(verifyOperands(loc, variableTyResolver)) ||
2371 failed(verifyRegions(loc)) || failed(verifySuccessors(loc)))
2372 return ::mlir::failure();
2373
2374 // Collect the set of used attributes in the format.
2375 fmt.usedAttributes = seenAttrs.takeVector();
2376 return ::mlir::success();
2377 }
2378
verifyAttributes(llvm::SMLoc loc)2379 LogicalResult FormatParser::verifyAttributes(llvm::SMLoc loc) {
2380 // Check that there are no `:` literals after an attribute without a constant
2381 // type. The attribute grammar contains an optional trailing colon type, which
2382 // can lead to unexpected and generally unintended behavior. Given that, it is
2383 // better to just error out here instead.
2384 using ElementsIterT = llvm::pointee_iterator<
2385 std::vector<std::unique_ptr<Element>>::const_iterator>;
2386 SmallVector<std::pair<ElementsIterT, ElementsIterT>, 1> iteratorStack;
2387 iteratorStack.emplace_back(fmt.elements.begin(), fmt.elements.end());
2388 while (!iteratorStack.empty())
2389 if (failed(verifyAttributes(loc, iteratorStack)))
2390 return ::mlir::failure();
2391 return ::mlir::success();
2392 }
2393 /// Verify the attribute elements at the back of the given stack of iterators.
verifyAttributes(llvm::SMLoc loc,SmallVectorImpl<std::pair<ElementsIterT,ElementsIterT>> & iteratorStack)2394 LogicalResult FormatParser::verifyAttributes(
2395 llvm::SMLoc loc,
2396 SmallVectorImpl<std::pair<ElementsIterT, ElementsIterT>> &iteratorStack) {
2397 auto &stackIt = iteratorStack.back();
2398 ElementsIterT &it = stackIt.first, e = stackIt.second;
2399 while (it != e) {
2400 Element *element = &*(it++);
2401
2402 // Traverse into optional groups.
2403 if (auto *optional = dyn_cast<OptionalElement>(element)) {
2404 auto elements = optional->getElements();
2405 iteratorStack.emplace_back(elements.begin(), elements.end());
2406 return ::mlir::success();
2407 }
2408
2409 // We are checking for an attribute element followed by a `:`, so there is
2410 // no need to check the end.
2411 if (it == e && iteratorStack.size() == 1)
2412 break;
2413
2414 // Check for an attribute with a constant type builder, followed by a `:`.
2415 auto *prevAttr = dyn_cast<AttributeVariable>(element);
2416 if (!prevAttr || prevAttr->getTypeBuilder())
2417 continue;
2418
2419 // Check the next iterator within the stack for literal elements.
2420 for (auto &nextItPair : iteratorStack) {
2421 ElementsIterT nextIt = nextItPair.first, nextE = nextItPair.second;
2422 for (; nextIt != nextE; ++nextIt) {
2423 // Skip any trailing whitespace, attribute dictionaries, or optional
2424 // groups.
2425 if (isa<WhitespaceElement>(*nextIt) ||
2426 isa<AttrDictDirective>(*nextIt) || isa<OptionalElement>(*nextIt))
2427 continue;
2428
2429 // We are only interested in `:` literals.
2430 auto *literal = dyn_cast<LiteralElement>(&*nextIt);
2431 if (!literal || literal->getLiteral() != ":")
2432 break;
2433
2434 // TODO: Use the location of the literal element itself.
2435 return emitError(
2436 loc, llvm::formatv("format ambiguity caused by `:` literal found "
2437 "after attribute `{0}` which does not have "
2438 "a buildable type",
2439 prevAttr->getVar()->name));
2440 }
2441 }
2442 }
2443 iteratorStack.pop_back();
2444 return ::mlir::success();
2445 }
2446
verifyOperands(llvm::SMLoc loc,llvm::StringMap<TypeResolutionInstance> & variableTyResolver)2447 LogicalResult FormatParser::verifyOperands(
2448 llvm::SMLoc loc,
2449 llvm::StringMap<TypeResolutionInstance> &variableTyResolver) {
2450 // Check that all of the operands are within the format, and their types can
2451 // be inferred.
2452 auto &buildableTypes = fmt.buildableTypes;
2453 for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i) {
2454 NamedTypeConstraint &operand = op.getOperand(i);
2455
2456 // Check that the operand itself is in the format.
2457 if (!fmt.allOperands && !seenOperands.count(&operand)) {
2458 return emitErrorAndNote(loc,
2459 "operand #" + Twine(i) + ", named '" +
2460 operand.name + "', not found",
2461 "suggest adding a '$" + operand.name +
2462 "' directive to the custom assembly format");
2463 }
2464
2465 // Check that the operand type is in the format, or that it can be inferred.
2466 if (fmt.allOperandTypes || seenOperandTypes.test(i))
2467 continue;
2468
2469 // Check to see if we can infer this type from another variable.
2470 auto varResolverIt = variableTyResolver.find(op.getOperand(i).name);
2471 if (varResolverIt != variableTyResolver.end()) {
2472 TypeResolutionInstance &resolver = varResolverIt->second;
2473 fmt.operandTypes[i].setResolver(resolver.resolver, resolver.transformer);
2474 continue;
2475 }
2476
2477 // Similarly to results, allow a custom builder for resolving the type if
2478 // we aren't using the 'operands' directive.
2479 Optional<StringRef> builder = operand.constraint.getBuilderCall();
2480 if (!builder || (fmt.allOperands && operand.isVariableLength())) {
2481 return emitErrorAndNote(
2482 loc,
2483 "type of operand #" + Twine(i) + ", named '" + operand.name +
2484 "', is not buildable and a buildable type cannot be inferred",
2485 "suggest adding a type constraint to the operation or adding a "
2486 "'type($" +
2487 operand.name + ")' directive to the " + "custom assembly format");
2488 }
2489 auto it = buildableTypes.insert({*builder, buildableTypes.size()});
2490 fmt.operandTypes[i].setBuilderIdx(it.first->second);
2491 }
2492 return ::mlir::success();
2493 }
2494
verifyRegions(llvm::SMLoc loc)2495 LogicalResult FormatParser::verifyRegions(llvm::SMLoc loc) {
2496 // Check that all of the regions are within the format.
2497 if (hasAllRegions)
2498 return ::mlir::success();
2499
2500 for (unsigned i = 0, e = op.getNumRegions(); i != e; ++i) {
2501 const NamedRegion ®ion = op.getRegion(i);
2502 if (!seenRegions.count(®ion)) {
2503 return emitErrorAndNote(loc,
2504 "region #" + Twine(i) + ", named '" +
2505 region.name + "', not found",
2506 "suggest adding a '$" + region.name +
2507 "' directive to the custom assembly format");
2508 }
2509 }
2510 return ::mlir::success();
2511 }
2512
verifyResults(llvm::SMLoc loc,llvm::StringMap<TypeResolutionInstance> & variableTyResolver)2513 LogicalResult FormatParser::verifyResults(
2514 llvm::SMLoc loc,
2515 llvm::StringMap<TypeResolutionInstance> &variableTyResolver) {
2516 // If we format all of the types together, there is nothing to check.
2517 if (fmt.allResultTypes)
2518 return ::mlir::success();
2519
2520 // Check that all of the result types can be inferred.
2521 auto &buildableTypes = fmt.buildableTypes;
2522 for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) {
2523 if (seenResultTypes.test(i))
2524 continue;
2525
2526 // Check to see if we can infer this type from another variable.
2527 auto varResolverIt = variableTyResolver.find(op.getResultName(i));
2528 if (varResolverIt != variableTyResolver.end()) {
2529 TypeResolutionInstance resolver = varResolverIt->second;
2530 fmt.resultTypes[i].setResolver(resolver.resolver, resolver.transformer);
2531 continue;
2532 }
2533
2534 // If the result is not variable length, allow for the case where the type
2535 // has a builder that we can use.
2536 NamedTypeConstraint &result = op.getResult(i);
2537 Optional<StringRef> builder = result.constraint.getBuilderCall();
2538 if (!builder || result.isVariableLength()) {
2539 return emitErrorAndNote(
2540 loc,
2541 "type of result #" + Twine(i) + ", named '" + result.name +
2542 "', is not buildable and a buildable type cannot be inferred",
2543 "suggest adding a type constraint to the operation or adding a "
2544 "'type($" +
2545 result.name + ")' directive to the " + "custom assembly format");
2546 }
2547 // Note in the format that this result uses the custom builder.
2548 auto it = buildableTypes.insert({*builder, buildableTypes.size()});
2549 fmt.resultTypes[i].setBuilderIdx(it.first->second);
2550 }
2551 return ::mlir::success();
2552 }
2553
verifySuccessors(llvm::SMLoc loc)2554 LogicalResult FormatParser::verifySuccessors(llvm::SMLoc loc) {
2555 // Check that all of the successors are within the format.
2556 if (hasAllSuccessors)
2557 return ::mlir::success();
2558
2559 for (unsigned i = 0, e = op.getNumSuccessors(); i != e; ++i) {
2560 const NamedSuccessor &successor = op.getSuccessor(i);
2561 if (!seenSuccessors.count(&successor)) {
2562 return emitErrorAndNote(loc,
2563 "successor #" + Twine(i) + ", named '" +
2564 successor.name + "', not found",
2565 "suggest adding a '$" + successor.name +
2566 "' directive to the custom assembly format");
2567 }
2568 }
2569 return ::mlir::success();
2570 }
2571
handleAllTypesMatchConstraint(ArrayRef<StringRef> values,llvm::StringMap<TypeResolutionInstance> & variableTyResolver)2572 void FormatParser::handleAllTypesMatchConstraint(
2573 ArrayRef<StringRef> values,
2574 llvm::StringMap<TypeResolutionInstance> &variableTyResolver) {
2575 for (unsigned i = 0, e = values.size(); i != e; ++i) {
2576 // Check to see if this value matches a resolved operand or result type.
2577 ConstArgument arg = findSeenArg(values[i]);
2578 if (!arg)
2579 continue;
2580
2581 // Mark this value as the type resolver for the other variables.
2582 for (unsigned j = 0; j != i; ++j)
2583 variableTyResolver[values[j]] = {arg, llvm::None};
2584 for (unsigned j = i + 1; j != e; ++j)
2585 variableTyResolver[values[j]] = {arg, llvm::None};
2586 }
2587 }
2588
handleSameTypesConstraint(llvm::StringMap<TypeResolutionInstance> & variableTyResolver,bool includeResults)2589 void FormatParser::handleSameTypesConstraint(
2590 llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
2591 bool includeResults) {
2592 const NamedTypeConstraint *resolver = nullptr;
2593 int resolvedIt = -1;
2594
2595 // Check to see if there is an operand or result to use for the resolution.
2596 if ((resolvedIt = seenOperandTypes.find_first()) != -1)
2597 resolver = &op.getOperand(resolvedIt);
2598 else if (includeResults && (resolvedIt = seenResultTypes.find_first()) != -1)
2599 resolver = &op.getResult(resolvedIt);
2600 else
2601 return;
2602
2603 // Set the resolvers for each operand and result.
2604 for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i)
2605 if (!seenOperandTypes.test(i) && !op.getOperand(i).name.empty())
2606 variableTyResolver[op.getOperand(i).name] = {resolver, llvm::None};
2607 if (includeResults) {
2608 for (unsigned i = 0, e = op.getNumResults(); i != e; ++i)
2609 if (!seenResultTypes.test(i) && !op.getResultName(i).empty())
2610 variableTyResolver[op.getResultName(i)] = {resolver, llvm::None};
2611 }
2612 }
2613
handleTypesMatchConstraint(llvm::StringMap<TypeResolutionInstance> & variableTyResolver,llvm::Record def)2614 void FormatParser::handleTypesMatchConstraint(
2615 llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
2616 llvm::Record def) {
2617 StringRef lhsName = def.getValueAsString("lhs");
2618 StringRef rhsName = def.getValueAsString("rhs");
2619 StringRef transformer = def.getValueAsString("transformer");
2620 if (ConstArgument arg = findSeenArg(lhsName))
2621 variableTyResolver[rhsName] = {arg, transformer};
2622 }
2623
findSeenArg(StringRef name)2624 ConstArgument FormatParser::findSeenArg(StringRef name) {
2625 if (const NamedTypeConstraint *arg = findArg(op.getOperands(), name))
2626 return seenOperandTypes.test(arg - op.operand_begin()) ? arg : nullptr;
2627 if (const NamedTypeConstraint *arg = findArg(op.getResults(), name))
2628 return seenResultTypes.test(arg - op.result_begin()) ? arg : nullptr;
2629 if (const NamedAttribute *attr = findArg(op.getAttributes(), name))
2630 return seenAttrs.count(attr) ? attr : nullptr;
2631 return nullptr;
2632 }
2633
parseElement(std::unique_ptr<Element> & element,bool isTopLevel)2634 LogicalResult FormatParser::parseElement(std::unique_ptr<Element> &element,
2635 bool isTopLevel) {
2636 // Directives.
2637 if (curToken.isKeyword())
2638 return parseDirective(element, isTopLevel);
2639 // Literals.
2640 if (curToken.getKind() == Token::literal)
2641 return parseLiteral(element);
2642 // Optionals.
2643 if (curToken.getKind() == Token::l_paren)
2644 return parseOptional(element, isTopLevel);
2645 // Variables.
2646 if (curToken.getKind() == Token::variable)
2647 return parseVariable(element, isTopLevel);
2648 return emitError(curToken.getLoc(),
2649 "expected directive, literal, variable, or optional group");
2650 }
2651
parseVariable(std::unique_ptr<Element> & element,bool isTopLevel)2652 LogicalResult FormatParser::parseVariable(std::unique_ptr<Element> &element,
2653 bool isTopLevel) {
2654 Token varTok = curToken;
2655 consumeToken();
2656
2657 StringRef name = varTok.getSpelling().drop_front();
2658 llvm::SMLoc loc = varTok.getLoc();
2659
2660 // Check that the parsed argument is something actually registered on the
2661 // op.
2662 /// Attributes
2663 if (const NamedAttribute *attr = findArg(op.getAttributes(), name)) {
2664 if (isTopLevel && !seenAttrs.insert(attr))
2665 return emitError(loc, "attribute '" + name + "' is already bound");
2666 element = std::make_unique<AttributeVariable>(attr);
2667 return ::mlir::success();
2668 }
2669 /// Operands
2670 if (const NamedTypeConstraint *operand = findArg(op.getOperands(), name)) {
2671 if (isTopLevel) {
2672 if (fmt.allOperands || !seenOperands.insert(operand).second)
2673 return emitError(loc, "operand '" + name + "' is already bound");
2674 }
2675 element = std::make_unique<OperandVariable>(operand);
2676 return ::mlir::success();
2677 }
2678 /// Regions
2679 if (const NamedRegion *region = findArg(op.getRegions(), name)) {
2680 if (!isTopLevel)
2681 return emitError(loc, "regions can only be used at the top level");
2682 if (hasAllRegions || !seenRegions.insert(region).second)
2683 return emitError(loc, "region '" + name + "' is already bound");
2684 element = std::make_unique<RegionVariable>(region);
2685 return ::mlir::success();
2686 }
2687 /// Results.
2688 if (const auto *result = findArg(op.getResults(), name)) {
2689 if (isTopLevel)
2690 return emitError(loc, "results can not be used at the top level");
2691 element = std::make_unique<ResultVariable>(result);
2692 return ::mlir::success();
2693 }
2694 /// Successors.
2695 if (const auto *successor = findArg(op.getSuccessors(), name)) {
2696 if (!isTopLevel)
2697 return emitError(loc, "successors can only be used at the top level");
2698 if (hasAllSuccessors || !seenSuccessors.insert(successor).second)
2699 return emitError(loc, "successor '" + name + "' is already bound");
2700 element = std::make_unique<SuccessorVariable>(successor);
2701 return ::mlir::success();
2702 }
2703 return emitError(loc, "expected variable to refer to an argument, region, "
2704 "result, or successor");
2705 }
2706
parseDirective(std::unique_ptr<Element> & element,bool isTopLevel)2707 LogicalResult FormatParser::parseDirective(std::unique_ptr<Element> &element,
2708 bool isTopLevel) {
2709 Token dirTok = curToken;
2710 consumeToken();
2711
2712 switch (dirTok.getKind()) {
2713 case Token::kw_attr_dict:
2714 return parseAttrDictDirective(element, dirTok.getLoc(), isTopLevel,
2715 /*withKeyword=*/false);
2716 case Token::kw_attr_dict_w_keyword:
2717 return parseAttrDictDirective(element, dirTok.getLoc(), isTopLevel,
2718 /*withKeyword=*/true);
2719 case Token::kw_custom:
2720 return parseCustomDirective(element, dirTok.getLoc(), isTopLevel);
2721 case Token::kw_functional_type:
2722 return parseFunctionalTypeDirective(element, dirTok, isTopLevel);
2723 case Token::kw_operands:
2724 return parseOperandsDirective(element, dirTok.getLoc(), isTopLevel);
2725 case Token::kw_regions:
2726 return parseRegionsDirective(element, dirTok.getLoc(), isTopLevel);
2727 case Token::kw_results:
2728 return parseResultsDirective(element, dirTok.getLoc(), isTopLevel);
2729 case Token::kw_successors:
2730 return parseSuccessorsDirective(element, dirTok.getLoc(), isTopLevel);
2731 case Token::kw_type_ref:
2732 return parseTypeDirective(element, dirTok, isTopLevel, /*isTypeRef=*/true);
2733 case Token::kw_type:
2734 return parseTypeDirective(element, dirTok, isTopLevel);
2735
2736 default:
2737 llvm_unreachable("unknown directive token");
2738 }
2739 }
2740
parseLiteral(std::unique_ptr<Element> & element)2741 LogicalResult FormatParser::parseLiteral(std::unique_ptr<Element> &element) {
2742 Token literalTok = curToken;
2743 consumeToken();
2744
2745 StringRef value = literalTok.getSpelling().drop_front().drop_back();
2746
2747 // The parsed literal is a space element (`` or ` `).
2748 if (value.empty() || (value.size() == 1 && value.front() == ' ')) {
2749 element = std::make_unique<SpaceElement>(!value.empty());
2750 return ::mlir::success();
2751 }
2752 // The parsed literal is a newline element.
2753 if (value == "\\n") {
2754 element = std::make_unique<NewlineElement>();
2755 return ::mlir::success();
2756 }
2757
2758 // Check that the parsed literal is valid.
2759 if (!LiteralElement::isValidLiteral(value))
2760 return emitError(literalTok.getLoc(), "expected valid literal");
2761
2762 element = std::make_unique<LiteralElement>(value);
2763 return ::mlir::success();
2764 }
2765
parseOptional(std::unique_ptr<Element> & element,bool isTopLevel)2766 LogicalResult FormatParser::parseOptional(std::unique_ptr<Element> &element,
2767 bool isTopLevel) {
2768 llvm::SMLoc curLoc = curToken.getLoc();
2769 if (!isTopLevel)
2770 return emitError(curLoc, "optional groups can only be used as top-level "
2771 "elements");
2772 consumeToken();
2773
2774 // Parse the child elements for this optional group.
2775 std::vector<std::unique_ptr<Element>> elements;
2776 Optional<unsigned> anchorIdx;
2777 do {
2778 if (failed(parseOptionalChildElement(elements, anchorIdx)))
2779 return ::mlir::failure();
2780 } while (curToken.getKind() != Token::r_paren);
2781 consumeToken();
2782 if (failed(parseToken(Token::question, "expected '?' after optional group")))
2783 return ::mlir::failure();
2784
2785 // The optional group is required to have an anchor.
2786 if (!anchorIdx)
2787 return emitError(curLoc, "optional group specified no anchor element");
2788
2789 // The first parsable element of the group must be able to be parsed in an
2790 // optional fashion.
2791 auto parseBegin = llvm::find_if_not(elements, [](auto &element) {
2792 return isa<WhitespaceElement>(element.get());
2793 });
2794 Element *firstElement = parseBegin->get();
2795 if (!isa<AttributeVariable>(firstElement) &&
2796 !isa<LiteralElement>(firstElement) &&
2797 !isa<OperandVariable>(firstElement) && !isa<RegionVariable>(firstElement))
2798 return emitError(curLoc,
2799 "first parsable element of an operand group must be "
2800 "an attribute, literal, operand, or region");
2801
2802 auto parseStart = parseBegin - elements.begin();
2803 element = std::make_unique<OptionalElement>(std::move(elements), *anchorIdx,
2804 parseStart);
2805 return ::mlir::success();
2806 }
2807
parseOptionalChildElement(std::vector<std::unique_ptr<Element>> & childElements,Optional<unsigned> & anchorIdx)2808 LogicalResult FormatParser::parseOptionalChildElement(
2809 std::vector<std::unique_ptr<Element>> &childElements,
2810 Optional<unsigned> &anchorIdx) {
2811 llvm::SMLoc childLoc = curToken.getLoc();
2812 childElements.push_back({});
2813 if (failed(parseElement(childElements.back(), /*isTopLevel=*/true)))
2814 return ::mlir::failure();
2815
2816 // Check to see if this element is the anchor of the optional group.
2817 bool isAnchor = curToken.getKind() == Token::caret;
2818 if (isAnchor) {
2819 if (anchorIdx)
2820 return emitError(childLoc, "only one element can be marked as the anchor "
2821 "of an optional group");
2822 anchorIdx = childElements.size() - 1;
2823 consumeToken();
2824 }
2825
2826 return verifyOptionalChildElement(childElements.back().get(), childLoc,
2827 isAnchor);
2828 }
2829
verifyOptionalChildElement(Element * element,llvm::SMLoc childLoc,bool isAnchor)2830 LogicalResult FormatParser::verifyOptionalChildElement(Element *element,
2831 llvm::SMLoc childLoc,
2832 bool isAnchor) {
2833 return TypeSwitch<Element *, LogicalResult>(element)
2834 // All attributes can be within the optional group, but only optional
2835 // attributes can be the anchor.
2836 .Case([&](AttributeVariable *attrEle) {
2837 if (isAnchor && !attrEle->getVar()->attr.isOptional())
2838 return emitError(childLoc, "only optional attributes can be used to "
2839 "anchor an optional group");
2840 return ::mlir::success();
2841 })
2842 // Only optional-like(i.e. variadic) operands can be within an optional
2843 // group.
2844 .Case<OperandVariable>([&](OperandVariable *ele) {
2845 if (!ele->getVar()->isVariableLength())
2846 return emitError(childLoc, "only variable length operands can be "
2847 "used within an optional group");
2848 return ::mlir::success();
2849 })
2850 // Only optional-like(i.e. variadic) results can be within an optional
2851 // group.
2852 .Case<ResultVariable>([&](ResultVariable *ele) {
2853 if (!ele->getVar()->isVariableLength())
2854 return emitError(childLoc, "only variable length results can be "
2855 "used within an optional group");
2856 return ::mlir::success();
2857 })
2858 .Case<RegionVariable>([&](RegionVariable *) {
2859 // TODO: When ODS has proper support for marking "optional" regions, add
2860 // a check here.
2861 return ::mlir::success();
2862 })
2863 .Case<TypeDirective>([&](TypeDirective *ele) {
2864 return verifyOptionalChildElement(ele->getOperand(), childLoc,
2865 /*isAnchor=*/false);
2866 })
2867 .Case<FunctionalTypeDirective>([&](FunctionalTypeDirective *ele) {
2868 if (failed(verifyOptionalChildElement(ele->getInputs(), childLoc,
2869 /*isAnchor=*/false)))
2870 return failure();
2871 return verifyOptionalChildElement(ele->getResults(), childLoc,
2872 /*isAnchor=*/false);
2873 })
2874 // Literals, whitespace, and custom directives may be used, but they can't
2875 // anchor the group.
2876 .Case<LiteralElement, WhitespaceElement, CustomDirective,
2877 FunctionalTypeDirective, OptionalElement, TypeRefDirective>(
2878 [&](Element *) {
2879 if (isAnchor)
2880 return emitError(childLoc, "only variables and types can be used "
2881 "to anchor an optional group");
2882 return ::mlir::success();
2883 })
2884 .Default([&](Element *) {
2885 return emitError(childLoc, "only literals, types, and variables can be "
2886 "used within an optional group");
2887 });
2888 }
2889
2890 LogicalResult
parseAttrDictDirective(std::unique_ptr<Element> & element,llvm::SMLoc loc,bool isTopLevel,bool withKeyword)2891 FormatParser::parseAttrDictDirective(std::unique_ptr<Element> &element,
2892 llvm::SMLoc loc, bool isTopLevel,
2893 bool withKeyword) {
2894 if (!isTopLevel)
2895 return emitError(loc, "'attr-dict' directive can only be used as a "
2896 "top-level directive");
2897 if (hasAttrDict)
2898 return emitError(loc, "'attr-dict' directive has already been seen");
2899
2900 hasAttrDict = true;
2901 element = std::make_unique<AttrDictDirective>(withKeyword);
2902 return ::mlir::success();
2903 }
2904
2905 LogicalResult
parseCustomDirective(std::unique_ptr<Element> & element,llvm::SMLoc loc,bool isTopLevel)2906 FormatParser::parseCustomDirective(std::unique_ptr<Element> &element,
2907 llvm::SMLoc loc, bool isTopLevel) {
2908 llvm::SMLoc curLoc = curToken.getLoc();
2909
2910 // Parse the custom directive name.
2911 if (failed(
2912 parseToken(Token::less, "expected '<' before custom directive name")))
2913 return ::mlir::failure();
2914
2915 Token nameTok = curToken;
2916 if (failed(parseToken(Token::identifier,
2917 "expected custom directive name identifier")) ||
2918 failed(parseToken(Token::greater,
2919 "expected '>' after custom directive name")) ||
2920 failed(parseToken(Token::l_paren,
2921 "expected '(' before custom directive parameters")))
2922 return ::mlir::failure();
2923
2924 // Parse the child elements for this optional group.=
2925 std::vector<std::unique_ptr<Element>> elements;
2926 do {
2927 if (failed(parseCustomDirectiveParameter(elements)))
2928 return ::mlir::failure();
2929 if (curToken.getKind() != Token::comma)
2930 break;
2931 consumeToken();
2932 } while (true);
2933
2934 if (failed(parseToken(Token::r_paren,
2935 "expected ')' after custom directive parameters")))
2936 return ::mlir::failure();
2937
2938 // After parsing all of the elements, ensure that all type directives refer
2939 // only to variables.
2940 for (auto &ele : elements) {
2941 if (auto *typeEle = dyn_cast<TypeRefDirective>(ele.get())) {
2942 if (!isa<OperandVariable, ResultVariable>(typeEle->getOperand())) {
2943 return emitError(curLoc,
2944 "type_ref directives within a custom directive "
2945 "may only refer to variables");
2946 }
2947 }
2948 if (auto *typeEle = dyn_cast<TypeDirective>(ele.get())) {
2949 if (!isa<OperandVariable, ResultVariable>(typeEle->getOperand())) {
2950 return emitError(curLoc, "type directives within a custom directive "
2951 "may only refer to variables");
2952 }
2953 }
2954 }
2955
2956 element = std::make_unique<CustomDirective>(nameTok.getSpelling(),
2957 std::move(elements));
2958 return ::mlir::success();
2959 }
2960
parseCustomDirectiveParameter(std::vector<std::unique_ptr<Element>> & parameters)2961 LogicalResult FormatParser::parseCustomDirectiveParameter(
2962 std::vector<std::unique_ptr<Element>> ¶meters) {
2963 llvm::SMLoc childLoc = curToken.getLoc();
2964 parameters.push_back({});
2965 if (failed(parseElement(parameters.back(), /*isTopLevel=*/true)))
2966 return ::mlir::failure();
2967
2968 // Verify that the element can be placed within a custom directive.
2969 if (!isa<TypeRefDirective, TypeDirective, AttrDictDirective,
2970 AttributeVariable, OperandVariable, RegionVariable,
2971 SuccessorVariable>(parameters.back().get())) {
2972 return emitError(childLoc, "only variables and types may be used as "
2973 "parameters to a custom directive");
2974 }
2975 return ::mlir::success();
2976 }
2977
2978 LogicalResult
parseFunctionalTypeDirective(std::unique_ptr<Element> & element,Token tok,bool isTopLevel)2979 FormatParser::parseFunctionalTypeDirective(std::unique_ptr<Element> &element,
2980 Token tok, bool isTopLevel) {
2981 llvm::SMLoc loc = tok.getLoc();
2982 if (!isTopLevel)
2983 return emitError(
2984 loc, "'functional-type' is only valid as a top-level directive");
2985
2986 // Parse the main operand.
2987 std::unique_ptr<Element> inputs, results;
2988 if (failed(parseToken(Token::l_paren, "expected '(' before argument list")) ||
2989 failed(parseTypeDirectiveOperand(inputs)) ||
2990 failed(parseToken(Token::comma, "expected ',' after inputs argument")) ||
2991 failed(parseTypeDirectiveOperand(results)) ||
2992 failed(parseToken(Token::r_paren, "expected ')' after argument list")))
2993 return ::mlir::failure();
2994 element = std::make_unique<FunctionalTypeDirective>(std::move(inputs),
2995 std::move(results));
2996 return ::mlir::success();
2997 }
2998
2999 LogicalResult
parseOperandsDirective(std::unique_ptr<Element> & element,llvm::SMLoc loc,bool isTopLevel)3000 FormatParser::parseOperandsDirective(std::unique_ptr<Element> &element,
3001 llvm::SMLoc loc, bool isTopLevel) {
3002 if (isTopLevel) {
3003 if (fmt.allOperands || !seenOperands.empty())
3004 return emitError(loc, "'operands' directive creates overlap in format");
3005 fmt.allOperands = true;
3006 }
3007 element = std::make_unique<OperandsDirective>();
3008 return ::mlir::success();
3009 }
3010
3011 LogicalResult
parseRegionsDirective(std::unique_ptr<Element> & element,llvm::SMLoc loc,bool isTopLevel)3012 FormatParser::parseRegionsDirective(std::unique_ptr<Element> &element,
3013 llvm::SMLoc loc, bool isTopLevel) {
3014 if (!isTopLevel)
3015 return emitError(loc, "'regions' is only valid as a top-level directive");
3016 if (hasAllRegions || !seenRegions.empty())
3017 return emitError(loc, "'regions' directive creates overlap in format");
3018 hasAllRegions = true;
3019 element = std::make_unique<RegionsDirective>();
3020 return ::mlir::success();
3021 }
3022
3023 LogicalResult
parseResultsDirective(std::unique_ptr<Element> & element,llvm::SMLoc loc,bool isTopLevel)3024 FormatParser::parseResultsDirective(std::unique_ptr<Element> &element,
3025 llvm::SMLoc loc, bool isTopLevel) {
3026 if (isTopLevel)
3027 return emitError(loc, "'results' directive can not be used as a "
3028 "top-level directive");
3029 element = std::make_unique<ResultsDirective>();
3030 return ::mlir::success();
3031 }
3032
3033 LogicalResult
parseSuccessorsDirective(std::unique_ptr<Element> & element,llvm::SMLoc loc,bool isTopLevel)3034 FormatParser::parseSuccessorsDirective(std::unique_ptr<Element> &element,
3035 llvm::SMLoc loc, bool isTopLevel) {
3036 if (!isTopLevel)
3037 return emitError(loc,
3038 "'successors' is only valid as a top-level directive");
3039 if (hasAllSuccessors || !seenSuccessors.empty())
3040 return emitError(loc, "'successors' directive creates overlap in format");
3041 hasAllSuccessors = true;
3042 element = std::make_unique<SuccessorsDirective>();
3043 return ::mlir::success();
3044 }
3045
3046 LogicalResult
parseTypeDirective(std::unique_ptr<Element> & element,Token tok,bool isTopLevel,bool isTypeRef)3047 FormatParser::parseTypeDirective(std::unique_ptr<Element> &element, Token tok,
3048 bool isTopLevel, bool isTypeRef) {
3049 llvm::SMLoc loc = tok.getLoc();
3050 if (!isTopLevel)
3051 return emitError(loc, "'type' is only valid as a top-level directive");
3052
3053 std::unique_ptr<Element> operand;
3054 if (failed(parseToken(Token::l_paren, "expected '(' before argument list")) ||
3055 failed(parseTypeDirectiveOperand(operand, isTypeRef)) ||
3056 failed(parseToken(Token::r_paren, "expected ')' after argument list")))
3057 return ::mlir::failure();
3058 if (isTypeRef)
3059 element = std::make_unique<TypeRefDirective>(std::move(operand));
3060 else
3061 element = std::make_unique<TypeDirective>(std::move(operand));
3062 return ::mlir::success();
3063 }
3064
3065 LogicalResult
parseTypeDirectiveOperand(std::unique_ptr<Element> & element,bool isTypeRef)3066 FormatParser::parseTypeDirectiveOperand(std::unique_ptr<Element> &element,
3067 bool isTypeRef) {
3068 llvm::SMLoc loc = curToken.getLoc();
3069 if (failed(parseElement(element, /*isTopLevel=*/false)))
3070 return ::mlir::failure();
3071 if (isa<LiteralElement>(element.get()))
3072 return emitError(
3073 loc, "'type' directive operand expects variable or directive operand");
3074
3075 if (auto *var = dyn_cast<OperandVariable>(element.get())) {
3076 unsigned opIdx = var->getVar() - op.operand_begin();
3077 if (!isTypeRef && (fmt.allOperandTypes || seenOperandTypes.test(opIdx)))
3078 return emitError(loc, "'type' of '" + var->getVar()->name +
3079 "' is already bound");
3080 if (isTypeRef && !(fmt.allOperandTypes || seenOperandTypes.test(opIdx)))
3081 return emitError(loc, "'type_ref' of '" + var->getVar()->name +
3082 "' is not bound by a prior 'type' directive");
3083 seenOperandTypes.set(opIdx);
3084 } else if (auto *var = dyn_cast<ResultVariable>(element.get())) {
3085 unsigned resIdx = var->getVar() - op.result_begin();
3086 if (!isTypeRef && (fmt.allResultTypes || seenResultTypes.test(resIdx)))
3087 return emitError(loc, "'type' of '" + var->getVar()->name +
3088 "' is already bound");
3089 if (isTypeRef && !(fmt.allResultTypes || seenResultTypes.test(resIdx)))
3090 return emitError(loc, "'type_ref' of '" + var->getVar()->name +
3091 "' is not bound by a prior 'type' directive");
3092 seenResultTypes.set(resIdx);
3093 } else if (isa<OperandsDirective>(&*element)) {
3094 if (!isTypeRef && (fmt.allOperandTypes || seenOperandTypes.any()))
3095 return emitError(loc, "'operands' 'type' is already bound");
3096 if (isTypeRef && !(fmt.allOperandTypes || seenOperandTypes.all()))
3097 return emitError(
3098 loc,
3099 "'operands' 'type_ref' is not bound by a prior 'type' directive");
3100 fmt.allOperandTypes = true;
3101 } else if (isa<ResultsDirective>(&*element)) {
3102 if (!isTypeRef && (fmt.allResultTypes || seenResultTypes.any()))
3103 return emitError(loc, "'results' 'type' is already bound");
3104 if (isTypeRef && !(fmt.allResultTypes || seenResultTypes.all()))
3105 return emitError(
3106 loc, "'results' 'type_ref' is not bound by a prior 'type' directive");
3107 fmt.allResultTypes = true;
3108 } else {
3109 return emitError(loc, "invalid argument to 'type' directive");
3110 }
3111 return ::mlir::success();
3112 }
3113
3114 //===----------------------------------------------------------------------===//
3115 // Interface
3116 //===----------------------------------------------------------------------===//
3117
generateOpFormat(const Operator & constOp,OpClass & opClass)3118 void mlir::tblgen::generateOpFormat(const Operator &constOp, OpClass &opClass) {
3119 // TODO: Operator doesn't expose all necessary functionality via
3120 // the const interface.
3121 Operator &op = const_cast<Operator &>(constOp);
3122 if (!op.hasAssemblyFormat())
3123 return;
3124
3125 // Parse the format description.
3126 llvm::SourceMgr mgr;
3127 mgr.AddNewSourceBuffer(
3128 llvm::MemoryBuffer::getMemBuffer(op.getAssemblyFormat()), llvm::SMLoc());
3129 OperationFormat format(op);
3130 if (failed(FormatParser(mgr, format, op).parse())) {
3131 // Exit the process if format errors are treated as fatal.
3132 if (formatErrorIsFatal) {
3133 // Invoke the interrupt handlers to run the file cleanup handlers.
3134 llvm::sys::RunInterruptHandlers();
3135 std::exit(1);
3136 }
3137 return;
3138 }
3139
3140 // Generate the printer and parser based on the parsed format.
3141 format.genParser(op, opClass);
3142 format.genPrinter(op, opClass);
3143 }
3144