1 //===- OpDefinitionsGen.cpp - MLIR op definitions generator ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // OpDefinitionsGen uses the description of operations to generate C++
10 // definitions for ops.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "OpFormatGen.h"
15 #include "mlir/TableGen/CodeGenHelpers.h"
16 #include "mlir/TableGen/Format.h"
17 #include "mlir/TableGen/GenInfo.h"
18 #include "mlir/TableGen/Interfaces.h"
19 #include "mlir/TableGen/OpClass.h"
20 #include "mlir/TableGen/OpTrait.h"
21 #include "mlir/TableGen/Operator.h"
22 #include "mlir/TableGen/SideEffects.h"
23 #include "llvm/ADT/Sequence.h"
24 #include "llvm/ADT/StringExtras.h"
25 #include "llvm/Support/CommandLine.h"
26 #include "llvm/Support/Regex.h"
27 #include "llvm/Support/Signals.h"
28 #include "llvm/TableGen/Error.h"
29 #include "llvm/TableGen/Record.h"
30 #include "llvm/TableGen/TableGenBackend.h"
31 
32 #define DEBUG_TYPE "mlir-tblgen-opdefgen"
33 
34 using namespace llvm;
35 using namespace mlir;
36 using namespace mlir::tblgen;
37 
// Command-line category grouping the options declared below; its description
// names the -gen-op-defs and -gen-op-decls generators.
cl::OptionCategory opDefGenCat("Options for -gen-op-defs and -gen-op-decls");

// Regex of op names to include. Per the option description, an empty value
// means no filtering is applied.
static cl::opt<std::string> opIncFilter(
    "op-include-regex",
    cl::desc("Regex of name of op's to include (no filter if empty)"),
    cl::cat(opDefGenCat));
// Regex of op names to exclude. Per the option description, an empty value
// means no filtering is applied.
static cl::opt<std::string> opExcFilter(
    "op-exclude-regex",
    cl::desc("Regex of name of op's to exclude (no filter if empty)"),
    cl::cat(opDefGenCat));

