1 //===- OpDefinitionsGen.cpp - MLIR op definitions generator ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // OpDefinitionsGen uses the description of operations to generate C++
10 // definitions for ops.
11 //
12 //===----------------------------------------------------------------------===//
13 
#include "OpFormatGen.h"
#include "OpGenHelpers.h"
#include "mlir/TableGen/CodeGenHelpers.h"
#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Interfaces.h"
#include "mlir/TableGen/OpClass.h"
#include "mlir/TableGen/Operator.h"
#include "mlir/TableGen/SideEffects.h"
#include "mlir/TableGen/Trait.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Path.h"
#include "llvm/Support/Signals.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
32 
33 #define DEBUG_TYPE "mlir-tblgen-opdefgen"
34 
35 using namespace llvm;
36 using namespace mlir;
37 using namespace mlir::tblgen;
38 
// Prefix prepended to local variable names in generated verifier code so that
// they do not hide the op's generated attribute accessors.
static const char *const tblgenNamePrefix = "tblgen_";
// Base name for operands that have no ODS name; getArgumentName() appends
// "_<index>" to it.
static const char *const generatedArgName = "odsArg";
// Presumably the name of the builder parameter in generated build() methods;
// its uses are outside this chunk — TODO confirm.
static const char *const odsBuilder = "odsBuilder";
// Presumably the name of the OperationState parameter in generated build()
// methods; its uses are outside this chunk — TODO confirm.
static const char *const builderOpState = "odsState";
43 
// The logic to calculate the actual value range for a declared operand/result
// of an op with variadic operands/results. Note that this logic is not for
// general use; it assumes all variadic operands/results must have the same
// number of values.
//
// This is a formatv() template: `{{` emits a literal `{`. The generated code
// reads a variable named `index` (the static operand/result index), which must
// be in scope wherever the snippet is emitted, and returns `{start, size}`.
//
// {0}: The list of whether each declared operand/result is variadic.
// {1}: The total number of non-variadic operands/results.
// {2}: The total number of variadic operands/results.
// {3}: The total number of actual values.
// {4}: "operand" or "result".
const char *sameVariadicSizeValueRangeCalcCode = R"(
  bool isVariadic[] = {{{0}};
  int prevVariadicCount = 0;
  for (unsigned i = 0; i < index; ++i)
    if (isVariadic[i]) ++prevVariadicCount;

  // Calculate how many dynamic values a static variadic {4} corresponds to.
  // This assumes all static variadic {4}s have the same dynamic value count.
  int variadicSize = ({3} - {1}) / {2};
  // `index` passed in as the parameter is the static index which counts each
  // {4} (variadic or not) as size 1. So here for each previous static variadic
  // {4}, we need to offset by (variadicSize - 1) to get where the dynamic
  // value pack for this static {4} starts.
  int start = index + (variadicSize - 1) * prevVariadicCount;
  int size = isVariadic[index] ? variadicSize : 1;
  return {{start, size};
)";
71 
// The logic to calculate the actual value range for a declared operand/result
// of an op with variadic operands/results. Note that this logic assumes
// the op has an attribute specifying the size of each operand/result segment
// (variadic or not).
//
// The two "init" snippets below each define a local `sizeAttr` holding the
// segment sizes: one reads it from an adaptor's `odsAttrs`, the other from the
// op itself.
//
// adapterSegmentSizeAttrInitCode is a formatv() template:
// {0}: The name of the attribute specifying the segment sizes. It is emitted
//      inside quotes and looked up in `odsAttrs`.
const char *adapterSegmentSizeAttrInitCode = R"(
  assert(odsAttrs && "missing segment size attribute for op");
  auto sizeAttr = odsAttrs.get("{0}").cast<::mlir::DenseIntElementsAttr>();
)";
// opSegmentSizeAttrInitCode is a formatv() template:
// {0}: A C++ expression, emitted unquoted, that is passed to
//      `(*this)->getAttr` to fetch the segment size attribute.
const char *opSegmentSizeAttrInitCode = R"(
  auto sizeAttr = (*this)->getAttr({0}).cast<::mlir::DenseIntElementsAttr>();
)";
// Has no placeholders and contains unescaped braces, so it must be emitted
// verbatim rather than passed through formatv(). It reads `sizeAttr` (defined
// by one of the init snippets above) and `index`. It sums the sizes of all
// segments before `index` to get the start of segment `index`, and returns
// `{start, size}`.
const char *attrSizedSegmentValueRangeCalcCode = R"(
  auto sizeAttrValues = sizeAttr.getValues<uint32_t>();
  unsigned start = 0;
  for (unsigned i = 0; i < index; ++i)
    start += *(sizeAttrValues.begin() + i);
  unsigned size = *(sizeAttrValues.begin() + index);
  return {start, size};
)";
93 
// The logic to build a range of either operand or result values.
//
// This is a formatv() template (`{{` emits a literal `{`). The generated code
// evaluates {1}, whose result must expose `.first` (start) and `.second`
// (length). It returns the iterator range from {0} advanced by `start` to {0}
// advanced by `start + length`.
//
// {0}: The begin iterator of the actual values.
// {1}: The call to generate the start and length of the value range.
const char *valueRangeReturnCode = R"(
  auto valueRange = {1};
  return {{std::next({0}, valueRange.first),
           std::next({0}, valueRange.first + valueRange.second)};
)";
103 
// formatv() template for the banner comment printed before a section of
// generated code, e.g. "Local Utility Method Definitions".
// {0}, {1}: The two parts of the banner title, joined by a single space.
static const char *const opCommentHeader = R"(
//===----------------------------------------------------------------------===//
// {0} {1}
//===----------------------------------------------------------------------===//

)";
110 
111 //===----------------------------------------------------------------------===//
112 // StaticVerifierFunctionEmitter
113 //===----------------------------------------------------------------------===//
114 
115 namespace {
116 /// This class deduplicates shared operation verification code by emitting
117 /// static functions alongside the op definitions. These methods are local to
118 /// the definition file, and are invoked within the operation verify methods.
119 /// An example is shown below:
120 ///
121 /// static LogicalResult localVerify(...)
122 ///
123 /// LogicalResult OpA::verify(...) {
124 ///  if (failed(localVerify(...)))
125 ///    return failure();
126 ///  ...
127 /// }
128 ///
129 /// LogicalResult OpB::verify(...) {
130 ///  if (failed(localVerify(...)))
131 ///    return failure();
132 ///  ...
133 /// }
134 ///
135 class StaticVerifierFunctionEmitter {
136 public:
137   StaticVerifierFunctionEmitter(const llvm::RecordKeeper &records,
138                                 ArrayRef<llvm::Record *> opDefs,
139                                 raw_ostream &os, bool emitDecl);
140 
141   /// Get the name of the local function used for the given type constraint.
142   /// These functions are used for operand and result constraints and have the
143   /// form:
144   ///   LogicalResult(Operation *op, Type type, StringRef valueKind,
145   ///                 unsigned valueGroupStartIndex);
getTypeConstraintFn(const Constraint & constraint) const146   StringRef getTypeConstraintFn(const Constraint &constraint) const {
147     auto it = localTypeConstraints.find(constraint.getAsOpaquePointer());
148     assert(it != localTypeConstraints.end() && "expected valid constraint fn");
149     return it->second;
150   }
151 
152 private:
153   /// Returns a unique name to use when generating local methods.
154   static std::string getUniqueName(const llvm::RecordKeeper &records);
155 
156   /// Emit local methods for the type constraints used within the provided op
157   /// definitions.
158   void emitTypeConstraintMethods(ArrayRef<llvm::Record *> opDefs,
159                                  raw_ostream &os, bool emitDecl);
160 
161   /// A unique label for the file currently being generated. This is used to
162   /// ensure that the local functions have a unique name.
163   std::string uniqueOutputLabel;
164 
165   /// A set of functions implementing type constraints, used for operand and
166   /// result verification.
167   llvm::DenseMap<const void *, std::string> localTypeConstraints;
168 };
169 } // namespace
170 
StaticVerifierFunctionEmitter(const llvm::RecordKeeper & records,ArrayRef<llvm::Record * > opDefs,raw_ostream & os,bool emitDecl)171 StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter(
172     const llvm::RecordKeeper &records, ArrayRef<llvm::Record *> opDefs,
173     raw_ostream &os, bool emitDecl)
174     : uniqueOutputLabel(getUniqueName(records)) {
175   llvm::Optional<NamespaceEmitter> namespaceEmitter;
176   if (!emitDecl) {
177     os << formatv(opCommentHeader, "Local Utility Method", "Definitions");
178     namespaceEmitter.emplace(os, Operator(*opDefs[0]).getCppNamespace());
179   }
180 
181   emitTypeConstraintMethods(opDefs, os, emitDecl);
182 }
183 
getUniqueName(const llvm::RecordKeeper & records)184 std::string StaticVerifierFunctionEmitter::getUniqueName(
185     const llvm::RecordKeeper &records) {
186   // Use the input file name when generating a unique name.
187   std::string inputFilename = records.getInputFilename();
188 
189   // Drop all but the base filename.
190   StringRef nameRef = llvm::sys::path::filename(inputFilename);
191   nameRef.consume_back(".td");
192 
193   // Sanitize any invalid characters.
194   std::string uniqueName;
195   for (char c : nameRef) {
196     if (llvm::isAlnum(c) || c == '_')
197       uniqueName.push_back(c);
198     else
199       uniqueName.append(llvm::utohexstr((unsigned char)c));
200   }
201   return uniqueName;
202 }
203 
emitTypeConstraintMethods(ArrayRef<llvm::Record * > opDefs,raw_ostream & os,bool emitDecl)204 void StaticVerifierFunctionEmitter::emitTypeConstraintMethods(
205     ArrayRef<llvm::Record *> opDefs, raw_ostream &os, bool emitDecl) {
206   // Collect a set of all of the used type constraints within the operation
207   // definitions.
208   llvm::SetVector<const void *> typeConstraints;
209   for (Record *def : opDefs) {
210     Operator op(*def);
211     for (NamedTypeConstraint &operand : op.getOperands())
212       if (operand.hasPredicate())
213         typeConstraints.insert(operand.constraint.getAsOpaquePointer());
214     for (NamedTypeConstraint &result : op.getResults())
215       if (result.hasPredicate())
216         typeConstraints.insert(result.constraint.getAsOpaquePointer());
217   }
218 
219   // Record the mapping from predicate to constraint. If two constraints has the
220   // same predicate and constraint summary, they can share the same verification
221   // function.
222   llvm::DenseMap<Pred, const void *> predToConstraint;
223   FmtContext fctx;
224   for (auto it : llvm::enumerate(typeConstraints)) {
225     std::string name;
226     Constraint constraint = Constraint::getFromOpaquePointer(it.value());
227     Pred pred = constraint.getPredicate();
228     auto iter = predToConstraint.find(pred);
229     if (iter != predToConstraint.end()) {
230       do {
231         Constraint built = Constraint::getFromOpaquePointer(iter->second);
232         // We may have the different constraints but have the same predicate,
233         // for example, ConstraintA and Variadic<ConstraintA>, note that
234         // Variadic<> doesn't introduce new predicate. In this case, we can
235         // share the same predicate function if they also have consistent
236         // summary, otherwise we may report the wrong message while verification
237         // fails.
238         if (constraint.getSummary() == built.getSummary()) {
239           name = getTypeConstraintFn(built).str();
240           break;
241         }
242         ++iter;
243       } while (iter != predToConstraint.end() && iter->first == pred);
244     }
245 
246     if (!name.empty()) {
247       localTypeConstraints.try_emplace(it.value(), name);
248       continue;
249     }
250 
251     // Generate an obscure and unique name for this type constraint.
252     name = (Twine("__mlir_ods_local_type_constraint_") + uniqueOutputLabel +
253             Twine(it.index()))
254                .str();
255     predToConstraint.insert(
256         std::make_pair(constraint.getPredicate(), it.value()));
257     localTypeConstraints.try_emplace(it.value(), name);
258 
259     // Only generate the methods if we are generating definitions.
260     if (emitDecl)
261       continue;
262 
263     os << "static ::mlir::LogicalResult " << name
264        << "(::mlir::Operation *op, ::mlir::Type type, ::llvm::StringRef "
265           "valueKind, unsigned valueGroupStartIndex) {\n";
266 
267     os << "  if (!("
268        << tgfmt(constraint.getConditionTemplate(), &fctx.withSelf("type"))
269        << ")) {\n"
270        << formatv(
271               "    return op->emitOpError(valueKind) << \" #\" << "
272               "valueGroupStartIndex << \" must be {0}, but got \" << type;\n",
273               constraint.getSummary())
274        << "  }\n"
275        << "  return ::mlir::success();\n"
276        << "}\n\n";
277   }
278 }
279 
280 //===----------------------------------------------------------------------===//
281 // Utility structs and functions
282 //===----------------------------------------------------------------------===//
283 
// Replaces all occurrences of `match` in `str` with `substitute` and returns
// the result. Inserted text is never rescanned, so `substitute` may itself
// contain `match`. An empty `match` leaves `str` unchanged: searching for an
// empty string would match at every position and the loop would never end.
static std::string replaceAllSubstrs(std::string str, const std::string &match,
                                     const std::string &substitute) {
  if (match.empty())
    return str;
  std::string::size_type scanLoc = 0;
  std::string::size_type matchLoc;
  while ((matchLoc = str.find(match, scanLoc)) != std::string::npos) {
    str.replace(matchLoc, match.size(), substitute);
    // Resume after the inserted text.
    scanLoc = matchLoc + substitute.size();
  }
  return str;
}
294 
295 // Returns whether the record has a value of the given name that can be returned
296 // via getValueAsString.
hasStringAttribute(const Record & record,StringRef fieldName)297 static inline bool hasStringAttribute(const Record &record,
298                                       StringRef fieldName) {
299   auto valueInit = record.getValueInit(fieldName);
300   return isa<StringInit>(valueInit);
301 }
302 
getArgumentName(const Operator & op,int index)303 static std::string getArgumentName(const Operator &op, int index) {
304   const auto &operand = op.getOperand(index);
305   if (!operand.name.empty())
306     return std::string(operand.name);
307   else
308     return std::string(formatv("{0}_{1}", generatedArgName, index));
309 }
310 
311 // Returns true if we can use unwrapped value for the given `attr` in builders.
canUseUnwrappedRawValue(const tblgen::Attribute & attr)312 static bool canUseUnwrappedRawValue(const tblgen::Attribute &attr) {
313   return attr.getReturnType() != attr.getStorageType() &&
314          // We need to wrap the raw value into an attribute in the builder impl
315          // so we need to make sure that the attribute specifies how to do that.
316          !attr.getConstBuilderTemplate().empty();
317 }
318 
319 //===----------------------------------------------------------------------===//
320 // Op emitter
321 //===----------------------------------------------------------------------===//
322 
323 namespace {
// Helper class that builds the C++ class (an OpClass) for a single ODS
// operation and writes its declaration or definition to an output stream.
// Instances are only created through the static emitDecl/emitDef entry points.
// The constructor runs every gen* step that populates `opClass`.
class OpEmitter {
public:
  // Builds the class for `op` and writes its declaration to `os`.
  static void
  emitDecl(const Operator &op, raw_ostream &os,
           const StaticVerifierFunctionEmitter &staticVerifierEmitter);
  // Builds the class for `op` and writes its definition to `os`.
  static void
  emitDef(const Operator &op, raw_ostream &os,
          const StaticVerifierFunctionEmitter &staticVerifierEmitter);

private:
  // Populates `opClass` for `op`. `staticVerifierEmitter` supplies the names of
  // the shared local verification functions.
  OpEmitter(const Operator &op,
            const StaticVerifierFunctionEmitter &staticVerifierEmitter);

  // Write the declaration / definition of the already-built `opClass` to `os`.
  void emitDecl(raw_ostream &os);
  void emitDef(raw_ostream &os);

  // Generate methods for accessing the attribute names of this operation.
  void genAttrNameGetters();

  // Return the index of the given attribute name. This is a relative ordering
  // for this name, used in attribute getters. The name must already have been
  // registered by genAttrNameGetters().
  unsigned getAttrNameIndex(StringRef attrName) const;

  // Generates the OpAsmOpInterface for this operation if possible.
  void genOpAsmInterface();

  // Generates the `getOperationName` method for this op.
  void genOpNameGetter();

  // Generates getters for the attributes.
  void genAttrGetters();

  // Generates setters for the attributes.
  void genAttrSetters();

  // Generates removers for optional attributes.
  void genOptionalAttrRemovers();

  // Generates getters for named operands.
  void genNamedOperandGetters();

  // Generates setters for named operands.
  void genNamedOperandSetters();

  // Generates getters for named results.
  void genNamedResultGetters();

  // Generates getters for named regions.
  void genNamedRegionGetters();

  // Generates getters for named successors.
  void genNamedSuccessorGetters();

  // Generates builder methods for the operation.
  void genBuilder();

  // Generates the build() method that takes each operand/attribute
  // as a stand-alone parameter.
  void genSeparateArgParamBuilder();

  // Generates the build() method that takes each operand/attribute as a
  // stand-alone parameter. The generated build() method uses first operand's
  // type as all results' types.
  void genUseOperandAsResultTypeSeparateParamBuilder();

  // Generates the build() method that takes all operands/attributes
  // collectively as one parameter. The generated build() method uses first
  // operand's type as all results' types.
  void genUseOperandAsResultTypeCollectiveParamBuilder();

  // Generates the build() method that takes aggregate operands/attributes
  // parameters. This build() method uses inferred types as result types.
  // Requires: The type needs to be inferable via InferTypeOpInterface.
  void genInferredTypeCollectiveParamBuilder();

  // Generates the build() method that takes each operand/attribute as a
  // stand-alone parameter. The generated build() method uses first attribute's
  // type as all result's types.
  void genUseAttrAsResultTypeBuilder();

  // Generates the build() method that takes all result types collectively as
  // one parameter. Similarly for operands and attributes.
  void genCollectiveParamBuilder();

  // The kind of parameter to generate for result types in builders.
  enum class TypeParamKind {
    None,       // No result type in parameter list.
    Separate,   // A separate parameter for each result type.
    Collective, // An ArrayRef<Type> for all result types.
  };

  // The kind of parameter to generate for attributes in builders.
  enum class AttrParamKind {
    WrappedAttr,    // A wrapped MLIR Attribute instance.
    UnwrappedValue, // A raw value without MLIR Attribute wrapper.
  };

  // Builds the parameter list for build() method of this op. This method writes
  // to `paramList` the comma-separated parameter list and updates
  // `resultTypeNames` with the names for parameters for specifying result
  // types. The given `typeParamKind` and `attrParamKind` controls how result
  // types and attributes are placed in the parameter list.
  void buildParamList(llvm::SmallVectorImpl<OpMethodParameter> &paramList,
                      SmallVectorImpl<std::string> &resultTypeNames,
                      TypeParamKind typeParamKind,
                      AttrParamKind attrParamKind = AttrParamKind::WrappedAttr);

  // Adds op arguments and regions into operation state for build() methods.
  void genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
                                              bool isRawValueAttr = false);

  // Generates canonicalizer declaration for the operation.
  void genCanonicalizerDecls();

  // Generates the folder declaration for the operation.
  void genFolderDecls();

  // Generates the parser for the operation.
  void genParser();

  // Generates the printer for the operation.
  void genPrinter();

  // Generates verify method for the operation.
  void genVerifier();

  // Generates verify statements for operands and results in the operation.
  // The generated code will be attached to `body`.
  void genOperandResultVerifier(OpMethodBody &body,
                                Operator::value_range values,
                                StringRef valueKind);

  // Generates verify statements for regions in the operation.
  // The generated code will be attached to `body`.
  void genRegionVerifier(OpMethodBody &body);

  // Generates verify statements for successors in the operation.
  // The generated code will be attached to `body`.
  void genSuccessorVerifier(OpMethodBody &body);

  // Generates the traits used by the object.
  void genTraits();

  // Generate the OpInterface methods for all interfaces.
  void genOpInterfaceMethods();

  // Generate op interface methods for the given interface.
  void genOpInterfaceMethods(const tblgen::InterfaceTrait *trait);

  // Generate op interface method for the given interface method. If
  // 'declaration' is true, generates a declaration, else a definition.
  OpMethod *genOpInterfaceMethod(const tblgen::InterfaceMethod &method,
                                 bool declaration = true);

  // Generate the side effect interface methods.
  void genSideEffectInterfaceMethods();

  // Generate the type inference interface methods.
  void genTypeInterfaceMethods();

private:
  // The TableGen record for this op.
  // TODO: OpEmitter should not have a Record directly,
  // it should rather go through the Operator for better abstraction.
  const Record &def;

  // The wrapper operator class for querying information from this op.
  Operator op;

  // The C++ code builder for this op.
  OpClass opClass;

  // The format context for verification code generation.
  FmtContext verifyCtx;

  // The emitter containing all of the locally emitted verification functions.
  const StaticVerifierFunctionEmitter &staticVerifierEmitter;

  // A map of attribute names (including implicit attributes) registered to the
  // current operation, to the relative order in which they were registered.
  // Filled by genAttrNameGetters() and queried by getAttrNameIndex().
  llvm::MapVector<StringRef, unsigned> attributeNames;
};
507 } // end anonymous namespace
508 
509 // Populate the format context `ctx` with substitutions of attributes, operands
510 // and results.
511 // - attrGet corresponds to the name of the function to call to get value of
512 //   attribute (the generated function call returns an Attribute);
513 // - operandGet corresponds to the name of the function with which to retrieve
514 //   an operand (the generated function call returns an OperandRange);
515 // - resultGet corresponds to the name of the function to get an result (the
516 //   generated function call returns a ValueRange);
populateSubstitutions(const Operator & op,const char * attrGet,const char * operandGet,const char * resultGet,FmtContext & ctx)517 static void populateSubstitutions(const Operator &op, const char *attrGet,
518                                   const char *operandGet, const char *resultGet,
519                                   FmtContext &ctx) {
520   // Populate substitutions for attributes and named operands.
521   for (const auto &namedAttr : op.getAttributes())
522     ctx.addSubst(namedAttr.name,
523                  formatv("{0}(\"{1}\")", attrGet, namedAttr.name));
524   for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
525     auto &value = op.getOperand(i);
526     if (value.name.empty())
527       continue;
528 
529     if (value.isVariadic())
530       ctx.addSubst(value.name, formatv("{0}({1})", operandGet, i));
531     else
532       ctx.addSubst(value.name, formatv("(*{0}({1}).begin())", operandGet, i));
533   }
534 
535   // Populate substitutions for results.
536   for (int i = 0, e = op.getNumResults(); i < e; ++i) {
537     auto &value = op.getResult(i);
538     if (value.name.empty())
539       continue;
540 
541     if (value.isVariadic())
542       ctx.addSubst(value.name, formatv("{0}({1})", resultGet, i));
543     else
544       ctx.addSubst(value.name, formatv("(*{0}({1}).begin())", resultGet, i));
545   }
546 }
547 
548 // Generate attribute verification. If emitVerificationRequiringOp is set then
549 // only verification for attributes whose value depend on op being known are
550 // emitted, else only verification that doesn't depend on the op being known are
551 // generated.
552 // - emitErrorPrefix is the prefix for the error emitting call which consists
553 //   of the entire function call up to start of error message fragment;
554 // - emitVerificationRequiringOp specifies whether verification should be
555 //   emitted for verification that require the op to exist;
genAttributeVerifier(const Operator & op,const char * attrGet,const Twine & emitErrorPrefix,bool emitVerificationRequiringOp,FmtContext & ctx,OpMethodBody & body)556 static void genAttributeVerifier(const Operator &op, const char *attrGet,
557                                  const Twine &emitErrorPrefix,
558                                  bool emitVerificationRequiringOp,
559                                  FmtContext &ctx, OpMethodBody &body) {
560   for (const auto &namedAttr : op.getAttributes()) {
561     const auto &attr = namedAttr.attr;
562     if (attr.isDerivedAttr())
563       continue;
564 
565     auto attrName = namedAttr.name;
566     bool allowMissingAttr = attr.hasDefaultValue() || attr.isOptional();
567     auto attrPred = attr.getPredicate();
568     auto condition = attrPred.isNull() ? "" : attrPred.getCondition();
569     // There is a condition to emit only if the use of $_op and whether to
570     // emit verifications for op matches.
571     bool hasConditionToEmit = (!(condition.find("$_op") != StringRef::npos) ^
572                                emitVerificationRequiringOp);
573 
574     // Prefix with `tblgen_` to avoid hiding the attribute accessor.
575     auto varName = tblgenNamePrefix + attrName;
576 
577     // If the attribute is
578     //  1. Required (not allowed missing) and not in op verification, or
579     //  2. Has a condition that will get verified
580     // then the variable will be used.
581     //
582     // Therefore, for optional attributes whose verification requires that an
583     // op already exists for verification/emitVerificationRequiringOp is set
584     // has nothing that can be verified here.
585     if ((allowMissingAttr || emitVerificationRequiringOp) &&
586         !hasConditionToEmit)
587       continue;
588 
589     body << formatv("  {\n  auto {0} = {1}(\"{2}\");\n", varName, attrGet,
590                     attrName);
591 
592     if (!emitVerificationRequiringOp && !allowMissingAttr) {
593       body << "  if (!" << varName << ") return " << emitErrorPrefix
594            << "\"requires attribute '" << attrName << "'\");\n";
595     }
596 
597     if (!hasConditionToEmit) {
598       body << "  }\n";
599       continue;
600     }
601 
602     if (allowMissingAttr) {
603       // If the attribute has a default value, then only verify the predicate if
604       // set. This does effectively assume that the default value is valid.
605       // TODO: verify the debug value is valid (perhaps in debug mode only).
606       body << "  if (" << varName << ") {\n";
607     }
608 
609     body << tgfmt("    if (!($0)) return $1\"attribute '$2' "
610                   "failed to satisfy constraint: $3\");\n",
611                   /*ctx=*/nullptr, tgfmt(condition, &ctx.withSelf(varName)),
612                   emitErrorPrefix, attrName, attr.getSummary());
613     if (allowMissingAttr)
614       body << "  }\n";
615     body << "  }\n";
616   }
617 }
618 
// Builds the complete C++ class model for `op` into `opClass` by running every
// gen* step. Nothing is written to a stream here; emitDecl/emitDef do that.
OpEmitter::OpEmitter(const Operator &op,
                     const StaticVerifierFunctionEmitter &staticVerifierEmitter)
    // In this initializer list, `op` names the constructor parameter (it
    // shadows the member of the same name).
    : def(op.getDef()), op(op),
      opClass(op.getCppClassName(), op.getExtraClassDeclaration()),
      staticVerifierEmitter(staticVerifierEmitter) {
  // Placeholders available to generated verification code: the op (set via
  // withOp) and `$_ctxt` for the op's context.
  verifyCtx.withOp("(*this->getOperation())");
  verifyCtx.addSubst("_ctxt", "this->getOperation()->getContext()");

  genTraits();

  // Generate C++ code for various op methods. The order here determines the
  // methods in the generated file. genAttrNameGetters also registers the
  // attribute names that getAttrNameIndex() later asserts are present.
  genAttrNameGetters();
  genOpAsmInterface();
  genOpNameGetter();
  genNamedOperandGetters();
  genNamedOperandSetters();
  genNamedResultGetters();
  genNamedRegionGetters();
  genNamedSuccessorGetters();
  genAttrGetters();
  genAttrSetters();
  genOptionalAttrRemovers();
  genBuilder();
  genParser();
  genPrinter();
  genVerifier();
  genCanonicalizerDecls();
  genFolderDecls();
  genTypeInterfaceMethods();
  genOpInterfaceMethods();
  generateOpFormat(op, opClass);
  genSideEffectInterfaceMethods();
}
653 
emitDecl(const Operator & op,raw_ostream & os,const StaticVerifierFunctionEmitter & staticVerifierEmitter)654 void OpEmitter::emitDecl(
655     const Operator &op, raw_ostream &os,
656     const StaticVerifierFunctionEmitter &staticVerifierEmitter) {
657   OpEmitter(op, staticVerifierEmitter).emitDecl(os);
658 }
659 
emitDef(const Operator & op,raw_ostream & os,const StaticVerifierFunctionEmitter & staticVerifierEmitter)660 void OpEmitter::emitDef(
661     const Operator &op, raw_ostream &os,
662     const StaticVerifierFunctionEmitter &staticVerifierEmitter) {
663   OpEmitter(op, staticVerifierEmitter).emitDef(os);
664 }
665 
emitDecl(raw_ostream & os)666 void OpEmitter::emitDecl(raw_ostream &os) { opClass.writeDeclTo(os); }
667 
emitDef(raw_ostream & os)668 void OpEmitter::emitDef(raw_ostream &os) { opClass.writeDefTo(os); }
669 
genAttrNameGetters()670 void OpEmitter::genAttrNameGetters() {
671   // Enumerate the attribute names of this op, assigning each a relative
672   // ordering.
673   auto addAttrName = [&](StringRef name) {
674     unsigned index = attributeNames.size();
675     attributeNames.insert({name, index});
676   };
677   for (const NamedAttribute &namedAttr : op.getAttributes())
678     addAttrName(namedAttr.name);
679   // Include key attributes from several traits as implicitly registered.
680   if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments"))
681     addAttrName("operand_segment_sizes");
682   if (op.getTrait("::mlir::OpTrait::AttrSizedResultSegments"))
683     addAttrName("result_segment_sizes");
684 
685   // Emit the getAttributeNames method.
686   {
687     auto *method = opClass.addMethodAndPrune(
688         "::llvm::ArrayRef<::llvm::StringRef>", "getAttributeNames",
689         OpMethod::Property(OpMethod::MP_Static | OpMethod::MP_Inline));
690     auto &body = method->body();
691     if (attributeNames.empty()) {
692       body << "  return {};";
693     } else {
694       body << "  static ::llvm::StringRef attrNames[] = {";
695       llvm::interleaveComma(llvm::make_first_range(attributeNames), body,
696                             [&](StringRef attrName) {
697                               body << "::llvm::StringRef(\"" << attrName
698                                    << "\")";
699                             });
700       body << "};\n  return ::llvm::makeArrayRef(attrNames);";
701     }
702   }
703   if (attributeNames.empty())
704     return;
705 
706   // Emit the getAttributeNameForIndex methods.
707   {
708     auto *method = opClass.addMethodAndPrune(
709         "::mlir::Identifier", "getAttributeNameForIndex",
710         OpMethod::Property(OpMethod::MP_Private | OpMethod::MP_Inline),
711         "unsigned", "index");
712     method->body()
713         << "  return getAttributeNameForIndex((*this)->getName(), index);";
714   }
715   {
716     auto *method = opClass.addMethodAndPrune(
717         "::mlir::Identifier", "getAttributeNameForIndex",
718         OpMethod::Property(OpMethod::MP_Private | OpMethod::MP_Inline |
719                            OpMethod::MP_Static),
720         "::mlir::OperationName name, unsigned index");
721     method->body() << "assert(index < " << attributeNames.size()
722                    << " && \"invalid attribute index\");\n"
723                       "  return name.getAbstractOperation()"
724                       "->getAttributeNames()[index];";
725   }
726 
727   // Generate the <attr>AttrName methods, that expose the attribute names to
728   // users.
729   const char *attrNameMethodBody = "  return getAttributeNameForIndex({0});";
730   for (const std::pair<StringRef, unsigned> &attrIt : attributeNames) {
731     std::string methodName = (attrIt.first + "AttrName").str();
732 
733     // Generate the non-static variant.
734     {
735       auto *method =
736           opClass.addMethodAndPrune("::mlir::Identifier", methodName,
737                                     OpMethod::Property(OpMethod::MP_Inline));
738       method->body() << llvm::formatv(attrNameMethodBody, attrIt.second).str();
739     }
740 
741     // Generate the static variant.
742     {
743       auto *method = opClass.addMethodAndPrune(
744           "::mlir::Identifier", methodName,
745           OpMethod::Property(OpMethod::MP_Inline | OpMethod::MP_Static),
746           "::mlir::OperationName", "name");
747       method->body() << llvm::formatv(attrNameMethodBody,
748                                       "name, " + Twine(attrIt.second))
749                             .str();
750     }
751   }
752 }
753 
getAttrNameIndex(StringRef attrName) const754 unsigned OpEmitter::getAttrNameIndex(StringRef attrName) const {
755   auto it = attributeNames.find(attrName);
756   assert(it != attributeNames.end() && "expected attribute name to have been "
757                                        "registered in genAttrNameGetters");
758   return it->second;
759 }
760 
genAttrGetters()761 void OpEmitter::genAttrGetters() {
762   FmtContext fctx;
763   fctx.withBuilder("::mlir::Builder((*this)->getContext())");
764 
765   // Emit the derived attribute body.
766   auto emitDerivedAttr = [&](StringRef name, Attribute attr) {
767     if (auto *method = opClass.addMethodAndPrune(attr.getReturnType(), name))
768       method->body() << "  " << attr.getDerivedCodeBody() << "\n";
769   };
770 
771   // Emit with return type specified.
772   auto emitAttrWithReturnType = [&](StringRef name, Attribute attr) {
773     auto *method = opClass.addMethodAndPrune(attr.getReturnType(), name);
774     auto &body = method->body();
775     body << "  auto attr = " << name << "Attr();\n";
776     if (attr.hasDefaultValue()) {
777       // Returns the default value if not set.
778       // TODO: this is inefficient, we are recreating the attribute for every
779       // call. This should be set instead.
780       std::string defaultValue = std::string(
781           tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
782       body << "    if (!attr)\n      return "
783            << tgfmt(attr.getConvertFromStorageCall(),
784                     &fctx.withSelf(defaultValue))
785            << ";\n";
786     }
787     body << "  return "
788          << tgfmt(attr.getConvertFromStorageCall(), &fctx.withSelf("attr"))
789          << ";\n";
790   };
791 
792   // Generate raw named accessor type. This is a wrapper class that allows
793   // referring to the attributes via accessors instead of having to use
794   // the string interface for better compile time verification.
795   auto emitAttrWithStorageType = [&](StringRef name, Attribute attr) {
796     auto *method =
797         opClass.addMethodAndPrune(attr.getStorageType(), (name + "Attr").str());
798     if (!method)
799       return;
800     auto &body = method->body();
801     body << "  return (*this)->getAttr(" << name << "AttrName()).template ";
802     if (attr.isOptional() || attr.hasDefaultValue())
803       body << "dyn_cast_or_null<";
804     else
805       body << "cast<";
806     body << attr.getStorageType() << ">();";
807   };
808 
809   for (const NamedAttribute &namedAttr : op.getAttributes()) {
810     if (namedAttr.attr.isDerivedAttr()) {
811       emitDerivedAttr(namedAttr.name, namedAttr.attr);
812     } else {
813       emitAttrWithStorageType(namedAttr.name, namedAttr.attr);
814       emitAttrWithReturnType(namedAttr.name, namedAttr.attr);
815     }
816   }
817 
818   auto derivedAttrs = make_filter_range(op.getAttributes(),
819                                         [](const NamedAttribute &namedAttr) {
820                                           return namedAttr.attr.isDerivedAttr();
821                                         });
822   if (!derivedAttrs.empty()) {
823     opClass.addTrait("::mlir::DerivedAttributeOpInterface::Trait");
824     // Generate helper method to query whether a named attribute is a derived
825     // attribute. This enables, for example, avoiding adding an attribute that
826     // overlaps with a derived attribute.
827     {
828       auto *method = opClass.addMethodAndPrune("bool", "isDerivedAttribute",
829                                                OpMethod::MP_Static,
830                                                "::llvm::StringRef", "name");
831       auto &body = method->body();
832       for (auto namedAttr : derivedAttrs)
833         body << "  if (name == \"" << namedAttr.name << "\") return true;\n";
834       body << " return false;";
835     }
836     // Generate method to materialize derived attributes as a DictionaryAttr.
837     {
838       auto *method = opClass.addMethodAndPrune("::mlir::DictionaryAttr",
839                                                "materializeDerivedAttributes");
840       auto &body = method->body();
841 
842       auto nonMaterializable =
843           make_filter_range(derivedAttrs, [](const NamedAttribute &namedAttr) {
844             return namedAttr.attr.getConvertFromStorageCall().empty();
845           });
846       if (!nonMaterializable.empty()) {
847         std::string attrs;
848         llvm::raw_string_ostream os(attrs);
849         interleaveComma(nonMaterializable, os,
850                         [&](const NamedAttribute &attr) { os << attr.name; });
851         PrintWarning(
852             op.getLoc(),
853             formatv(
854                 "op has non-materializable derived attributes '{0}', skipping",
855                 os.str()));
856         body << formatv("  emitOpError(\"op has non-materializable derived "
857                         "attributes '{0}'\");\n",
858                         attrs);
859         body << "  return nullptr;";
860         return;
861       }
862 
863       body << "  ::mlir::MLIRContext* ctx = getContext();\n";
864       body << "  ::mlir::Builder odsBuilder(ctx); (void)odsBuilder;\n";
865       body << "  return ::mlir::DictionaryAttr::get(";
866       body << "  ctx, {\n";
867       interleave(
868           derivedAttrs, body,
869           [&](const NamedAttribute &namedAttr) {
870             auto tmpl = namedAttr.attr.getConvertFromStorageCall();
871             body << "    {" << namedAttr.name << "AttrName(),\n"
872                  << tgfmt(tmpl, &fctx.withSelf(namedAttr.name + "()")
873                                      .withBuilder("odsBuilder")
874                                      .addSubst("_ctx", "ctx"))
875                  << "}";
876           },
877           ",\n");
878       body << "});";
879     }
880   }
881 }
882 
genAttrSetters()883 void OpEmitter::genAttrSetters() {
884   // Generate raw named setter type. This is a wrapper class that allows setting
885   // to the attributes via setters instead of having to use the string interface
886   // for better compile time verification.
887   auto emitAttrWithStorageType = [&](StringRef name, Attribute attr) {
888     auto *method = opClass.addMethodAndPrune("void", (name + "Attr").str(),
889                                              attr.getStorageType(), "attr");
890     if (method)
891       method->body() << "  (*this)->setAttr(" << name << "AttrName(), attr);";
892   };
893 
894   for (const NamedAttribute &namedAttr : op.getAttributes())
895     if (!namedAttr.attr.isDerivedAttr())
896       emitAttrWithStorageType(namedAttr.name, namedAttr.attr);
897 }
898 
genOptionalAttrRemovers()899 void OpEmitter::genOptionalAttrRemovers() {
900   // Generate methods for removing optional attributes, instead of having to
901   // use the string interface. Enables better compile time verification.
902   auto emitRemoveAttr = [&](StringRef name) {
903     auto upperInitial = name.take_front().upper();
904     auto suffix = name.drop_front();
905     auto *method = opClass.addMethodAndPrune(
906         "::mlir::Attribute", ("remove" + upperInitial + suffix + "Attr").str());
907     if (!method)
908       return;
909     method->body() << "  return (*this)->removeAttr(" << name << "AttrName());";
910   };
911 
912   for (const NamedAttribute &namedAttr : op.getAttributes())
913     if (namedAttr.attr.isOptional())
914       emitRemoveAttr(namedAttr.name);
915 }
916 
917 // Generates the code to compute the start and end index of an operand or result
918 // range.
919 template <typename RangeT>
920 static void
generateValueRangeStartAndEnd(Class & opClass,StringRef methodName,int numVariadic,int numNonVariadic,StringRef rangeSizeCall,bool hasAttrSegmentSize,StringRef sizeAttrInit,RangeT && odsValues)921 generateValueRangeStartAndEnd(Class &opClass, StringRef methodName,
922                               int numVariadic, int numNonVariadic,
923                               StringRef rangeSizeCall, bool hasAttrSegmentSize,
924                               StringRef sizeAttrInit, RangeT &&odsValues) {
925   auto *method = opClass.addMethodAndPrune("std::pair<unsigned, unsigned>",
926                                            methodName, "unsigned", "index");
927   if (!method)
928     return;
929   auto &body = method->body();
930   if (numVariadic == 0) {
931     body << "  return {index, 1};\n";
932   } else if (hasAttrSegmentSize) {
933     body << sizeAttrInit << attrSizedSegmentValueRangeCalcCode;
934   } else {
935     // Because the op can have arbitrarily interleaved variadic and non-variadic
936     // operands, we need to embed a list in the "sink" getter method for
937     // calculation at run-time.
938     llvm::SmallVector<StringRef, 4> isVariadic;
939     isVariadic.reserve(llvm::size(odsValues));
940     for (auto &it : odsValues)
941       isVariadic.push_back(it.isVariableLength() ? "true" : "false");
942     std::string isVariadicList = llvm::join(isVariadic, ", ");
943     body << formatv(sameVariadicSizeValueRangeCalcCode, isVariadicList,
944                     numNonVariadic, numVariadic, rangeSizeCall, "operand");
945   }
946 }
947 
948 // Generates the named operand getter methods for the given Operator `op` and
949 // puts them in `opClass`.  Uses `rangeType` as the return type of getters that
950 // return a range of operands (individual operands are `Value ` and each
951 // element in the range must also be `Value `); use `rangeBeginCall` to get
952 // an iterator to the beginning of the operand range; use `rangeSizeCall` to
953 // obtain the number of operands. `getOperandCallPattern` contains the code
954 // necessary to obtain a single operand whose position will be substituted
955 // instead of
956 // "{0}" marker in the pattern.  Note that the pattern should work for any kind
957 // of ops, in particular for one-operand ops that may not have the
958 // `getOperand(unsigned)` method.
generateNamedOperandGetters(const Operator & op,Class & opClass,StringRef sizeAttrInit,StringRef rangeType,StringRef rangeBeginCall,StringRef rangeSizeCall,StringRef getOperandCallPattern)959 static void generateNamedOperandGetters(const Operator &op, Class &opClass,
960                                         StringRef sizeAttrInit,
961                                         StringRef rangeType,
962                                         StringRef rangeBeginCall,
963                                         StringRef rangeSizeCall,
964                                         StringRef getOperandCallPattern) {
965   const int numOperands = op.getNumOperands();
966   const int numVariadicOperands = op.getNumVariableLengthOperands();
967   const int numNormalOperands = numOperands - numVariadicOperands;
968 
969   const auto *sameVariadicSize =
970       op.getTrait("::mlir::OpTrait::SameVariadicOperandSize");
971   const auto *attrSizedOperands =
972       op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments");
973 
974   if (numVariadicOperands > 1 && !sameVariadicSize && !attrSizedOperands) {
975     PrintFatalError(op.getLoc(), "op has multiple variadic operands but no "
976                                  "specification over their sizes");
977   }
978 
979   if (numVariadicOperands < 2 && attrSizedOperands) {
980     PrintFatalError(op.getLoc(), "op must have at least two variadic operands "
981                                  "to use 'AttrSizedOperandSegments' trait");
982   }
983 
984   if (attrSizedOperands && sameVariadicSize) {
985     PrintFatalError(op.getLoc(),
986                     "op cannot have both 'AttrSizedOperandSegments' and "
987                     "'SameVariadicOperandSize' traits");
988   }
989 
990   // First emit a few "sink" getter methods upon which we layer all nicer named
991   // getter methods.
992   generateValueRangeStartAndEnd(opClass, "getODSOperandIndexAndLength",
993                                 numVariadicOperands, numNormalOperands,
994                                 rangeSizeCall, attrSizedOperands, sizeAttrInit,
995                                 const_cast<Operator &>(op).getOperands());
996 
997   auto *m = opClass.addMethodAndPrune(rangeType, "getODSOperands", "unsigned",
998                                       "index");
999   auto &body = m->body();
1000   body << formatv(valueRangeReturnCode, rangeBeginCall,
1001                   "getODSOperandIndexAndLength(index)");
1002 
1003   // Then we emit nicer named getter methods by redirecting to the "sink" getter
1004   // method.
1005   for (int i = 0; i != numOperands; ++i) {
1006     const auto &operand = op.getOperand(i);
1007     if (operand.name.empty())
1008       continue;
1009 
1010     if (operand.isOptional()) {
1011       m = opClass.addMethodAndPrune("::mlir::Value", operand.name);
1012       m->body()
1013           << "  auto operands = getODSOperands(" << i << ");\n"
1014           << "  return operands.empty() ? ::mlir::Value() : *operands.begin();";
1015     } else if (operand.isVariadic()) {
1016       m = opClass.addMethodAndPrune(rangeType, operand.name);
1017       m->body() << "  return getODSOperands(" << i << ");";
1018     } else {
1019       m = opClass.addMethodAndPrune("::mlir::Value", operand.name);
1020       m->body() << "  return *getODSOperands(" << i << ").begin();";
1021     }
1022   }
1023 }
1024 
genNamedOperandGetters()1025 void OpEmitter::genNamedOperandGetters() {
1026   // Build the code snippet used for initializing the operand_segment_sizes
1027   // array.
1028   std::string attrSizeInitCode;
1029   if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
1030     attrSizeInitCode =
1031         formatv(opSegmentSizeAttrInitCode, "operand_segment_sizesAttrName()")
1032             .str();
1033   }
1034 
1035   generateNamedOperandGetters(
1036       op, opClass,
1037       /*sizeAttrInit=*/attrSizeInitCode,
1038       /*rangeType=*/"::mlir::Operation::operand_range",
1039       /*rangeBeginCall=*/"getOperation()->operand_begin()",
1040       /*rangeSizeCall=*/"getOperation()->getNumOperands()",
1041       /*getOperandCallPattern=*/"getOperation()->getOperand({0})");
1042 }
1043 
genNamedOperandSetters()1044 void OpEmitter::genNamedOperandSetters() {
1045   auto *attrSizedOperands =
1046       op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments");
1047   for (int i = 0, e = op.getNumOperands(); i != e; ++i) {
1048     const auto &operand = op.getOperand(i);
1049     if (operand.name.empty())
1050       continue;
1051     auto *m = opClass.addMethodAndPrune("::mlir::MutableOperandRange",
1052                                         (operand.name + "Mutable").str());
1053     auto &body = m->body();
1054     body << "  auto range = getODSOperandIndexAndLength(" << i << ");\n"
1055          << "  return ::mlir::MutableOperandRange(getOperation(), "
1056             "range.first, range.second";
1057     if (attrSizedOperands)
1058       body << ", ::mlir::MutableOperandRange::OperandSegment(" << i
1059            << "u, *getOperation()->getAttrDictionary().getNamed("
1060               "operand_segment_sizesAttrName()))";
1061     body << ");\n";
1062   }
1063 }
1064 
genNamedResultGetters()1065 void OpEmitter::genNamedResultGetters() {
1066   const int numResults = op.getNumResults();
1067   const int numVariadicResults = op.getNumVariableLengthResults();
1068   const int numNormalResults = numResults - numVariadicResults;
1069 
1070   // If we have more than one variadic results, we need more complicated logic
1071   // to calculate the value range for each result.
1072 
1073   const auto *sameVariadicSize =
1074       op.getTrait("::mlir::OpTrait::SameVariadicResultSize");
1075   const auto *attrSizedResults =
1076       op.getTrait("::mlir::OpTrait::AttrSizedResultSegments");
1077 
1078   if (numVariadicResults > 1 && !sameVariadicSize && !attrSizedResults) {
1079     PrintFatalError(op.getLoc(), "op has multiple variadic results but no "
1080                                  "specification over their sizes");
1081   }
1082 
1083   if (numVariadicResults < 2 && attrSizedResults) {
1084     PrintFatalError(op.getLoc(), "op must have at least two variadic results "
1085                                  "to use 'AttrSizedResultSegments' trait");
1086   }
1087 
1088   if (attrSizedResults && sameVariadicSize) {
1089     PrintFatalError(op.getLoc(),
1090                     "op cannot have both 'AttrSizedResultSegments' and "
1091                     "'SameVariadicResultSize' traits");
1092   }
1093 
1094   // Build the initializer string for the result segment size attribute.
1095   std::string attrSizeInitCode;
1096   if (attrSizedResults) {
1097     attrSizeInitCode =
1098         formatv(opSegmentSizeAttrInitCode, "result_segment_sizesAttrName()")
1099             .str();
1100   }
1101 
1102   generateValueRangeStartAndEnd(
1103       opClass, "getODSResultIndexAndLength", numVariadicResults,
1104       numNormalResults, "getOperation()->getNumResults()", attrSizedResults,
1105       attrSizeInitCode, op.getResults());
1106 
1107   auto *m = opClass.addMethodAndPrune("::mlir::Operation::result_range",
1108                                       "getODSResults", "unsigned", "index");
1109   m->body() << formatv(valueRangeReturnCode, "getOperation()->result_begin()",
1110                        "getODSResultIndexAndLength(index)");
1111 
1112   for (int i = 0; i != numResults; ++i) {
1113     const auto &result = op.getResult(i);
1114     if (result.name.empty())
1115       continue;
1116 
1117     if (result.isOptional()) {
1118       m = opClass.addMethodAndPrune("::mlir::Value", result.name);
1119       m->body()
1120           << "  auto results = getODSResults(" << i << ");\n"
1121           << "  return results.empty() ? ::mlir::Value() : *results.begin();";
1122     } else if (result.isVariadic()) {
1123       m = opClass.addMethodAndPrune("::mlir::Operation::result_range",
1124                                     result.name);
1125       m->body() << "  return getODSResults(" << i << ");";
1126     } else {
1127       m = opClass.addMethodAndPrune("::mlir::Value", result.name);
1128       m->body() << "  return *getODSResults(" << i << ").begin();";
1129     }
1130   }
1131 }
1132 
genNamedRegionGetters()1133 void OpEmitter::genNamedRegionGetters() {
1134   unsigned numRegions = op.getNumRegions();
1135   for (unsigned i = 0; i < numRegions; ++i) {
1136     const auto &region = op.getRegion(i);
1137     if (region.name.empty())
1138       continue;
1139 
1140     // Generate the accessors for a variadic region.
1141     if (region.isVariadic()) {
1142       auto *m = opClass.addMethodAndPrune(
1143           "::mlir::MutableArrayRef<::mlir::Region>", region.name);
1144       m->body() << formatv("  return (*this)->getRegions().drop_front({0});",
1145                            i);
1146       continue;
1147     }
1148 
1149     auto *m = opClass.addMethodAndPrune("::mlir::Region &", region.name);
1150     m->body() << formatv("  return (*this)->getRegion({0});", i);
1151   }
1152 }
1153 
genNamedSuccessorGetters()1154 void OpEmitter::genNamedSuccessorGetters() {
1155   unsigned numSuccessors = op.getNumSuccessors();
1156   for (unsigned i = 0; i < numSuccessors; ++i) {
1157     const NamedSuccessor &successor = op.getSuccessor(i);
1158     if (successor.name.empty())
1159       continue;
1160 
1161     // Generate the accessors for a variadic successor list.
1162     if (successor.isVariadic()) {
1163       auto *m =
1164           opClass.addMethodAndPrune("::mlir::SuccessorRange", successor.name);
1165       m->body() << formatv(
1166           "  return {std::next((*this)->successor_begin(), {0}), "
1167           "(*this)->successor_end()};",
1168           i);
1169       continue;
1170     }
1171 
1172     auto *m = opClass.addMethodAndPrune("::mlir::Block *", successor.name);
1173     m->body() << formatv("  return (*this)->getSuccessor({0});", i);
1174   }
1175 }
1176 
canGenerateUnwrappedBuilder(Operator & op)1177 static bool canGenerateUnwrappedBuilder(Operator &op) {
1178   // If this op does not have native attributes at all, return directly to avoid
1179   // redefining builders.
1180   if (op.getNumNativeAttributes() == 0)
1181     return false;
1182 
1183   bool canGenerate = false;
1184   // We are generating builders that take raw values for attributes. We need to
1185   // make sure the native attributes have a meaningful "unwrapped" value type
1186   // different from the wrapped mlir::Attribute type to avoid redefining
1187   // builders. This checks for the op has at least one such native attribute.
1188   for (int i = 0, e = op.getNumNativeAttributes(); i < e; ++i) {
1189     NamedAttribute &namedAttr = op.getAttribute(i);
1190     if (canUseUnwrappedRawValue(namedAttr.attr)) {
1191       canGenerate = true;
1192       break;
1193     }
1194   }
1195   return canGenerate;
1196 }
1197 
canInferType(Operator & op)1198 static bool canInferType(Operator &op) {
1199   return op.getTrait("::mlir::InferTypeOpInterface::Trait") &&
1200          op.getNumRegions() == 0;
1201 }
1202 
/// Emits `build` methods that take each operand and attribute as a separate
/// parameter.
///
/// Attributes are always taken as wrapped attributes. They are additionally
/// taken as unwrapped raw values when canGenerateUnwrappedBuilder(op) holds.
/// For each of those attribute kinds, the following variants are emitted:
///  - result types passed separately;
///  - no result types, inferred via the op's `inferReturnTypes`, only when
///    canInferType(op) holds;
///  - result types passed collectively as `resultTypes`.
/// Variants that `addMethodAndPrune` reports as redundant are skipped.
void OpEmitter::genSeparateArgParamBuilder() {
  SmallVector<AttrParamKind, 2> attrBuilderType;
  attrBuilderType.push_back(AttrParamKind::WrappedAttr);
  if (canGenerateUnwrappedBuilder(op))
    attrBuilderType.push_back(AttrParamKind::UnwrappedValue);

  // Emit with separate builders with or without unwrapped attributes and/or
  // inferring result type.
  auto emit = [&](AttrParamKind attrType, TypeParamKind paramKind,
                  bool inferType) {
    llvm::SmallVector<OpMethodParameter, 4> paramList;
    llvm::SmallVector<std::string, 4> resultNames;
    // Fills in the builder signature; `resultNames` receives the names of the
    // separate result-type parameters (used by the `Separate` case below).
    buildParamList(paramList, resultNames, paramKind, attrType);

    auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
                                        std::move(paramList));
    // If the builder is redundant, skip generating the method.
    if (!m)
      return;
    auto &body = m->body();
    genCodeForAddingArgAndRegionForBuilder(
        body, /*isRawValueAttr=*/attrType == AttrParamKind::UnwrappedValue);

    // Push all result types to the operation state, either inferred or taken
    // from the builder's result-type parameter(s).

    if (inferType) {
      // Generate builder that infers type too.
      // TODO: Subsume this with general checking if type can be
      // inferred automatically.
      // TODO: Expand to handle regions.
      // The generated code calls report_fatal_error if inference fails.
      body << formatv(R"(
        ::llvm::SmallVector<::mlir::Type, 2> inferredReturnTypes;
        if (succeeded({0}::inferReturnTypes(odsBuilder.getContext(),
                      {1}.location, {1}.operands,
                      {1}.attributes.getDictionary({1}.getContext()),
                      /*regions=*/{{}, inferredReturnTypes)))
          {1}.addTypes(inferredReturnTypes);
        else
          ::llvm::report_fatal_error("Failed to infer result type(s).");)",
                      opClass.getClassName(), builderOpState);
      return;
    }

    switch (paramKind) {
    case TypeParamKind::None:
      return;
    case TypeParamKind::Separate:
      // One parameter per result; optional results are guarded with
      // `if (<name>)` in the generated code.
      for (int i = 0, e = op.getNumResults(); i < e; ++i) {
        if (op.getResult(i).isOptional())
          body << "  if (" << resultNames[i] << ")\n  ";
        body << "  " << builderOpState << ".addTypes(" << resultNames[i]
             << ");\n";
      }
      return;
    case TypeParamKind::Collective: {
      int numResults = op.getNumResults();
      int numVariadicResults = op.getNumVariableLengthResults();
      int numNonVariadicResults = numResults - numVariadicResults;
      bool hasVariadicResult = numVariadicResults != 0;

      // Avoid emitting "resultTypes.size() >= 0u" which is always true.
      // With a variadic result there may be more types than non-variadic
      // results, hence ">=" rather than "==".
      if (!(hasVariadicResult && numNonVariadicResults == 0))
        body << "  "
             << "assert(resultTypes.size() "
             << (hasVariadicResult ? ">=" : "==") << " "
             << numNonVariadicResults
             << "u && \"mismatched number of results\");\n";
      body << "  " << builderOpState << ".addTypes(resultTypes);\n";
    }
      return;
    }
    llvm_unreachable("unhandled TypeParamKind");
  };

  // Some of the build methods generated here may be ambiguous, but TableGen's
  // ambiguous function detection will elide those ones.
  for (auto attrType : attrBuilderType) {
    emit(attrType, TypeParamKind::Separate, /*inferType=*/false);
    if (canInferType(op))
      emit(attrType, TypeParamKind::None, /*inferType=*/true);
    emit(attrType, TypeParamKind::Collective, /*inferType=*/false);
  }
}
1286 
genUseOperandAsResultTypeCollectiveParamBuilder()1287 void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() {
1288   int numResults = op.getNumResults();
1289 
1290   // Signature
1291   llvm::SmallVector<OpMethodParameter, 4> paramList;
1292   paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
1293   paramList.emplace_back("::mlir::OperationState &", builderOpState);
1294   paramList.emplace_back("::mlir::ValueRange", "operands");
1295   // Provide default value for `attributes` when its the last parameter
1296   StringRef attributesDefaultValue = op.getNumVariadicRegions() ? "" : "{}";
1297   paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
1298                          "attributes", attributesDefaultValue);
1299   if (op.getNumVariadicRegions())
1300     paramList.emplace_back("unsigned", "numRegions");
1301 
1302   auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
1303                                       std::move(paramList));
1304   // If the builder is redundant, skip generating the method
1305   if (!m)
1306     return;
1307   auto &body = m->body();
1308 
1309   // Operands
1310   body << "  " << builderOpState << ".addOperands(operands);\n";
1311 
1312   // Attributes
1313   body << "  " << builderOpState << ".addAttributes(attributes);\n";
1314 
1315   // Create the correct number of regions
1316   if (int numRegions = op.getNumRegions()) {
1317     body << llvm::formatv(
1318         "  for (unsigned i = 0; i != {0}; ++i)\n",
1319         (op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions)));
1320     body << "    (void)" << builderOpState << ".addRegion();\n";
1321   }
1322 
1323   // Result types
1324   SmallVector<std::string, 2> resultTypes(numResults, "operands[0].getType()");
1325   body << "  " << builderOpState << ".addTypes({"
1326        << llvm::join(resultTypes, ", ") << "});\n\n";
1327 }
1328 
genInferredTypeCollectiveParamBuilder()1329 void OpEmitter::genInferredTypeCollectiveParamBuilder() {
1330   // TODO: Expand to support regions.
1331   SmallVector<OpMethodParameter, 4> paramList;
1332   paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
1333   paramList.emplace_back("::mlir::OperationState &", builderOpState);
1334   paramList.emplace_back("::mlir::ValueRange", "operands");
1335   paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
1336                          "attributes", "{}");
1337   auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
1338                                       std::move(paramList));
1339   // If the builder is redundant, skip generating the method
1340   if (!m)
1341     return;
1342   auto &body = m->body();
1343 
1344   int numResults = op.getNumResults();
1345   int numVariadicResults = op.getNumVariableLengthResults();
1346   int numNonVariadicResults = numResults - numVariadicResults;
1347 
1348   int numOperands = op.getNumOperands();
1349   int numVariadicOperands = op.getNumVariableLengthOperands();
1350   int numNonVariadicOperands = numOperands - numVariadicOperands;
1351 
1352   // Operands
1353   if (numVariadicOperands == 0 || numNonVariadicOperands != 0)
1354     body << "  assert(operands.size()"
1355          << (numVariadicOperands != 0 ? " >= " : " == ")
1356          << numNonVariadicOperands
1357          << "u && \"mismatched number of parameters\");\n";
1358   body << "  " << builderOpState << ".addOperands(operands);\n";
1359   body << "  " << builderOpState << ".addAttributes(attributes);\n";
1360 
1361   // Create the correct number of regions
1362   if (int numRegions = op.getNumRegions()) {
1363     body << llvm::formatv(
1364         "  for (unsigned i = 0; i != {0}; ++i)\n",
1365         (op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions)));
1366     body << "    (void)" << builderOpState << ".addRegion();\n";
1367   }
1368 
1369   // Result types
1370   body << formatv(R"(
1371     ::mlir::SmallVector<::mlir::Type, 2> inferredReturnTypes;
1372     if (succeeded({0}::inferReturnTypes(odsBuilder.getContext(),
1373                   {1}.location, operands,
1374                   {1}.attributes.getDictionary({1}.getContext()),
1375                   /*regions=*/{{}, inferredReturnTypes))) {{)",
1376                   opClass.getClassName(), builderOpState);
1377   if (numVariadicResults == 0 || numNonVariadicResults != 0)
1378     body << "  assert(inferredReturnTypes.size()"
1379          << (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults
1380          << "u && \"mismatched number of return types\");\n";
1381   body << "      " << builderOpState << ".addTypes(inferredReturnTypes);";
1382 
1383   body << formatv(R"(
1384     } else
1385       ::llvm::report_fatal_error("Failed to infer result type(s).");)",
1386                   opClass.getClassName(), builderOpState);
1387 }
1388 
genUseOperandAsResultTypeSeparateParamBuilder()1389 void OpEmitter::genUseOperandAsResultTypeSeparateParamBuilder() {
1390   llvm::SmallVector<OpMethodParameter, 4> paramList;
1391   llvm::SmallVector<std::string, 4> resultNames;
1392   buildParamList(paramList, resultNames, TypeParamKind::None);
1393 
1394   auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
1395                                       std::move(paramList));
1396   // If the builder is redundant, skip generating the method
1397   if (!m)
1398     return;
1399   auto &body = m->body();
1400   genCodeForAddingArgAndRegionForBuilder(body);
1401 
1402   auto numResults = op.getNumResults();
1403   if (numResults == 0)
1404     return;
1405 
1406   // Push all result types to the operation state
1407   const char *index = op.getOperand(0).isVariadic() ? ".front()" : "";
1408   std::string resultType =
1409       formatv("{0}{1}.getType()", getArgumentName(op, 0), index).str();
1410   body << "  " << builderOpState << ".addTypes({" << resultType;
1411   for (int i = 1; i != numResults; ++i)
1412     body << ", " << resultType;
1413   body << "});\n\n";
1414 }
1415 
genUseAttrAsResultTypeBuilder()1416 void OpEmitter::genUseAttrAsResultTypeBuilder() {
1417   SmallVector<OpMethodParameter, 4> paramList;
1418   paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
1419   paramList.emplace_back("::mlir::OperationState &", builderOpState);
1420   paramList.emplace_back("::mlir::ValueRange", "operands");
1421   paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
1422                          "attributes", "{}");
1423   auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
1424                                       std::move(paramList));
1425   // If the builder is redundant, skip generating the method
1426   if (!m)
1427     return;
1428 
1429   auto &body = m->body();
1430 
1431   // Push all result types to the operation state
1432   std::string resultType;
1433   const auto &namedAttr = op.getAttribute(0);
1434 
1435   body << "  auto attrName = " << namedAttr.name << "AttrName("
1436        << builderOpState
1437        << ".name);\n"
1438           "  for (auto attr : attributes) {\n"
1439           "    if (attr.first != attrName) continue;\n";
1440   if (namedAttr.attr.isTypeAttr()) {
1441     resultType = "attr.second.cast<::mlir::TypeAttr>().getValue()";
1442   } else {
1443     resultType = "attr.second.getType()";
1444   }
1445 
1446   // Operands
1447   body << "  " << builderOpState << ".addOperands(operands);\n";
1448 
1449   // Attributes
1450   body << "  " << builderOpState << ".addAttributes(attributes);\n";
1451 
1452   // Result types
1453   SmallVector<std::string, 2> resultTypes(op.getNumResults(), resultType);
1454   body << "    " << builderOpState << ".addTypes({"
1455        << llvm::join(resultTypes, ", ") << "});\n";
1456   body << "  }\n";
1457 }
1458 
1459 /// Returns a signature of the builder. Updates the context `fctx` to enable
1460 /// replacement of $_builder and $_state in the body.
getBuilderSignature(const Builder & builder)1461 static std::string getBuilderSignature(const Builder &builder) {
1462   ArrayRef<Builder::Parameter> params(builder.getParameters());
1463 
1464   // Inject builder and state arguments.
1465   llvm::SmallVector<std::string, 8> arguments;
1466   arguments.reserve(params.size() + 2);
1467   arguments.push_back(
1468       llvm::formatv("::mlir::OpBuilder &{0}", odsBuilder).str());
1469   arguments.push_back(
1470       llvm::formatv("::mlir::OperationState &{0}", builderOpState).str());
1471 
1472   for (unsigned i = 0, e = params.size(); i < e; ++i) {
1473     // If no name is provided, generate one.
1474     Optional<StringRef> paramName = params[i].getName();
1475     std::string name =
1476         paramName ? paramName->str() : "odsArg" + std::to_string(i);
1477 
1478     std::string defaultValue;
1479     if (Optional<StringRef> defaultParamValue = params[i].getDefaultValue())
1480       defaultValue = llvm::formatv(" = {0}", *defaultParamValue).str();
1481     arguments.push_back(
1482         llvm::formatv("{0} {1}{2}", params[i].getCppType(), name, defaultValue)
1483             .str());
1484   }
1485 
1486   return llvm::join(arguments, ", ");
1487 }
1488 
genBuilder()1489 void OpEmitter::genBuilder() {
1490   // Handle custom builders if provided.
1491   for (const Builder &builder : op.getBuilders()) {
1492     std::string paramStr = getBuilderSignature(builder);
1493 
1494     Optional<StringRef> body = builder.getBody();
1495     OpMethod::Property properties =
1496         body ? OpMethod::MP_Static : OpMethod::MP_StaticDeclaration;
1497     auto *method =
1498         opClass.addMethodAndPrune("void", "build", properties, paramStr);
1499 
1500     FmtContext fctx;
1501     fctx.withBuilder(odsBuilder);
1502     fctx.addSubst("_state", builderOpState);
1503     if (body)
1504       method->body() << tgfmt(*body, &fctx);
1505   }
1506 
1507   // Generate default builders that requires all result type, operands, and
1508   // attributes as parameters.
1509   if (op.skipDefaultBuilders())
1510     return;
1511 
1512   // We generate three classes of builders here:
1513   // 1. one having a stand-alone parameter for each operand / attribute, and
1514   genSeparateArgParamBuilder();
1515   // 2. one having an aggregated parameter for all result types / operands /
1516   //    attributes, and
1517   genCollectiveParamBuilder();
1518   // 3. one having a stand-alone parameter for each operand and attribute,
1519   //    use the first operand or attribute's type as all result types
1520   //    to facilitate different call patterns.
1521   if (op.getNumVariableLengthResults() == 0) {
1522     if (op.getTrait("::mlir::OpTrait::SameOperandsAndResultType")) {
1523       genUseOperandAsResultTypeSeparateParamBuilder();
1524       genUseOperandAsResultTypeCollectiveParamBuilder();
1525     }
1526     if (op.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType"))
1527       genUseAttrAsResultTypeBuilder();
1528   }
1529 }
1530 
genCollectiveParamBuilder()1531 void OpEmitter::genCollectiveParamBuilder() {
1532   int numResults = op.getNumResults();
1533   int numVariadicResults = op.getNumVariableLengthResults();
1534   int numNonVariadicResults = numResults - numVariadicResults;
1535 
1536   int numOperands = op.getNumOperands();
1537   int numVariadicOperands = op.getNumVariableLengthOperands();
1538   int numNonVariadicOperands = numOperands - numVariadicOperands;
1539 
1540   SmallVector<OpMethodParameter, 4> paramList;
1541   paramList.emplace_back("::mlir::OpBuilder &", "");
1542   paramList.emplace_back("::mlir::OperationState &", builderOpState);
1543   paramList.emplace_back("::mlir::TypeRange", "resultTypes");
1544   paramList.emplace_back("::mlir::ValueRange", "operands");
1545   // Provide default value for `attributes` when its the last parameter
1546   StringRef attributesDefaultValue = op.getNumVariadicRegions() ? "" : "{}";
1547   paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
1548                          "attributes", attributesDefaultValue);
1549   if (op.getNumVariadicRegions())
1550     paramList.emplace_back("unsigned", "numRegions");
1551 
1552   auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
1553                                       std::move(paramList));
1554   // If the builder is redundant, skip generating the method
1555   if (!m)
1556     return;
1557   auto &body = m->body();
1558 
1559   // Operands
1560   if (numVariadicOperands == 0 || numNonVariadicOperands != 0)
1561     body << "  assert(operands.size()"
1562          << (numVariadicOperands != 0 ? " >= " : " == ")
1563          << numNonVariadicOperands
1564          << "u && \"mismatched number of parameters\");\n";
1565   body << "  " << builderOpState << ".addOperands(operands);\n";
1566 
1567   // Attributes
1568   body << "  " << builderOpState << ".addAttributes(attributes);\n";
1569 
1570   // Create the correct number of regions
1571   if (int numRegions = op.getNumRegions()) {
1572     body << llvm::formatv(
1573         "  for (unsigned i = 0; i != {0}; ++i)\n",
1574         (op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions)));
1575     body << "    (void)" << builderOpState << ".addRegion();\n";
1576   }
1577 
1578   // Result types
1579   if (numVariadicResults == 0 || numNonVariadicResults != 0)
1580     body << "  assert(resultTypes.size()"
1581          << (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults
1582          << "u && \"mismatched number of return types\");\n";
1583   body << "  " << builderOpState << ".addTypes(resultTypes);\n";
1584 
1585   // Generate builder that infers type too.
1586   // TODO: Expand to handle regions and successors.
1587   if (canInferType(op) && op.getNumSuccessors() == 0)
1588     genInferredTypeCollectiveParamBuilder();
1589 }
1590 
buildParamList(SmallVectorImpl<OpMethodParameter> & paramList,SmallVectorImpl<std::string> & resultTypeNames,TypeParamKind typeParamKind,AttrParamKind attrParamKind)1591 void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> &paramList,
1592                                SmallVectorImpl<std::string> &resultTypeNames,
1593                                TypeParamKind typeParamKind,
1594                                AttrParamKind attrParamKind) {
1595   resultTypeNames.clear();
1596   auto numResults = op.getNumResults();
1597   resultTypeNames.reserve(numResults);
1598 
1599   paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
1600   paramList.emplace_back("::mlir::OperationState &", builderOpState);
1601 
1602   switch (typeParamKind) {
1603   case TypeParamKind::None:
1604     break;
1605   case TypeParamKind::Separate: {
1606     // Add parameters for all return types
1607     for (int i = 0; i < numResults; ++i) {
1608       const auto &result = op.getResult(i);
1609       std::string resultName = std::string(result.name);
1610       if (resultName.empty())
1611         resultName = std::string(formatv("resultType{0}", i));
1612 
1613       StringRef type =
1614           result.isVariadic() ? "::mlir::TypeRange" : "::mlir::Type";
1615       OpMethodParameter::Property properties = OpMethodParameter::PP_None;
1616       if (result.isOptional())
1617         properties = OpMethodParameter::PP_Optional;
1618 
1619       paramList.emplace_back(type, resultName, properties);
1620       resultTypeNames.emplace_back(std::move(resultName));
1621     }
1622   } break;
1623   case TypeParamKind::Collective: {
1624     paramList.emplace_back("::mlir::TypeRange", "resultTypes");
1625     resultTypeNames.push_back("resultTypes");
1626   } break;
1627   }
1628 
1629   // Add parameters for all arguments (operands and attributes).
1630 
1631   int numOperands = 0;
1632   int numAttrs = 0;
1633 
1634   int defaultValuedAttrStartIndex = op.getNumArgs();
1635   if (attrParamKind == AttrParamKind::UnwrappedValue) {
1636     // Calculate the start index from which we can attach default values in the
1637     // builder declaration.
1638     for (int i = op.getNumArgs() - 1; i >= 0; --i) {
1639       auto *namedAttr = op.getArg(i).dyn_cast<tblgen::NamedAttribute *>();
1640       if (!namedAttr || !namedAttr->attr.hasDefaultValue())
1641         break;
1642 
1643       if (!canUseUnwrappedRawValue(namedAttr->attr))
1644         break;
1645 
1646       // Creating an APInt requires us to provide bitwidth, value, and
1647       // signedness, which is complicated compared to others. Similarly
1648       // for APFloat.
1649       // TODO: Adjust the 'returnType' field of such attributes
1650       // to support them.
1651       StringRef retType = namedAttr->attr.getReturnType();
1652       if (retType == "::llvm::APInt" || retType == "::llvm::APFloat")
1653         break;
1654 
1655       defaultValuedAttrStartIndex = i;
1656     }
1657   }
1658 
1659   for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
1660     auto argument = op.getArg(i);
1661     if (argument.is<tblgen::NamedTypeConstraint *>()) {
1662       const auto &operand = op.getOperand(numOperands);
1663       StringRef type =
1664           operand.isVariadic() ? "::mlir::ValueRange" : "::mlir::Value";
1665       OpMethodParameter::Property properties = OpMethodParameter::PP_None;
1666       if (operand.isOptional())
1667         properties = OpMethodParameter::PP_Optional;
1668 
1669       paramList.emplace_back(type, getArgumentName(op, numOperands),
1670                              properties);
1671       ++numOperands;
1672     } else {
1673       const auto &namedAttr = op.getAttribute(numAttrs);
1674       const auto &attr = namedAttr.attr;
1675 
1676       OpMethodParameter::Property properties = OpMethodParameter::PP_None;
1677       if (attr.isOptional())
1678         properties = OpMethodParameter::PP_Optional;
1679 
1680       StringRef type;
1681       switch (attrParamKind) {
1682       case AttrParamKind::WrappedAttr:
1683         type = attr.getStorageType();
1684         break;
1685       case AttrParamKind::UnwrappedValue:
1686         if (canUseUnwrappedRawValue(attr))
1687           type = attr.getReturnType();
1688         else
1689           type = attr.getStorageType();
1690         break;
1691       }
1692 
1693       std::string defaultValue;
1694       // Attach default value if requested and possible.
1695       if (attrParamKind == AttrParamKind::UnwrappedValue &&
1696           i >= defaultValuedAttrStartIndex) {
1697         bool isString = attr.getReturnType() == "::llvm::StringRef";
1698         if (isString)
1699           defaultValue.append("\"");
1700         defaultValue += attr.getDefaultValue();
1701         if (isString)
1702           defaultValue.append("\"");
1703       }
1704       paramList.emplace_back(type, namedAttr.name, defaultValue, properties);
1705       ++numAttrs;
1706     }
1707   }
1708 
1709   /// Insert parameters for each successor.
1710   for (const NamedSuccessor &succ : op.getSuccessors()) {
1711     StringRef type =
1712         succ.isVariadic() ? "::mlir::BlockRange" : "::mlir::Block *";
1713     paramList.emplace_back(type, succ.name);
1714   }
1715 
1716   /// Insert parameters for variadic regions.
1717   for (const NamedRegion &region : op.getRegions())
1718     if (region.isVariadic())
1719       paramList.emplace_back("unsigned",
1720                              llvm::formatv("{0}Count", region.name).str());
1721 }
1722 
genCodeForAddingArgAndRegionForBuilder(OpMethodBody & body,bool isRawValueAttr)1723 void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
1724                                                        bool isRawValueAttr) {
1725   // Push all operands to the result.
1726   for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
1727     std::string argName = getArgumentName(op, i);
1728     if (op.getOperand(i).isOptional())
1729       body << "  if (" << argName << ")\n  ";
1730     body << "  " << builderOpState << ".addOperands(" << argName << ");\n";
1731   }
1732 
1733   // If the operation has the operand segment size attribute, add it here.
1734   if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
1735     body << "  " << builderOpState
1736          << ".addAttribute(operand_segment_sizesAttrName(" << builderOpState
1737          << ".name), "
1738          << "odsBuilder.getI32VectorAttr({";
1739     interleaveComma(llvm::seq<int>(0, op.getNumOperands()), body, [&](int i) {
1740       if (op.getOperand(i).isOptional())
1741         body << "(" << getArgumentName(op, i) << " ? 1 : 0)";
1742       else if (op.getOperand(i).isVariadic())
1743         body << "static_cast<int32_t>(" << getArgumentName(op, i) << ".size())";
1744       else
1745         body << "1";
1746     });
1747     body << "}));\n";
1748   }
1749 
1750   // Push all attributes to the result.
1751   for (const auto &namedAttr : op.getAttributes()) {
1752     auto &attr = namedAttr.attr;
1753     if (!attr.isDerivedAttr()) {
1754       bool emitNotNullCheck = attr.isOptional();
1755       if (emitNotNullCheck)
1756         body << formatv("  if ({0}) ", namedAttr.name) << "{\n";
1757 
1758       if (isRawValueAttr && canUseUnwrappedRawValue(attr)) {
1759         // If this is a raw value, then we need to wrap it in an Attribute
1760         // instance.
1761         FmtContext fctx;
1762         fctx.withBuilder("odsBuilder");
1763 
1764         std::string builderTemplate =
1765             std::string(attr.getConstBuilderTemplate());
1766 
1767         // For StringAttr, its constant builder call will wrap the input in
1768         // quotes, which is correct for normal string literals, but incorrect
1769         // here given we use function arguments. So we need to strip the
1770         // wrapping quotes.
1771         if (StringRef(builderTemplate).contains("\"$0\""))
1772           builderTemplate = replaceAllSubstrs(builderTemplate, "\"$0\"", "$0");
1773 
1774         std::string value =
1775             std::string(tgfmt(builderTemplate, &fctx, namedAttr.name));
1776         body << formatv("  {0}.addAttribute({1}AttrName({0}.name), {2});\n",
1777                         builderOpState, namedAttr.name, value);
1778       } else {
1779         body << formatv("  {0}.addAttribute({1}AttrName({0}.name), {1});\n",
1780                         builderOpState, namedAttr.name);
1781       }
1782       if (emitNotNullCheck)
1783         body << "  }\n";
1784     }
1785   }
1786 
1787   // Create the correct number of regions.
1788   for (const NamedRegion &region : op.getRegions()) {
1789     if (region.isVariadic())
1790       body << formatv("  for (unsigned i = 0; i < {0}Count; ++i)\n  ",
1791                       region.name);
1792 
1793     body << "  (void)" << builderOpState << ".addRegion();\n";
1794   }
1795 
1796   // Push all successors to the result.
1797   for (const NamedSuccessor &namedSuccessor : op.getSuccessors()) {
1798     body << formatv("  {0}.addSuccessors({1});\n", builderOpState,
1799                     namedSuccessor.name);
1800   }
1801 }
1802 
genCanonicalizerDecls()1803 void OpEmitter::genCanonicalizerDecls() {
1804   bool hasCanonicalizeMethod = def.getValueAsBit("hasCanonicalizeMethod");
1805   if (hasCanonicalizeMethod) {
1806     // static LogicResult FooOp::
1807     // canonicalize(FooOp op, PatternRewriter &rewriter);
1808     SmallVector<OpMethodParameter, 2> paramList;
1809     paramList.emplace_back(op.getCppClassName(), "op");
1810     paramList.emplace_back("::mlir::PatternRewriter &", "rewriter");
1811     opClass.addMethodAndPrune("::mlir::LogicalResult", "canonicalize",
1812                               OpMethod::MP_StaticDeclaration,
1813                               std::move(paramList));
1814   }
1815 
1816   // We get a prototype for 'getCanonicalizationPatterns' if requested directly
1817   // or if using a 'canonicalize' method.
1818   bool hasCanonicalizer = def.getValueAsBit("hasCanonicalizer");
1819   if (!hasCanonicalizeMethod && !hasCanonicalizer)
1820     return;
1821 
1822   // We get a body for 'getCanonicalizationPatterns' when using a 'canonicalize'
1823   // method, but not implementing 'getCanonicalizationPatterns' manually.
1824   bool hasBody = hasCanonicalizeMethod && !hasCanonicalizer;
1825 
1826   // Add a signature for getCanonicalizationPatterns if implemented by the
1827   // dialect or if synthesized to call 'canonicalize'.
1828   SmallVector<OpMethodParameter, 2> paramList;
1829   paramList.emplace_back("::mlir::RewritePatternSet &", "results");
1830   paramList.emplace_back("::mlir::MLIRContext *", "context");
1831   auto kind = hasBody ? OpMethod::MP_Static : OpMethod::MP_StaticDeclaration;
1832   auto *method = opClass.addMethodAndPrune(
1833       "void", "getCanonicalizationPatterns", kind, std::move(paramList));
1834 
1835   // If synthesizing the method, fill it it.
1836   if (hasBody)
1837     method->body() << "  results.add(canonicalize);\n";
1838 }
1839 
genFolderDecls()1840 void OpEmitter::genFolderDecls() {
1841   bool hasSingleResult =
1842       op.getNumResults() == 1 && op.getNumVariableLengthResults() == 0;
1843 
1844   if (def.getValueAsBit("hasFolder")) {
1845     if (hasSingleResult) {
1846       opClass.addMethodAndPrune(
1847           "::mlir::OpFoldResult", "fold", OpMethod::MP_Declaration,
1848           "::llvm::ArrayRef<::mlir::Attribute>", "operands");
1849     } else {
1850       SmallVector<OpMethodParameter, 2> paramList;
1851       paramList.emplace_back("::llvm::ArrayRef<::mlir::Attribute>", "operands");
1852       paramList.emplace_back("::llvm::SmallVectorImpl<::mlir::OpFoldResult> &",
1853                              "results");
1854       opClass.addMethodAndPrune("::mlir::LogicalResult", "fold",
1855                                 OpMethod::MP_Declaration, std::move(paramList));
1856     }
1857   }
1858 }
1859 
genOpInterfaceMethods(const tblgen::InterfaceTrait * opTrait)1860 void OpEmitter::genOpInterfaceMethods(const tblgen::InterfaceTrait *opTrait) {
1861   Interface interface = opTrait->getInterface();
1862 
1863   // Get the set of methods that should always be declared.
1864   auto alwaysDeclaredMethodsVec = opTrait->getAlwaysDeclaredMethods();
1865   llvm::StringSet<> alwaysDeclaredMethods;
1866   alwaysDeclaredMethods.insert(alwaysDeclaredMethodsVec.begin(),
1867                                alwaysDeclaredMethodsVec.end());
1868 
1869   for (const InterfaceMethod &method : interface.getMethods()) {
1870     // Don't declare if the method has a body.
1871     if (method.getBody())
1872       continue;
1873     // Don't declare if the method has a default implementation and the op
1874     // didn't request that it always be declared.
1875     if (method.getDefaultImplementation() &&
1876         !alwaysDeclaredMethods.count(method.getName()))
1877       continue;
1878     genOpInterfaceMethod(method);
1879   }
1880 }
1881 
genOpInterfaceMethod(const InterfaceMethod & method,bool declaration)1882 OpMethod *OpEmitter::genOpInterfaceMethod(const InterfaceMethod &method,
1883                                           bool declaration) {
1884   SmallVector<OpMethodParameter, 4> paramList;
1885   for (const InterfaceMethod::Argument &arg : method.getArguments())
1886     paramList.emplace_back(arg.type, arg.name);
1887 
1888   auto properties = method.isStatic() ? OpMethod::MP_Static : OpMethod::MP_None;
1889   if (declaration)
1890     properties =
1891         static_cast<OpMethod::Property>(properties | OpMethod::MP_Declaration);
1892   return opClass.addMethodAndPrune(method.getReturnType(), method.getName(),
1893                                    properties, std::move(paramList));
1894 }
1895 
genOpInterfaceMethods()1896 void OpEmitter::genOpInterfaceMethods() {
1897   for (const auto &trait : op.getTraits()) {
1898     if (const auto *opTrait = dyn_cast<tblgen::InterfaceTrait>(&trait))
1899       if (opTrait->shouldDeclareMethods())
1900         genOpInterfaceMethods(opTrait);
1901   }
1902 }
1903 
genSideEffectInterfaceMethods()1904 void OpEmitter::genSideEffectInterfaceMethods() {
1905   enum EffectKind { Operand, Result, Symbol, Static };
1906   struct EffectLocation {
1907     /// The effect applied.
1908     SideEffect effect;
1909 
1910     /// The index if the kind is not static.
1911     unsigned index : 30;
1912 
1913     /// The kind of the location.
1914     unsigned kind : 2;
1915   };
1916 
1917   StringMap<SmallVector<EffectLocation, 1>> interfaceEffects;
1918   auto resolveDecorators = [&](Operator::var_decorator_range decorators,
1919                                unsigned index, unsigned kind) {
1920     for (auto decorator : decorators)
1921       if (SideEffect *effect = dyn_cast<SideEffect>(&decorator)) {
1922         opClass.addTrait(effect->getInterfaceTrait());
1923         interfaceEffects[effect->getBaseEffectName()].push_back(
1924             EffectLocation{*effect, index, kind});
1925       }
1926   };
1927 
1928   // Collect effects that were specified via:
1929   /// Traits.
1930   for (const auto &trait : op.getTraits()) {
1931     const auto *opTrait = dyn_cast<tblgen::SideEffectTrait>(&trait);
1932     if (!opTrait)
1933       continue;
1934     auto &effects = interfaceEffects[opTrait->getBaseEffectName()];
1935     for (auto decorator : opTrait->getEffects())
1936       effects.push_back(EffectLocation{cast<SideEffect>(decorator),
1937                                        /*index=*/0, EffectKind::Static});
1938   }
1939   /// Attributes and Operands.
1940   for (unsigned i = 0, operandIt = 0, e = op.getNumArgs(); i != e; ++i) {
1941     Argument arg = op.getArg(i);
1942     if (arg.is<NamedTypeConstraint *>()) {
1943       resolveDecorators(op.getArgDecorators(i), operandIt, EffectKind::Operand);
1944       ++operandIt;
1945       continue;
1946     }
1947     const NamedAttribute *attr = arg.get<NamedAttribute *>();
1948     if (attr->attr.getBaseAttr().isSymbolRefAttr())
1949       resolveDecorators(op.getArgDecorators(i), i, EffectKind::Symbol);
1950   }
1951   /// Results.
1952   for (unsigned i = 0, e = op.getNumResults(); i != e; ++i)
1953     resolveDecorators(op.getResultDecorators(i), i, EffectKind::Result);
1954 
1955   // The code used to add an effect instance.
1956   // {0}: The effect class.
1957   // {1}: Optional value or symbol reference.
1958   // {1}: The resource class.
1959   const char *addEffectCode =
1960       "  effects.emplace_back({0}::get(), {1}{2}::get());\n";
1961 
1962   for (auto &it : interfaceEffects) {
1963     // Generate the 'getEffects' method.
1964     std::string type = llvm::formatv("::mlir::SmallVectorImpl<::mlir::"
1965                                      "SideEffects::EffectInstance<{0}>> &",
1966                                      it.first())
1967                            .str();
1968     auto *getEffects =
1969         opClass.addMethodAndPrune("void", "getEffects", type, "effects");
1970     auto &body = getEffects->body();
1971 
1972     // Add effect instances for each of the locations marked on the operation.
1973     for (auto &location : it.second) {
1974       StringRef effect = location.effect.getName();
1975       StringRef resource = location.effect.getResource();
1976       if (location.kind == EffectKind::Static) {
1977         // A static instance has no attached value.
1978         body << llvm::formatv(addEffectCode, effect, "", resource).str();
1979       } else if (location.kind == EffectKind::Symbol) {
1980         // A symbol reference requires adding the proper attribute.
1981         const auto *attr = op.getArg(location.index).get<NamedAttribute *>();
1982         if (attr->attr.isOptional()) {
1983           body << "  if (auto symbolRef = " << attr->name << "Attr())\n  "
1984                << llvm::formatv(addEffectCode, effect, "symbolRef, ", resource)
1985                       .str();
1986         } else {
1987           body << llvm::formatv(addEffectCode, effect, attr->name + "(), ",
1988                                 resource)
1989                       .str();
1990         }
1991       } else {
1992         // Otherwise this is an operand/result, so we need to attach the Value.
1993         body << "  for (::mlir::Value value : getODS"
1994              << (location.kind == EffectKind::Operand ? "Operands" : "Results")
1995              << "(" << location.index << "))\n  "
1996              << llvm::formatv(addEffectCode, effect, "value, ", resource).str();
1997       }
1998     }
1999   }
2000 }
2001 
genTypeInterfaceMethods()2002 void OpEmitter::genTypeInterfaceMethods() {
2003   if (!op.allResultTypesKnown())
2004     return;
2005   // Generate 'inferReturnTypes' method declaration using the interface method
2006   // declared in 'InferTypeOpInterface' op interface.
2007   const auto *trait = dyn_cast<InterfaceTrait>(
2008       op.getTrait("::mlir::InferTypeOpInterface::Trait"));
2009   Interface interface = trait->getInterface();
2010   OpMethod *method = [&]() -> OpMethod * {
2011     for (const InterfaceMethod &interfaceMethod : interface.getMethods()) {
2012       if (interfaceMethod.getName() == "inferReturnTypes") {
2013         return genOpInterfaceMethod(interfaceMethod, /*declaration=*/false);
2014       }
2015     }
2016     assert(0 && "unable to find inferReturnTypes interface method");
2017     return nullptr;
2018   }();
2019   auto &body = method->body();
2020   body << "  inferredReturnTypes.resize(" << op.getNumResults() << ");\n";
2021 
2022   FmtContext fctx;
2023   fctx.withBuilder("odsBuilder");
2024   body << "  ::mlir::Builder odsBuilder(context);\n";
2025 
2026   auto emitType =
2027       [&](const tblgen::Operator::ArgOrType &type) -> OpMethodBody & {
2028     if (type.isArg()) {
2029       auto argIndex = type.getArg();
2030       assert(!op.getArg(argIndex).is<NamedAttribute *>());
2031       auto arg = op.getArgToOperandOrAttribute(argIndex);
2032       if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand)
2033         return body << "operands[" << arg.operandOrAttributeIndex()
2034                     << "].getType()";
2035       return body << "attributes[" << arg.operandOrAttributeIndex()
2036                   << "].getType()";
2037     } else {
2038       return body << tgfmt(*type.getType().getBuilderCall(), &fctx);
2039     }
2040   };
2041 
2042   for (int i = 0, e = op.getNumResults(); i != e; ++i) {
2043     body << "  inferredReturnTypes[" << i << "] = ";
2044     auto types = op.getSameTypeAsResult(i);
2045     emitType(types[0]) << ";\n";
2046     if (types.size() == 1)
2047       continue;
2048     // TODO: We could verify equality here, but skipping that for verification.
2049   }
2050   body << "  return ::mlir::success();";
2051 }
2052 
genParser()2053 void OpEmitter::genParser() {
2054   if (!hasStringAttribute(def, "parser") ||
2055       hasStringAttribute(def, "assemblyFormat"))
2056     return;
2057 
2058   SmallVector<OpMethodParameter, 2> paramList;
2059   paramList.emplace_back("::mlir::OpAsmParser &", "parser");
2060   paramList.emplace_back("::mlir::OperationState &", "result");
2061   auto *method =
2062       opClass.addMethodAndPrune("::mlir::ParseResult", "parse",
2063                                 OpMethod::MP_Static, std::move(paramList));
2064 
2065   FmtContext fctx;
2066   fctx.addSubst("cppClass", opClass.getClassName());
2067   auto parser = def.getValueAsString("parser").ltrim().rtrim(" \t\v\f\r");
2068   method->body() << "  " << tgfmt(parser, &fctx);
2069 }
2070 
genPrinter()2071 void OpEmitter::genPrinter() {
2072   if (hasStringAttribute(def, "assemblyFormat"))
2073     return;
2074 
2075   auto valueInit = def.getValueInit("printer");
2076   StringInit *stringInit = dyn_cast<StringInit>(valueInit);
2077   if (!stringInit)
2078     return;
2079 
2080   auto *method =
2081       opClass.addMethodAndPrune("void", "print", "::mlir::OpAsmPrinter &", "p");
2082   FmtContext fctx;
2083   fctx.addSubst("cppClass", opClass.getClassName());
2084   auto printer = stringInit->getValue().ltrim().rtrim(" \t\v\f\r");
2085   method->body() << "  " << tgfmt(printer, &fctx);
2086 }
2087 
genVerifier()2088 void OpEmitter::genVerifier() {
2089   auto *method = opClass.addMethodAndPrune("::mlir::LogicalResult", "verify");
2090   auto &body = method->body();
2091   body << "  if (failed(" << op.getAdaptorName()
2092        << "(*this).verify((*this)->getLoc()))) "
2093        << "return ::mlir::failure();\n";
2094 
2095   auto *valueInit = def.getValueInit("verifier");
2096   StringInit *stringInit = dyn_cast<StringInit>(valueInit);
2097   bool hasCustomVerify = stringInit && !stringInit->getValue().empty();
2098   populateSubstitutions(op, "(*this)->getAttr", "this->getODSOperands",
2099                         "this->getODSResults", verifyCtx);
2100 
2101   genAttributeVerifier(op, "(*this)->getAttr", "emitOpError(",
2102                        /*emitVerificationRequiringOp=*/true, verifyCtx, body);
2103   genOperandResultVerifier(body, op.getOperands(), "operand");
2104   genOperandResultVerifier(body, op.getResults(), "result");
2105 
2106   for (auto &trait : op.getTraits()) {
2107     if (auto *t = dyn_cast<tblgen::PredTrait>(&trait)) {
2108       body << tgfmt("  if (!($0))\n    "
2109                     "return emitOpError(\"failed to verify that $1\");\n",
2110                     &verifyCtx, tgfmt(t->getPredTemplate(), &verifyCtx),
2111                     t->getSummary());
2112     }
2113   }
2114 
2115   genRegionVerifier(body);
2116   genSuccessorVerifier(body);
2117 
2118   if (hasCustomVerify) {
2119     FmtContext fctx;
2120     fctx.addSubst("cppClass", opClass.getClassName());
2121     auto printer = stringInit->getValue().ltrim().rtrim(" \t\v\f\r");
2122     body << "  " << tgfmt(printer, &fctx);
2123   } else {
2124     body << "  return ::mlir::success();\n";
2125   }
2126 }
2127 
genOperandResultVerifier(OpMethodBody & body,Operator::value_range values,StringRef valueKind)2128 void OpEmitter::genOperandResultVerifier(OpMethodBody &body,
2129                                          Operator::value_range values,
2130                                          StringRef valueKind) {
2131   FmtContext fctx;
2132 
2133   body << "  {\n";
2134   body << "    unsigned index = 0; (void)index;\n";
2135 
2136   for (auto staticValue : llvm::enumerate(values)) {
2137     bool hasPredicate = staticValue.value().hasPredicate();
2138     bool isOptional = staticValue.value().isOptional();
2139     if (!hasPredicate && !isOptional)
2140       continue;
2141     body << formatv("    auto valueGroup{2} = getODS{0}{1}s({2});\n",
2142                     // Capitalize the first letter to match the function name
2143                     valueKind.substr(0, 1).upper(), valueKind.substr(1),
2144                     staticValue.index());
2145 
2146     // If the constraint is optional check that the value group has at most 1
2147     // value.
2148     if (isOptional) {
2149       body << formatv("    if (valueGroup{0}.size() > 1)\n"
2150                       "      return emitOpError(\"{1} group starting at #\") "
2151                       "<< index << \" requires 0 or 1 element, but found \" << "
2152                       "valueGroup{0}.size();\n",
2153                       staticValue.index(), valueKind);
2154     }
2155 
2156     // Otherwise, if there is no predicate there is nothing left to do.
2157     if (!hasPredicate)
2158       continue;
2159     // Emit a loop to check all the dynamic values in the pack.
2160     StringRef constraintFn = staticVerifierEmitter.getTypeConstraintFn(
2161         staticValue.value().constraint);
2162     body << "    for (::mlir::Value v : valueGroup" << staticValue.index()
2163          << ") {\n"
2164          << "      if (::mlir::failed(" << constraintFn
2165          << "(getOperation(), v.getType(), \"" << valueKind << "\", index)))\n"
2166          << "        return ::mlir::failure();\n"
2167          << "      ++index;\n"
2168          << "    }\n";
2169   }
2170 
2171   body << "  }\n";
2172 }
2173 
genRegionVerifier(OpMethodBody & body)2174 void OpEmitter::genRegionVerifier(OpMethodBody &body) {
2175   // If we have no regions, there is nothing more to do.
2176   unsigned numRegions = op.getNumRegions();
2177   if (numRegions == 0)
2178     return;
2179 
2180   body << "{\n";
2181   body << "    unsigned index = 0; (void)index;\n";
2182 
2183   for (unsigned i = 0; i < numRegions; ++i) {
2184     const auto &region = op.getRegion(i);
2185     if (region.constraint.getPredicate().isNull())
2186       continue;
2187 
2188     body << "    for (::mlir::Region &region : ";
2189     body << formatv(region.isVariadic()
2190                         ? "{0}()"
2191                         : "::mlir::MutableArrayRef<::mlir::Region>((*this)"
2192                           "->getRegion({1}))",
2193                     region.name, i);
2194     body << ") {\n";
2195     auto constraint = tgfmt(region.constraint.getConditionTemplate(),
2196                             &verifyCtx.withSelf("region"))
2197                           .str();
2198 
2199     body << formatv("      (void)region;\n"
2200                     "      if (!({0})) {\n        "
2201                     "return emitOpError(\"region #\") << index << \" {1}"
2202                     "failed to "
2203                     "verify constraint: {2}\";\n      }\n",
2204                     constraint,
2205                     region.name.empty() ? "" : "('" + region.name + "') ",
2206                     region.constraint.getSummary())
2207          << "      ++index;\n"
2208          << "    }\n";
2209   }
2210   body << "  }\n";
2211 }
2212 
genSuccessorVerifier(OpMethodBody & body)2213 void OpEmitter::genSuccessorVerifier(OpMethodBody &body) {
2214   // If we have no successors, there is nothing more to do.
2215   unsigned numSuccessors = op.getNumSuccessors();
2216   if (numSuccessors == 0)
2217     return;
2218 
2219   body << "{\n";
2220   body << "    unsigned index = 0; (void)index;\n";
2221 
2222   for (unsigned i = 0; i < numSuccessors; ++i) {
2223     const auto &successor = op.getSuccessor(i);
2224     if (successor.constraint.getPredicate().isNull())
2225       continue;
2226 
2227     if (successor.isVariadic()) {
2228       body << formatv("    for (::mlir::Block *successor : {0}()) {\n",
2229                       successor.name);
2230     } else {
2231       body << "    {\n";
2232       body << formatv("      ::mlir::Block *successor = {0}();\n",
2233                       successor.name);
2234     }
2235     auto constraint = tgfmt(successor.constraint.getConditionTemplate(),
2236                             &verifyCtx.withSelf("successor"))
2237                           .str();
2238 
2239     body << formatv("      (void)successor;\n"
2240                     "      if (!({0})) {\n        "
2241                     "return emitOpError(\"successor #\") << index << \"('{1}') "
2242                     "failed to "
2243                     "verify constraint: {2}\";\n      }\n",
2244                     constraint, successor.name,
2245                     successor.constraint.getSummary())
2246          << "      ++index;\n"
2247          << "    }\n";
2248   }
2249   body << "  }\n";
2250 }
2251 
2252 /// Add a size count trait to the given operation class.
addSizeCountTrait(OpClass & opClass,StringRef traitKind,int numTotal,int numVariadic)2253 static void addSizeCountTrait(OpClass &opClass, StringRef traitKind,
2254                               int numTotal, int numVariadic) {
2255   if (numVariadic != 0) {
2256     if (numTotal == numVariadic)
2257       opClass.addTrait("::mlir::OpTrait::Variadic" + traitKind + "s");
2258     else
2259       opClass.addTrait("::mlir::OpTrait::AtLeastN" + traitKind + "s<" +
2260                        Twine(numTotal - numVariadic) + ">::Impl");
2261     return;
2262   }
2263   switch (numTotal) {
2264   case 0:
2265     opClass.addTrait("::mlir::OpTrait::Zero" + traitKind);
2266     break;
2267   case 1:
2268     opClass.addTrait("::mlir::OpTrait::One" + traitKind);
2269     break;
2270   default:
2271     opClass.addTrait("::mlir::OpTrait::N" + traitKind + "s<" + Twine(numTotal) +
2272                      ">::Impl");
2273     break;
2274   }
2275 }
2276 
genTraits()2277 void OpEmitter::genTraits() {
2278   // Add region size trait.
2279   unsigned numRegions = op.getNumRegions();
2280   unsigned numVariadicRegions = op.getNumVariadicRegions();
2281   addSizeCountTrait(opClass, "Region", numRegions, numVariadicRegions);
2282 
2283   // Add result size traits.
2284   int numResults = op.getNumResults();
2285   int numVariadicResults = op.getNumVariableLengthResults();
2286   addSizeCountTrait(opClass, "Result", numResults, numVariadicResults);
2287 
2288   // For single result ops with a known specific type, generate a OneTypedResult
2289   // trait.
2290   if (numResults == 1 && numVariadicResults == 0) {
2291     auto cppName = op.getResults().begin()->constraint.getCPPClassName();
2292     opClass.addTrait("::mlir::OpTrait::OneTypedResult<" + cppName + ">::Impl");
2293   }
2294 
2295   // Add successor size trait.
2296   unsigned numSuccessors = op.getNumSuccessors();
2297   unsigned numVariadicSuccessors = op.getNumVariadicSuccessors();
2298   addSizeCountTrait(opClass, "Successor", numSuccessors, numVariadicSuccessors);
2299 
2300   // Add variadic size trait and normal op traits.
2301   int numOperands = op.getNumOperands();
2302   int numVariadicOperands = op.getNumVariableLengthOperands();
2303 
2304   // Add operand size trait.
2305   if (numVariadicOperands != 0) {
2306     if (numOperands == numVariadicOperands)
2307       opClass.addTrait("::mlir::OpTrait::VariadicOperands");
2308     else
2309       opClass.addTrait("::mlir::OpTrait::AtLeastNOperands<" +
2310                        Twine(numOperands - numVariadicOperands) + ">::Impl");
2311   } else {
2312     switch (numOperands) {
2313     case 0:
2314       opClass.addTrait("::mlir::OpTrait::ZeroOperands");
2315       break;
2316     case 1:
2317       opClass.addTrait("::mlir::OpTrait::OneOperand");
2318       break;
2319     default:
2320       opClass.addTrait("::mlir::OpTrait::NOperands<" + Twine(numOperands) +
2321                        ">::Impl");
2322       break;
2323     }
2324   }
2325 
2326   // Add the native and interface traits.
2327   for (const auto &trait : op.getTraits()) {
2328     if (auto opTrait = dyn_cast<tblgen::NativeTrait>(&trait))
2329       opClass.addTrait(opTrait->getFullyQualifiedTraitName());
2330     else if (auto opTrait = dyn_cast<tblgen::InterfaceTrait>(&trait))
2331       opClass.addTrait(opTrait->getFullyQualifiedTraitName());
2332   }
2333 }
2334 
genOpNameGetter()2335 void OpEmitter::genOpNameGetter() {
2336   auto *method = opClass.addMethodAndPrune(
2337       "::llvm::StringLiteral", "getOperationName",
2338       OpMethod::Property(OpMethod::MP_Static | OpMethod::MP_Constexpr));
2339   method->body() << "  return ::llvm::StringLiteral(\"" << op.getOperationName()
2340                  << "\");";
2341 }
2342 
genOpAsmInterface()2343 void OpEmitter::genOpAsmInterface() {
2344   // If the user only has one results or specifically added the Asm trait,
2345   // then don't generate it for them. We specifically only handle multi result
2346   // operations, because the name of a single result in the common case is not
2347   // interesting(generally 'result'/'output'/etc.).
2348   // TODO: We could also add a flag to allow operations to opt in to this
2349   // generation, even if they only have a single operation.
2350   int numResults = op.getNumResults();
2351   if (numResults <= 1 || op.getTrait("::mlir::OpAsmOpInterface::Trait"))
2352     return;
2353 
2354   SmallVector<StringRef, 4> resultNames(numResults);
2355   for (int i = 0; i != numResults; ++i)
2356     resultNames[i] = op.getResultName(i);
2357 
2358   // Don't add the trait if none of the results have a valid name.
2359   if (llvm::all_of(resultNames, [](StringRef name) { return name.empty(); }))
2360     return;
2361   opClass.addTrait("::mlir::OpAsmOpInterface::Trait");
2362 
2363   // Generate the right accessor for the number of results.
2364   auto *method = opClass.addMethodAndPrune(
2365       "void", "getAsmResultNames", "::mlir::OpAsmSetValueNameFn", "setNameFn");
2366   auto &body = method->body();
2367   for (int i = 0; i != numResults; ++i) {
2368     body << "  auto resultGroup" << i << " = getODSResults(" << i << ");\n"
2369          << "  if (!llvm::empty(resultGroup" << i << "))\n"
2370          << "    setNameFn(*resultGroup" << i << ".begin(), \""
2371          << resultNames[i] << "\");\n";
2372   }
2373 }
2374 
2375 //===----------------------------------------------------------------------===//
2376 // OpOperandAdaptor emitter
2377 //===----------------------------------------------------------------------===//
2378 
namespace {
// Helper class to emit Op operand adaptors to an output stream. An operand
// adaptor wraps an operand range (`::mlir::ValueRange`), an attribute
// dictionary (`::mlir::DictionaryAttr`) and a region range
// (`::mlir::RegionRange`). It provides named operand, attribute and region
// getters that mirror those defined on the Op.
class OpOperandAdaptorEmitter {
public:
  // Build the adaptor for `op` and write its class declaration to `os`.
  static void emitDecl(const Operator &op, raw_ostream &os);
  // Build the adaptor for `op` and write its method definitions to `os`.
  static void emitDef(const Operator &op, raw_ostream &os);

private:
  // Populate `adaptor` (fields, constructors, accessors, verify()) for `op`.
  // Only reachable through emitDecl/emitDef.
  explicit OpOperandAdaptorEmitter(const Operator &op);

  // Add verification function. This generates a verify method for the adaptor
  // which verifies all the op-independent attribute constraints.
  void addVerification();

  // The op whose adaptor is being generated (not owned).
  const Operator &op;
  // In-memory model of the generated adaptor class.
  Class adaptor;
};
} // end namespace
2399 
OpOperandAdaptorEmitter(const Operator & op)2400 OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(const Operator &op)
2401     : op(op), adaptor(op.getAdaptorName()) {
2402   adaptor.newField("::mlir::ValueRange", "odsOperands");
2403   adaptor.newField("::mlir::DictionaryAttr", "odsAttrs");
2404   adaptor.newField("::mlir::RegionRange", "odsRegions");
2405   const auto *attrSizedOperands =
2406       op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments");
2407   {
2408     SmallVector<OpMethodParameter, 2> paramList;
2409     paramList.emplace_back("::mlir::ValueRange", "values");
2410     paramList.emplace_back("::mlir::DictionaryAttr", "attrs",
2411                            attrSizedOperands ? "" : "nullptr");
2412     paramList.emplace_back("::mlir::RegionRange", "regions", "{}");
2413     auto *constructor = adaptor.addConstructorAndPrune(std::move(paramList));
2414 
2415     constructor->addMemberInitializer("odsOperands", "values");
2416     constructor->addMemberInitializer("odsAttrs", "attrs");
2417     constructor->addMemberInitializer("odsRegions", "regions");
2418   }
2419 
2420   {
2421     auto *constructor = adaptor.addConstructorAndPrune(
2422         llvm::formatv("{0}&", op.getCppClassName()).str(), "op");
2423     constructor->addMemberInitializer("odsOperands", "op->getOperands()");
2424     constructor->addMemberInitializer("odsAttrs", "op->getAttrDictionary()");
2425     constructor->addMemberInitializer("odsRegions", "op->getRegions()");
2426   }
2427 
2428   {
2429     auto *m = adaptor.addMethodAndPrune("::mlir::ValueRange", "getOperands");
2430     m->body() << "  return odsOperands;";
2431   }
2432   std::string sizeAttrInit =
2433       formatv(adapterSegmentSizeAttrInitCode, "operand_segment_sizes");
2434   generateNamedOperandGetters(op, adaptor, sizeAttrInit,
2435                               /*rangeType=*/"::mlir::ValueRange",
2436                               /*rangeBeginCall=*/"odsOperands.begin()",
2437                               /*rangeSizeCall=*/"odsOperands.size()",
2438                               /*getOperandCallPattern=*/"odsOperands[{0}]");
2439 
2440   FmtContext fctx;
2441   fctx.withBuilder("::mlir::Builder(odsAttrs.getContext())");
2442 
2443   auto emitAttr = [&](StringRef name, Attribute attr) {
2444     auto &body = adaptor.addMethodAndPrune(attr.getStorageType(), name)->body();
2445     body << "  assert(odsAttrs && \"no attributes when constructing adapter\");"
2446          << "\n  " << attr.getStorageType() << " attr = "
2447          << "odsAttrs.get(\"" << name << "\").";
2448     if (attr.hasDefaultValue() || attr.isOptional())
2449       body << "dyn_cast_or_null<";
2450     else
2451       body << "cast<";
2452     body << attr.getStorageType() << ">();\n";
2453 
2454     if (attr.hasDefaultValue()) {
2455       // Use the default value if attribute is not set.
2456       // TODO: this is inefficient, we are recreating the attribute for every
2457       // call. This should be set instead.
2458       std::string defaultValue = std::string(
2459           tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
2460       body << "  if (!attr)\n    attr = " << defaultValue << ";\n";
2461     }
2462     body << "  return attr;\n";
2463   };
2464 
2465   {
2466     auto *m =
2467         adaptor.addMethodAndPrune("::mlir::DictionaryAttr", "getAttributes");
2468     m->body() << "  return odsAttrs;";
2469   }
2470   for (auto &namedAttr : op.getAttributes()) {
2471     const auto &name = namedAttr.name;
2472     const auto &attr = namedAttr.attr;
2473     if (!attr.isDerivedAttr())
2474       emitAttr(name, attr);
2475   }
2476 
2477   unsigned numRegions = op.getNumRegions();
2478   if (numRegions > 0) {
2479     auto *m = adaptor.addMethodAndPrune("::mlir::RegionRange", "getRegions");
2480     m->body() << "  return odsRegions;";
2481   }
2482   for (unsigned i = 0; i < numRegions; ++i) {
2483     const auto &region = op.getRegion(i);
2484     if (region.name.empty())
2485       continue;
2486 
2487     // Generate the accessors for a variadic region.
2488     if (region.isVariadic()) {
2489       auto *m = adaptor.addMethodAndPrune("::mlir::RegionRange", region.name);
2490       m->body() << formatv("  return odsRegions.drop_front({0});", i);
2491       continue;
2492     }
2493 
2494     auto *m = adaptor.addMethodAndPrune("::mlir::Region &", region.name);
2495     m->body() << formatv("  return *odsRegions[{0}];", i);
2496   }
2497 
2498   // Add verification function.
2499   addVerification();
2500 }
2501 
addVerification()2502 void OpOperandAdaptorEmitter::addVerification() {
2503   auto *method = adaptor.addMethodAndPrune("::mlir::LogicalResult", "verify",
2504                                            "::mlir::Location", "loc");
2505   auto &body = method->body();
2506 
2507   const char *checkAttrSizedValueSegmentsCode = R"(
2508   {
2509     auto sizeAttr = odsAttrs.get("{0}").cast<::mlir::DenseIntElementsAttr>();
2510     auto numElements = sizeAttr.getType().cast<::mlir::ShapedType>().getNumElements();
2511     if (numElements != {1})
2512       return emitError(loc, "'{0}' attribute for specifying {2} segments "
2513                        "must have {1} elements, but got ") << numElements;
2514   }
2515   )";
2516 
2517   // Verify a few traits first so that we can use
2518   // getODSOperands()/getODSResults() in the rest of the verifier.
2519   for (auto &trait : op.getTraits()) {
2520     if (auto *t = dyn_cast<tblgen::NativeTrait>(&trait)) {
2521       if (t->getFullyQualifiedTraitName() ==
2522           "::mlir::OpTrait::AttrSizedOperandSegments") {
2523         body << formatv(checkAttrSizedValueSegmentsCode,
2524                         "operand_segment_sizes", op.getNumOperands(),
2525                         "operand");
2526       } else if (t->getFullyQualifiedTraitName() ==
2527                  "::mlir::OpTrait::AttrSizedResultSegments") {
2528         body << formatv(checkAttrSizedValueSegmentsCode, "result_segment_sizes",
2529                         op.getNumResults(), "result");
2530       }
2531     }
2532   }
2533 
2534   FmtContext verifyCtx;
2535   populateSubstitutions(op, "odsAttrs.get", "getODSOperands",
2536                         "<no results should be generated>", verifyCtx);
2537   genAttributeVerifier(op, "odsAttrs.get",
2538                        Twine("emitError(loc, \"'") + op.getOperationName() +
2539                            "' op \"",
2540                        /*emitVerificationRequiringOp*/ false, verifyCtx, body);
2541 
2542   body << "  return ::mlir::success();";
2543 }
2544 
emitDecl(const Operator & op,raw_ostream & os)2545 void OpOperandAdaptorEmitter::emitDecl(const Operator &op, raw_ostream &os) {
2546   OpOperandAdaptorEmitter(op).adaptor.writeDeclTo(os);
2547 }
2548 
emitDef(const Operator & op,raw_ostream & os)2549 void OpOperandAdaptorEmitter::emitDef(const Operator &op, raw_ostream &os) {
2550   OpOperandAdaptorEmitter(op).adaptor.writeDefTo(os);
2551 }
2552 
2553 // Emits the opcode enum and op classes.
emitOpClasses(const RecordKeeper & recordKeeper,const std::vector<Record * > & defs,raw_ostream & os,bool emitDecl)2554 static void emitOpClasses(const RecordKeeper &recordKeeper,
2555                           const std::vector<Record *> &defs, raw_ostream &os,
2556                           bool emitDecl) {
2557   // First emit forward declaration for each class, this allows them to refer
2558   // to each others in traits for example.
2559   if (emitDecl) {
2560     os << "#if defined(GET_OP_CLASSES) || defined(GET_OP_FWD_DEFINES)\n";
2561     os << "#undef GET_OP_FWD_DEFINES\n";
2562     for (auto *def : defs) {
2563       Operator op(*def);
2564       NamespaceEmitter emitter(os, op.getCppNamespace());
2565       os << "class " << op.getCppClassName() << ";\n";
2566     }
2567     os << "#endif\n\n";
2568   }
2569 
2570   IfDefScope scope("GET_OP_CLASSES", os);
2571   if (defs.empty())
2572     return;
2573 
2574   // Generate all of the locally instantiated methods first.
2575   StaticVerifierFunctionEmitter staticVerifierEmitter(recordKeeper, defs, os,
2576                                                       emitDecl);
2577   for (auto *def : defs) {
2578     Operator op(*def);
2579     NamespaceEmitter emitter(os, op.getCppNamespace());
2580     if (emitDecl) {
2581       os << formatv(opCommentHeader, op.getQualCppClassName(), "declarations");
2582       OpOperandAdaptorEmitter::emitDecl(op, os);
2583       OpEmitter::emitDecl(op, os, staticVerifierEmitter);
2584     } else {
2585       os << formatv(opCommentHeader, op.getQualCppClassName(), "definitions");
2586       OpOperandAdaptorEmitter::emitDef(op, os);
2587       OpEmitter::emitDef(op, os, staticVerifierEmitter);
2588     }
2589   }
2590 }
2591 
2592 // Emits a comma-separated list of the ops.
emitOpList(const std::vector<Record * > & defs,raw_ostream & os)2593 static void emitOpList(const std::vector<Record *> &defs, raw_ostream &os) {
2594   IfDefScope scope("GET_OP_LIST", os);
2595 
2596   interleave(
2597       // TODO: We are constructing the Operator wrapper instance just for
2598       // getting it's qualified class name here. Reduce the overhead by having a
2599       // lightweight version of Operator class just for that purpose.
2600       defs, [&os](Record *def) { os << Operator(def).getQualCppClassName(); },
2601       [&os]() { os << ",\n"; });
2602 }
2603 
emitOpDecls(const RecordKeeper & recordKeeper,raw_ostream & os)2604 static bool emitOpDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
2605   emitSourceFileHeader("Op Declarations", os);
2606 
2607   std::vector<Record *> defs = getRequestedOpDefinitions(recordKeeper);
2608   emitOpClasses(recordKeeper, defs, os, /*emitDecl=*/true);
2609 
2610   return false;
2611 }
2612 
emitOpDefs(const RecordKeeper & recordKeeper,raw_ostream & os)2613 static bool emitOpDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
2614   emitSourceFileHeader("Op Definitions", os);
2615 
2616   std::vector<Record *> defs = getRequestedOpDefinitions(recordKeeper);
2617   emitOpList(defs, os);
2618   emitOpClasses(recordKeeper, defs, os, /*emitDecl=*/false);
2619 
2620   return false;
2621 }
2622 
2623 static mlir::GenRegistration
2624     genOpDecls("gen-op-decls", "Generate op declarations",
__anon96b8615e1802(const RecordKeeper &records, raw_ostream &os) 2625                [](const RecordKeeper &records, raw_ostream &os) {
2626                  return emitOpDecls(records, os);
2627                });
2628 
2629 static mlir::GenRegistration genOpDefs("gen-op-defs", "Generate op definitions",
2630                                        [](const RecordKeeper &records,
__anon96b8615e1902(const RecordKeeper &records, raw_ostream &os) 2631                                           raw_ostream &os) {
2632                                          return emitOpDefs(records, os);
2633                                        });
2634