// Prefix for local variables in generated verification code; it keeps those
// locals from hiding the attribute accessor of the same name.
static const char *const tblgenNamePrefix = "tblgen_";
// Base name used to synthesize `<base>_<index>` names for operands that were
// declared without a name.
static const char *const generatedArgName = "odsArg";
// NOTE(review): presumably the names of the builder and operation-state
// parameters in generated build() methods; their uses are outside this part
// of the file -- confirm.
static const char *const odsBuilder = "odsBuilder";
static const char *const builderOpState = "odsState";
53 
// The logic to calculate the actual value range for a declared operand/result
// of an op with variadic operands/results. Note that this logic is not for
// general use; it assumes all variadic operands/results must have the same
// number of values.
//
// {0}: The list of whether each declared operand/result is variadic.
// {1}: The total number of non-variadic operands/results.
// {2}: The total number of variadic operands/results.
// {3}: The total number of actual values.
// {4}: "operand" or "result".
//
// The emitted snippet reads a variable named `index` from the surrounding
// generated function and returns a brace-initialized {start, size} pair.
// NOTE(review): the doubled `{{` is presumably the llvm::formatv escape for a
// literal `{`; the expansion site is outside this part of the file -- confirm.
const char *sameVariadicSizeValueRangeCalcCode = R"(
  bool isVariadic[] = {{{0}};
  int prevVariadicCount = 0;
  for (unsigned i = 0; i < index; ++i)
    if (isVariadic[i]) ++prevVariadicCount;

  // Calculate how many dynamic values a static variadic {4} corresponds to.
  // This assumes all static variadic {4}s have the same dynamic value count.
  int variadicSize = ({3} - {1}) / {2};
  // `index` passed in as the parameter is the static index which counts each
  // {4} (variadic or not) as size 1. So here for each previous static variadic
  // {4}, we need to offset by (variadicSize - 1) to get where the dynamic
  // value pack for this static {4} starts.
  int start = index + (variadicSize - 1) * prevVariadicCount;
  int size = isVariadic[index] ? variadicSize : 1;
  return {{start, size};
)";
81 
// The logic to calculate the actual value range for a declared operand/result
// of an op with variadic operands/results. Note that this logic assumes the op
// has an attribute specifying the size of each operand/result segment
// (variadic or not).
//
// {0}: The name of the attribute specifying the segment sizes.
//
// This variant declares `sizeAttr` by looking the attribute up in `odsAttrs`
// after asserting that `odsAttrs` is non-null.
const char *adapterSegmentSizeAttrInitCode = R"(
  assert(odsAttrs && "missing segment size attribute for op");
  auto sizeAttr = odsAttrs.get("{0}").cast<::mlir::DenseIntElementsAttr>();
)";
// Declares `sizeAttr` by reading the segment-size attribute from the op
// itself.
//
// {0}: The name of the attribute specifying the segment sizes.
const char *opSegmentSizeAttrInitCode = R"(
  auto sizeAttr = (*this)->getAttrOfType<::mlir::DenseIntElementsAttr>("{0}");
)";
// Computes the {start, size} pair of the segment at `index`: `start` is the
// sum of the sizes of all preceding segments and `size` is the entry at
// `index`. The snippet expects `sizeAttr` and `index` to already be in scope
// in the generated code; it contains no format placeholders.
const char *attrSizedSegmentValueRangeCalcCode = R"(
  unsigned start = 0;
  for (unsigned i = 0; i < index; ++i)
    start += (*(sizeAttr.begin() + i)).getZExtValue();
  unsigned size = (*(sizeAttr.begin() + index)).getZExtValue();
  return {start, size};
)";
102 
// The logic to build a range of either operand or result values.
//
// {0}: The begin iterator of the actual values.
// {1}: The call to generate the start and length of the value range.
//
// The emitted code treats the result of {1} as a pair: `first` is the start
// offset and `second` is the length, and the returned range spans
// [begin + first, begin + first + second).
const char *valueRangeReturnCode = R"(
  auto valueRange = {1};
  return {{std::next({0}, valueRange.first),
           std::next({0}, valueRange.first + valueRange.second)};
)";
112 
// Banner comment written ahead of a section of generated output.
//
// {0} and {1}: Two pieces of text placed side by side on the title line, e.g.
// "Local Utility Method" and "Definitions".
static const char *const opCommentHeader = R"(
//===----------------------------------------------------------------------===//
// {0} {1}
//===----------------------------------------------------------------------===//

)";
119 
120 //===----------------------------------------------------------------------===//
121 // StaticVerifierFunctionEmitter
122 //===----------------------------------------------------------------------===//
123 
124 namespace {
125 /// This class deduplicates shared operation verification code by emitting
126 /// static functions alongside the op definitions. These methods are local to
127 /// the definition file, and are invoked within the operation verify methods.
128 /// An example is shown below:
129 ///
130 /// static LogicalResult localVerify(...)
131 ///
132 /// LogicalResult OpA::verify(...) {
133 ///  if (failed(localVerify(...)))
134 ///    return failure();
135 ///  ...
136 /// }
137 ///
138 /// LogicalResult OpB::verify(...) {
139 ///  if (failed(localVerify(...)))
140 ///    return failure();
141 ///  ...
142 /// }
143 ///
144 class StaticVerifierFunctionEmitter {
145 public:
146   StaticVerifierFunctionEmitter(const llvm::RecordKeeper &records,
147                                 ArrayRef<llvm::Record *> opDefs,
148                                 raw_ostream &os, bool emitDecl);
149 
150   /// Get the name of the local function used for the given type constraint.
151   /// These functions are used for operand and result constraints and have the
152   /// form:
153   ///   LogicalResult(Operation *op, Type type, StringRef valueKind,
154   ///                 unsigned valueGroupStartIndex);
getTypeConstraintFn(const Constraint & constraint) const155   StringRef getTypeConstraintFn(const Constraint &constraint) const {
156     auto it = localTypeConstraints.find(constraint.getAsOpaquePointer());
157     assert(it != localTypeConstraints.end() && "expected valid constraint fn");
158     return it->second;
159   }
160 
161 private:
162   /// Returns a unique name to use when generating local methods.
163   static std::string getUniqueName(const llvm::RecordKeeper &records);
164 
165   /// Emit local methods for the type constraints used within the provided op
166   /// definitions.
167   void emitTypeConstraintMethods(ArrayRef<llvm::Record *> opDefs,
168                                  raw_ostream &os, bool emitDecl);
169 
170   /// A unique label for the file currently being generated. This is used to
171   /// ensure that the local functions have a unique name.
172   std::string uniqueOutputLabel;
173 
174   /// A set of functions implementing type constraints, used for operand and
175   /// result verification.
176   llvm::DenseMap<const void *, std::string> localTypeConstraints;
177 };
178 } // namespace
179 
StaticVerifierFunctionEmitter(const llvm::RecordKeeper & records,ArrayRef<llvm::Record * > opDefs,raw_ostream & os,bool emitDecl)180 StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter(
181     const llvm::RecordKeeper &records, ArrayRef<llvm::Record *> opDefs,
182     raw_ostream &os, bool emitDecl)
183     : uniqueOutputLabel(getUniqueName(records)) {
184   llvm::Optional<NamespaceEmitter> namespaceEmitter;
185   if (!emitDecl) {
186     os << formatv(opCommentHeader, "Local Utility Method", "Definitions");
187     namespaceEmitter.emplace(os, Operator(*opDefs[0]).getDialect());
188   }
189 
190   emitTypeConstraintMethods(opDefs, os, emitDecl);
191 }
192 
getUniqueName(const llvm::RecordKeeper & records)193 std::string StaticVerifierFunctionEmitter::getUniqueName(
194     const llvm::RecordKeeper &records) {
195   // Use the input file name when generating a unique name.
196   std::string inputFilename = records.getInputFilename();
197 
198   // Drop all but the base filename.
199   StringRef nameRef = llvm::sys::path::filename(inputFilename);
200   nameRef.consume_back(".td");
201 
202   // Sanitize any invalid characters.
203   std::string uniqueName;
204   for (char c : nameRef) {
205     if (llvm::isAlnum(c) || c == '_')
206       uniqueName.push_back(c);
207     else
208       uniqueName.append(llvm::utohexstr((unsigned char)c));
209   }
210   return uniqueName;
211 }
212 
emitTypeConstraintMethods(ArrayRef<llvm::Record * > opDefs,raw_ostream & os,bool emitDecl)213 void StaticVerifierFunctionEmitter::emitTypeConstraintMethods(
214     ArrayRef<llvm::Record *> opDefs, raw_ostream &os, bool emitDecl) {
215   // Collect a set of all of the used type constraints within the operation
216   // definitions.
217   llvm::SetVector<const void *> typeConstraints;
218   for (Record *def : opDefs) {
219     Operator op(*def);
220     for (NamedTypeConstraint &operand : op.getOperands())
221       if (operand.hasPredicate())
222         typeConstraints.insert(operand.constraint.getAsOpaquePointer());
223     for (NamedTypeConstraint &result : op.getResults())
224       if (result.hasPredicate())
225         typeConstraints.insert(result.constraint.getAsOpaquePointer());
226   }
227 
228   FmtContext fctx;
229   for (auto it : llvm::enumerate(typeConstraints)) {
230     // Generate an obscure and unique name for this type constraint.
231     std::string name = (Twine("__mlir_ods_local_type_constraint_") +
232                         uniqueOutputLabel + Twine(it.index()))
233                            .str();
234     localTypeConstraints.try_emplace(it.value(), name);
235 
236     // Only generate the methods if we are generating definitions.
237     if (emitDecl)
238       continue;
239 
240     Constraint constraint = Constraint::getFromOpaquePointer(it.value());
241     os << "static ::mlir::LogicalResult " << name
242        << "(::mlir::Operation *op, ::mlir::Type type, ::llvm::StringRef "
243           "valueKind, unsigned valueGroupStartIndex) {\n";
244 
245     os << "  if (!("
246        << tgfmt(constraint.getConditionTemplate(), &fctx.withSelf("type"))
247        << ")) {\n"
248        << formatv(
249               "    return op->emitOpError(valueKind) << \" #\" << "
250               "valueGroupStartIndex << \" must be {0}, but got \" << type;\n",
251               constraint.getSummary())
252        << "  }\n"
253        << "  return ::mlir::success();\n"
254        << "}\n\n";
255   }
256 }
257 
258 //===----------------------------------------------------------------------===//
259 // Utility structs and functions
260 //===----------------------------------------------------------------------===//
261 
// Replaces all occurrences of `match` in `str` with `substitute` and returns
// the resulting string.
//
// Matches are located left to right and never overlap. Text introduced by a
// substitution is not rescanned, so a `substitute` that itself contains
// `match` is safe. An empty `match` leaves `str` unchanged: an empty pattern
// matches at every position, so scanning for it would never terminate.
static std::string replaceAllSubstrs(std::string str, const std::string &match,
                                     const std::string &substitute) {
  if (match.empty())
    return str;

  std::string::size_type scanLoc = 0, matchLoc = std::string::npos;
  while ((matchLoc = str.find(match, scanLoc)) != std::string::npos) {
    str.replace(matchLoc, match.size(), substitute);
    // Resume right after the text just inserted.
    scanLoc = matchLoc + substitute.size();
  }
  return str;
}
272 
273 // Returns whether the record has a value of the given name that can be returned
274 // via getValueAsString.
hasStringAttribute(const Record & record,StringRef fieldName)275 static inline bool hasStringAttribute(const Record &record,
276                                       StringRef fieldName) {
277   auto valueInit = record.getValueInit(fieldName);
278   return isa<StringInit>(valueInit);
279 }
280 
getArgumentName(const Operator & op,int index)281 static std::string getArgumentName(const Operator &op, int index) {
282   const auto &operand = op.getOperand(index);
283   if (!operand.name.empty())
284     return std::string(operand.name);
285   else
286     return std::string(formatv("{0}_{1}", generatedArgName, index));
287 }
288 
289 // Returns true if we can use unwrapped value for the given `attr` in builders.
canUseUnwrappedRawValue(const tblgen::Attribute & attr)290 static bool canUseUnwrappedRawValue(const tblgen::Attribute &attr) {
291   return attr.getReturnType() != attr.getStorageType() &&
292          // We need to wrap the raw value into an attribute in the builder impl
293          // so we need to make sure that the attribute specifies how to do that.
294          !attr.getConstBuilderTemplate().empty();
295 }
296 
297 //===----------------------------------------------------------------------===//
298 // Op emitter
299 //===----------------------------------------------------------------------===//
300 
301 namespace {
// Helper class to emit a record into the given output stream.
//
// The constructor runs all of the gen* methods below to populate `opClass`;
// the emitDecl/emitDef entry points then write that class's declaration or
// definition to the stream.
class OpEmitter {
public:
  // Emits the declaration of the C++ class for `op` to `os`.
  static void
  emitDecl(const Operator &op, raw_ostream &os,
           const StaticVerifierFunctionEmitter &staticVerifierEmitter);
  // Emits the definition of the C++ class for `op` to `os`.
  static void
  emitDef(const Operator &op, raw_ostream &os,
          const StaticVerifierFunctionEmitter &staticVerifierEmitter);

private:
  // Builds the complete op class. Private: instances are only created by the
  // static entry points above.
  OpEmitter(const Operator &op,
            const StaticVerifierFunctionEmitter &staticVerifierEmitter);

  // Write the already-built op class declaration/definition to `os`.
  void emitDecl(raw_ostream &os);
  void emitDef(raw_ostream &os);

  // Generates the OpAsmOpInterface for this operation if possible.
  void genOpAsmInterface();

  // Generates the `getOperationName` method for this op.
  void genOpNameGetter();

  // Generates getters for the attributes.
  void genAttrGetters();

  // Generates setter for the attributes.
  void genAttrSetters();

  // Generates removers for optional attributes.
  void genOptionalAttrRemovers();

  // Generates getters for named operands.
  void genNamedOperandGetters();

  // Generates setters for named operands.
  void genNamedOperandSetters();

  // Generates getters for named results.
  void genNamedResultGetters();

  // Generates getters for named regions.
  void genNamedRegionGetters();

  // Generates getters for named successors.
  void genNamedSuccessorGetters();

  // Generates builder methods for the operation.
  void genBuilder();

  // Generates the build() method that takes each operand/attribute
  // as a stand-alone parameter.
  void genSeparateArgParamBuilder();

  // Generates the build() method that takes each operand/attribute as a
  // stand-alone parameter. The generated build() method uses first operand's
  // type as all results' types.
  void genUseOperandAsResultTypeSeparateParamBuilder();

  // Generates the build() method that takes all operands/attributes
  // collectively as one parameter. The generated build() method uses first
  // operand's type as all results' types.
  void genUseOperandAsResultTypeCollectiveParamBuilder();

  // Generates the build() method that takes aggregate operands/attributes
  // parameters. This build() method uses inferred types as result types.
  // Requires: The type needs to be inferable via InferTypeOpInterface.
  void genInferredTypeCollectiveParamBuilder();

  // Generates the build() method that takes each operand/attribute as a
  // stand-alone parameter. The generated build() method uses first attribute's
  // type as all result's types.
  void genUseAttrAsResultTypeBuilder();

  // Generates the build() method that takes all result types collectively as
  // one parameter. Similarly for operands and attributes.
  void genCollectiveParamBuilder();

  // The kind of parameter to generate for result types in builders.
  enum class TypeParamKind {
    None,       // No result type in parameter list.
    Separate,   // A separate parameter for each result type.
    Collective, // An ArrayRef<Type> for all result types.
  };

  // The kind of parameter to generate for attributes in builders.
  enum class AttrParamKind {
    WrappedAttr,    // A wrapped MLIR Attribute instance.
    UnwrappedValue, // A raw value without MLIR Attribute wrapper.
  };

  // Builds the parameter list for build() method of this op. This method writes
  // to `paramList` the comma-separated parameter list and updates
  // `resultTypeNames` with the names for parameters for specifying result
  // types. The given `typeParamKind` and `attrParamKind` controls how result
  // types and attributes are placed in the parameter list.
  void buildParamList(llvm::SmallVectorImpl<OpMethodParameter> &paramList,
                      SmallVectorImpl<std::string> &resultTypeNames,
                      TypeParamKind typeParamKind,
                      AttrParamKind attrParamKind = AttrParamKind::WrappedAttr);

  // Adds op arguments and regions into operation state for build() methods.
  void genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
                                              bool isRawValueAttr = false);

  // Generates canonicalizer declaration for the operation.
  void genCanonicalizerDecls();

  // Generates the folder declaration for the operation.
  void genFolderDecls();

  // Generates the parser for the operation.
  void genParser();

  // Generates the printer for the operation.
  void genPrinter();

  // Generates verify method for the operation.
  void genVerifier();

  // Generates verify statements for operands and results in the operation.
  // The generated code will be attached to `body`.
  void genOperandResultVerifier(OpMethodBody &body,
                                Operator::value_range values,
                                StringRef valueKind);

  // Generates verify statements for regions in the operation.
  // The generated code will be attached to `body`.
  void genRegionVerifier(OpMethodBody &body);

  // Generates verify statements for successors in the operation.
  // The generated code will be attached to `body`.
  void genSuccessorVerifier(OpMethodBody &body);

  // Generates the traits used by the object.
  void genTraits();

  // Generate the OpInterface methods for all interfaces.
  void genOpInterfaceMethods();

  // Generate op interface methods for the given interface.
  void genOpInterfaceMethods(const tblgen::InterfaceOpTrait *trait);

  // Generate op interface method for the given interface method. If
  // 'declaration' is true, generates a declaration, else a definition.
  OpMethod *genOpInterfaceMethod(const tblgen::InterfaceMethod &method,
                                 bool declaration = true);

  // Generate the side effect interface methods.
  void genSideEffectInterfaceMethods();

  // Generate the type inference interface methods.
  void genTypeInterfaceMethods();

private:
  // NOTE: data members are initialized in the order declared here, which is
  // also the order used by the constructor's initializer list; keep the two
  // in sync when adding or moving members.

  // The TableGen record for this op.
  // TODO: OpEmitter should not have a Record directly,
  // it should rather go through the Operator for better abstraction.
  const Record &def;

  // The wrapper operator class for querying information from this op.
  Operator op;

  // The C++ code builder for this op
  OpClass opClass;

  // The format context for verification code generation.
  FmtContext verifyCtx;

  // The emitter containing all of the locally emitted verification functions.
  const StaticVerifierFunctionEmitter &staticVerifierEmitter;
};
474 } // end anonymous namespace
475 
476 // Populate the format context `ctx` with substitutions of attributes, operands
477 // and results.
478 // - attrGet corresponds to the name of the function to call to get value of
479 //   attribute (the generated function call returns an Attribute);
480 // - operandGet corresponds to the name of the function with which to retrieve
481 //   an operand (the generated function call returns an OperandRange);
482 // - resultGet corresponds to the name of the function to get an result (the
483 //   generated function call returns a ValueRange);
populateSubstitutions(const Operator & op,const char * attrGet,const char * operandGet,const char * resultGet,FmtContext & ctx)484 static void populateSubstitutions(const Operator &op, const char *attrGet,
485                                   const char *operandGet, const char *resultGet,
486                                   FmtContext &ctx) {
487   // Populate substitutions for attributes and named operands.
488   for (const auto &namedAttr : op.getAttributes())
489     ctx.addSubst(namedAttr.name,
490                  formatv("{0}(\"{1}\")", attrGet, namedAttr.name));
491   for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
492     auto &value = op.getOperand(i);
493     if (value.name.empty())
494       continue;
495 
496     if (value.isVariadic())
497       ctx.addSubst(value.name, formatv("{0}({1})", operandGet, i));
498     else
499       ctx.addSubst(value.name, formatv("(*{0}({1}).begin())", operandGet, i));
500   }
501 
502   // Populate substitutions for results.
503   for (int i = 0, e = op.getNumResults(); i < e; ++i) {
504     auto &value = op.getResult(i);
505     if (value.name.empty())
506       continue;
507 
508     if (value.isVariadic())
509       ctx.addSubst(value.name, formatv("{0}({1})", resultGet, i));
510     else
511       ctx.addSubst(value.name, formatv("(*{0}({1}).begin())", resultGet, i));
512   }
513 }
514 
515 // Generate attribute verification. If emitVerificationRequiringOp is set then
516 // only verification for attributes whose value depend on op being known are
517 // emitted, else only verification that doesn't depend on the op being known are
518 // generated.
519 // - emitErrorPrefix is the prefix for the error emitting call which consists
520 //   of the entire function call up to start of error message fragment;
521 // - emitVerificationRequiringOp specifies whether verification should be
522 //   emitted for verification that require the op to exist;
genAttributeVerifier(const Operator & op,const char * attrGet,const Twine & emitErrorPrefix,bool emitVerificationRequiringOp,FmtContext & ctx,OpMethodBody & body)523 static void genAttributeVerifier(const Operator &op, const char *attrGet,
524                                  const Twine &emitErrorPrefix,
525                                  bool emitVerificationRequiringOp,
526                                  FmtContext &ctx, OpMethodBody &body) {
527   for (const auto &namedAttr : op.getAttributes()) {
528     const auto &attr = namedAttr.attr;
529     if (attr.isDerivedAttr())
530       continue;
531 
532     auto attrName = namedAttr.name;
533     bool allowMissingAttr = attr.hasDefaultValue() || attr.isOptional();
534     auto attrPred = attr.getPredicate();
535     auto condition = attrPred.isNull() ? "" : attrPred.getCondition();
536     // There is a condition to emit only if the use of $_op and whether to
537     // emit verifications for op matches.
538     bool hasConditionToEmit = (!(condition.find("$_op") != StringRef::npos) ^
539                                emitVerificationRequiringOp);
540 
541     // Prefix with `tblgen_` to avoid hiding the attribute accessor.
542     auto varName = tblgenNamePrefix + attrName;
543 
544     // If the attribute is
545     //  1. Required (not allowed missing) and not in op verification, or
546     //  2. Has a condition that will get verified
547     // then the variable will be used.
548     //
549     // Therefore, for optional attributes whose verification requires that an
550     // op already exists for verification/emitVerificationRequiringOp is set
551     // has nothing that can be verified here.
552     if ((allowMissingAttr || emitVerificationRequiringOp) &&
553         !hasConditionToEmit)
554       continue;
555 
556     body << formatv("  {\n  auto {0} = {1}(\"{2}\");\n", varName, attrGet,
557                     attrName);
558 
559     if (!emitVerificationRequiringOp && !allowMissingAttr) {
560       body << "  if (!" << varName << ") return " << emitErrorPrefix
561            << "\"requires attribute '" << attrName << "'\");\n";
562     }
563 
564     if (!hasConditionToEmit) {
565       body << "  }\n";
566       continue;
567     }
568 
569     if (allowMissingAttr) {
570       // If the attribute has a default value, then only verify the predicate if
571       // set. This does effectively assume that the default value is valid.
572       // TODO: verify the debug value is valid (perhaps in debug mode only).
573       body << "  if (" << varName << ") {\n";
574     }
575 
576     body << tgfmt("    if (!($0)) return $1\"attribute '$2' "
577                   "failed to satisfy constraint: $3\");\n",
578                   /*ctx=*/nullptr, tgfmt(condition, &ctx.withSelf(varName)),
579                   emitErrorPrefix, attrName, attr.getSummary());
580     if (allowMissingAttr)
581       body << "  }\n";
582     body << "  }\n";
583   }
584 }
585 
// Builds the complete C++ class for `op` in `opClass` by running every
// generation step below. Nothing is written to a stream here; emitDecl and
// emitDef do that afterwards.
OpEmitter::OpEmitter(const Operator &op,
                     const StaticVerifierFunctionEmitter &staticVerifierEmitter)
    : def(op.getDef()), op(op),
      opClass(op.getCppClassName(), op.getExtraClassDeclaration()),
      staticVerifierEmitter(staticVerifierEmitter) {
  // Set the expression that the verification context substitutes for the op
  // (presumably the `$_op` placeholder used in predicate conditions).
  verifyCtx.withOp("(*this->getOperation())");

  genTraits();

  // Generate C++ code for various op methods. The order here determines the
  // methods in the generated file.
  genOpAsmInterface();
  genOpNameGetter();
  genNamedOperandGetters();
  genNamedOperandSetters();
  genNamedResultGetters();
  genNamedRegionGetters();
  genNamedSuccessorGetters();
  genAttrGetters();
  genAttrSetters();
  genOptionalAttrRemovers();
  genBuilder();
  genParser();
  genPrinter();
  genVerifier();
  genCanonicalizerDecls();
  genFolderDecls();
  genTypeInterfaceMethods();
  genOpInterfaceMethods();
  // Free function rather than a member; NOTE(review): presumably declared in
  // OpFormatGen.h -- confirm.
  generateOpFormat(op, opClass);
  genSideEffectInterfaceMethods();
}
618 
emitDecl(const Operator & op,raw_ostream & os,const StaticVerifierFunctionEmitter & staticVerifierEmitter)619 void OpEmitter::emitDecl(
620     const Operator &op, raw_ostream &os,
621     const StaticVerifierFunctionEmitter &staticVerifierEmitter) {
622   OpEmitter(op, staticVerifierEmitter).emitDecl(os);
623 }
624 
emitDef(const Operator & op,raw_ostream & os,const StaticVerifierFunctionEmitter & staticVerifierEmitter)625 void OpEmitter::emitDef(
626     const Operator &op, raw_ostream &os,
627     const StaticVerifierFunctionEmitter &staticVerifierEmitter) {
628   OpEmitter(op, staticVerifierEmitter).emitDef(os);
629 }
630 
emitDecl(raw_ostream & os)631 void OpEmitter::emitDecl(raw_ostream &os) { opClass.writeDeclTo(os); }
632 
emitDef(raw_ostream & os)633 void OpEmitter::emitDef(raw_ostream &os) { opClass.writeDefTo(os); }
634 
genAttrGetters()635 void OpEmitter::genAttrGetters() {
636   FmtContext fctx;
637   fctx.withBuilder("::mlir::Builder((*this)->getContext())");
638 
639   Dialect opDialect = op.getDialect();
640   // Emit the derived attribute body.
641   auto emitDerivedAttr = [&](StringRef name, Attribute attr) {
642     auto *method = opClass.addMethodAndPrune(attr.getReturnType(), name);
643     if (!method)
644       return;
645     auto &body = method->body();
646     body << "  " << attr.getDerivedCodeBody() << "\n";
647   };
648 
649   // Emit with return type specified.
650   auto emitAttrWithReturnType = [&](StringRef name, Attribute attr) {
651     auto *method = opClass.addMethodAndPrune(attr.getReturnType(), name);
652     auto &body = method->body();
653     body << "  auto attr = " << name << "Attr();\n";
654     if (attr.hasDefaultValue()) {
655       // Returns the default value if not set.
656       // TODO: this is inefficient, we are recreating the attribute for every
657       // call. This should be set instead.
658       std::string defaultValue = std::string(
659           tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
660       body << "    if (!attr)\n      return "
661            << tgfmt(attr.getConvertFromStorageCall(),
662                     &fctx.withSelf(defaultValue))
663            << ";\n";
664     }
665     body << "  return "
666          << tgfmt(attr.getConvertFromStorageCall(), &fctx.withSelf("attr"))
667          << ";\n";
668   };
669 
670   // Generate raw named accessor type. This is a wrapper class that allows
671   // referring to the attributes via accessors instead of having to use
672   // the string interface for better compile time verification.
673   auto emitAttrWithStorageType = [&](StringRef name, Attribute attr) {
674     auto *method =
675         opClass.addMethodAndPrune(attr.getStorageType(), (name + "Attr").str());
676     if (!method)
677       return;
678     auto &body = method->body();
679     body << "  return (*this)->getAttr(\"" << name << "\").template ";
680     if (attr.isOptional() || attr.hasDefaultValue())
681       body << "dyn_cast_or_null<";
682     else
683       body << "cast<";
684     body << attr.getStorageType() << ">();";
685   };
686 
687   for (auto &namedAttr : op.getAttributes()) {
688     const auto &name = namedAttr.name;
689     const auto &attr = namedAttr.attr;
690     if (attr.isDerivedAttr()) {
691       emitDerivedAttr(name, attr);
692     } else {
693       emitAttrWithStorageType(name, attr);
694       emitAttrWithReturnType(name, attr);
695     }
696   }
697 
698   auto derivedAttrs = make_filter_range(op.getAttributes(),
699                                         [](const NamedAttribute &namedAttr) {
700                                           return namedAttr.attr.isDerivedAttr();
701                                         });
702   if (!derivedAttrs.empty()) {
703     opClass.addTrait("::mlir::DerivedAttributeOpInterface::Trait");
704     // Generate helper method to query whether a named attribute is a derived
705     // attribute. This enables, for example, avoiding adding an attribute that
706     // overlaps with a derived attribute.
707     {
708       auto *method = opClass.addMethodAndPrune("bool", "isDerivedAttribute",
709                                                OpMethod::MP_Static,
710                                                "::llvm::StringRef", "name");
711       auto &body = method->body();
712       for (auto namedAttr : derivedAttrs)
713         body << "  if (name == \"" << namedAttr.name << "\") return true;\n";
714       body << " return false;";
715     }
716     // Generate method to materialize derived attributes as a DictionaryAttr.
717     {
718       auto *method = opClass.addMethodAndPrune("::mlir::DictionaryAttr",
719                                                "materializeDerivedAttributes");
720       auto &body = method->body();
721 
722       auto nonMaterializable =
723           make_filter_range(derivedAttrs, [](const NamedAttribute &namedAttr) {
724             return namedAttr.attr.getConvertFromStorageCall().empty();
725           });
726       if (!nonMaterializable.empty()) {
727         std::string attrs;
728         llvm::raw_string_ostream os(attrs);
729         interleaveComma(nonMaterializable, os,
730                         [&](const NamedAttribute &attr) { os << attr.name; });
731         PrintWarning(
732             op.getLoc(),
733             formatv(
734                 "op has non-materializable derived attributes '{0}', skipping",
735                 os.str()));
736         body << formatv("  emitOpError(\"op has non-materializable derived "
737                         "attributes '{0}'\");\n",
738                         attrs);
739         body << "  return nullptr;";
740         return;
741       }
742 
743       body << "  ::mlir::MLIRContext* ctx = getContext();\n";
744       body << "  ::mlir::Builder odsBuilder(ctx); (void)odsBuilder;\n";
745       body << "  return ::mlir::DictionaryAttr::get({\n";
746       interleave(
747           derivedAttrs, body,
748           [&](const NamedAttribute &namedAttr) {
749             auto tmpl = namedAttr.attr.getConvertFromStorageCall();
750             body << "    {::mlir::Identifier::get(\"" << namedAttr.name
751                  << "\", ctx),\n"
752                  << tgfmt(tmpl, &fctx.withSelf(namedAttr.name + "()")
753                                      .withBuilder("odsBuilder")
754                                      .addSubst("_ctx", "ctx"))
755                  << "}";
756           },
757           ",\n");
758       body << "\n    }, ctx);";
759     }
760   }
761 }
762 
genAttrSetters()763 void OpEmitter::genAttrSetters() {
764   // Generate raw named setter type. This is a wrapper class that allows setting
765   // to the attributes via setters instead of having to use the string interface
766   // for better compile time verification.
767   auto emitAttrWithStorageType = [&](StringRef name, Attribute attr) {
768     auto *method = opClass.addMethodAndPrune("void", (name + "Attr").str(),
769                                              attr.getStorageType(), "attr");
770     if (!method)
771       return;
772     auto &body = method->body();
773     body << "  (*this)->setAttr(\"" << name << "\", attr);";
774   };
775 
776   for (auto &namedAttr : op.getAttributes()) {
777     const auto &name = namedAttr.name;
778     const auto &attr = namedAttr.attr;
779     if (!attr.isDerivedAttr())
780       emitAttrWithStorageType(name, attr);
781   }
782 }
783 
genOptionalAttrRemovers()784 void OpEmitter::genOptionalAttrRemovers() {
785   // Generate methods for removing optional attributes, instead of having to
786   // use the string interface. Enables better compile time verification.
787   auto emitRemoveAttr = [&](StringRef name) {
788     auto upperInitial = name.take_front().upper();
789     auto suffix = name.drop_front();
790     auto *method = opClass.addMethodAndPrune(
791         "::mlir::Attribute", ("remove" + upperInitial + suffix + "Attr").str());
792     if (!method)
793       return;
794     auto &body = method->body();
795     body << "  return (*this)->removeAttr(\"" << name << "\");";
796   };
797 
798   for (const auto &namedAttr : op.getAttributes()) {
799     const auto &name = namedAttr.name;
800     const auto &attr = namedAttr.attr;
801     if (attr.isOptional())
802       emitRemoveAttr(name);
803   }
804 }
805 
806 // Generates the code to compute the start and end index of an operand or result
807 // range.
808 template <typename RangeT>
809 static void
generateValueRangeStartAndEnd(Class & opClass,StringRef methodName,int numVariadic,int numNonVariadic,StringRef rangeSizeCall,bool hasAttrSegmentSize,StringRef sizeAttrInit,RangeT && odsValues)810 generateValueRangeStartAndEnd(Class &opClass, StringRef methodName,
811                               int numVariadic, int numNonVariadic,
812                               StringRef rangeSizeCall, bool hasAttrSegmentSize,
813                               StringRef sizeAttrInit, RangeT &&odsValues) {
814   auto *method = opClass.addMethodAndPrune("std::pair<unsigned, unsigned>",
815                                            methodName, "unsigned", "index");
816   if (!method)
817     return;
818   auto &body = method->body();
819   if (numVariadic == 0) {
820     body << "  return {index, 1};\n";
821   } else if (hasAttrSegmentSize) {
822     body << sizeAttrInit << attrSizedSegmentValueRangeCalcCode;
823   } else {
824     // Because the op can have arbitrarily interleaved variadic and non-variadic
825     // operands, we need to embed a list in the "sink" getter method for
826     // calculation at run-time.
827     llvm::SmallVector<StringRef, 4> isVariadic;
828     isVariadic.reserve(llvm::size(odsValues));
829     for (auto &it : odsValues)
830       isVariadic.push_back(it.isVariableLength() ? "true" : "false");
831     std::string isVariadicList = llvm::join(isVariadic, ", ");
832     body << formatv(sameVariadicSizeValueRangeCalcCode, isVariadicList,
833                     numNonVariadic, numVariadic, rangeSizeCall, "operand");
834   }
835 }
836 
837 // Generates the named operand getter methods for the given Operator `op` and
838 // puts them in `opClass`.  Uses `rangeType` as the return type of getters that
839 // return a range of operands (individual operands are `Value ` and each
840 // element in the range must also be `Value `); use `rangeBeginCall` to get
841 // an iterator to the beginning of the operand range; use `rangeSizeCall` to
842 // obtain the number of operands. `getOperandCallPattern` contains the code
843 // necessary to obtain a single operand whose position will be substituted
844 // instead of
845 // "{0}" marker in the pattern.  Note that the pattern should work for any kind
846 // of ops, in particular for one-operand ops that may not have the
847 // `getOperand(unsigned)` method.
generateNamedOperandGetters(const Operator & op,Class & opClass,StringRef sizeAttrInit,StringRef rangeType,StringRef rangeBeginCall,StringRef rangeSizeCall,StringRef getOperandCallPattern)848 static void generateNamedOperandGetters(const Operator &op, Class &opClass,
849                                         StringRef sizeAttrInit,
850                                         StringRef rangeType,
851                                         StringRef rangeBeginCall,
852                                         StringRef rangeSizeCall,
853                                         StringRef getOperandCallPattern) {
854   const int numOperands = op.getNumOperands();
855   const int numVariadicOperands = op.getNumVariableLengthOperands();
856   const int numNormalOperands = numOperands - numVariadicOperands;
857 
858   const auto *sameVariadicSize =
859       op.getTrait("::mlir::OpTrait::SameVariadicOperandSize");
860   const auto *attrSizedOperands =
861       op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments");
862 
863   if (numVariadicOperands > 1 && !sameVariadicSize && !attrSizedOperands) {
864     PrintFatalError(op.getLoc(), "op has multiple variadic operands but no "
865                                  "specification over their sizes");
866   }
867 
868   if (numVariadicOperands < 2 && attrSizedOperands) {
869     PrintFatalError(op.getLoc(), "op must have at least two variadic operands "
870                                  "to use 'AttrSizedOperandSegments' trait");
871   }
872 
873   if (attrSizedOperands && sameVariadicSize) {
874     PrintFatalError(op.getLoc(),
875                     "op cannot have both 'AttrSizedOperandSegments' and "
876                     "'SameVariadicOperandSize' traits");
877   }
878 
879   // First emit a few "sink" getter methods upon which we layer all nicer named
880   // getter methods.
881   generateValueRangeStartAndEnd(opClass, "getODSOperandIndexAndLength",
882                                 numVariadicOperands, numNormalOperands,
883                                 rangeSizeCall, attrSizedOperands, sizeAttrInit,
884                                 const_cast<Operator &>(op).getOperands());
885 
886   auto *m = opClass.addMethodAndPrune(rangeType, "getODSOperands", "unsigned",
887                                       "index");
888   auto &body = m->body();
889   body << formatv(valueRangeReturnCode, rangeBeginCall,
890                   "getODSOperandIndexAndLength(index)");
891 
892   // Then we emit nicer named getter methods by redirecting to the "sink" getter
893   // method.
894   for (int i = 0; i != numOperands; ++i) {
895     const auto &operand = op.getOperand(i);
896     if (operand.name.empty())
897       continue;
898 
899     if (operand.isOptional()) {
900       m = opClass.addMethodAndPrune("::mlir::Value", operand.name);
901       m->body() << "  auto operands = getODSOperands(" << i << ");\n"
902                 << "  return operands.empty() ? Value() : *operands.begin();";
903     } else if (operand.isVariadic()) {
904       m = opClass.addMethodAndPrune(rangeType, operand.name);
905       m->body() << "  return getODSOperands(" << i << ");";
906     } else {
907       m = opClass.addMethodAndPrune("::mlir::Value", operand.name);
908       m->body() << "  return *getODSOperands(" << i << ").begin();";
909     }
910   }
911 }
912 
genNamedOperandGetters()913 void OpEmitter::genNamedOperandGetters() {
914   generateNamedOperandGetters(
915       op, opClass,
916       /*sizeAttrInit=*/
917       formatv(opSegmentSizeAttrInitCode, "operand_segment_sizes").str(),
918       /*rangeType=*/"::mlir::Operation::operand_range",
919       /*rangeBeginCall=*/"getOperation()->operand_begin()",
920       /*rangeSizeCall=*/"getOperation()->getNumOperands()",
921       /*getOperandCallPattern=*/"getOperation()->getOperand({0})");
922 }
923 
genNamedOperandSetters()924 void OpEmitter::genNamedOperandSetters() {
925   auto *attrSizedOperands =
926       op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments");
927   for (int i = 0, e = op.getNumOperands(); i != e; ++i) {
928     const auto &operand = op.getOperand(i);
929     if (operand.name.empty())
930       continue;
931     auto *m = opClass.addMethodAndPrune("::mlir::MutableOperandRange",
932                                         (operand.name + "Mutable").str());
933     auto &body = m->body();
934     body << "  auto range = getODSOperandIndexAndLength(" << i << ");\n"
935          << "  return ::mlir::MutableOperandRange(getOperation(), "
936             "range.first, range.second";
937     if (attrSizedOperands)
938       body << ", ::mlir::MutableOperandRange::OperandSegment(" << i
939            << "u, *getOperation()->getAttrDictionary().getNamed("
940               "\"operand_segment_sizes\"))";
941     body << ");\n";
942   }
943 }
944 
genNamedResultGetters()945 void OpEmitter::genNamedResultGetters() {
946   const int numResults = op.getNumResults();
947   const int numVariadicResults = op.getNumVariableLengthResults();
948   const int numNormalResults = numResults - numVariadicResults;
949 
950   // If we have more than one variadic results, we need more complicated logic
951   // to calculate the value range for each result.
952 
953   const auto *sameVariadicSize =
954       op.getTrait("::mlir::OpTrait::SameVariadicResultSize");
955   const auto *attrSizedResults =
956       op.getTrait("::mlir::OpTrait::AttrSizedResultSegments");
957 
958   if (numVariadicResults > 1 && !sameVariadicSize && !attrSizedResults) {
959     PrintFatalError(op.getLoc(), "op has multiple variadic results but no "
960                                  "specification over their sizes");
961   }
962 
963   if (numVariadicResults < 2 && attrSizedResults) {
964     PrintFatalError(op.getLoc(), "op must have at least two variadic results "
965                                  "to use 'AttrSizedResultSegments' trait");
966   }
967 
968   if (attrSizedResults && sameVariadicSize) {
969     PrintFatalError(op.getLoc(),
970                     "op cannot have both 'AttrSizedResultSegments' and "
971                     "'SameVariadicResultSize' traits");
972   }
973 
974   generateValueRangeStartAndEnd(
975       opClass, "getODSResultIndexAndLength", numVariadicResults,
976       numNormalResults, "getOperation()->getNumResults()", attrSizedResults,
977       formatv(opSegmentSizeAttrInitCode, "result_segment_sizes").str(),
978       op.getResults());
979 
980   auto *m = opClass.addMethodAndPrune("::mlir::Operation::result_range",
981                                       "getODSResults", "unsigned", "index");
982   m->body() << formatv(valueRangeReturnCode, "getOperation()->result_begin()",
983                        "getODSResultIndexAndLength(index)");
984 
985   for (int i = 0; i != numResults; ++i) {
986     const auto &result = op.getResult(i);
987     if (result.name.empty())
988       continue;
989 
990     if (result.isOptional()) {
991       m = opClass.addMethodAndPrune("::mlir::Value", result.name);
992       m->body()
993           << "  auto results = getODSResults(" << i << ");\n"
994           << "  return results.empty() ? ::mlir::Value() : *results.begin();";
995     } else if (result.isVariadic()) {
996       m = opClass.addMethodAndPrune("::mlir::Operation::result_range",
997                                     result.name);
998       m->body() << "  return getODSResults(" << i << ");";
999     } else {
1000       m = opClass.addMethodAndPrune("::mlir::Value", result.name);
1001       m->body() << "  return *getODSResults(" << i << ").begin();";
1002     }
1003   }
1004 }
1005 
genNamedRegionGetters()1006 void OpEmitter::genNamedRegionGetters() {
1007   unsigned numRegions = op.getNumRegions();
1008   for (unsigned i = 0; i < numRegions; ++i) {
1009     const auto &region = op.getRegion(i);
1010     if (region.name.empty())
1011       continue;
1012 
1013     // Generate the accessors for a variadic region.
1014     if (region.isVariadic()) {
1015       auto *m = opClass.addMethodAndPrune("::mlir::MutableArrayRef<Region>",
1016                                           region.name);
1017       m->body() << formatv("  return (*this)->getRegions().drop_front({0});",
1018                            i);
1019       continue;
1020     }
1021 
1022     auto *m = opClass.addMethodAndPrune("::mlir::Region &", region.name);
1023     m->body() << formatv("  return (*this)->getRegion({0});", i);
1024   }
1025 }
1026 
genNamedSuccessorGetters()1027 void OpEmitter::genNamedSuccessorGetters() {
1028   unsigned numSuccessors = op.getNumSuccessors();
1029   for (unsigned i = 0; i < numSuccessors; ++i) {
1030     const NamedSuccessor &successor = op.getSuccessor(i);
1031     if (successor.name.empty())
1032       continue;
1033 
1034     // Generate the accessors for a variadic successor list.
1035     if (successor.isVariadic()) {
1036       auto *m =
1037           opClass.addMethodAndPrune("::mlir::SuccessorRange", successor.name);
1038       m->body() << formatv(
1039           "  return {std::next((*this)->successor_begin(), {0}), "
1040           "(*this)->successor_end()};",
1041           i);
1042       continue;
1043     }
1044 
1045     auto *m = opClass.addMethodAndPrune("::mlir::Block *", successor.name);
1046     m->body() << formatv("  return (*this)->getSuccessor({0});", i);
1047   }
1048 }
1049 
canGenerateUnwrappedBuilder(Operator & op)1050 static bool canGenerateUnwrappedBuilder(Operator &op) {
1051   // If this op does not have native attributes at all, return directly to avoid
1052   // redefining builders.
1053   if (op.getNumNativeAttributes() == 0)
1054     return false;
1055 
1056   bool canGenerate = false;
1057   // We are generating builders that take raw values for attributes. We need to
1058   // make sure the native attributes have a meaningful "unwrapped" value type
1059   // different from the wrapped mlir::Attribute type to avoid redefining
1060   // builders. This checks for the op has at least one such native attribute.
1061   for (int i = 0, e = op.getNumNativeAttributes(); i < e; ++i) {
1062     NamedAttribute &namedAttr = op.getAttribute(i);
1063     if (canUseUnwrappedRawValue(namedAttr.attr)) {
1064       canGenerate = true;
1065       break;
1066     }
1067   }
1068   return canGenerate;
1069 }
1070 
canInferType(Operator & op)1071 static bool canInferType(Operator &op) {
1072   return op.getTrait("::mlir::InferTypeOpInterface::Trait") &&
1073          op.getNumRegions() == 0;
1074 }
1075 
// Emits the `build` methods that take a stand-alone parameter for each
// operand and attribute. For each attribute parameter kind (wrapped attribute
// always; unwrapped raw value when `canGenerateUnwrappedBuilder` allows it) up
// to three variants are requested: separate result-type parameters, inferred
// result types (only when `canInferType(op)`), and a collective `resultTypes`
// parameter.
void OpEmitter::genSeparateArgParamBuilder() {
  SmallVector<AttrParamKind, 2> attrBuilderType;
  attrBuilderType.push_back(AttrParamKind::WrappedAttr);
  if (canGenerateUnwrappedBuilder(op))
    attrBuilderType.push_back(AttrParamKind::UnwrappedValue);

  // Emit with separate builders with or without unwrapped attributes and/or
  // inferring result type.
  auto emit = [&](AttrParamKind attrType, TypeParamKind paramKind,
                  bool inferType) {
    llvm::SmallVector<OpMethodParameter, 4> paramList;
    llvm::SmallVector<std::string, 4> resultNames;
    buildParamList(paramList, resultNames, paramKind, attrType);

    auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
                                        std::move(paramList));
    // If the builder is redundant, skip generating the method.
    if (!m)
      return;
    auto &body = m->body();
    genCodeForAddingArgAndRegionForBuilder(
        body, /*isRawValueAttr=*/attrType == AttrParamKind::UnwrappedValue);

    // Push all result types to the operation state. How this is done depends
    // on whether the types are inferred and on the type parameter kind.

    if (inferType) {
      // Generate builder that infers type too.
      // TODO: Subsume this with general checking if type can be
      // inferred automatically.
      // TODO: Expand to handle regions.
      // In the format string, {0} is the op class name and {1} is the name of
      // the OperationState parameter.
      body << formatv(R"(
        ::llvm::SmallVector<::mlir::Type, 2> inferredReturnTypes;
        if (succeeded({0}::inferReturnTypes(odsBuilder.getContext(),
                      {1}.location, {1}.operands,
                      {1}.attributes.getDictionary({1}.getContext()),
                      /*regions=*/{{}, inferredReturnTypes)))
          {1}.addTypes(inferredReturnTypes);
        else
          ::llvm::report_fatal_error("Failed to infer result type(s).");)",
                      opClass.getClassName(), builderOpState);
      return;
    }

    switch (paramKind) {
    case TypeParamKind::None:
      // No result-type parameters: nothing to add.
      return;
    case TypeParamKind::Separate:
      // One parameter per result; an optional result's type is only added
      // when the corresponding parameter is non-null.
      for (int i = 0, e = op.getNumResults(); i < e; ++i) {
        if (op.getResult(i).isOptional())
          body << "  if (" << resultNames[i] << ")\n  ";
        body << "  " << builderOpState << ".addTypes(" << resultNames[i]
             << ");\n";
      }
      return;
    case TypeParamKind::Collective: {
      int numResults = op.getNumResults();
      int numVariadicResults = op.getNumVariableLengthResults();
      int numNonVariadicResults = numResults - numVariadicResults;
      bool hasVariadicResult = numVariadicResults != 0;

      // Avoid emitting "resultTypes.size() >= 0u" which is always true.
      if (!(hasVariadicResult && numNonVariadicResults == 0))
        body << "  "
             << "assert(resultTypes.size() "
             << (hasVariadicResult ? ">=" : "==") << " "
             << numNonVariadicResults
             << "u && \"mismatched number of results\");\n";
      body << "  " << builderOpState << ".addTypes(resultTypes);\n";
    }
      return;
    }
    llvm_unreachable("unhandled TypeParamKind");
  };

  // Some of the build methods generated here may be ambiguous, but TableGen's
  // ambiguous function detection will elide those ones.
  for (auto attrType : attrBuilderType) {
    emit(attrType, TypeParamKind::Separate, /*inferType=*/false);
    if (canInferType(op))
      emit(attrType, TypeParamKind::None, /*inferType=*/true);
    emit(attrType, TypeParamKind::Collective, /*inferType=*/false);
  }
}
1159 
genUseOperandAsResultTypeCollectiveParamBuilder()1160 void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() {
1161   int numResults = op.getNumResults();
1162 
1163   // Signature
1164   llvm::SmallVector<OpMethodParameter, 4> paramList;
1165   paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
1166   paramList.emplace_back("::mlir::OperationState &", builderOpState);
1167   paramList.emplace_back("::mlir::ValueRange", "operands");
1168   // Provide default value for `attributes` when its the last parameter
1169   StringRef attributesDefaultValue = op.getNumVariadicRegions() ? "" : "{}";
1170   paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
1171                          "attributes", attributesDefaultValue);
1172   if (op.getNumVariadicRegions())
1173     paramList.emplace_back("unsigned", "numRegions");
1174 
1175   auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
1176                                       std::move(paramList));
1177   // If the builder is redundant, skip generating the method
1178   if (!m)
1179     return;
1180   auto &body = m->body();
1181 
1182   // Operands
1183   body << "  " << builderOpState << ".addOperands(operands);\n";
1184 
1185   // Attributes
1186   body << "  " << builderOpState << ".addAttributes(attributes);\n";
1187 
1188   // Create the correct number of regions
1189   if (int numRegions = op.getNumRegions()) {
1190     body << llvm::formatv(
1191         "  for (unsigned i = 0; i != {0}; ++i)\n",
1192         (op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions)));
1193     body << "    (void)" << builderOpState << ".addRegion();\n";
1194   }
1195 
1196   // Result types
1197   SmallVector<std::string, 2> resultTypes(numResults, "operands[0].getType()");
1198   body << "  " << builderOpState << ".addTypes({"
1199        << llvm::join(resultTypes, ", ") << "});\n\n";
1200 }
1201 
genInferredTypeCollectiveParamBuilder()1202 void OpEmitter::genInferredTypeCollectiveParamBuilder() {
1203   // TODO: Expand to support regions.
1204   SmallVector<OpMethodParameter, 4> paramList;
1205   paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
1206   paramList.emplace_back("::mlir::OperationState &", builderOpState);
1207   paramList.emplace_back("::mlir::ValueRange", "operands");
1208   paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
1209                          "attributes", "{}");
1210   auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
1211                                       std::move(paramList));
1212   // If the builder is redundant, skip generating the method
1213   if (!m)
1214     return;
1215   auto &body = m->body();
1216 
1217   int numResults = op.getNumResults();
1218   int numVariadicResults = op.getNumVariableLengthResults();
1219   int numNonVariadicResults = numResults - numVariadicResults;
1220 
1221   int numOperands = op.getNumOperands();
1222   int numVariadicOperands = op.getNumVariableLengthOperands();
1223   int numNonVariadicOperands = numOperands - numVariadicOperands;
1224 
1225   // Operands
1226   if (numVariadicOperands == 0 || numNonVariadicOperands != 0)
1227     body << "  assert(operands.size()"
1228          << (numVariadicOperands != 0 ? " >= " : " == ")
1229          << numNonVariadicOperands
1230          << "u && \"mismatched number of parameters\");\n";
1231   body << "  " << builderOpState << ".addOperands(operands);\n";
1232   body << "  " << builderOpState << ".addAttributes(attributes);\n";
1233 
1234   // Create the correct number of regions
1235   if (int numRegions = op.getNumRegions()) {
1236     body << llvm::formatv(
1237         "  for (unsigned i = 0; i != {0}; ++i)\n",
1238         (op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions)));
1239     body << "    (void)" << builderOpState << ".addRegion();\n";
1240   }
1241 
1242   // Result types
1243   body << formatv(R"(
1244     ::mlir::SmallVector<::mlir::Type, 2> inferredReturnTypes;
1245     if (succeeded({0}::inferReturnTypes(odsBuilder.getContext(),
1246                   {1}.location, operands,
1247                   {1}.attributes.getDictionary({1}.getContext()),
1248                   /*regions=*/{{}, inferredReturnTypes))) {{)",
1249                   opClass.getClassName(), builderOpState);
1250   if (numVariadicResults == 0 || numNonVariadicResults != 0)
1251     body << "  assert(inferredReturnTypes.size()"
1252          << (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults
1253          << "u && \"mismatched number of return types\");\n";
1254   body << "      " << builderOpState << ".addTypes(inferredReturnTypes);";
1255 
1256   body << formatv(R"(
1257     } else
1258       ::llvm::report_fatal_error("Failed to infer result type(s).");)",
1259                   opClass.getClassName(), builderOpState);
1260 }
1261 
genUseOperandAsResultTypeSeparateParamBuilder()1262 void OpEmitter::genUseOperandAsResultTypeSeparateParamBuilder() {
1263   llvm::SmallVector<OpMethodParameter, 4> paramList;
1264   llvm::SmallVector<std::string, 4> resultNames;
1265   buildParamList(paramList, resultNames, TypeParamKind::None);
1266 
1267   auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
1268                                       std::move(paramList));
1269   // If the builder is redundant, skip generating the method
1270   if (!m)
1271     return;
1272   auto &body = m->body();
1273   genCodeForAddingArgAndRegionForBuilder(body);
1274 
1275   auto numResults = op.getNumResults();
1276   if (numResults == 0)
1277     return;
1278 
1279   // Push all result types to the operation state
1280   const char *index = op.getOperand(0).isVariadic() ? ".front()" : "";
1281   std::string resultType =
1282       formatv("{0}{1}.getType()", getArgumentName(op, 0), index).str();
1283   body << "  " << builderOpState << ".addTypes({" << resultType;
1284   for (int i = 1; i != numResults; ++i)
1285     body << ", " << resultType;
1286   body << "});\n\n";
1287 }
1288 
genUseAttrAsResultTypeBuilder()1289 void OpEmitter::genUseAttrAsResultTypeBuilder() {
1290   SmallVector<OpMethodParameter, 4> paramList;
1291   paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
1292   paramList.emplace_back("::mlir::OperationState &", builderOpState);
1293   paramList.emplace_back("::mlir::ValueRange", "operands");
1294   paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
1295                          "attributes", "{}");
1296   auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
1297                                       std::move(paramList));
1298   // If the builder is redundant, skip generating the method
1299   if (!m)
1300     return;
1301 
1302   auto &body = m->body();
1303 
1304   // Push all result types to the operation state
1305   std::string resultType;
1306   const auto &namedAttr = op.getAttribute(0);
1307 
1308   body << "  for (auto attr : attributes) {\n";
1309   body << "    if (attr.first != \"" << namedAttr.name << "\") continue;\n";
1310   if (namedAttr.attr.isTypeAttr()) {
1311     resultType = "attr.second.cast<::mlir::TypeAttr>().getValue()";
1312   } else {
1313     resultType = "attr.second.getType()";
1314   }
1315 
1316   // Operands
1317   body << "  " << builderOpState << ".addOperands(operands);\n";
1318 
1319   // Attributes
1320   body << "  " << builderOpState << ".addAttributes(attributes);\n";
1321 
1322   // Result types
1323   SmallVector<std::string, 2> resultTypes(op.getNumResults(), resultType);
1324   body << "    " << builderOpState << ".addTypes({"
1325        << llvm::join(resultTypes, ", ") << "});\n";
1326   body << "  }\n";
1327 }
1328 
1329 /// Returns a signature of the builder. Updates the context `fctx` to enable
1330 /// replacement of $_builder and $_state in the body.
getBuilderSignature(const Builder & builder)1331 static std::string getBuilderSignature(const Builder &builder) {
1332   ArrayRef<Builder::Parameter> params(builder.getParameters());
1333 
1334   // Inject builder and state arguments.
1335   llvm::SmallVector<std::string, 8> arguments;
1336   arguments.reserve(params.size() + 2);
1337   arguments.push_back(
1338       llvm::formatv("::mlir::OpBuilder &{0}", odsBuilder).str());
1339   arguments.push_back(
1340       llvm::formatv("::mlir::OperationState &{0}", builderOpState).str());
1341 
1342   for (unsigned i = 0, e = params.size(); i < e; ++i) {
1343     // If no name is provided, generate one.
1344     Optional<StringRef> paramName = params[i].getName();
1345     std::string name =
1346         paramName ? paramName->str() : "odsArg" + std::to_string(i);
1347 
1348     std::string defaultValue;
1349     if (Optional<StringRef> defaultParamValue = params[i].getDefaultValue())
1350       defaultValue = llvm::formatv(" = {0}", *defaultParamValue).str();
1351     arguments.push_back(
1352         llvm::formatv("{0} {1}{2}", params[i].getCppType(), name, defaultValue)
1353             .str());
1354   }
1355 
1356   return llvm::join(arguments, ", ");
1357 }
1358 
genBuilder()1359 void OpEmitter::genBuilder() {
1360   // Handle custom builders if provided.
1361   for (const Builder &builder : op.getBuilders()) {
1362     std::string paramStr = getBuilderSignature(builder);
1363 
1364     Optional<StringRef> body = builder.getBody();
1365     OpMethod::Property properties =
1366         body ? OpMethod::MP_Static : OpMethod::MP_StaticDeclaration;
1367     auto *method =
1368         opClass.addMethodAndPrune("void", "build", properties, paramStr);
1369 
1370     FmtContext fctx;
1371     fctx.withBuilder(odsBuilder);
1372     fctx.addSubst("_state", builderOpState);
1373     if (body)
1374       method->body() << tgfmt(*body, &fctx);
1375   }
1376 
1377   // Generate default builders that requires all result type, operands, and
1378   // attributes as parameters.
1379   if (op.skipDefaultBuilders())
1380     return;
1381 
1382   // We generate three classes of builders here:
1383   // 1. one having a stand-alone parameter for each operand / attribute, and
1384   genSeparateArgParamBuilder();
1385   // 2. one having an aggregated parameter for all result types / operands /
1386   //    attributes, and
1387   genCollectiveParamBuilder();
1388   // 3. one having a stand-alone parameter for each operand and attribute,
1389   //    use the first operand or attribute's type as all result types
1390   //    to facilitate different call patterns.
1391   if (op.getNumVariableLengthResults() == 0) {
1392     if (op.getTrait("::mlir::OpTrait::SameOperandsAndResultType")) {
1393       genUseOperandAsResultTypeSeparateParamBuilder();
1394       genUseOperandAsResultTypeCollectiveParamBuilder();
1395     }
1396     if (op.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType"))
1397       genUseAttrAsResultTypeBuilder();
1398   }
1399 }
1400 
genCollectiveParamBuilder()1401 void OpEmitter::genCollectiveParamBuilder() {
1402   int numResults = op.getNumResults();
1403   int numVariadicResults = op.getNumVariableLengthResults();
1404   int numNonVariadicResults = numResults - numVariadicResults;
1405 
1406   int numOperands = op.getNumOperands();
1407   int numVariadicOperands = op.getNumVariableLengthOperands();
1408   int numNonVariadicOperands = numOperands - numVariadicOperands;
1409 
1410   SmallVector<OpMethodParameter, 4> paramList;
1411   paramList.emplace_back("::mlir::OpBuilder &", "");
1412   paramList.emplace_back("::mlir::OperationState &", builderOpState);
1413   paramList.emplace_back("::mlir::TypeRange", "resultTypes");
1414   paramList.emplace_back("::mlir::ValueRange", "operands");
1415   // Provide default value for `attributes` when its the last parameter
1416   StringRef attributesDefaultValue = op.getNumVariadicRegions() ? "" : "{}";
1417   paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
1418                          "attributes", attributesDefaultValue);
1419   if (op.getNumVariadicRegions())
1420     paramList.emplace_back("unsigned", "numRegions");
1421 
1422   auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
1423                                       std::move(paramList));
1424   // If the builder is redundant, skip generating the method
1425   if (!m)
1426     return;
1427   auto &body = m->body();
1428 
1429   // Operands
1430   if (numVariadicOperands == 0 || numNonVariadicOperands != 0)
1431     body << "  assert(operands.size()"
1432          << (numVariadicOperands != 0 ? " >= " : " == ")
1433          << numNonVariadicOperands
1434          << "u && \"mismatched number of parameters\");\n";
1435   body << "  " << builderOpState << ".addOperands(operands);\n";
1436 
1437   // Attributes
1438   body << "  " << builderOpState << ".addAttributes(attributes);\n";
1439 
1440   // Create the correct number of regions
1441   if (int numRegions = op.getNumRegions()) {
1442     body << llvm::formatv(
1443         "  for (unsigned i = 0; i != {0}; ++i)\n",
1444         (op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions)));
1445     body << "    (void)" << builderOpState << ".addRegion();\n";
1446   }
1447 
1448   // Result types
1449   if (numVariadicResults == 0 || numNonVariadicResults != 0)
1450     body << "  assert(resultTypes.size()"
1451          << (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults
1452          << "u && \"mismatched number of return types\");\n";
1453   body << "  " << builderOpState << ".addTypes(resultTypes);\n";
1454 
1455   // Generate builder that infers type too.
1456   // TODO: Expand to handle regions and successors.
1457   if (canInferType(op) && op.getNumSuccessors() == 0)
1458     genInferredTypeCollectiveParamBuilder();
1459 }
1460 
buildParamList(SmallVectorImpl<OpMethodParameter> & paramList,SmallVectorImpl<std::string> & resultTypeNames,TypeParamKind typeParamKind,AttrParamKind attrParamKind)1461 void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> &paramList,
1462                                SmallVectorImpl<std::string> &resultTypeNames,
1463                                TypeParamKind typeParamKind,
1464                                AttrParamKind attrParamKind) {
1465   resultTypeNames.clear();
1466   auto numResults = op.getNumResults();
1467   resultTypeNames.reserve(numResults);
1468 
1469   paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
1470   paramList.emplace_back("::mlir::OperationState &", builderOpState);
1471 
1472   switch (typeParamKind) {
1473   case TypeParamKind::None:
1474     break;
1475   case TypeParamKind::Separate: {
1476     // Add parameters for all return types
1477     for (int i = 0; i < numResults; ++i) {
1478       const auto &result = op.getResult(i);
1479       std::string resultName = std::string(result.name);
1480       if (resultName.empty())
1481         resultName = std::string(formatv("resultType{0}", i));
1482 
1483       StringRef type =
1484           result.isVariadic() ? "::mlir::TypeRange" : "::mlir::Type";
1485       OpMethodParameter::Property properties = OpMethodParameter::PP_None;
1486       if (result.isOptional())
1487         properties = OpMethodParameter::PP_Optional;
1488 
1489       paramList.emplace_back(type, resultName, properties);
1490       resultTypeNames.emplace_back(std::move(resultName));
1491     }
1492   } break;
1493   case TypeParamKind::Collective: {
1494     paramList.emplace_back("::mlir::TypeRange", "resultTypes");
1495     resultTypeNames.push_back("resultTypes");
1496   } break;
1497   }
1498 
1499   // Add parameters for all arguments (operands and attributes).
1500 
1501   int numOperands = 0;
1502   int numAttrs = 0;
1503 
1504   int defaultValuedAttrStartIndex = op.getNumArgs();
1505   if (attrParamKind == AttrParamKind::UnwrappedValue) {
1506     // Calculate the start index from which we can attach default values in the
1507     // builder declaration.
1508     for (int i = op.getNumArgs() - 1; i >= 0; --i) {
1509       auto *namedAttr = op.getArg(i).dyn_cast<tblgen::NamedAttribute *>();
1510       if (!namedAttr || !namedAttr->attr.hasDefaultValue())
1511         break;
1512 
1513       if (!canUseUnwrappedRawValue(namedAttr->attr))
1514         break;
1515 
1516       // Creating an APInt requires us to provide bitwidth, value, and
1517       // signedness, which is complicated compared to others. Similarly
1518       // for APFloat.
1519       // TODO: Adjust the 'returnType' field of such attributes
1520       // to support them.
1521       StringRef retType = namedAttr->attr.getReturnType();
1522       if (retType == "::llvm::APInt" || retType == "::llvm::APFloat")
1523         break;
1524 
1525       defaultValuedAttrStartIndex = i;
1526     }
1527   }
1528 
1529   for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
1530     auto argument = op.getArg(i);
1531     if (argument.is<tblgen::NamedTypeConstraint *>()) {
1532       const auto &operand = op.getOperand(numOperands);
1533       StringRef type =
1534           operand.isVariadic() ? "::mlir::ValueRange" : "::mlir::Value";
1535       OpMethodParameter::Property properties = OpMethodParameter::PP_None;
1536       if (operand.isOptional())
1537         properties = OpMethodParameter::PP_Optional;
1538 
1539       paramList.emplace_back(type, getArgumentName(op, numOperands),
1540                              properties);
1541       ++numOperands;
1542     } else {
1543       const auto &namedAttr = op.getAttribute(numAttrs);
1544       const auto &attr = namedAttr.attr;
1545 
1546       OpMethodParameter::Property properties = OpMethodParameter::PP_None;
1547       if (attr.isOptional())
1548         properties = OpMethodParameter::PP_Optional;
1549 
1550       StringRef type;
1551       switch (attrParamKind) {
1552       case AttrParamKind::WrappedAttr:
1553         type = attr.getStorageType();
1554         break;
1555       case AttrParamKind::UnwrappedValue:
1556         if (canUseUnwrappedRawValue(attr))
1557           type = attr.getReturnType();
1558         else
1559           type = attr.getStorageType();
1560         break;
1561       }
1562 
1563       std::string defaultValue;
1564       // Attach default value if requested and possible.
1565       if (attrParamKind == AttrParamKind::UnwrappedValue &&
1566           i >= defaultValuedAttrStartIndex) {
1567         bool isString = attr.getReturnType() == "::llvm::StringRef";
1568         if (isString)
1569           defaultValue.append("\"");
1570         defaultValue += attr.getDefaultValue();
1571         if (isString)
1572           defaultValue.append("\"");
1573       }
1574       paramList.emplace_back(type, namedAttr.name, defaultValue, properties);
1575       ++numAttrs;
1576     }
1577   }
1578 
1579   /// Insert parameters for each successor.
1580   for (const NamedSuccessor &succ : op.getSuccessors()) {
1581     StringRef type =
1582         succ.isVariadic() ? "::mlir::BlockRange" : "::mlir::Block *";
1583     paramList.emplace_back(type, succ.name);
1584   }
1585 
1586   /// Insert parameters for variadic regions.
1587   for (const NamedRegion &region : op.getRegions())
1588     if (region.isVariadic())
1589       paramList.emplace_back("unsigned",
1590                              llvm::formatv("{0}Count", region.name).str());
1591 }
1592 
genCodeForAddingArgAndRegionForBuilder(OpMethodBody & body,bool isRawValueAttr)1593 void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
1594                                                        bool isRawValueAttr) {
1595   // Push all operands to the result.
1596   for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
1597     std::string argName = getArgumentName(op, i);
1598     if (op.getOperand(i).isOptional())
1599       body << "  if (" << argName << ")\n  ";
1600     body << "  " << builderOpState << ".addOperands(" << argName << ");\n";
1601   }
1602 
1603   // If the operation has the operand segment size attribute, add it here.
1604   if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
1605     body << "  " << builderOpState
1606          << ".addAttribute(\"operand_segment_sizes\", "
1607             "odsBuilder.getI32VectorAttr({";
1608     interleaveComma(llvm::seq<int>(0, op.getNumOperands()), body, [&](int i) {
1609       if (op.getOperand(i).isOptional())
1610         body << "(" << getArgumentName(op, i) << " ? 1 : 0)";
1611       else if (op.getOperand(i).isVariadic())
1612         body << "static_cast<int32_t>(" << getArgumentName(op, i) << ".size())";
1613       else
1614         body << "1";
1615     });
1616     body << "}));\n";
1617   }
1618 
1619   // Push all attributes to the result.
1620   for (const auto &namedAttr : op.getAttributes()) {
1621     auto &attr = namedAttr.attr;
1622     if (!attr.isDerivedAttr()) {
1623       bool emitNotNullCheck = attr.isOptional();
1624       if (emitNotNullCheck) {
1625         body << formatv("  if ({0}) ", namedAttr.name) << "{\n";
1626       }
1627       if (isRawValueAttr && canUseUnwrappedRawValue(attr)) {
1628         // If this is a raw value, then we need to wrap it in an Attribute
1629         // instance.
1630         FmtContext fctx;
1631         fctx.withBuilder("odsBuilder");
1632 
1633         std::string builderTemplate =
1634             std::string(attr.getConstBuilderTemplate());
1635 
1636         // For StringAttr, its constant builder call will wrap the input in
1637         // quotes, which is correct for normal string literals, but incorrect
1638         // here given we use function arguments. So we need to strip the
1639         // wrapping quotes.
1640         if (StringRef(builderTemplate).contains("\"$0\""))
1641           builderTemplate = replaceAllSubstrs(builderTemplate, "\"$0\"", "$0");
1642 
1643         std::string value =
1644             std::string(tgfmt(builderTemplate, &fctx, namedAttr.name));
1645         body << formatv("  {0}.addAttribute(\"{1}\", {2});\n", builderOpState,
1646                         namedAttr.name, value);
1647       } else {
1648         body << formatv("  {0}.addAttribute(\"{1}\", {1});\n", builderOpState,
1649                         namedAttr.name);
1650       }
1651       if (emitNotNullCheck) {
1652         body << "  }\n";
1653       }
1654     }
1655   }
1656 
1657   // Create the correct number of regions.
1658   for (const NamedRegion &region : op.getRegions()) {
1659     if (region.isVariadic())
1660       body << formatv("  for (unsigned i = 0; i < {0}Count; ++i)\n  ",
1661                       region.name);
1662 
1663     body << "  (void)" << builderOpState << ".addRegion();\n";
1664   }
1665 
1666   // Push all successors to the result.
1667   for (const NamedSuccessor &namedSuccessor : op.getSuccessors()) {
1668     body << formatv("  {0}.addSuccessors({1});\n", builderOpState,
1669                     namedSuccessor.name);
1670   }
1671 }
1672 
genCanonicalizerDecls()1673 void OpEmitter::genCanonicalizerDecls() {
1674   if (!def.getValueAsBit("hasCanonicalizer"))
1675     return;
1676 
1677   SmallVector<OpMethodParameter, 2> paramList;
1678   paramList.emplace_back("::mlir::OwningRewritePatternList &", "results");
1679   paramList.emplace_back("::mlir::MLIRContext *", "context");
1680   opClass.addMethodAndPrune("void", "getCanonicalizationPatterns",
1681                             OpMethod::MP_StaticDeclaration,
1682                             std::move(paramList));
1683 }
1684 
genFolderDecls()1685 void OpEmitter::genFolderDecls() {
1686   bool hasSingleResult =
1687       op.getNumResults() == 1 && op.getNumVariableLengthResults() == 0;
1688 
1689   if (def.getValueAsBit("hasFolder")) {
1690     if (hasSingleResult) {
1691       opClass.addMethodAndPrune(
1692           "::mlir::OpFoldResult", "fold", OpMethod::MP_Declaration,
1693           "::llvm::ArrayRef<::mlir::Attribute>", "operands");
1694     } else {
1695       SmallVector<OpMethodParameter, 2> paramList;
1696       paramList.emplace_back("::llvm::ArrayRef<::mlir::Attribute>", "operands");
1697       paramList.emplace_back("::llvm::SmallVectorImpl<::mlir::OpFoldResult> &",
1698                              "results");
1699       opClass.addMethodAndPrune("::mlir::LogicalResult", "fold",
1700                                 OpMethod::MP_Declaration, std::move(paramList));
1701     }
1702   }
1703 }
1704 
genOpInterfaceMethods(const tblgen::InterfaceOpTrait * opTrait)1705 void OpEmitter::genOpInterfaceMethods(const tblgen::InterfaceOpTrait *opTrait) {
1706   auto interface = opTrait->getOpInterface();
1707 
1708   // Get the set of methods that should always be declared.
1709   auto alwaysDeclaredMethodsVec = opTrait->getAlwaysDeclaredMethods();
1710   llvm::StringSet<> alwaysDeclaredMethods;
1711   alwaysDeclaredMethods.insert(alwaysDeclaredMethodsVec.begin(),
1712                                alwaysDeclaredMethodsVec.end());
1713 
1714   for (const InterfaceMethod &method : interface.getMethods()) {
1715     // Don't declare if the method has a body.
1716     if (method.getBody())
1717       continue;
1718     // Don't declare if the method has a default implementation and the op
1719     // didn't request that it always be declared.
1720     if (method.getDefaultImplementation() &&
1721         !alwaysDeclaredMethods.count(method.getName()))
1722       continue;
1723     genOpInterfaceMethod(method);
1724   }
1725 }
1726 
genOpInterfaceMethod(const InterfaceMethod & method,bool declaration)1727 OpMethod *OpEmitter::genOpInterfaceMethod(const InterfaceMethod &method,
1728                                           bool declaration) {
1729   SmallVector<OpMethodParameter, 4> paramList;
1730   for (const InterfaceMethod::Argument &arg : method.getArguments())
1731     paramList.emplace_back(arg.type, arg.name);
1732 
1733   auto properties = method.isStatic() ? OpMethod::MP_Static : OpMethod::MP_None;
1734   if (declaration)
1735     properties =
1736         static_cast<OpMethod::Property>(properties | OpMethod::MP_Declaration);
1737   return opClass.addMethodAndPrune(method.getReturnType(), method.getName(),
1738                                    properties, std::move(paramList));
1739 }
1740 
genOpInterfaceMethods()1741 void OpEmitter::genOpInterfaceMethods() {
1742   for (const auto &trait : op.getTraits()) {
1743     if (const auto *opTrait = dyn_cast<tblgen::InterfaceOpTrait>(&trait))
1744       if (opTrait->shouldDeclareMethods())
1745         genOpInterfaceMethods(opTrait);
1746   }
1747 }
1748 
/// Generates one `getEffects` method per base effect name found on the op.
/// Effects are collected from side effect traits (recorded as static effects)
/// and from decorators on operands, symbol reference attributes, and results;
/// each collected effect becomes one `effects.emplace_back(...)` statement in
/// the generated method.
void OpEmitter::genSideEffectInterfaceMethods() {
  enum EffectKind { Operand, Result, Symbol, Static };
  struct EffectLocation {
    /// The effect applied.
    SideEffect effect;

    /// The index if the kind is not static.
    unsigned index : 30;

    /// The kind of the location.
    unsigned kind : 2;
  };

  // Effects grouped by base effect name; every entry gets its own method.
  StringMap<SmallVector<EffectLocation, 1>> interfaceEffects;
  // Records every side effect decorator in `decorators` under its base effect
  // name, tagged with the given location, and adds the matching interface
  // trait to the op class.
  auto resolveDecorators = [&](Operator::var_decorator_range decorators,
                               unsigned index, unsigned kind) {
    for (auto decorator : decorators)
      if (SideEffect *effect = dyn_cast<SideEffect>(&decorator)) {
        opClass.addTrait(effect->getInterfaceTrait());
        interfaceEffects[effect->getBaseEffectName()].push_back(
            EffectLocation{*effect, index, kind});
      }
  };

  // Collect effects that were specified via:
  /// Traits.
  for (const auto &trait : op.getTraits()) {
    const auto *opTrait = dyn_cast<tblgen::SideEffectTrait>(&trait);
    if (!opTrait)
      continue;
    auto &effects = interfaceEffects[opTrait->getBaseEffectName()];
    for (auto decorator : opTrait->getEffects())
      effects.push_back(EffectLocation{cast<SideEffect>(decorator),
                                       /*index=*/0, EffectKind::Static});
  }
  /// Attributes and Operands.
  // Operands are recorded by operand index (`operandIt`), whereas symbol
  // reference attributes are recorded by argument index (`i`), which is what
  // the `op.getArg` lookup further down expects.
  for (unsigned i = 0, operandIt = 0, e = op.getNumArgs(); i != e; ++i) {
    Argument arg = op.getArg(i);
    if (arg.is<NamedTypeConstraint *>()) {
      resolveDecorators(op.getArgDecorators(i), operandIt, EffectKind::Operand);
      ++operandIt;
      continue;
    }
    // Only symbol reference attributes have their decorators resolved.
    const NamedAttribute *attr = arg.get<NamedAttribute *>();
    if (attr->attr.getBaseAttr().isSymbolRefAttr())
      resolveDecorators(op.getArgDecorators(i), i, EffectKind::Symbol);
  }
  /// Results.
  for (unsigned i = 0, e = op.getNumResults(); i != e; ++i)
    resolveDecorators(op.getResultDecorators(i), i, EffectKind::Result);

  // The code used to add an effect instance.
  // {0}: The effect class.
  // {1}: Optional value or symbol reference (including the trailing ", ").
  // {2}: The resource class.
  const char *addEffectCode =
      "  effects.emplace_back({0}::get(), {1}{2}::get());\n";

  for (auto &it : interfaceEffects) {
    // Generate the 'getEffects' method.
    std::string type = llvm::formatv("::mlir::SmallVectorImpl<::mlir::"
                                     "SideEffects::EffectInstance<{0}>> &",
                                     it.first())
                           .str();
    auto *getEffects =
        opClass.addMethodAndPrune("void", "getEffects", type, "effects");
    auto &body = getEffects->body();

    // Add effect instances for each of the locations marked on the operation.
    for (auto &location : it.second) {
      StringRef effect = location.effect.getName();
      StringRef resource = location.effect.getResource();
      if (location.kind == EffectKind::Static) {
        // A static instance has no attached value.
        body << llvm::formatv(addEffectCode, effect, "", resource).str();
      } else if (location.kind == EffectKind::Symbol) {
        // A symbol reference requires adding the proper attribute.
        const auto *attr = op.getArg(location.index).get<NamedAttribute *>();
        if (attr->attr.isOptional()) {
          // An optional symbol only contributes an effect when it is present.
          body << "  if (auto symbolRef = " << attr->name << "Attr())\n  "
               << llvm::formatv(addEffectCode, effect, "symbolRef, ", resource)
                      .str();
        } else {
          body << llvm::formatv(addEffectCode, effect, attr->name + "(), ",
                                resource)
                      .str();
        }
      } else {
        // Otherwise this is an operand/result, so we need to attach the Value.
        // The effect is added once for every value in the ODS value group.
        body << "  for (::mlir::Value value : getODS"
             << (location.kind == EffectKind::Operand ? "Operands" : "Results")
             << "(" << location.index << "))\n  "
             << llvm::formatv(addEffectCode, effect, "value, ", resource).str();
      }
    }
  }
}
1846 
genTypeInterfaceMethods()1847 void OpEmitter::genTypeInterfaceMethods() {
1848   if (!op.allResultTypesKnown())
1849     return;
1850   // Generate 'inferReturnTypes' method declaration using the interface method
1851   // declared in 'InferTypeOpInterface' op interface.
1852   const auto *trait = dyn_cast<InterfaceOpTrait>(
1853       op.getTrait("::mlir::InferTypeOpInterface::Trait"));
1854   auto interface = trait->getOpInterface();
1855   OpMethod *method = [&]() -> OpMethod * {
1856     for (const InterfaceMethod &interfaceMethod : interface.getMethods()) {
1857       if (interfaceMethod.getName() == "inferReturnTypes") {
1858         return genOpInterfaceMethod(interfaceMethod, /*declaration=*/false);
1859       }
1860     }
1861     assert(0 && "unable to find inferReturnTypes interface method");
1862     return nullptr;
1863   }();
1864   auto &body = method->body();
1865   body << "  inferredReturnTypes.resize(" << op.getNumResults() << ");\n";
1866 
1867   FmtContext fctx;
1868   fctx.withBuilder("odsBuilder");
1869   body << "  ::mlir::Builder odsBuilder(context);\n";
1870 
1871   auto emitType =
1872       [&](const tblgen::Operator::ArgOrType &type) -> OpMethodBody & {
1873     if (type.isArg()) {
1874       auto argIndex = type.getArg();
1875       assert(!op.getArg(argIndex).is<NamedAttribute *>());
1876       auto arg = op.getArgToOperandOrAttribute(argIndex);
1877       if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand)
1878         return body << "operands[" << arg.operandOrAttributeIndex()
1879                     << "].getType()";
1880       return body << "attributes[" << arg.operandOrAttributeIndex()
1881                   << "].getType()";
1882     } else {
1883       return body << tgfmt(*type.getType().getBuilderCall(), &fctx);
1884     }
1885   };
1886 
1887   for (int i = 0, e = op.getNumResults(); i != e; ++i) {
1888     body << "  inferredReturnTypes[" << i << "] = ";
1889     auto types = op.getSameTypeAsResult(i);
1890     emitType(types[0]) << ";\n";
1891     if (types.size() == 1)
1892       continue;
1893     // TODO: We could verify equality here, but skipping that for verification.
1894   }
1895   body << "  return ::mlir::success();";
1896 }
1897 
genParser()1898 void OpEmitter::genParser() {
1899   if (!hasStringAttribute(def, "parser") ||
1900       hasStringAttribute(def, "assemblyFormat"))
1901     return;
1902 
1903   SmallVector<OpMethodParameter, 2> paramList;
1904   paramList.emplace_back("::mlir::OpAsmParser &", "parser");
1905   paramList.emplace_back("::mlir::OperationState &", "result");
1906   auto *method =
1907       opClass.addMethodAndPrune("::mlir::ParseResult", "parse",
1908                                 OpMethod::MP_Static, std::move(paramList));
1909 
1910   FmtContext fctx;
1911   fctx.addSubst("cppClass", opClass.getClassName());
1912   auto parser = def.getValueAsString("parser").ltrim().rtrim(" \t\v\f\r");
1913   method->body() << "  " << tgfmt(parser, &fctx);
1914 }
1915 
genPrinter()1916 void OpEmitter::genPrinter() {
1917   if (hasStringAttribute(def, "assemblyFormat"))
1918     return;
1919 
1920   auto valueInit = def.getValueInit("printer");
1921   StringInit *stringInit = dyn_cast<StringInit>(valueInit);
1922   if (!stringInit)
1923     return;
1924 
1925   auto *method =
1926       opClass.addMethodAndPrune("void", "print", "::mlir::OpAsmPrinter &", "p");
1927   FmtContext fctx;
1928   fctx.addSubst("cppClass", opClass.getClassName());
1929   auto printer = stringInit->getValue().ltrim().rtrim(" \t\v\f\r");
1930   method->body() << "  " << tgfmt(printer, &fctx);
1931 }
1932 
genVerifier()1933 void OpEmitter::genVerifier() {
1934   auto *method = opClass.addMethodAndPrune("::mlir::LogicalResult", "verify");
1935   auto &body = method->body();
1936   body << "  if (failed(" << op.getAdaptorName()
1937        << "(*this).verify((*this)->getLoc()))) "
1938        << "return ::mlir::failure();\n";
1939 
1940   auto *valueInit = def.getValueInit("verifier");
1941   StringInit *stringInit = dyn_cast<StringInit>(valueInit);
1942   bool hasCustomVerify = stringInit && !stringInit->getValue().empty();
1943   populateSubstitutions(op, "(*this)->getAttr", "this->getODSOperands",
1944                         "this->getODSResults", verifyCtx);
1945 
1946   genAttributeVerifier(op, "(*this)->getAttr", "emitOpError(",
1947                        /*emitVerificationRequiringOp=*/true, verifyCtx, body);
1948   genOperandResultVerifier(body, op.getOperands(), "operand");
1949   genOperandResultVerifier(body, op.getResults(), "result");
1950 
1951   for (auto &trait : op.getTraits()) {
1952     if (auto *t = dyn_cast<tblgen::PredOpTrait>(&trait)) {
1953       body << tgfmt("  if (!($0))\n    "
1954                     "return emitOpError(\"failed to verify that $1\");\n",
1955                     &verifyCtx, tgfmt(t->getPredTemplate(), &verifyCtx),
1956                     t->getSummary());
1957     }
1958   }
1959 
1960   genRegionVerifier(body);
1961   genSuccessorVerifier(body);
1962 
1963   if (hasCustomVerify) {
1964     FmtContext fctx;
1965     fctx.addSubst("cppClass", opClass.getClassName());
1966     auto printer = stringInit->getValue().ltrim().rtrim(" \t\v\f\r");
1967     body << "  " << tgfmt(printer, &fctx);
1968   } else {
1969     body << "  return ::mlir::success();\n";
1970   }
1971 }
1972 
genOperandResultVerifier(OpMethodBody & body,Operator::value_range values,StringRef valueKind)1973 void OpEmitter::genOperandResultVerifier(OpMethodBody &body,
1974                                          Operator::value_range values,
1975                                          StringRef valueKind) {
1976   FmtContext fctx;
1977 
1978   body << "  {\n";
1979   body << "    unsigned index = 0; (void)index;\n";
1980 
1981   for (auto staticValue : llvm::enumerate(values)) {
1982     bool hasPredicate = staticValue.value().hasPredicate();
1983     bool isOptional = staticValue.value().isOptional();
1984     if (!hasPredicate && !isOptional)
1985       continue;
1986     body << formatv("    auto valueGroup{2} = getODS{0}{1}s({2});\n",
1987                     // Capitalize the first letter to match the function name
1988                     valueKind.substr(0, 1).upper(), valueKind.substr(1),
1989                     staticValue.index());
1990 
1991     // If the constraint is optional check that the value group has at most 1
1992     // value.
1993     if (isOptional) {
1994       body << formatv("    if (valueGroup{0}.size() > 1)\n"
1995                       "      return emitOpError(\"{1} group starting at #\") "
1996                       "<< index << \" requires 0 or 1 element, but found \" << "
1997                       "valueGroup{0}.size();\n",
1998                       staticValue.index(), valueKind);
1999     }
2000 
2001     // Otherwise, if there is no predicate there is nothing left to do.
2002     if (!hasPredicate)
2003       continue;
2004     // Emit a loop to check all the dynamic values in the pack.
2005     StringRef constraintFn = staticVerifierEmitter.getTypeConstraintFn(
2006         staticValue.value().constraint);
2007     body << "    for (::mlir::Value v : valueGroup" << staticValue.index()
2008          << ") {\n"
2009          << "      if (::mlir::failed(" << constraintFn
2010          << "(getOperation(), v.getType(), \"" << valueKind << "\", index)))\n"
2011          << "        return ::mlir::failure();\n"
2012          << "      ++index;\n"
2013          << "    }\n";
2014   }
2015 
2016   body << "  }\n";
2017 }
2018 
genRegionVerifier(OpMethodBody & body)2019 void OpEmitter::genRegionVerifier(OpMethodBody &body) {
2020   // If we have no regions, there is nothing more to do.
2021   unsigned numRegions = op.getNumRegions();
2022   if (numRegions == 0)
2023     return;
2024 
2025   body << "{\n";
2026   body << "    unsigned index = 0; (void)index;\n";
2027 
2028   for (unsigned i = 0; i < numRegions; ++i) {
2029     const auto &region = op.getRegion(i);
2030     if (region.constraint.getPredicate().isNull())
2031       continue;
2032 
2033     body << "    for (::mlir::Region &region : ";
2034     body << formatv(region.isVariadic()
2035                         ? "{0}()"
2036                         : "::mlir::MutableArrayRef<::mlir::Region>((*this)"
2037                           "->getRegion({1}))",
2038                     region.name, i);
2039     body << ") {\n";
2040     auto constraint = tgfmt(region.constraint.getConditionTemplate(),
2041                             &verifyCtx.withSelf("region"))
2042                           .str();
2043 
2044     body << formatv("      (void)region;\n"
2045                     "      if (!({0})) {\n        "
2046                     "return emitOpError(\"region #\") << index << \" {1}"
2047                     "failed to "
2048                     "verify constraint: {2}\";\n      }\n",
2049                     constraint,
2050                     region.name.empty() ? "" : "('" + region.name + "') ",
2051                     region.constraint.getSummary())
2052          << "      ++index;\n"
2053          << "    }\n";
2054   }
2055   body << "  }\n";
2056 }
2057 
genSuccessorVerifier(OpMethodBody & body)2058 void OpEmitter::genSuccessorVerifier(OpMethodBody &body) {
2059   // If we have no successors, there is nothing more to do.
2060   unsigned numSuccessors = op.getNumSuccessors();
2061   if (numSuccessors == 0)
2062     return;
2063 
2064   body << "{\n";
2065   body << "    unsigned index = 0; (void)index;\n";
2066 
2067   for (unsigned i = 0; i < numSuccessors; ++i) {
2068     const auto &successor = op.getSuccessor(i);
2069     if (successor.constraint.getPredicate().isNull())
2070       continue;
2071 
2072     if (successor.isVariadic()) {
2073       body << formatv("    for (::mlir::Block *successor : {0}()) {\n",
2074                       successor.name);
2075     } else {
2076       body << "    {\n";
2077       body << formatv("      ::mlir::Block *successor = {0}();\n",
2078                       successor.name);
2079     }
2080     auto constraint = tgfmt(successor.constraint.getConditionTemplate(),
2081                             &verifyCtx.withSelf("successor"))
2082                           .str();
2083 
2084     body << formatv("      (void)successor;\n"
2085                     "      if (!({0})) {\n        "
2086                     "return emitOpError(\"successor #\") << index << \"('{1}') "
2087                     "failed to "
2088                     "verify constraint: {2}\";\n      }\n",
2089                     constraint, successor.name,
2090                     successor.constraint.getSummary())
2091          << "      ++index;\n"
2092          << "    }\n";
2093   }
2094   body << "  }\n";
2095 }
2096 
2097 /// Add a size count trait to the given operation class.
addSizeCountTrait(OpClass & opClass,StringRef traitKind,int numTotal,int numVariadic)2098 static void addSizeCountTrait(OpClass &opClass, StringRef traitKind,
2099                               int numTotal, int numVariadic) {
2100   if (numVariadic != 0) {
2101     if (numTotal == numVariadic)
2102       opClass.addTrait("::mlir::OpTrait::Variadic" + traitKind + "s");
2103     else
2104       opClass.addTrait("::mlir::OpTrait::AtLeastN" + traitKind + "s<" +
2105                        Twine(numTotal - numVariadic) + ">::Impl");
2106     return;
2107   }
2108   switch (numTotal) {
2109   case 0:
2110     opClass.addTrait("::mlir::OpTrait::Zero" + traitKind);
2111     break;
2112   case 1:
2113     opClass.addTrait("::mlir::OpTrait::One" + traitKind);
2114     break;
2115   default:
2116     opClass.addTrait("::mlir::OpTrait::N" + traitKind + "s<" + Twine(numTotal) +
2117                      ">::Impl");
2118     break;
2119   }
2120 }
2121 
genTraits()2122 void OpEmitter::genTraits() {
2123   // Add region size trait.
2124   unsigned numRegions = op.getNumRegions();
2125   unsigned numVariadicRegions = op.getNumVariadicRegions();
2126   addSizeCountTrait(opClass, "Region", numRegions, numVariadicRegions);
2127 
2128   // Add result size traits.
2129   int numResults = op.getNumResults();
2130   int numVariadicResults = op.getNumVariableLengthResults();
2131   addSizeCountTrait(opClass, "Result", numResults, numVariadicResults);
2132 
2133   // For single result ops with a known specific type, generate a OneTypedResult
2134   // trait.
2135   if (numResults == 1 && numVariadicResults == 0) {
2136     auto cppName = op.getResults().begin()->constraint.getCPPClassName();
2137     opClass.addTrait("::mlir::OpTrait::OneTypedResult<" + cppName + ">::Impl");
2138   }
2139 
2140   // Add successor size trait.
2141   unsigned numSuccessors = op.getNumSuccessors();
2142   unsigned numVariadicSuccessors = op.getNumVariadicSuccessors();
2143   addSizeCountTrait(opClass, "Successor", numSuccessors, numVariadicSuccessors);
2144 
2145   // Add variadic size trait and normal op traits.
2146   int numOperands = op.getNumOperands();
2147   int numVariadicOperands = op.getNumVariableLengthOperands();
2148 
2149   // Add operand size trait.
2150   if (numVariadicOperands != 0) {
2151     if (numOperands == numVariadicOperands)
2152       opClass.addTrait("::mlir::OpTrait::VariadicOperands");
2153     else
2154       opClass.addTrait("::mlir::OpTrait::AtLeastNOperands<" +
2155                        Twine(numOperands - numVariadicOperands) + ">::Impl");
2156   } else {
2157     switch (numOperands) {
2158     case 0:
2159       opClass.addTrait("::mlir::OpTrait::ZeroOperands");
2160       break;
2161     case 1:
2162       opClass.addTrait("::mlir::OpTrait::OneOperand");
2163       break;
2164     default:
2165       opClass.addTrait("::mlir::OpTrait::NOperands<" + Twine(numOperands) +
2166                        ">::Impl");
2167       break;
2168     }
2169   }
2170 
2171   // Add the native and interface traits.
2172   for (const auto &trait : op.getTraits()) {
2173     if (auto opTrait = dyn_cast<tblgen::NativeOpTrait>(&trait))
2174       opClass.addTrait(opTrait->getTrait());
2175     else if (auto opTrait = dyn_cast<tblgen::InterfaceOpTrait>(&trait))
2176       opClass.addTrait(opTrait->getTrait());
2177   }
2178 }
2179 
genOpNameGetter()2180 void OpEmitter::genOpNameGetter() {
2181   auto *method = opClass.addMethodAndPrune(
2182       "::llvm::StringRef", "getOperationName", OpMethod::MP_Static);
2183   method->body() << "  return \"" << op.getOperationName() << "\";\n";
2184 }
2185 
genOpAsmInterface()2186 void OpEmitter::genOpAsmInterface() {
2187   // If the user only has one results or specifically added the Asm trait,
2188   // then don't generate it for them. We specifically only handle multi result
2189   // operations, because the name of a single result in the common case is not
2190   // interesting(generally 'result'/'output'/etc.).
2191   // TODO: We could also add a flag to allow operations to opt in to this
2192   // generation, even if they only have a single operation.
2193   int numResults = op.getNumResults();
2194   if (numResults <= 1 || op.getTrait("::mlir::OpAsmOpInterface::Trait"))
2195     return;
2196 
2197   SmallVector<StringRef, 4> resultNames(numResults);
2198   for (int i = 0; i != numResults; ++i)
2199     resultNames[i] = op.getResultName(i);
2200 
2201   // Don't add the trait if none of the results have a valid name.
2202   if (llvm::all_of(resultNames, [](StringRef name) { return name.empty(); }))
2203     return;
2204   opClass.addTrait("::mlir::OpAsmOpInterface::Trait");
2205 
2206   // Generate the right accessor for the number of results.
2207   auto *method = opClass.addMethodAndPrune(
2208       "void", "getAsmResultNames", "::mlir::OpAsmSetValueNameFn", "setNameFn");
2209   auto &body = method->body();
2210   for (int i = 0; i != numResults; ++i) {
2211     body << "  auto resultGroup" << i << " = getODSResults(" << i << ");\n"
2212          << "  if (!llvm::empty(resultGroup" << i << "))\n"
2213          << "    setNameFn(*resultGroup" << i << ".begin(), \""
2214          << resultNames[i] << "\");\n";
2215   }
2216 }
2217 
2218 //===----------------------------------------------------------------------===//
2219 // OpOperandAdaptor emitter
2220 //===----------------------------------------------------------------------===//
2221 
namespace {
// Helper class to emit Op operand adaptors to an output stream.  The generated
// adaptor class holds a `::mlir::ValueRange` of operands and a
// `::mlir::DictionaryAttr` of attributes, and exposes named operand and
// attribute getters mirroring those defined in the Op.
class OpOperandAdaptorEmitter {
public:
  // Builds the adaptor class for `op` and writes its declaration to `os`.
  static void emitDecl(const Operator &op, raw_ostream &os);
  // Builds the adaptor class for `op` and writes its definition to `os`.
  static void emitDef(const Operator &op, raw_ostream &os);

private:
  // Populates `adaptor` with the fields, constructors, getters and verifier
  // of the adaptor class for `op`. Private: instances are only created by the
  // static emit functions above.
  explicit OpOperandAdaptorEmitter(const Operator &op);

  // Add verification function. This generates a verify method for the adaptor
  // which verifies all the op-independent attribute constraints.
  void addVerification();

  // The operation the adaptor is generated for.
  const Operator &op;
  // The adaptor class being assembled.
  Class adaptor;
};
} // end namespace
2242 
OpOperandAdaptorEmitter(const Operator & op)2243 OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(const Operator &op)
2244     : op(op), adaptor(op.getAdaptorName()) {
2245   adaptor.newField("::mlir::ValueRange", "odsOperands");
2246   adaptor.newField("::mlir::DictionaryAttr", "odsAttrs");
2247   const auto *attrSizedOperands =
2248       op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments");
2249   {
2250     SmallVector<OpMethodParameter, 2> paramList;
2251     paramList.emplace_back("::mlir::ValueRange", "values");
2252     paramList.emplace_back("::mlir::DictionaryAttr", "attrs",
2253                            attrSizedOperands ? "" : "nullptr");
2254     auto *constructor = adaptor.addConstructorAndPrune(std::move(paramList));
2255 
2256     constructor->addMemberInitializer("odsOperands", "values");
2257     constructor->addMemberInitializer("odsAttrs", "attrs");
2258   }
2259 
2260   {
2261     auto *constructor = adaptor.addConstructorAndPrune(
2262         llvm::formatv("{0}&", op.getCppClassName()).str(), "op");
2263     constructor->addMemberInitializer("odsOperands", "op->getOperands()");
2264     constructor->addMemberInitializer("odsAttrs", "op->getAttrDictionary()");
2265   }
2266 
2267   std::string sizeAttrInit =
2268       formatv(adapterSegmentSizeAttrInitCode, "operand_segment_sizes");
2269   generateNamedOperandGetters(op, adaptor, sizeAttrInit,
2270                               /*rangeType=*/"::mlir::ValueRange",
2271                               /*rangeBeginCall=*/"odsOperands.begin()",
2272                               /*rangeSizeCall=*/"odsOperands.size()",
2273                               /*getOperandCallPattern=*/"odsOperands[{0}]");
2274 
2275   FmtContext fctx;
2276   fctx.withBuilder("::mlir::Builder(odsAttrs.getContext())");
2277 
2278   auto emitAttr = [&](StringRef name, Attribute attr) {
2279     auto &body = adaptor.addMethodAndPrune(attr.getStorageType(), name)->body();
2280     body << "  assert(odsAttrs && \"no attributes when constructing adapter\");"
2281          << "\n  " << attr.getStorageType() << " attr = "
2282          << "odsAttrs.get(\"" << name << "\").";
2283     if (attr.hasDefaultValue() || attr.isOptional())
2284       body << "dyn_cast_or_null<";
2285     else
2286       body << "cast<";
2287     body << attr.getStorageType() << ">();\n";
2288 
2289     if (attr.hasDefaultValue()) {
2290       // Use the default value if attribute is not set.
2291       // TODO: this is inefficient, we are recreating the attribute for every
2292       // call. This should be set instead.
2293       std::string defaultValue = std::string(
2294           tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
2295       body << "  if (!attr)\n    attr = " << defaultValue << ";\n";
2296     }
2297     body << "  return attr;\n";
2298   };
2299 
2300   for (auto &namedAttr : op.getAttributes()) {
2301     const auto &name = namedAttr.name;
2302     const auto &attr = namedAttr.attr;
2303     if (!attr.isDerivedAttr())
2304       emitAttr(name, attr);
2305   }
2306 
2307   // Add verification function.
2308   addVerification();
2309 }
2310 
addVerification()2311 void OpOperandAdaptorEmitter::addVerification() {
2312   auto *method = adaptor.addMethodAndPrune("::mlir::LogicalResult", "verify",
2313                                            "::mlir::Location", "loc");
2314   auto &body = method->body();
2315 
2316   const char *checkAttrSizedValueSegmentsCode = R"(
2317   {
2318     auto sizeAttr = odsAttrs.get("{0}").cast<::mlir::DenseIntElementsAttr>();
2319     auto numElements = sizeAttr.getType().cast<::mlir::ShapedType>().getNumElements();
2320     if (numElements != {1})
2321       return emitError(loc, "'{0}' attribute for specifying {2} segments "
2322                        "must have {1} elements, but got ") << numElements;
2323   }
2324   )";
2325 
2326   // Verify a few traits first so that we can use
2327   // getODSOperands()/getODSResults() in the rest of the verifier.
2328   for (auto &trait : op.getTraits()) {
2329     if (auto *t = dyn_cast<tblgen::NativeOpTrait>(&trait)) {
2330       if (t->getTrait() == "::mlir::OpTrait::AttrSizedOperandSegments") {
2331         body << formatv(checkAttrSizedValueSegmentsCode,
2332                         "operand_segment_sizes", op.getNumOperands(),
2333                         "operand");
2334       } else if (t->getTrait() == "::mlir::OpTrait::AttrSizedResultSegments") {
2335         body << formatv(checkAttrSizedValueSegmentsCode, "result_segment_sizes",
2336                         op.getNumResults(), "result");
2337       }
2338     }
2339   }
2340 
2341   FmtContext verifyCtx;
2342   populateSubstitutions(op, "odsAttrs.get", "getODSOperands",
2343                         "<no results should be generated>", verifyCtx);
2344   genAttributeVerifier(op, "odsAttrs.get",
2345                        Twine("emitError(loc, \"'") + op.getOperationName() +
2346                            "' op \"",
2347                        /*emitVerificationRequiringOp*/ false, verifyCtx, body);
2348 
2349   body << "  return ::mlir::success();";
2350 }
2351 
emitDecl(const Operator & op,raw_ostream & os)2352 void OpOperandAdaptorEmitter::emitDecl(const Operator &op, raw_ostream &os) {
2353   OpOperandAdaptorEmitter(op).adaptor.writeDeclTo(os);
2354 }
2355 
emitDef(const Operator & op,raw_ostream & os)2356 void OpOperandAdaptorEmitter::emitDef(const Operator &op, raw_ostream &os) {
2357   OpOperandAdaptorEmitter(op).adaptor.writeDefTo(os);
2358 }
2359 
2360 // Emits the opcode enum and op classes.
emitOpClasses(const RecordKeeper & recordKeeper,const std::vector<Record * > & defs,raw_ostream & os,bool emitDecl)2361 static void emitOpClasses(const RecordKeeper &recordKeeper,
2362                           const std::vector<Record *> &defs, raw_ostream &os,
2363                           bool emitDecl) {
2364   // First emit forward declaration for each class, this allows them to refer
2365   // to each others in traits for example.
2366   if (emitDecl) {
2367     os << "#if defined(GET_OP_CLASSES) || defined(GET_OP_FWD_DEFINES)\n";
2368     os << "#undef GET_OP_FWD_DEFINES\n";
2369     for (auto *def : defs) {
2370       Operator op(*def);
2371       NamespaceEmitter emitter(os, op.getDialect());
2372       os << "class " << op.getCppClassName() << ";\n";
2373     }
2374     os << "#endif\n\n";
2375   }
2376 
2377   IfDefScope scope("GET_OP_CLASSES", os);
2378   if (defs.empty())
2379     return;
2380 
2381   // Generate all of the locally instantiated methods first.
2382   StaticVerifierFunctionEmitter staticVerifierEmitter(recordKeeper, defs, os,
2383                                                       emitDecl);
2384   for (auto *def : defs) {
2385     Operator op(*def);
2386     NamespaceEmitter emitter(os, op.getDialect());
2387     if (emitDecl) {
2388       os << formatv(opCommentHeader, op.getQualCppClassName(), "declarations");
2389       OpOperandAdaptorEmitter::emitDecl(op, os);
2390       OpEmitter::emitDecl(op, os, staticVerifierEmitter);
2391     } else {
2392       os << formatv(opCommentHeader, op.getQualCppClassName(), "definitions");
2393       OpOperandAdaptorEmitter::emitDef(op, os);
2394       OpEmitter::emitDef(op, os, staticVerifierEmitter);
2395     }
2396   }
2397 }
2398 
2399 // Emits a comma-separated list of the ops.
emitOpList(const std::vector<Record * > & defs,raw_ostream & os)2400 static void emitOpList(const std::vector<Record *> &defs, raw_ostream &os) {
2401   IfDefScope scope("GET_OP_LIST", os);
2402 
2403   interleave(
2404       // TODO: We are constructing the Operator wrapper instance just for
2405       // getting it's qualified class name here. Reduce the overhead by having a
2406       // lightweight version of Operator class just for that purpose.
2407       defs, [&os](Record *def) { os << Operator(def).getQualCppClassName(); },
2408       [&os]() { os << ",\n"; });
2409 }
2410 
getOperationName(const Record & def)2411 static std::string getOperationName(const Record &def) {
2412   auto prefix = def.getValueAsDef("opDialect")->getValueAsString("name");
2413   auto opName = def.getValueAsString("opName");
2414   if (prefix.empty())
2415     return std::string(opName);
2416   return std::string(llvm::formatv("{0}.{1}", prefix, opName));
2417 }
2418 
2419 static std::vector<Record *>
getAllDerivedDefinitions(const RecordKeeper & recordKeeper,StringRef className)2420 getAllDerivedDefinitions(const RecordKeeper &recordKeeper,
2421                          StringRef className) {
2422   Record *classDef = recordKeeper.getClass(className);
2423   if (!classDef)
2424     PrintFatalError("ERROR: Couldn't find the `" + className + "' class!\n");
2425 
2426   llvm::Regex includeRegex(opIncFilter), excludeRegex(opExcFilter);
2427   std::vector<Record *> defs;
2428   for (const auto &def : recordKeeper.getDefs()) {
2429     if (!def.second->isSubClassOf(classDef))
2430       continue;
2431     // Include if no include filter or include filter matches.
2432     if (!opIncFilter.empty() &&
2433         !includeRegex.match(getOperationName(*def.second)))
2434       continue;
2435     // Unless there is an exclude filter and it matches.
2436     if (!opExcFilter.empty() &&
2437         excludeRegex.match(getOperationName(*def.second)))
2438       continue;
2439     defs.push_back(def.second.get());
2440   }
2441 
2442   return defs;
2443 }
2444 
emitOpDecls(const RecordKeeper & recordKeeper,raw_ostream & os)2445 static bool emitOpDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
2446   emitSourceFileHeader("Op Declarations", os);
2447 
2448   const auto &defs = getAllDerivedDefinitions(recordKeeper, "Op");
2449   emitOpClasses(recordKeeper, defs, os, /*emitDecl=*/true);
2450 
2451   return false;
2452 }
2453 
emitOpDefs(const RecordKeeper & recordKeeper,raw_ostream & os)2454 static bool emitOpDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
2455   emitSourceFileHeader("Op Definitions", os);
2456 
2457   const auto &defs = getAllDerivedDefinitions(recordKeeper, "Op");
2458   emitOpList(defs, os);
2459   emitOpClasses(recordKeeper, defs, os, /*emitDecl=*/false);
2460 
2461   return false;
2462 }
2463 
2464 static mlir::GenRegistration
2465     genOpDecls("gen-op-decls", "Generate op declarations",
__anon984378f01602(const RecordKeeper &records, raw_ostream &os) 2466                [](const RecordKeeper &records, raw_ostream &os) {
2467                  return emitOpDecls(records, os);
2468                });
2469 
2470 static mlir::GenRegistration genOpDefs("gen-op-defs", "Generate op definitions",
2471                                        [](const RecordKeeper &records,
__anon984378f01702(const RecordKeeper &records, raw_ostream &os) 2472                                           raw_ostream &os) {
2473                                          return emitOpDefs(records, os);
2474                                        });
2475