1 //===- OpDefinitionsGen.cpp - MLIR op definitions generator ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // OpDefinitionsGen uses the description of operations to generate C++
10 // definitions for ops.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "OpFormatGen.h"
15 #include "OpGenHelpers.h"
16 #include "mlir/TableGen/CodeGenHelpers.h"
17 #include "mlir/TableGen/Format.h"
18 #include "mlir/TableGen/GenInfo.h"
19 #include "mlir/TableGen/Interfaces.h"
20 #include "mlir/TableGen/OpClass.h"
21 #include "mlir/TableGen/Operator.h"
22 #include "mlir/TableGen/SideEffects.h"
23 #include "mlir/TableGen/Trait.h"
24 #include "llvm/ADT/MapVector.h"
25 #include "llvm/ADT/Sequence.h"
26 #include "llvm/ADT/StringExtras.h"
27 #include "llvm/Support/Path.h"
28 #include "llvm/Support/Signals.h"
29 #include "llvm/TableGen/Error.h"
30 #include "llvm/TableGen/Record.h"
31 #include "llvm/TableGen/TableGenBackend.h"
32
33 #define DEBUG_TYPE "mlir-tblgen-opdefgen"
34
35 using namespace llvm;
36 using namespace mlir;
37 using namespace mlir::tblgen;
38
// Prefix prepended to generated local variable names so that they do not hide
// the op's own accessors (e.g. the attribute accessors).
static constexpr const char *tblgenNamePrefix = "tblgen_";
// Base name used when synthesizing a name for an operand that has none.
static constexpr const char *generatedArgName = "odsArg";
// NOTE(review): presumably the name of the builder parameter in generated
// build() methods -- its use is outside this part of the file.
static constexpr const char *odsBuilder = "odsBuilder";
// NOTE(review): presumably the name of the operation state parameter in
// generated build() methods -- its use is outside this part of the file.
static constexpr const char *builderOpState = "odsState";
43
// Code template for the body of a generated function that computes the actual
// value range for a declared operand/result of an op with variadic
// operands/results. Note that this logic is not for general use; it assumes
// all variadic operands/results have the same number of values.
//
// The emitted code reads a variable named `index` (the static index of the
// declared operand/result) and returns a {start, size} pair.
//
// NOTE(review): the `{N}` placeholders and the doubled `{{` look like
// llvm::formatv syntax (where `{{` yields a literal `{`); the formatting call
// itself is not in this part of the file.
//
// {0}: The list of whether each declared operand/result is variadic.
// {1}: The total number of non-variadic operands/results.
// {2}: The total number of variadic operands/results.
// {3}: The total number of actual values.
// {4}: "operand" or "result".
const char *sameVariadicSizeValueRangeCalcCode = R"(
bool isVariadic[] = {{{0}};
int prevVariadicCount = 0;
for (unsigned i = 0; i < index; ++i)
if (isVariadic[i]) ++prevVariadicCount;

// Calculate how many dynamic values a static variadic {4} corresponds to.
// This assumes all static variadic {4}s have the same dynamic value count.
int variadicSize = ({3} - {1}) / {2};
// `index` passed in as the parameter is the static index which counts each
// {4} (variadic or not) as size 1. So here for each previous static variadic
// {4}, we need to offset by (variadicSize - 1) to get where the dynamic
// value pack for this static {4} starts.
int start = index + (variadicSize - 1) * prevVariadicCount;
int size = isVariadic[index] ? variadicSize : 1;
return {{start, size};
)";
71
// The logic to calculate the actual value range for a declared operand/result
// of an op with variadic operands/results. Note that this logic assumes the op
// has an attribute specifying the size of each operand/result segment
// (variadic or not).
//
// This first template declares `sizeAttr` for code that has a variable named
// `odsAttrs` in scope: it asserts that `odsAttrs` is non-null and looks the
// attribute up by its quoted name.
//
// {0}: The name of the attribute specifying the segment sizes.
const char *adapterSegmentSizeAttrInitCode = R"(
assert(odsAttrs && "missing segment size attribute for op");
auto sizeAttr = odsAttrs.get("{0}").cast<::mlir::DenseIntElementsAttr>();
)";
// Counterpart of the adapter template for code emitted inside an op class:
// declares `sizeAttr` by reading the attribute through `(*this)->getAttr`.
//
// {0}: Substituted verbatim (unquoted) as the argument to `getAttr`, so it
//      must be an expression that identifies the segment size attribute.
const char *opSegmentSizeAttrInitCode = R"(
auto sizeAttr = (*this)->getAttr({0}).cast<::mlir::DenseIntElementsAttr>();
)";
// Code template that turns the segment size attribute into a {start, size}
// pair. The emitted code expects `sizeAttr` and `index` to be in scope: the
// start is the sum of the sizes of all segments before `index`, and the size
// is the entry at `index`. This template has no placeholders.
const char *attrSizedSegmentValueRangeCalcCode = R"(
auto sizeAttrValues = sizeAttr.getValues<uint32_t>();
unsigned start = 0;
for (unsigned i = 0; i < index; ++i)
start += *(sizeAttrValues.begin() + i);
unsigned size = *(sizeAttrValues.begin() + index);
return {start, size};
)";
93
// The logic to build a range of either operand or result values. The emitted
// code evaluates {1} once into `valueRange` and returns the iterator pair
// [begin + first, begin + first + second).
//
// {0}: The begin iterator of the actual values.
// {1}: The call to generate the start and length of the value range.
const char *valueRangeReturnCode = R"(
auto valueRange = {1};
return {{std::next({0}, valueRange.first),
std::next({0}, valueRange.first + valueRange.second)};
)";
103
// Banner comment written into the generated output to introduce a section.
//
// {0} and {1}: Two pieces of text placed side by side on the banner's title
//              line (e.g. "Local Utility Method" and "Definitions").
static const char *const opCommentHeader = R"(
//===----------------------------------------------------------------------===//
// {0} {1}
//===----------------------------------------------------------------------===//

)";
110
111 //===----------------------------------------------------------------------===//
112 // StaticVerifierFunctionEmitter
113 //===----------------------------------------------------------------------===//
114
115 namespace {
116 /// This class deduplicates shared operation verification code by emitting
117 /// static functions alongside the op definitions. These methods are local to
118 /// the definition file, and are invoked within the operation verify methods.
119 /// An example is shown below:
120 ///
121 /// static LogicalResult localVerify(...)
122 ///
123 /// LogicalResult OpA::verify(...) {
124 /// if (failed(localVerify(...)))
125 /// return failure();
126 /// ...
127 /// }
128 ///
129 /// LogicalResult OpB::verify(...) {
130 /// if (failed(localVerify(...)))
131 /// return failure();
132 /// ...
133 /// }
134 ///
135 class StaticVerifierFunctionEmitter {
136 public:
137 StaticVerifierFunctionEmitter(const llvm::RecordKeeper &records,
138 ArrayRef<llvm::Record *> opDefs,
139 raw_ostream &os, bool emitDecl);
140
141 /// Get the name of the local function used for the given type constraint.
142 /// These functions are used for operand and result constraints and have the
143 /// form:
144 /// LogicalResult(Operation *op, Type type, StringRef valueKind,
145 /// unsigned valueGroupStartIndex);
getTypeConstraintFn(const Constraint & constraint) const146 StringRef getTypeConstraintFn(const Constraint &constraint) const {
147 auto it = localTypeConstraints.find(constraint.getAsOpaquePointer());
148 assert(it != localTypeConstraints.end() && "expected valid constraint fn");
149 return it->second;
150 }
151
152 private:
153 /// Returns a unique name to use when generating local methods.
154 static std::string getUniqueName(const llvm::RecordKeeper &records);
155
156 /// Emit local methods for the type constraints used within the provided op
157 /// definitions.
158 void emitTypeConstraintMethods(ArrayRef<llvm::Record *> opDefs,
159 raw_ostream &os, bool emitDecl);
160
161 /// A unique label for the file currently being generated. This is used to
162 /// ensure that the local functions have a unique name.
163 std::string uniqueOutputLabel;
164
165 /// A set of functions implementing type constraints, used for operand and
166 /// result verification.
167 llvm::DenseMap<const void *, std::string> localTypeConstraints;
168 };
169 } // namespace
170
StaticVerifierFunctionEmitter(const llvm::RecordKeeper & records,ArrayRef<llvm::Record * > opDefs,raw_ostream & os,bool emitDecl)171 StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter(
172 const llvm::RecordKeeper &records, ArrayRef<llvm::Record *> opDefs,
173 raw_ostream &os, bool emitDecl)
174 : uniqueOutputLabel(getUniqueName(records)) {
175 llvm::Optional<NamespaceEmitter> namespaceEmitter;
176 if (!emitDecl) {
177 os << formatv(opCommentHeader, "Local Utility Method", "Definitions");
178 namespaceEmitter.emplace(os, Operator(*opDefs[0]).getCppNamespace());
179 }
180
181 emitTypeConstraintMethods(opDefs, os, emitDecl);
182 }
183
getUniqueName(const llvm::RecordKeeper & records)184 std::string StaticVerifierFunctionEmitter::getUniqueName(
185 const llvm::RecordKeeper &records) {
186 // Use the input file name when generating a unique name.
187 std::string inputFilename = records.getInputFilename();
188
189 // Drop all but the base filename.
190 StringRef nameRef = llvm::sys::path::filename(inputFilename);
191 nameRef.consume_back(".td");
192
193 // Sanitize any invalid characters.
194 std::string uniqueName;
195 for (char c : nameRef) {
196 if (llvm::isAlnum(c) || c == '_')
197 uniqueName.push_back(c);
198 else
199 uniqueName.append(llvm::utohexstr((unsigned char)c));
200 }
201 return uniqueName;
202 }
203
emitTypeConstraintMethods(ArrayRef<llvm::Record * > opDefs,raw_ostream & os,bool emitDecl)204 void StaticVerifierFunctionEmitter::emitTypeConstraintMethods(
205 ArrayRef<llvm::Record *> opDefs, raw_ostream &os, bool emitDecl) {
206 // Collect a set of all of the used type constraints within the operation
207 // definitions.
208 llvm::SetVector<const void *> typeConstraints;
209 for (Record *def : opDefs) {
210 Operator op(*def);
211 for (NamedTypeConstraint &operand : op.getOperands())
212 if (operand.hasPredicate())
213 typeConstraints.insert(operand.constraint.getAsOpaquePointer());
214 for (NamedTypeConstraint &result : op.getResults())
215 if (result.hasPredicate())
216 typeConstraints.insert(result.constraint.getAsOpaquePointer());
217 }
218
219 // Record the mapping from predicate to constraint. If two constraints has the
220 // same predicate and constraint summary, they can share the same verification
221 // function.
222 llvm::DenseMap<Pred, const void *> predToConstraint;
223 FmtContext fctx;
224 for (auto it : llvm::enumerate(typeConstraints)) {
225 std::string name;
226 Constraint constraint = Constraint::getFromOpaquePointer(it.value());
227 Pred pred = constraint.getPredicate();
228 auto iter = predToConstraint.find(pred);
229 if (iter != predToConstraint.end()) {
230 do {
231 Constraint built = Constraint::getFromOpaquePointer(iter->second);
232 // We may have the different constraints but have the same predicate,
233 // for example, ConstraintA and Variadic<ConstraintA>, note that
234 // Variadic<> doesn't introduce new predicate. In this case, we can
235 // share the same predicate function if they also have consistent
236 // summary, otherwise we may report the wrong message while verification
237 // fails.
238 if (constraint.getSummary() == built.getSummary()) {
239 name = getTypeConstraintFn(built).str();
240 break;
241 }
242 ++iter;
243 } while (iter != predToConstraint.end() && iter->first == pred);
244 }
245
246 if (!name.empty()) {
247 localTypeConstraints.try_emplace(it.value(), name);
248 continue;
249 }
250
251 // Generate an obscure and unique name for this type constraint.
252 name = (Twine("__mlir_ods_local_type_constraint_") + uniqueOutputLabel +
253 Twine(it.index()))
254 .str();
255 predToConstraint.insert(
256 std::make_pair(constraint.getPredicate(), it.value()));
257 localTypeConstraints.try_emplace(it.value(), name);
258
259 // Only generate the methods if we are generating definitions.
260 if (emitDecl)
261 continue;
262
263 os << "static ::mlir::LogicalResult " << name
264 << "(::mlir::Operation *op, ::mlir::Type type, ::llvm::StringRef "
265 "valueKind, unsigned valueGroupStartIndex) {\n";
266
267 os << " if (!("
268 << tgfmt(constraint.getConditionTemplate(), &fctx.withSelf("type"))
269 << ")) {\n"
270 << formatv(
271 " return op->emitOpError(valueKind) << \" #\" << "
272 "valueGroupStartIndex << \" must be {0}, but got \" << type;\n",
273 constraint.getSummary())
274 << " }\n"
275 << " return ::mlir::success();\n"
276 << "}\n\n";
277 }
278 }
279
280 //===----------------------------------------------------------------------===//
281 // Utility structs and functions
282 //===----------------------------------------------------------------------===//
283
// Replaces all occurrences of `match` in `str` with `substitute` and returns
// the result. Occurrences are found left to right and text that was just
// substituted is not rescanned. An empty `match` has no occurrences to
// replace, so `str` is returned unchanged.
static std::string replaceAllSubstrs(std::string str, const std::string &match,
                                     const std::string &substitute) {
  // An empty pattern matches at every position without consuming any input,
  // so the loop below would never terminate for it.
  if (match.empty())
    return str;

  std::string::size_type scanLoc = 0;
  std::string::size_type matchLoc = std::string::npos;
  while ((matchLoc = str.find(match, scanLoc)) != std::string::npos) {
    str.replace(matchLoc, match.size(), substitute);
    // Resume scanning right after the inserted text.
    scanLoc = matchLoc + substitute.size();
  }
  return str;
}
294
295 // Returns whether the record has a value of the given name that can be returned
296 // via getValueAsString.
hasStringAttribute(const Record & record,StringRef fieldName)297 static inline bool hasStringAttribute(const Record &record,
298 StringRef fieldName) {
299 auto valueInit = record.getValueInit(fieldName);
300 return isa<StringInit>(valueInit);
301 }
302
getArgumentName(const Operator & op,int index)303 static std::string getArgumentName(const Operator &op, int index) {
304 const auto &operand = op.getOperand(index);
305 if (!operand.name.empty())
306 return std::string(operand.name);
307 else
308 return std::string(formatv("{0}_{1}", generatedArgName, index));
309 }
310
311 // Returns true if we can use unwrapped value for the given `attr` in builders.
canUseUnwrappedRawValue(const tblgen::Attribute & attr)312 static bool canUseUnwrappedRawValue(const tblgen::Attribute &attr) {
313 return attr.getReturnType() != attr.getStorageType() &&
314 // We need to wrap the raw value into an attribute in the builder impl
315 // so we need to make sure that the attribute specifies how to do that.
316 !attr.getConstBuilderTemplate().empty();
317 }
318
319 //===----------------------------------------------------------------------===//
320 // Op emitter
321 //===----------------------------------------------------------------------===//
322
323 namespace {
324 // Helper class to emit a record into the given output stream.
325 class OpEmitter {
326 public:
327 static void
328 emitDecl(const Operator &op, raw_ostream &os,
329 const StaticVerifierFunctionEmitter &staticVerifierEmitter);
330 static void
331 emitDef(const Operator &op, raw_ostream &os,
332 const StaticVerifierFunctionEmitter &staticVerifierEmitter);
333
334 private:
335 OpEmitter(const Operator &op,
336 const StaticVerifierFunctionEmitter &staticVerifierEmitter);
337
338 void emitDecl(raw_ostream &os);
339 void emitDef(raw_ostream &os);
340
341 // Generate methods for accessing the attribute names of this operation.
342 void genAttrNameGetters();
343
344 // Return the index of the given attribute name. This is a relative ordering
345 // for this name, used in attribute getters.
346 unsigned getAttrNameIndex(StringRef attrName) const;
347
348 // Generates the OpAsmOpInterface for this operation if possible.
349 void genOpAsmInterface();
350
351 // Generates the `getOperationName` method for this op.
352 void genOpNameGetter();
353
354 // Generates getters for the attributes.
355 void genAttrGetters();
356
357 // Generates setter for the attributes.
358 void genAttrSetters();
359
360 // Generates removers for optional attributes.
361 void genOptionalAttrRemovers();
362
363 // Generates getters for named operands.
364 void genNamedOperandGetters();
365
366 // Generates setters for named operands.
367 void genNamedOperandSetters();
368
369 // Generates getters for named results.
370 void genNamedResultGetters();
371
372 // Generates getters for named regions.
373 void genNamedRegionGetters();
374
375 // Generates getters for named successors.
376 void genNamedSuccessorGetters();
377
378 // Generates builder methods for the operation.
379 void genBuilder();
380
381 // Generates the build() method that takes each operand/attribute
382 // as a stand-alone parameter.
383 void genSeparateArgParamBuilder();
384
385 // Generates the build() method that takes each operand/attribute as a
386 // stand-alone parameter. The generated build() method uses first operand's
387 // type as all results' types.
388 void genUseOperandAsResultTypeSeparateParamBuilder();
389
390 // Generates the build() method that takes all operands/attributes
391 // collectively as one parameter. The generated build() method uses first
392 // operand's type as all results' types.
393 void genUseOperandAsResultTypeCollectiveParamBuilder();
394
395 // Generates the build() method that takes aggregate operands/attributes
396 // parameters. This build() method uses inferred types as result types.
397 // Requires: The type needs to be inferable via InferTypeOpInterface.
398 void genInferredTypeCollectiveParamBuilder();
399
400 // Generates the build() method that takes each operand/attribute as a
401 // stand-alone parameter. The generated build() method uses first attribute's
402 // type as all result's types.
403 void genUseAttrAsResultTypeBuilder();
404
405 // Generates the build() method that takes all result types collectively as
406 // one parameter. Similarly for operands and attributes.
407 void genCollectiveParamBuilder();
408
409 // The kind of parameter to generate for result types in builders.
410 enum class TypeParamKind {
411 None, // No result type in parameter list.
412 Separate, // A separate parameter for each result type.
413 Collective, // An ArrayRef<Type> for all result types.
414 };
415
416 // The kind of parameter to generate for attributes in builders.
417 enum class AttrParamKind {
418 WrappedAttr, // A wrapped MLIR Attribute instance.
419 UnwrappedValue, // A raw value without MLIR Attribute wrapper.
420 };
421
422 // Builds the parameter list for build() method of this op. This method writes
423 // to `paramList` the comma-separated parameter list and updates
424 // `resultTypeNames` with the names for parameters for specifying result
425 // types. The given `typeParamKind` and `attrParamKind` controls how result
426 // types and attributes are placed in the parameter list.
427 void buildParamList(llvm::SmallVectorImpl<OpMethodParameter> ¶mList,
428 SmallVectorImpl<std::string> &resultTypeNames,
429 TypeParamKind typeParamKind,
430 AttrParamKind attrParamKind = AttrParamKind::WrappedAttr);
431
432 // Adds op arguments and regions into operation state for build() methods.
433 void genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
434 bool isRawValueAttr = false);
435
436 // Generates canonicalizer declaration for the operation.
437 void genCanonicalizerDecls();
438
439 // Generates the folder declaration for the operation.
440 void genFolderDecls();
441
442 // Generates the parser for the operation.
443 void genParser();
444
445 // Generates the printer for the operation.
446 void genPrinter();
447
448 // Generates verify method for the operation.
449 void genVerifier();
450
451 // Generates verify statements for operands and results in the operation.
452 // The generated code will be attached to `body`.
453 void genOperandResultVerifier(OpMethodBody &body,
454 Operator::value_range values,
455 StringRef valueKind);
456
457 // Generates verify statements for regions in the operation.
458 // The generated code will be attached to `body`.
459 void genRegionVerifier(OpMethodBody &body);
460
461 // Generates verify statements for successors in the operation.
462 // The generated code will be attached to `body`.
463 void genSuccessorVerifier(OpMethodBody &body);
464
465 // Generates the traits used by the object.
466 void genTraits();
467
468 // Generate the OpInterface methods for all interfaces.
469 void genOpInterfaceMethods();
470
471 // Generate op interface methods for the given interface.
472 void genOpInterfaceMethods(const tblgen::InterfaceTrait *trait);
473
474 // Generate op interface method for the given interface method. If
475 // 'declaration' is true, generates a declaration, else a definition.
476 OpMethod *genOpInterfaceMethod(const tblgen::InterfaceMethod &method,
477 bool declaration = true);
478
479 // Generate the side effect interface methods.
480 void genSideEffectInterfaceMethods();
481
482 // Generate the type inference interface methods.
483 void genTypeInterfaceMethods();
484
485 private:
486 // The TableGen record for this op.
487 // TODO: OpEmitter should not have a Record directly,
488 // it should rather go through the Operator for better abstraction.
489 const Record &def;
490
491 // The wrapper operator class for querying information from this op.
492 Operator op;
493
494 // The C++ code builder for this op
495 OpClass opClass;
496
497 // The format context for verification code generation.
498 FmtContext verifyCtx;
499
500 // The emitter containing all of the locally emitted verification functions.
501 const StaticVerifierFunctionEmitter &staticVerifierEmitter;
502
503 // A map of attribute names (including implicit attributes) registered to the
504 // current operation, to the relative order in which they were registered.
505 llvm::MapVector<StringRef, unsigned> attributeNames;
506 };
507 } // end anonymous namespace
508
509 // Populate the format context `ctx` with substitutions of attributes, operands
510 // and results.
511 // - attrGet corresponds to the name of the function to call to get value of
512 // attribute (the generated function call returns an Attribute);
513 // - operandGet corresponds to the name of the function with which to retrieve
514 // an operand (the generated function call returns an OperandRange);
515 // - resultGet corresponds to the name of the function to get an result (the
516 // generated function call returns a ValueRange);
populateSubstitutions(const Operator & op,const char * attrGet,const char * operandGet,const char * resultGet,FmtContext & ctx)517 static void populateSubstitutions(const Operator &op, const char *attrGet,
518 const char *operandGet, const char *resultGet,
519 FmtContext &ctx) {
520 // Populate substitutions for attributes and named operands.
521 for (const auto &namedAttr : op.getAttributes())
522 ctx.addSubst(namedAttr.name,
523 formatv("{0}(\"{1}\")", attrGet, namedAttr.name));
524 for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
525 auto &value = op.getOperand(i);
526 if (value.name.empty())
527 continue;
528
529 if (value.isVariadic())
530 ctx.addSubst(value.name, formatv("{0}({1})", operandGet, i));
531 else
532 ctx.addSubst(value.name, formatv("(*{0}({1}).begin())", operandGet, i));
533 }
534
535 // Populate substitutions for results.
536 for (int i = 0, e = op.getNumResults(); i < e; ++i) {
537 auto &value = op.getResult(i);
538 if (value.name.empty())
539 continue;
540
541 if (value.isVariadic())
542 ctx.addSubst(value.name, formatv("{0}({1})", resultGet, i));
543 else
544 ctx.addSubst(value.name, formatv("(*{0}({1}).begin())", resultGet, i));
545 }
546 }
547
548 // Generate attribute verification. If emitVerificationRequiringOp is set then
549 // only verification for attributes whose value depend on op being known are
550 // emitted, else only verification that doesn't depend on the op being known are
551 // generated.
552 // - emitErrorPrefix is the prefix for the error emitting call which consists
553 // of the entire function call up to start of error message fragment;
554 // - emitVerificationRequiringOp specifies whether verification should be
555 // emitted for verification that require the op to exist;
genAttributeVerifier(const Operator & op,const char * attrGet,const Twine & emitErrorPrefix,bool emitVerificationRequiringOp,FmtContext & ctx,OpMethodBody & body)556 static void genAttributeVerifier(const Operator &op, const char *attrGet,
557 const Twine &emitErrorPrefix,
558 bool emitVerificationRequiringOp,
559 FmtContext &ctx, OpMethodBody &body) {
560 for (const auto &namedAttr : op.getAttributes()) {
561 const auto &attr = namedAttr.attr;
562 if (attr.isDerivedAttr())
563 continue;
564
565 auto attrName = namedAttr.name;
566 bool allowMissingAttr = attr.hasDefaultValue() || attr.isOptional();
567 auto attrPred = attr.getPredicate();
568 auto condition = attrPred.isNull() ? "" : attrPred.getCondition();
569 // There is a condition to emit only if the use of $_op and whether to
570 // emit verifications for op matches.
571 bool hasConditionToEmit = (!(condition.find("$_op") != StringRef::npos) ^
572 emitVerificationRequiringOp);
573
574 // Prefix with `tblgen_` to avoid hiding the attribute accessor.
575 auto varName = tblgenNamePrefix + attrName;
576
577 // If the attribute is
578 // 1. Required (not allowed missing) and not in op verification, or
579 // 2. Has a condition that will get verified
580 // then the variable will be used.
581 //
582 // Therefore, for optional attributes whose verification requires that an
583 // op already exists for verification/emitVerificationRequiringOp is set
584 // has nothing that can be verified here.
585 if ((allowMissingAttr || emitVerificationRequiringOp) &&
586 !hasConditionToEmit)
587 continue;
588
589 body << formatv(" {\n auto {0} = {1}(\"{2}\");\n", varName, attrGet,
590 attrName);
591
592 if (!emitVerificationRequiringOp && !allowMissingAttr) {
593 body << " if (!" << varName << ") return " << emitErrorPrefix
594 << "\"requires attribute '" << attrName << "'\");\n";
595 }
596
597 if (!hasConditionToEmit) {
598 body << " }\n";
599 continue;
600 }
601
602 if (allowMissingAttr) {
603 // If the attribute has a default value, then only verify the predicate if
604 // set. This does effectively assume that the default value is valid.
605 // TODO: verify the debug value is valid (perhaps in debug mode only).
606 body << " if (" << varName << ") {\n";
607 }
608
609 body << tgfmt(" if (!($0)) return $1\"attribute '$2' "
610 "failed to satisfy constraint: $3\");\n",
611 /*ctx=*/nullptr, tgfmt(condition, &ctx.withSelf(varName)),
612 emitErrorPrefix, attrName, attr.getSummary());
613 if (allowMissingAttr)
614 body << " }\n";
615 body << " }\n";
616 }
617 }
618
// Builds the complete C++ class for `op`: sets up the format context used by
// verification code, then runs every generator so that `opClass` holds all
// traits and methods by the time emitDecl/emitDef write it out.
OpEmitter::OpEmitter(const Operator &op,
                     const StaticVerifierFunctionEmitter &staticVerifierEmitter)
    : def(op.getDef()), op(op),
      opClass(op.getCppClassName(), op.getExtraClassDeclaration()),
      staticVerifierEmitter(staticVerifierEmitter) {
  // Expressions that verification code generated for this op uses for the
  // operation itself and for the `_ctxt` substitution.
  verifyCtx.withOp("(*this->getOperation())");
  verifyCtx.addSubst("_ctxt", "this->getOperation()->getContext()");

  genTraits();

  // Generate C++ code for various op methods. The order here determines the
  // methods in the generated file.
  // genAttrNameGetters also registers the attribute names that
  // getAttrNameIndex later looks up, so it runs before the other generators.
  genAttrNameGetters();
  genOpAsmInterface();
  genOpNameGetter();
  genNamedOperandGetters();
  genNamedOperandSetters();
  genNamedResultGetters();
  genNamedRegionGetters();
  genNamedSuccessorGetters();
  genAttrGetters();
  genAttrSetters();
  genOptionalAttrRemovers();
  genBuilder();
  genParser();
  genPrinter();
  genVerifier();
  genCanonicalizerDecls();
  genFolderDecls();
  genTypeInterfaceMethods();
  genOpInterfaceMethods();
  generateOpFormat(op, opClass);
  genSideEffectInterfaceMethods();
}
653
emitDecl(const Operator & op,raw_ostream & os,const StaticVerifierFunctionEmitter & staticVerifierEmitter)654 void OpEmitter::emitDecl(
655 const Operator &op, raw_ostream &os,
656 const StaticVerifierFunctionEmitter &staticVerifierEmitter) {
657 OpEmitter(op, staticVerifierEmitter).emitDecl(os);
658 }
659
emitDef(const Operator & op,raw_ostream & os,const StaticVerifierFunctionEmitter & staticVerifierEmitter)660 void OpEmitter::emitDef(
661 const Operator &op, raw_ostream &os,
662 const StaticVerifierFunctionEmitter &staticVerifierEmitter) {
663 OpEmitter(op, staticVerifierEmitter).emitDef(os);
664 }
665
emitDecl(raw_ostream & os)666 void OpEmitter::emitDecl(raw_ostream &os) { opClass.writeDeclTo(os); }
667
emitDef(raw_ostream & os)668 void OpEmitter::emitDef(raw_ostream &os) { opClass.writeDefTo(os); }
669
genAttrNameGetters()670 void OpEmitter::genAttrNameGetters() {
671 // Enumerate the attribute names of this op, assigning each a relative
672 // ordering.
673 auto addAttrName = [&](StringRef name) {
674 unsigned index = attributeNames.size();
675 attributeNames.insert({name, index});
676 };
677 for (const NamedAttribute &namedAttr : op.getAttributes())
678 addAttrName(namedAttr.name);
679 // Include key attributes from several traits as implicitly registered.
680 if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments"))
681 addAttrName("operand_segment_sizes");
682 if (op.getTrait("::mlir::OpTrait::AttrSizedResultSegments"))
683 addAttrName("result_segment_sizes");
684
685 // Emit the getAttributeNames method.
686 {
687 auto *method = opClass.addMethodAndPrune(
688 "::llvm::ArrayRef<::llvm::StringRef>", "getAttributeNames",
689 OpMethod::Property(OpMethod::MP_Static | OpMethod::MP_Inline));
690 auto &body = method->body();
691 if (attributeNames.empty()) {
692 body << " return {};";
693 } else {
694 body << " static ::llvm::StringRef attrNames[] = {";
695 llvm::interleaveComma(llvm::make_first_range(attributeNames), body,
696 [&](StringRef attrName) {
697 body << "::llvm::StringRef(\"" << attrName
698 << "\")";
699 });
700 body << "};\n return ::llvm::makeArrayRef(attrNames);";
701 }
702 }
703 if (attributeNames.empty())
704 return;
705
706 // Emit the getAttributeNameForIndex methods.
707 {
708 auto *method = opClass.addMethodAndPrune(
709 "::mlir::Identifier", "getAttributeNameForIndex",
710 OpMethod::Property(OpMethod::MP_Private | OpMethod::MP_Inline),
711 "unsigned", "index");
712 method->body()
713 << " return getAttributeNameForIndex((*this)->getName(), index);";
714 }
715 {
716 auto *method = opClass.addMethodAndPrune(
717 "::mlir::Identifier", "getAttributeNameForIndex",
718 OpMethod::Property(OpMethod::MP_Private | OpMethod::MP_Inline |
719 OpMethod::MP_Static),
720 "::mlir::OperationName name, unsigned index");
721 method->body() << "assert(index < " << attributeNames.size()
722 << " && \"invalid attribute index\");\n"
723 " return name.getAbstractOperation()"
724 "->getAttributeNames()[index];";
725 }
726
727 // Generate the <attr>AttrName methods, that expose the attribute names to
728 // users.
729 const char *attrNameMethodBody = " return getAttributeNameForIndex({0});";
730 for (const std::pair<StringRef, unsigned> &attrIt : attributeNames) {
731 std::string methodName = (attrIt.first + "AttrName").str();
732
733 // Generate the non-static variant.
734 {
735 auto *method =
736 opClass.addMethodAndPrune("::mlir::Identifier", methodName,
737 OpMethod::Property(OpMethod::MP_Inline));
738 method->body() << llvm::formatv(attrNameMethodBody, attrIt.second).str();
739 }
740
741 // Generate the static variant.
742 {
743 auto *method = opClass.addMethodAndPrune(
744 "::mlir::Identifier", methodName,
745 OpMethod::Property(OpMethod::MP_Inline | OpMethod::MP_Static),
746 "::mlir::OperationName", "name");
747 method->body() << llvm::formatv(attrNameMethodBody,
748 "name, " + Twine(attrIt.second))
749 .str();
750 }
751 }
752 }
753
getAttrNameIndex(StringRef attrName) const754 unsigned OpEmitter::getAttrNameIndex(StringRef attrName) const {
755 auto it = attributeNames.find(attrName);
756 assert(it != attributeNames.end() && "expected attribute name to have been "
757 "registered in genAttrNameGetters");
758 return it->second;
759 }
760
genAttrGetters()761 void OpEmitter::genAttrGetters() {
762 FmtContext fctx;
763 fctx.withBuilder("::mlir::Builder((*this)->getContext())");
764
765 // Emit the derived attribute body.
766 auto emitDerivedAttr = [&](StringRef name, Attribute attr) {
767 if (auto *method = opClass.addMethodAndPrune(attr.getReturnType(), name))
768 method->body() << " " << attr.getDerivedCodeBody() << "\n";
769 };
770
771 // Emit with return type specified.
772 auto emitAttrWithReturnType = [&](StringRef name, Attribute attr) {
773 auto *method = opClass.addMethodAndPrune(attr.getReturnType(), name);
774 auto &body = method->body();
775 body << " auto attr = " << name << "Attr();\n";
776 if (attr.hasDefaultValue()) {
777 // Returns the default value if not set.
778 // TODO: this is inefficient, we are recreating the attribute for every
779 // call. This should be set instead.
780 std::string defaultValue = std::string(
781 tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
782 body << " if (!attr)\n return "
783 << tgfmt(attr.getConvertFromStorageCall(),
784 &fctx.withSelf(defaultValue))
785 << ";\n";
786 }
787 body << " return "
788 << tgfmt(attr.getConvertFromStorageCall(), &fctx.withSelf("attr"))
789 << ";\n";
790 };
791
792 // Generate raw named accessor type. This is a wrapper class that allows
793 // referring to the attributes via accessors instead of having to use
794 // the string interface for better compile time verification.
795 auto emitAttrWithStorageType = [&](StringRef name, Attribute attr) {
796 auto *method =
797 opClass.addMethodAndPrune(attr.getStorageType(), (name + "Attr").str());
798 if (!method)
799 return;
800 auto &body = method->body();
801 body << " return (*this)->getAttr(" << name << "AttrName()).template ";
802 if (attr.isOptional() || attr.hasDefaultValue())
803 body << "dyn_cast_or_null<";
804 else
805 body << "cast<";
806 body << attr.getStorageType() << ">();";
807 };
808
809 for (const NamedAttribute &namedAttr : op.getAttributes()) {
810 if (namedAttr.attr.isDerivedAttr()) {
811 emitDerivedAttr(namedAttr.name, namedAttr.attr);
812 } else {
813 emitAttrWithStorageType(namedAttr.name, namedAttr.attr);
814 emitAttrWithReturnType(namedAttr.name, namedAttr.attr);
815 }
816 }
817
818 auto derivedAttrs = make_filter_range(op.getAttributes(),
819 [](const NamedAttribute &namedAttr) {
820 return namedAttr.attr.isDerivedAttr();
821 });
822 if (!derivedAttrs.empty()) {
823 opClass.addTrait("::mlir::DerivedAttributeOpInterface::Trait");
824 // Generate helper method to query whether a named attribute is a derived
825 // attribute. This enables, for example, avoiding adding an attribute that
826 // overlaps with a derived attribute.
827 {
828 auto *method = opClass.addMethodAndPrune("bool", "isDerivedAttribute",
829 OpMethod::MP_Static,
830 "::llvm::StringRef", "name");
831 auto &body = method->body();
832 for (auto namedAttr : derivedAttrs)
833 body << " if (name == \"" << namedAttr.name << "\") return true;\n";
834 body << " return false;";
835 }
836 // Generate method to materialize derived attributes as a DictionaryAttr.
837 {
838 auto *method = opClass.addMethodAndPrune("::mlir::DictionaryAttr",
839 "materializeDerivedAttributes");
840 auto &body = method->body();
841
842 auto nonMaterializable =
843 make_filter_range(derivedAttrs, [](const NamedAttribute &namedAttr) {
844 return namedAttr.attr.getConvertFromStorageCall().empty();
845 });
846 if (!nonMaterializable.empty()) {
847 std::string attrs;
848 llvm::raw_string_ostream os(attrs);
849 interleaveComma(nonMaterializable, os,
850 [&](const NamedAttribute &attr) { os << attr.name; });
851 PrintWarning(
852 op.getLoc(),
853 formatv(
854 "op has non-materializable derived attributes '{0}', skipping",
855 os.str()));
856 body << formatv(" emitOpError(\"op has non-materializable derived "
857 "attributes '{0}'\");\n",
858 attrs);
859 body << " return nullptr;";
860 return;
861 }
862
863 body << " ::mlir::MLIRContext* ctx = getContext();\n";
864 body << " ::mlir::Builder odsBuilder(ctx); (void)odsBuilder;\n";
865 body << " return ::mlir::DictionaryAttr::get(";
866 body << " ctx, {\n";
867 interleave(
868 derivedAttrs, body,
869 [&](const NamedAttribute &namedAttr) {
870 auto tmpl = namedAttr.attr.getConvertFromStorageCall();
871 body << " {" << namedAttr.name << "AttrName(),\n"
872 << tgfmt(tmpl, &fctx.withSelf(namedAttr.name + "()")
873 .withBuilder("odsBuilder")
874 .addSubst("_ctx", "ctx"))
875 << "}";
876 },
877 ",\n");
878 body << "});";
879 }
880 }
881 }
882
genAttrSetters()883 void OpEmitter::genAttrSetters() {
884 // Generate raw named setter type. This is a wrapper class that allows setting
885 // to the attributes via setters instead of having to use the string interface
886 // for better compile time verification.
887 auto emitAttrWithStorageType = [&](StringRef name, Attribute attr) {
888 auto *method = opClass.addMethodAndPrune("void", (name + "Attr").str(),
889 attr.getStorageType(), "attr");
890 if (method)
891 method->body() << " (*this)->setAttr(" << name << "AttrName(), attr);";
892 };
893
894 for (const NamedAttribute &namedAttr : op.getAttributes())
895 if (!namedAttr.attr.isDerivedAttr())
896 emitAttrWithStorageType(namedAttr.name, namedAttr.attr);
897 }
898
genOptionalAttrRemovers()899 void OpEmitter::genOptionalAttrRemovers() {
900 // Generate methods for removing optional attributes, instead of having to
901 // use the string interface. Enables better compile time verification.
902 auto emitRemoveAttr = [&](StringRef name) {
903 auto upperInitial = name.take_front().upper();
904 auto suffix = name.drop_front();
905 auto *method = opClass.addMethodAndPrune(
906 "::mlir::Attribute", ("remove" + upperInitial + suffix + "Attr").str());
907 if (!method)
908 return;
909 method->body() << " return (*this)->removeAttr(" << name << "AttrName());";
910 };
911
912 for (const NamedAttribute &namedAttr : op.getAttributes())
913 if (namedAttr.attr.isOptional())
914 emitRemoveAttr(namedAttr.name);
915 }
916
917 // Generates the code to compute the start and end index of an operand or result
918 // range.
919 template <typename RangeT>
920 static void
generateValueRangeStartAndEnd(Class & opClass,StringRef methodName,int numVariadic,int numNonVariadic,StringRef rangeSizeCall,bool hasAttrSegmentSize,StringRef sizeAttrInit,RangeT && odsValues)921 generateValueRangeStartAndEnd(Class &opClass, StringRef methodName,
922 int numVariadic, int numNonVariadic,
923 StringRef rangeSizeCall, bool hasAttrSegmentSize,
924 StringRef sizeAttrInit, RangeT &&odsValues) {
925 auto *method = opClass.addMethodAndPrune("std::pair<unsigned, unsigned>",
926 methodName, "unsigned", "index");
927 if (!method)
928 return;
929 auto &body = method->body();
930 if (numVariadic == 0) {
931 body << " return {index, 1};\n";
932 } else if (hasAttrSegmentSize) {
933 body << sizeAttrInit << attrSizedSegmentValueRangeCalcCode;
934 } else {
935 // Because the op can have arbitrarily interleaved variadic and non-variadic
936 // operands, we need to embed a list in the "sink" getter method for
937 // calculation at run-time.
938 llvm::SmallVector<StringRef, 4> isVariadic;
939 isVariadic.reserve(llvm::size(odsValues));
940 for (auto &it : odsValues)
941 isVariadic.push_back(it.isVariableLength() ? "true" : "false");
942 std::string isVariadicList = llvm::join(isVariadic, ", ");
943 body << formatv(sameVariadicSizeValueRangeCalcCode, isVariadicList,
944 numNonVariadic, numVariadic, rangeSizeCall, "operand");
945 }
946 }
947
948 // Generates the named operand getter methods for the given Operator `op` and
949 // puts them in `opClass`. Uses `rangeType` as the return type of getters that
950 // return a range of operands (individual operands are `Value ` and each
951 // element in the range must also be `Value `); use `rangeBeginCall` to get
952 // an iterator to the beginning of the operand range; use `rangeSizeCall` to
953 // obtain the number of operands. `getOperandCallPattern` contains the code
954 // necessary to obtain a single operand whose position will be substituted
955 // instead of
956 // "{0}" marker in the pattern. Note that the pattern should work for any kind
957 // of ops, in particular for one-operand ops that may not have the
958 // `getOperand(unsigned)` method.
generateNamedOperandGetters(const Operator & op,Class & opClass,StringRef sizeAttrInit,StringRef rangeType,StringRef rangeBeginCall,StringRef rangeSizeCall,StringRef getOperandCallPattern)959 static void generateNamedOperandGetters(const Operator &op, Class &opClass,
960 StringRef sizeAttrInit,
961 StringRef rangeType,
962 StringRef rangeBeginCall,
963 StringRef rangeSizeCall,
964 StringRef getOperandCallPattern) {
965 const int numOperands = op.getNumOperands();
966 const int numVariadicOperands = op.getNumVariableLengthOperands();
967 const int numNormalOperands = numOperands - numVariadicOperands;
968
969 const auto *sameVariadicSize =
970 op.getTrait("::mlir::OpTrait::SameVariadicOperandSize");
971 const auto *attrSizedOperands =
972 op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments");
973
974 if (numVariadicOperands > 1 && !sameVariadicSize && !attrSizedOperands) {
975 PrintFatalError(op.getLoc(), "op has multiple variadic operands but no "
976 "specification over their sizes");
977 }
978
979 if (numVariadicOperands < 2 && attrSizedOperands) {
980 PrintFatalError(op.getLoc(), "op must have at least two variadic operands "
981 "to use 'AttrSizedOperandSegments' trait");
982 }
983
984 if (attrSizedOperands && sameVariadicSize) {
985 PrintFatalError(op.getLoc(),
986 "op cannot have both 'AttrSizedOperandSegments' and "
987 "'SameVariadicOperandSize' traits");
988 }
989
990 // First emit a few "sink" getter methods upon which we layer all nicer named
991 // getter methods.
992 generateValueRangeStartAndEnd(opClass, "getODSOperandIndexAndLength",
993 numVariadicOperands, numNormalOperands,
994 rangeSizeCall, attrSizedOperands, sizeAttrInit,
995 const_cast<Operator &>(op).getOperands());
996
997 auto *m = opClass.addMethodAndPrune(rangeType, "getODSOperands", "unsigned",
998 "index");
999 auto &body = m->body();
1000 body << formatv(valueRangeReturnCode, rangeBeginCall,
1001 "getODSOperandIndexAndLength(index)");
1002
1003 // Then we emit nicer named getter methods by redirecting to the "sink" getter
1004 // method.
1005 for (int i = 0; i != numOperands; ++i) {
1006 const auto &operand = op.getOperand(i);
1007 if (operand.name.empty())
1008 continue;
1009
1010 if (operand.isOptional()) {
1011 m = opClass.addMethodAndPrune("::mlir::Value", operand.name);
1012 m->body()
1013 << " auto operands = getODSOperands(" << i << ");\n"
1014 << " return operands.empty() ? ::mlir::Value() : *operands.begin();";
1015 } else if (operand.isVariadic()) {
1016 m = opClass.addMethodAndPrune(rangeType, operand.name);
1017 m->body() << " return getODSOperands(" << i << ");";
1018 } else {
1019 m = opClass.addMethodAndPrune("::mlir::Value", operand.name);
1020 m->body() << " return *getODSOperands(" << i << ").begin();";
1021 }
1022 }
1023 }
1024
genNamedOperandGetters()1025 void OpEmitter::genNamedOperandGetters() {
1026 // Build the code snippet used for initializing the operand_segment_sizes
1027 // array.
1028 std::string attrSizeInitCode;
1029 if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
1030 attrSizeInitCode =
1031 formatv(opSegmentSizeAttrInitCode, "operand_segment_sizesAttrName()")
1032 .str();
1033 }
1034
1035 generateNamedOperandGetters(
1036 op, opClass,
1037 /*sizeAttrInit=*/attrSizeInitCode,
1038 /*rangeType=*/"::mlir::Operation::operand_range",
1039 /*rangeBeginCall=*/"getOperation()->operand_begin()",
1040 /*rangeSizeCall=*/"getOperation()->getNumOperands()",
1041 /*getOperandCallPattern=*/"getOperation()->getOperand({0})");
1042 }
1043
genNamedOperandSetters()1044 void OpEmitter::genNamedOperandSetters() {
1045 auto *attrSizedOperands =
1046 op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments");
1047 for (int i = 0, e = op.getNumOperands(); i != e; ++i) {
1048 const auto &operand = op.getOperand(i);
1049 if (operand.name.empty())
1050 continue;
1051 auto *m = opClass.addMethodAndPrune("::mlir::MutableOperandRange",
1052 (operand.name + "Mutable").str());
1053 auto &body = m->body();
1054 body << " auto range = getODSOperandIndexAndLength(" << i << ");\n"
1055 << " return ::mlir::MutableOperandRange(getOperation(), "
1056 "range.first, range.second";
1057 if (attrSizedOperands)
1058 body << ", ::mlir::MutableOperandRange::OperandSegment(" << i
1059 << "u, *getOperation()->getAttrDictionary().getNamed("
1060 "operand_segment_sizesAttrName()))";
1061 body << ");\n";
1062 }
1063 }
1064
genNamedResultGetters()1065 void OpEmitter::genNamedResultGetters() {
1066 const int numResults = op.getNumResults();
1067 const int numVariadicResults = op.getNumVariableLengthResults();
1068 const int numNormalResults = numResults - numVariadicResults;
1069
1070 // If we have more than one variadic results, we need more complicated logic
1071 // to calculate the value range for each result.
1072
1073 const auto *sameVariadicSize =
1074 op.getTrait("::mlir::OpTrait::SameVariadicResultSize");
1075 const auto *attrSizedResults =
1076 op.getTrait("::mlir::OpTrait::AttrSizedResultSegments");
1077
1078 if (numVariadicResults > 1 && !sameVariadicSize && !attrSizedResults) {
1079 PrintFatalError(op.getLoc(), "op has multiple variadic results but no "
1080 "specification over their sizes");
1081 }
1082
1083 if (numVariadicResults < 2 && attrSizedResults) {
1084 PrintFatalError(op.getLoc(), "op must have at least two variadic results "
1085 "to use 'AttrSizedResultSegments' trait");
1086 }
1087
1088 if (attrSizedResults && sameVariadicSize) {
1089 PrintFatalError(op.getLoc(),
1090 "op cannot have both 'AttrSizedResultSegments' and "
1091 "'SameVariadicResultSize' traits");
1092 }
1093
1094 // Build the initializer string for the result segment size attribute.
1095 std::string attrSizeInitCode;
1096 if (attrSizedResults) {
1097 attrSizeInitCode =
1098 formatv(opSegmentSizeAttrInitCode, "result_segment_sizesAttrName()")
1099 .str();
1100 }
1101
1102 generateValueRangeStartAndEnd(
1103 opClass, "getODSResultIndexAndLength", numVariadicResults,
1104 numNormalResults, "getOperation()->getNumResults()", attrSizedResults,
1105 attrSizeInitCode, op.getResults());
1106
1107 auto *m = opClass.addMethodAndPrune("::mlir::Operation::result_range",
1108 "getODSResults", "unsigned", "index");
1109 m->body() << formatv(valueRangeReturnCode, "getOperation()->result_begin()",
1110 "getODSResultIndexAndLength(index)");
1111
1112 for (int i = 0; i != numResults; ++i) {
1113 const auto &result = op.getResult(i);
1114 if (result.name.empty())
1115 continue;
1116
1117 if (result.isOptional()) {
1118 m = opClass.addMethodAndPrune("::mlir::Value", result.name);
1119 m->body()
1120 << " auto results = getODSResults(" << i << ");\n"
1121 << " return results.empty() ? ::mlir::Value() : *results.begin();";
1122 } else if (result.isVariadic()) {
1123 m = opClass.addMethodAndPrune("::mlir::Operation::result_range",
1124 result.name);
1125 m->body() << " return getODSResults(" << i << ");";
1126 } else {
1127 m = opClass.addMethodAndPrune("::mlir::Value", result.name);
1128 m->body() << " return *getODSResults(" << i << ").begin();";
1129 }
1130 }
1131 }
1132
genNamedRegionGetters()1133 void OpEmitter::genNamedRegionGetters() {
1134 unsigned numRegions = op.getNumRegions();
1135 for (unsigned i = 0; i < numRegions; ++i) {
1136 const auto ®ion = op.getRegion(i);
1137 if (region.name.empty())
1138 continue;
1139
1140 // Generate the accessors for a variadic region.
1141 if (region.isVariadic()) {
1142 auto *m = opClass.addMethodAndPrune(
1143 "::mlir::MutableArrayRef<::mlir::Region>", region.name);
1144 m->body() << formatv(" return (*this)->getRegions().drop_front({0});",
1145 i);
1146 continue;
1147 }
1148
1149 auto *m = opClass.addMethodAndPrune("::mlir::Region &", region.name);
1150 m->body() << formatv(" return (*this)->getRegion({0});", i);
1151 }
1152 }
1153
genNamedSuccessorGetters()1154 void OpEmitter::genNamedSuccessorGetters() {
1155 unsigned numSuccessors = op.getNumSuccessors();
1156 for (unsigned i = 0; i < numSuccessors; ++i) {
1157 const NamedSuccessor &successor = op.getSuccessor(i);
1158 if (successor.name.empty())
1159 continue;
1160
1161 // Generate the accessors for a variadic successor list.
1162 if (successor.isVariadic()) {
1163 auto *m =
1164 opClass.addMethodAndPrune("::mlir::SuccessorRange", successor.name);
1165 m->body() << formatv(
1166 " return {std::next((*this)->successor_begin(), {0}), "
1167 "(*this)->successor_end()};",
1168 i);
1169 continue;
1170 }
1171
1172 auto *m = opClass.addMethodAndPrune("::mlir::Block *", successor.name);
1173 m->body() << formatv(" return (*this)->getSuccessor({0});", i);
1174 }
1175 }
1176
canGenerateUnwrappedBuilder(Operator & op)1177 static bool canGenerateUnwrappedBuilder(Operator &op) {
1178 // If this op does not have native attributes at all, return directly to avoid
1179 // redefining builders.
1180 if (op.getNumNativeAttributes() == 0)
1181 return false;
1182
1183 bool canGenerate = false;
1184 // We are generating builders that take raw values for attributes. We need to
1185 // make sure the native attributes have a meaningful "unwrapped" value type
1186 // different from the wrapped mlir::Attribute type to avoid redefining
1187 // builders. This checks for the op has at least one such native attribute.
1188 for (int i = 0, e = op.getNumNativeAttributes(); i < e; ++i) {
1189 NamedAttribute &namedAttr = op.getAttribute(i);
1190 if (canUseUnwrappedRawValue(namedAttr.attr)) {
1191 canGenerate = true;
1192 break;
1193 }
1194 }
1195 return canGenerate;
1196 }
1197
canInferType(Operator & op)1198 static bool canInferType(Operator &op) {
1199 return op.getTrait("::mlir::InferTypeOpInterface::Trait") &&
1200 op.getNumRegions() == 0;
1201 }
1202
genSeparateArgParamBuilder()1203 void OpEmitter::genSeparateArgParamBuilder() {
1204 SmallVector<AttrParamKind, 2> attrBuilderType;
1205 attrBuilderType.push_back(AttrParamKind::WrappedAttr);
1206 if (canGenerateUnwrappedBuilder(op))
1207 attrBuilderType.push_back(AttrParamKind::UnwrappedValue);
1208
1209 // Emit with separate builders with or without unwrapped attributes and/or
1210 // inferring result type.
1211 auto emit = [&](AttrParamKind attrType, TypeParamKind paramKind,
1212 bool inferType) {
1213 llvm::SmallVector<OpMethodParameter, 4> paramList;
1214 llvm::SmallVector<std::string, 4> resultNames;
1215 buildParamList(paramList, resultNames, paramKind, attrType);
1216
1217 auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
1218 std::move(paramList));
1219 // If the builder is redundant, skip generating the method.
1220 if (!m)
1221 return;
1222 auto &body = m->body();
1223 genCodeForAddingArgAndRegionForBuilder(
1224 body, /*isRawValueAttr=*/attrType == AttrParamKind::UnwrappedValue);
1225
1226 // Push all result types to the operation state
1227
1228 if (inferType) {
1229 // Generate builder that infers type too.
1230 // TODO: Subsume this with general checking if type can be
1231 // inferred automatically.
1232 // TODO: Expand to handle regions.
1233 body << formatv(R"(
1234 ::llvm::SmallVector<::mlir::Type, 2> inferredReturnTypes;
1235 if (succeeded({0}::inferReturnTypes(odsBuilder.getContext(),
1236 {1}.location, {1}.operands,
1237 {1}.attributes.getDictionary({1}.getContext()),
1238 /*regions=*/{{}, inferredReturnTypes)))
1239 {1}.addTypes(inferredReturnTypes);
1240 else
1241 ::llvm::report_fatal_error("Failed to infer result type(s).");)",
1242 opClass.getClassName(), builderOpState);
1243 return;
1244 }
1245
1246 switch (paramKind) {
1247 case TypeParamKind::None:
1248 return;
1249 case TypeParamKind::Separate:
1250 for (int i = 0, e = op.getNumResults(); i < e; ++i) {
1251 if (op.getResult(i).isOptional())
1252 body << " if (" << resultNames[i] << ")\n ";
1253 body << " " << builderOpState << ".addTypes(" << resultNames[i]
1254 << ");\n";
1255 }
1256 return;
1257 case TypeParamKind::Collective: {
1258 int numResults = op.getNumResults();
1259 int numVariadicResults = op.getNumVariableLengthResults();
1260 int numNonVariadicResults = numResults - numVariadicResults;
1261 bool hasVariadicResult = numVariadicResults != 0;
1262
1263 // Avoid emitting "resultTypes.size() >= 0u" which is always true.
1264 if (!(hasVariadicResult && numNonVariadicResults == 0))
1265 body << " "
1266 << "assert(resultTypes.size() "
1267 << (hasVariadicResult ? ">=" : "==") << " "
1268 << numNonVariadicResults
1269 << "u && \"mismatched number of results\");\n";
1270 body << " " << builderOpState << ".addTypes(resultTypes);\n";
1271 }
1272 return;
1273 }
1274 llvm_unreachable("unhandled TypeParamKind");
1275 };
1276
1277 // Some of the build methods generated here may be ambiguous, but TableGen's
1278 // ambiguous function detection will elide those ones.
1279 for (auto attrType : attrBuilderType) {
1280 emit(attrType, TypeParamKind::Separate, /*inferType=*/false);
1281 if (canInferType(op))
1282 emit(attrType, TypeParamKind::None, /*inferType=*/true);
1283 emit(attrType, TypeParamKind::Collective, /*inferType=*/false);
1284 }
1285 }
1286
genUseOperandAsResultTypeCollectiveParamBuilder()1287 void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() {
1288 int numResults = op.getNumResults();
1289
1290 // Signature
1291 llvm::SmallVector<OpMethodParameter, 4> paramList;
1292 paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
1293 paramList.emplace_back("::mlir::OperationState &", builderOpState);
1294 paramList.emplace_back("::mlir::ValueRange", "operands");
1295 // Provide default value for `attributes` when its the last parameter
1296 StringRef attributesDefaultValue = op.getNumVariadicRegions() ? "" : "{}";
1297 paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
1298 "attributes", attributesDefaultValue);
1299 if (op.getNumVariadicRegions())
1300 paramList.emplace_back("unsigned", "numRegions");
1301
1302 auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
1303 std::move(paramList));
1304 // If the builder is redundant, skip generating the method
1305 if (!m)
1306 return;
1307 auto &body = m->body();
1308
1309 // Operands
1310 body << " " << builderOpState << ".addOperands(operands);\n";
1311
1312 // Attributes
1313 body << " " << builderOpState << ".addAttributes(attributes);\n";
1314
1315 // Create the correct number of regions
1316 if (int numRegions = op.getNumRegions()) {
1317 body << llvm::formatv(
1318 " for (unsigned i = 0; i != {0}; ++i)\n",
1319 (op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions)));
1320 body << " (void)" << builderOpState << ".addRegion();\n";
1321 }
1322
1323 // Result types
1324 SmallVector<std::string, 2> resultTypes(numResults, "operands[0].getType()");
1325 body << " " << builderOpState << ".addTypes({"
1326 << llvm::join(resultTypes, ", ") << "});\n\n";
1327 }
1328
genInferredTypeCollectiveParamBuilder()1329 void OpEmitter::genInferredTypeCollectiveParamBuilder() {
1330 // TODO: Expand to support regions.
1331 SmallVector<OpMethodParameter, 4> paramList;
1332 paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
1333 paramList.emplace_back("::mlir::OperationState &", builderOpState);
1334 paramList.emplace_back("::mlir::ValueRange", "operands");
1335 paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
1336 "attributes", "{}");
1337 auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
1338 std::move(paramList));
1339 // If the builder is redundant, skip generating the method
1340 if (!m)
1341 return;
1342 auto &body = m->body();
1343
1344 int numResults = op.getNumResults();
1345 int numVariadicResults = op.getNumVariableLengthResults();
1346 int numNonVariadicResults = numResults - numVariadicResults;
1347
1348 int numOperands = op.getNumOperands();
1349 int numVariadicOperands = op.getNumVariableLengthOperands();
1350 int numNonVariadicOperands = numOperands - numVariadicOperands;
1351
1352 // Operands
1353 if (numVariadicOperands == 0 || numNonVariadicOperands != 0)
1354 body << " assert(operands.size()"
1355 << (numVariadicOperands != 0 ? " >= " : " == ")
1356 << numNonVariadicOperands
1357 << "u && \"mismatched number of parameters\");\n";
1358 body << " " << builderOpState << ".addOperands(operands);\n";
1359 body << " " << builderOpState << ".addAttributes(attributes);\n";
1360
1361 // Create the correct number of regions
1362 if (int numRegions = op.getNumRegions()) {
1363 body << llvm::formatv(
1364 " for (unsigned i = 0; i != {0}; ++i)\n",
1365 (op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions)));
1366 body << " (void)" << builderOpState << ".addRegion();\n";
1367 }
1368
1369 // Result types
1370 body << formatv(R"(
1371 ::mlir::SmallVector<::mlir::Type, 2> inferredReturnTypes;
1372 if (succeeded({0}::inferReturnTypes(odsBuilder.getContext(),
1373 {1}.location, operands,
1374 {1}.attributes.getDictionary({1}.getContext()),
1375 /*regions=*/{{}, inferredReturnTypes))) {{)",
1376 opClass.getClassName(), builderOpState);
1377 if (numVariadicResults == 0 || numNonVariadicResults != 0)
1378 body << " assert(inferredReturnTypes.size()"
1379 << (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults
1380 << "u && \"mismatched number of return types\");\n";
1381 body << " " << builderOpState << ".addTypes(inferredReturnTypes);";
1382
1383 body << formatv(R"(
1384 } else
1385 ::llvm::report_fatal_error("Failed to infer result type(s).");)",
1386 opClass.getClassName(), builderOpState);
1387 }
1388
genUseOperandAsResultTypeSeparateParamBuilder()1389 void OpEmitter::genUseOperandAsResultTypeSeparateParamBuilder() {
1390 llvm::SmallVector<OpMethodParameter, 4> paramList;
1391 llvm::SmallVector<std::string, 4> resultNames;
1392 buildParamList(paramList, resultNames, TypeParamKind::None);
1393
1394 auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
1395 std::move(paramList));
1396 // If the builder is redundant, skip generating the method
1397 if (!m)
1398 return;
1399 auto &body = m->body();
1400 genCodeForAddingArgAndRegionForBuilder(body);
1401
1402 auto numResults = op.getNumResults();
1403 if (numResults == 0)
1404 return;
1405
1406 // Push all result types to the operation state
1407 const char *index = op.getOperand(0).isVariadic() ? ".front()" : "";
1408 std::string resultType =
1409 formatv("{0}{1}.getType()", getArgumentName(op, 0), index).str();
1410 body << " " << builderOpState << ".addTypes({" << resultType;
1411 for (int i = 1; i != numResults; ++i)
1412 body << ", " << resultType;
1413 body << "});\n\n";
1414 }
1415
genUseAttrAsResultTypeBuilder()1416 void OpEmitter::genUseAttrAsResultTypeBuilder() {
1417 SmallVector<OpMethodParameter, 4> paramList;
1418 paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
1419 paramList.emplace_back("::mlir::OperationState &", builderOpState);
1420 paramList.emplace_back("::mlir::ValueRange", "operands");
1421 paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
1422 "attributes", "{}");
1423 auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
1424 std::move(paramList));
1425 // If the builder is redundant, skip generating the method
1426 if (!m)
1427 return;
1428
1429 auto &body = m->body();
1430
1431 // Push all result types to the operation state
1432 std::string resultType;
1433 const auto &namedAttr = op.getAttribute(0);
1434
1435 body << " auto attrName = " << namedAttr.name << "AttrName("
1436 << builderOpState
1437 << ".name);\n"
1438 " for (auto attr : attributes) {\n"
1439 " if (attr.first != attrName) continue;\n";
1440 if (namedAttr.attr.isTypeAttr()) {
1441 resultType = "attr.second.cast<::mlir::TypeAttr>().getValue()";
1442 } else {
1443 resultType = "attr.second.getType()";
1444 }
1445
1446 // Operands
1447 body << " " << builderOpState << ".addOperands(operands);\n";
1448
1449 // Attributes
1450 body << " " << builderOpState << ".addAttributes(attributes);\n";
1451
1452 // Result types
1453 SmallVector<std::string, 2> resultTypes(op.getNumResults(), resultType);
1454 body << " " << builderOpState << ".addTypes({"
1455 << llvm::join(resultTypes, ", ") << "});\n";
1456 body << " }\n";
1457 }
1458
1459 /// Returns a signature of the builder. Updates the context `fctx` to enable
1460 /// replacement of $_builder and $_state in the body.
getBuilderSignature(const Builder & builder)1461 static std::string getBuilderSignature(const Builder &builder) {
1462 ArrayRef<Builder::Parameter> params(builder.getParameters());
1463
1464 // Inject builder and state arguments.
1465 llvm::SmallVector<std::string, 8> arguments;
1466 arguments.reserve(params.size() + 2);
1467 arguments.push_back(
1468 llvm::formatv("::mlir::OpBuilder &{0}", odsBuilder).str());
1469 arguments.push_back(
1470 llvm::formatv("::mlir::OperationState &{0}", builderOpState).str());
1471
1472 for (unsigned i = 0, e = params.size(); i < e; ++i) {
1473 // If no name is provided, generate one.
1474 Optional<StringRef> paramName = params[i].getName();
1475 std::string name =
1476 paramName ? paramName->str() : "odsArg" + std::to_string(i);
1477
1478 std::string defaultValue;
1479 if (Optional<StringRef> defaultParamValue = params[i].getDefaultValue())
1480 defaultValue = llvm::formatv(" = {0}", *defaultParamValue).str();
1481 arguments.push_back(
1482 llvm::formatv("{0} {1}{2}", params[i].getCppType(), name, defaultValue)
1483 .str());
1484 }
1485
1486 return llvm::join(arguments, ", ");
1487 }
1488
genBuilder()1489 void OpEmitter::genBuilder() {
1490 // Handle custom builders if provided.
1491 for (const Builder &builder : op.getBuilders()) {
1492 std::string paramStr = getBuilderSignature(builder);
1493
1494 Optional<StringRef> body = builder.getBody();
1495 OpMethod::Property properties =
1496 body ? OpMethod::MP_Static : OpMethod::MP_StaticDeclaration;
1497 auto *method =
1498 opClass.addMethodAndPrune("void", "build", properties, paramStr);
1499
1500 FmtContext fctx;
1501 fctx.withBuilder(odsBuilder);
1502 fctx.addSubst("_state", builderOpState);
1503 if (body)
1504 method->body() << tgfmt(*body, &fctx);
1505 }
1506
1507 // Generate default builders that requires all result type, operands, and
1508 // attributes as parameters.
1509 if (op.skipDefaultBuilders())
1510 return;
1511
1512 // We generate three classes of builders here:
1513 // 1. one having a stand-alone parameter for each operand / attribute, and
1514 genSeparateArgParamBuilder();
1515 // 2. one having an aggregated parameter for all result types / operands /
1516 // attributes, and
1517 genCollectiveParamBuilder();
1518 // 3. one having a stand-alone parameter for each operand and attribute,
1519 // use the first operand or attribute's type as all result types
1520 // to facilitate different call patterns.
1521 if (op.getNumVariableLengthResults() == 0) {
1522 if (op.getTrait("::mlir::OpTrait::SameOperandsAndResultType")) {
1523 genUseOperandAsResultTypeSeparateParamBuilder();
1524 genUseOperandAsResultTypeCollectiveParamBuilder();
1525 }
1526 if (op.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType"))
1527 genUseAttrAsResultTypeBuilder();
1528 }
1529 }
1530
genCollectiveParamBuilder()1531 void OpEmitter::genCollectiveParamBuilder() {
1532 int numResults = op.getNumResults();
1533 int numVariadicResults = op.getNumVariableLengthResults();
1534 int numNonVariadicResults = numResults - numVariadicResults;
1535
1536 int numOperands = op.getNumOperands();
1537 int numVariadicOperands = op.getNumVariableLengthOperands();
1538 int numNonVariadicOperands = numOperands - numVariadicOperands;
1539
1540 SmallVector<OpMethodParameter, 4> paramList;
1541 paramList.emplace_back("::mlir::OpBuilder &", "");
1542 paramList.emplace_back("::mlir::OperationState &", builderOpState);
1543 paramList.emplace_back("::mlir::TypeRange", "resultTypes");
1544 paramList.emplace_back("::mlir::ValueRange", "operands");
1545 // Provide default value for `attributes` when its the last parameter
1546 StringRef attributesDefaultValue = op.getNumVariadicRegions() ? "" : "{}";
1547 paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
1548 "attributes", attributesDefaultValue);
1549 if (op.getNumVariadicRegions())
1550 paramList.emplace_back("unsigned", "numRegions");
1551
1552 auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
1553 std::move(paramList));
1554 // If the builder is redundant, skip generating the method
1555 if (!m)
1556 return;
1557 auto &body = m->body();
1558
1559 // Operands
1560 if (numVariadicOperands == 0 || numNonVariadicOperands != 0)
1561 body << " assert(operands.size()"
1562 << (numVariadicOperands != 0 ? " >= " : " == ")
1563 << numNonVariadicOperands
1564 << "u && \"mismatched number of parameters\");\n";
1565 body << " " << builderOpState << ".addOperands(operands);\n";
1566
1567 // Attributes
1568 body << " " << builderOpState << ".addAttributes(attributes);\n";
1569
1570 // Create the correct number of regions
1571 if (int numRegions = op.getNumRegions()) {
1572 body << llvm::formatv(
1573 " for (unsigned i = 0; i != {0}; ++i)\n",
1574 (op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions)));
1575 body << " (void)" << builderOpState << ".addRegion();\n";
1576 }
1577
1578 // Result types
1579 if (numVariadicResults == 0 || numNonVariadicResults != 0)
1580 body << " assert(resultTypes.size()"
1581 << (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults
1582 << "u && \"mismatched number of return types\");\n";
1583 body << " " << builderOpState << ".addTypes(resultTypes);\n";
1584
1585 // Generate builder that infers type too.
1586 // TODO: Expand to handle regions and successors.
1587 if (canInferType(op) && op.getNumSuccessors() == 0)
1588 genInferredTypeCollectiveParamBuilder();
1589 }
1590
buildParamList(SmallVectorImpl<OpMethodParameter> & paramList,SmallVectorImpl<std::string> & resultTypeNames,TypeParamKind typeParamKind,AttrParamKind attrParamKind)1591 void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> ¶mList,
1592 SmallVectorImpl<std::string> &resultTypeNames,
1593 TypeParamKind typeParamKind,
1594 AttrParamKind attrParamKind) {
1595 resultTypeNames.clear();
1596 auto numResults = op.getNumResults();
1597 resultTypeNames.reserve(numResults);
1598
1599 paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
1600 paramList.emplace_back("::mlir::OperationState &", builderOpState);
1601
1602 switch (typeParamKind) {
1603 case TypeParamKind::None:
1604 break;
1605 case TypeParamKind::Separate: {
1606 // Add parameters for all return types
1607 for (int i = 0; i < numResults; ++i) {
1608 const auto &result = op.getResult(i);
1609 std::string resultName = std::string(result.name);
1610 if (resultName.empty())
1611 resultName = std::string(formatv("resultType{0}", i));
1612
1613 StringRef type =
1614 result.isVariadic() ? "::mlir::TypeRange" : "::mlir::Type";
1615 OpMethodParameter::Property properties = OpMethodParameter::PP_None;
1616 if (result.isOptional())
1617 properties = OpMethodParameter::PP_Optional;
1618
1619 paramList.emplace_back(type, resultName, properties);
1620 resultTypeNames.emplace_back(std::move(resultName));
1621 }
1622 } break;
1623 case TypeParamKind::Collective: {
1624 paramList.emplace_back("::mlir::TypeRange", "resultTypes");
1625 resultTypeNames.push_back("resultTypes");
1626 } break;
1627 }
1628
1629 // Add parameters for all arguments (operands and attributes).
1630
1631 int numOperands = 0;
1632 int numAttrs = 0;
1633
1634 int defaultValuedAttrStartIndex = op.getNumArgs();
1635 if (attrParamKind == AttrParamKind::UnwrappedValue) {
1636 // Calculate the start index from which we can attach default values in the
1637 // builder declaration.
1638 for (int i = op.getNumArgs() - 1; i >= 0; --i) {
1639 auto *namedAttr = op.getArg(i).dyn_cast<tblgen::NamedAttribute *>();
1640 if (!namedAttr || !namedAttr->attr.hasDefaultValue())
1641 break;
1642
1643 if (!canUseUnwrappedRawValue(namedAttr->attr))
1644 break;
1645
1646 // Creating an APInt requires us to provide bitwidth, value, and
1647 // signedness, which is complicated compared to others. Similarly
1648 // for APFloat.
1649 // TODO: Adjust the 'returnType' field of such attributes
1650 // to support them.
1651 StringRef retType = namedAttr->attr.getReturnType();
1652 if (retType == "::llvm::APInt" || retType == "::llvm::APFloat")
1653 break;
1654
1655 defaultValuedAttrStartIndex = i;
1656 }
1657 }
1658
1659 for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
1660 auto argument = op.getArg(i);
1661 if (argument.is<tblgen::NamedTypeConstraint *>()) {
1662 const auto &operand = op.getOperand(numOperands);
1663 StringRef type =
1664 operand.isVariadic() ? "::mlir::ValueRange" : "::mlir::Value";
1665 OpMethodParameter::Property properties = OpMethodParameter::PP_None;
1666 if (operand.isOptional())
1667 properties = OpMethodParameter::PP_Optional;
1668
1669 paramList.emplace_back(type, getArgumentName(op, numOperands),
1670 properties);
1671 ++numOperands;
1672 } else {
1673 const auto &namedAttr = op.getAttribute(numAttrs);
1674 const auto &attr = namedAttr.attr;
1675
1676 OpMethodParameter::Property properties = OpMethodParameter::PP_None;
1677 if (attr.isOptional())
1678 properties = OpMethodParameter::PP_Optional;
1679
1680 StringRef type;
1681 switch (attrParamKind) {
1682 case AttrParamKind::WrappedAttr:
1683 type = attr.getStorageType();
1684 break;
1685 case AttrParamKind::UnwrappedValue:
1686 if (canUseUnwrappedRawValue(attr))
1687 type = attr.getReturnType();
1688 else
1689 type = attr.getStorageType();
1690 break;
1691 }
1692
1693 std::string defaultValue;
1694 // Attach default value if requested and possible.
1695 if (attrParamKind == AttrParamKind::UnwrappedValue &&
1696 i >= defaultValuedAttrStartIndex) {
1697 bool isString = attr.getReturnType() == "::llvm::StringRef";
1698 if (isString)
1699 defaultValue.append("\"");
1700 defaultValue += attr.getDefaultValue();
1701 if (isString)
1702 defaultValue.append("\"");
1703 }
1704 paramList.emplace_back(type, namedAttr.name, defaultValue, properties);
1705 ++numAttrs;
1706 }
1707 }
1708
1709 /// Insert parameters for each successor.
1710 for (const NamedSuccessor &succ : op.getSuccessors()) {
1711 StringRef type =
1712 succ.isVariadic() ? "::mlir::BlockRange" : "::mlir::Block *";
1713 paramList.emplace_back(type, succ.name);
1714 }
1715
1716 /// Insert parameters for variadic regions.
1717 for (const NamedRegion ®ion : op.getRegions())
1718 if (region.isVariadic())
1719 paramList.emplace_back("unsigned",
1720 llvm::formatv("{0}Count", region.name).str());
1721 }
1722
genCodeForAddingArgAndRegionForBuilder(OpMethodBody & body,bool isRawValueAttr)1723 void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
1724 bool isRawValueAttr) {
1725 // Push all operands to the result.
1726 for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
1727 std::string argName = getArgumentName(op, i);
1728 if (op.getOperand(i).isOptional())
1729 body << " if (" << argName << ")\n ";
1730 body << " " << builderOpState << ".addOperands(" << argName << ");\n";
1731 }
1732
1733 // If the operation has the operand segment size attribute, add it here.
1734 if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
1735 body << " " << builderOpState
1736 << ".addAttribute(operand_segment_sizesAttrName(" << builderOpState
1737 << ".name), "
1738 << "odsBuilder.getI32VectorAttr({";
1739 interleaveComma(llvm::seq<int>(0, op.getNumOperands()), body, [&](int i) {
1740 if (op.getOperand(i).isOptional())
1741 body << "(" << getArgumentName(op, i) << " ? 1 : 0)";
1742 else if (op.getOperand(i).isVariadic())
1743 body << "static_cast<int32_t>(" << getArgumentName(op, i) << ".size())";
1744 else
1745 body << "1";
1746 });
1747 body << "}));\n";
1748 }
1749
1750 // Push all attributes to the result.
1751 for (const auto &namedAttr : op.getAttributes()) {
1752 auto &attr = namedAttr.attr;
1753 if (!attr.isDerivedAttr()) {
1754 bool emitNotNullCheck = attr.isOptional();
1755 if (emitNotNullCheck)
1756 body << formatv(" if ({0}) ", namedAttr.name) << "{\n";
1757
1758 if (isRawValueAttr && canUseUnwrappedRawValue(attr)) {
1759 // If this is a raw value, then we need to wrap it in an Attribute
1760 // instance.
1761 FmtContext fctx;
1762 fctx.withBuilder("odsBuilder");
1763
1764 std::string builderTemplate =
1765 std::string(attr.getConstBuilderTemplate());
1766
1767 // For StringAttr, its constant builder call will wrap the input in
1768 // quotes, which is correct for normal string literals, but incorrect
1769 // here given we use function arguments. So we need to strip the
1770 // wrapping quotes.
1771 if (StringRef(builderTemplate).contains("\"$0\""))
1772 builderTemplate = replaceAllSubstrs(builderTemplate, "\"$0\"", "$0");
1773
1774 std::string value =
1775 std::string(tgfmt(builderTemplate, &fctx, namedAttr.name));
1776 body << formatv(" {0}.addAttribute({1}AttrName({0}.name), {2});\n",
1777 builderOpState, namedAttr.name, value);
1778 } else {
1779 body << formatv(" {0}.addAttribute({1}AttrName({0}.name), {1});\n",
1780 builderOpState, namedAttr.name);
1781 }
1782 if (emitNotNullCheck)
1783 body << " }\n";
1784 }
1785 }
1786
1787 // Create the correct number of regions.
1788 for (const NamedRegion ®ion : op.getRegions()) {
1789 if (region.isVariadic())
1790 body << formatv(" for (unsigned i = 0; i < {0}Count; ++i)\n ",
1791 region.name);
1792
1793 body << " (void)" << builderOpState << ".addRegion();\n";
1794 }
1795
1796 // Push all successors to the result.
1797 for (const NamedSuccessor &namedSuccessor : op.getSuccessors()) {
1798 body << formatv(" {0}.addSuccessors({1});\n", builderOpState,
1799 namedSuccessor.name);
1800 }
1801 }
1802
genCanonicalizerDecls()1803 void OpEmitter::genCanonicalizerDecls() {
1804 bool hasCanonicalizeMethod = def.getValueAsBit("hasCanonicalizeMethod");
1805 if (hasCanonicalizeMethod) {
1806 // static LogicResult FooOp::
1807 // canonicalize(FooOp op, PatternRewriter &rewriter);
1808 SmallVector<OpMethodParameter, 2> paramList;
1809 paramList.emplace_back(op.getCppClassName(), "op");
1810 paramList.emplace_back("::mlir::PatternRewriter &", "rewriter");
1811 opClass.addMethodAndPrune("::mlir::LogicalResult", "canonicalize",
1812 OpMethod::MP_StaticDeclaration,
1813 std::move(paramList));
1814 }
1815
1816 // We get a prototype for 'getCanonicalizationPatterns' if requested directly
1817 // or if using a 'canonicalize' method.
1818 bool hasCanonicalizer = def.getValueAsBit("hasCanonicalizer");
1819 if (!hasCanonicalizeMethod && !hasCanonicalizer)
1820 return;
1821
1822 // We get a body for 'getCanonicalizationPatterns' when using a 'canonicalize'
1823 // method, but not implementing 'getCanonicalizationPatterns' manually.
1824 bool hasBody = hasCanonicalizeMethod && !hasCanonicalizer;
1825
1826 // Add a signature for getCanonicalizationPatterns if implemented by the
1827 // dialect or if synthesized to call 'canonicalize'.
1828 SmallVector<OpMethodParameter, 2> paramList;
1829 paramList.emplace_back("::mlir::RewritePatternSet &", "results");
1830 paramList.emplace_back("::mlir::MLIRContext *", "context");
1831 auto kind = hasBody ? OpMethod::MP_Static : OpMethod::MP_StaticDeclaration;
1832 auto *method = opClass.addMethodAndPrune(
1833 "void", "getCanonicalizationPatterns", kind, std::move(paramList));
1834
1835 // If synthesizing the method, fill it it.
1836 if (hasBody)
1837 method->body() << " results.add(canonicalize);\n";
1838 }
1839
genFolderDecls()1840 void OpEmitter::genFolderDecls() {
1841 bool hasSingleResult =
1842 op.getNumResults() == 1 && op.getNumVariableLengthResults() == 0;
1843
1844 if (def.getValueAsBit("hasFolder")) {
1845 if (hasSingleResult) {
1846 opClass.addMethodAndPrune(
1847 "::mlir::OpFoldResult", "fold", OpMethod::MP_Declaration,
1848 "::llvm::ArrayRef<::mlir::Attribute>", "operands");
1849 } else {
1850 SmallVector<OpMethodParameter, 2> paramList;
1851 paramList.emplace_back("::llvm::ArrayRef<::mlir::Attribute>", "operands");
1852 paramList.emplace_back("::llvm::SmallVectorImpl<::mlir::OpFoldResult> &",
1853 "results");
1854 opClass.addMethodAndPrune("::mlir::LogicalResult", "fold",
1855 OpMethod::MP_Declaration, std::move(paramList));
1856 }
1857 }
1858 }
1859
genOpInterfaceMethods(const tblgen::InterfaceTrait * opTrait)1860 void OpEmitter::genOpInterfaceMethods(const tblgen::InterfaceTrait *opTrait) {
1861 Interface interface = opTrait->getInterface();
1862
1863 // Get the set of methods that should always be declared.
1864 auto alwaysDeclaredMethodsVec = opTrait->getAlwaysDeclaredMethods();
1865 llvm::StringSet<> alwaysDeclaredMethods;
1866 alwaysDeclaredMethods.insert(alwaysDeclaredMethodsVec.begin(),
1867 alwaysDeclaredMethodsVec.end());
1868
1869 for (const InterfaceMethod &method : interface.getMethods()) {
1870 // Don't declare if the method has a body.
1871 if (method.getBody())
1872 continue;
1873 // Don't declare if the method has a default implementation and the op
1874 // didn't request that it always be declared.
1875 if (method.getDefaultImplementation() &&
1876 !alwaysDeclaredMethods.count(method.getName()))
1877 continue;
1878 genOpInterfaceMethod(method);
1879 }
1880 }
1881
genOpInterfaceMethod(const InterfaceMethod & method,bool declaration)1882 OpMethod *OpEmitter::genOpInterfaceMethod(const InterfaceMethod &method,
1883 bool declaration) {
1884 SmallVector<OpMethodParameter, 4> paramList;
1885 for (const InterfaceMethod::Argument &arg : method.getArguments())
1886 paramList.emplace_back(arg.type, arg.name);
1887
1888 auto properties = method.isStatic() ? OpMethod::MP_Static : OpMethod::MP_None;
1889 if (declaration)
1890 properties =
1891 static_cast<OpMethod::Property>(properties | OpMethod::MP_Declaration);
1892 return opClass.addMethodAndPrune(method.getReturnType(), method.getName(),
1893 properties, std::move(paramList));
1894 }
1895
genOpInterfaceMethods()1896 void OpEmitter::genOpInterfaceMethods() {
1897 for (const auto &trait : op.getTraits()) {
1898 if (const auto *opTrait = dyn_cast<tblgen::InterfaceTrait>(&trait))
1899 if (opTrait->shouldDeclareMethods())
1900 genOpInterfaceMethods(opTrait);
1901 }
1902 }
1903
genSideEffectInterfaceMethods()1904 void OpEmitter::genSideEffectInterfaceMethods() {
1905 enum EffectKind { Operand, Result, Symbol, Static };
1906 struct EffectLocation {
1907 /// The effect applied.
1908 SideEffect effect;
1909
1910 /// The index if the kind is not static.
1911 unsigned index : 30;
1912
1913 /// The kind of the location.
1914 unsigned kind : 2;
1915 };
1916
1917 StringMap<SmallVector<EffectLocation, 1>> interfaceEffects;
1918 auto resolveDecorators = [&](Operator::var_decorator_range decorators,
1919 unsigned index, unsigned kind) {
1920 for (auto decorator : decorators)
1921 if (SideEffect *effect = dyn_cast<SideEffect>(&decorator)) {
1922 opClass.addTrait(effect->getInterfaceTrait());
1923 interfaceEffects[effect->getBaseEffectName()].push_back(
1924 EffectLocation{*effect, index, kind});
1925 }
1926 };
1927
1928 // Collect effects that were specified via:
1929 /// Traits.
1930 for (const auto &trait : op.getTraits()) {
1931 const auto *opTrait = dyn_cast<tblgen::SideEffectTrait>(&trait);
1932 if (!opTrait)
1933 continue;
1934 auto &effects = interfaceEffects[opTrait->getBaseEffectName()];
1935 for (auto decorator : opTrait->getEffects())
1936 effects.push_back(EffectLocation{cast<SideEffect>(decorator),
1937 /*index=*/0, EffectKind::Static});
1938 }
1939 /// Attributes and Operands.
1940 for (unsigned i = 0, operandIt = 0, e = op.getNumArgs(); i != e; ++i) {
1941 Argument arg = op.getArg(i);
1942 if (arg.is<NamedTypeConstraint *>()) {
1943 resolveDecorators(op.getArgDecorators(i), operandIt, EffectKind::Operand);
1944 ++operandIt;
1945 continue;
1946 }
1947 const NamedAttribute *attr = arg.get<NamedAttribute *>();
1948 if (attr->attr.getBaseAttr().isSymbolRefAttr())
1949 resolveDecorators(op.getArgDecorators(i), i, EffectKind::Symbol);
1950 }
1951 /// Results.
1952 for (unsigned i = 0, e = op.getNumResults(); i != e; ++i)
1953 resolveDecorators(op.getResultDecorators(i), i, EffectKind::Result);
1954
1955 // The code used to add an effect instance.
1956 // {0}: The effect class.
1957 // {1}: Optional value or symbol reference.
1958 // {1}: The resource class.
1959 const char *addEffectCode =
1960 " effects.emplace_back({0}::get(), {1}{2}::get());\n";
1961
1962 for (auto &it : interfaceEffects) {
1963 // Generate the 'getEffects' method.
1964 std::string type = llvm::formatv("::mlir::SmallVectorImpl<::mlir::"
1965 "SideEffects::EffectInstance<{0}>> &",
1966 it.first())
1967 .str();
1968 auto *getEffects =
1969 opClass.addMethodAndPrune("void", "getEffects", type, "effects");
1970 auto &body = getEffects->body();
1971
1972 // Add effect instances for each of the locations marked on the operation.
1973 for (auto &location : it.second) {
1974 StringRef effect = location.effect.getName();
1975 StringRef resource = location.effect.getResource();
1976 if (location.kind == EffectKind::Static) {
1977 // A static instance has no attached value.
1978 body << llvm::formatv(addEffectCode, effect, "", resource).str();
1979 } else if (location.kind == EffectKind::Symbol) {
1980 // A symbol reference requires adding the proper attribute.
1981 const auto *attr = op.getArg(location.index).get<NamedAttribute *>();
1982 if (attr->attr.isOptional()) {
1983 body << " if (auto symbolRef = " << attr->name << "Attr())\n "
1984 << llvm::formatv(addEffectCode, effect, "symbolRef, ", resource)
1985 .str();
1986 } else {
1987 body << llvm::formatv(addEffectCode, effect, attr->name + "(), ",
1988 resource)
1989 .str();
1990 }
1991 } else {
1992 // Otherwise this is an operand/result, so we need to attach the Value.
1993 body << " for (::mlir::Value value : getODS"
1994 << (location.kind == EffectKind::Operand ? "Operands" : "Results")
1995 << "(" << location.index << "))\n "
1996 << llvm::formatv(addEffectCode, effect, "value, ", resource).str();
1997 }
1998 }
1999 }
2000 }
2001
genTypeInterfaceMethods()2002 void OpEmitter::genTypeInterfaceMethods() {
2003 if (!op.allResultTypesKnown())
2004 return;
2005 // Generate 'inferReturnTypes' method declaration using the interface method
2006 // declared in 'InferTypeOpInterface' op interface.
2007 const auto *trait = dyn_cast<InterfaceTrait>(
2008 op.getTrait("::mlir::InferTypeOpInterface::Trait"));
2009 Interface interface = trait->getInterface();
2010 OpMethod *method = [&]() -> OpMethod * {
2011 for (const InterfaceMethod &interfaceMethod : interface.getMethods()) {
2012 if (interfaceMethod.getName() == "inferReturnTypes") {
2013 return genOpInterfaceMethod(interfaceMethod, /*declaration=*/false);
2014 }
2015 }
2016 assert(0 && "unable to find inferReturnTypes interface method");
2017 return nullptr;
2018 }();
2019 auto &body = method->body();
2020 body << " inferredReturnTypes.resize(" << op.getNumResults() << ");\n";
2021
2022 FmtContext fctx;
2023 fctx.withBuilder("odsBuilder");
2024 body << " ::mlir::Builder odsBuilder(context);\n";
2025
2026 auto emitType =
2027 [&](const tblgen::Operator::ArgOrType &type) -> OpMethodBody & {
2028 if (type.isArg()) {
2029 auto argIndex = type.getArg();
2030 assert(!op.getArg(argIndex).is<NamedAttribute *>());
2031 auto arg = op.getArgToOperandOrAttribute(argIndex);
2032 if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand)
2033 return body << "operands[" << arg.operandOrAttributeIndex()
2034 << "].getType()";
2035 return body << "attributes[" << arg.operandOrAttributeIndex()
2036 << "].getType()";
2037 } else {
2038 return body << tgfmt(*type.getType().getBuilderCall(), &fctx);
2039 }
2040 };
2041
2042 for (int i = 0, e = op.getNumResults(); i != e; ++i) {
2043 body << " inferredReturnTypes[" << i << "] = ";
2044 auto types = op.getSameTypeAsResult(i);
2045 emitType(types[0]) << ";\n";
2046 if (types.size() == 1)
2047 continue;
2048 // TODO: We could verify equality here, but skipping that for verification.
2049 }
2050 body << " return ::mlir::success();";
2051 }
2052
genParser()2053 void OpEmitter::genParser() {
2054 if (!hasStringAttribute(def, "parser") ||
2055 hasStringAttribute(def, "assemblyFormat"))
2056 return;
2057
2058 SmallVector<OpMethodParameter, 2> paramList;
2059 paramList.emplace_back("::mlir::OpAsmParser &", "parser");
2060 paramList.emplace_back("::mlir::OperationState &", "result");
2061 auto *method =
2062 opClass.addMethodAndPrune("::mlir::ParseResult", "parse",
2063 OpMethod::MP_Static, std::move(paramList));
2064
2065 FmtContext fctx;
2066 fctx.addSubst("cppClass", opClass.getClassName());
2067 auto parser = def.getValueAsString("parser").ltrim().rtrim(" \t\v\f\r");
2068 method->body() << " " << tgfmt(parser, &fctx);
2069 }
2070
genPrinter()2071 void OpEmitter::genPrinter() {
2072 if (hasStringAttribute(def, "assemblyFormat"))
2073 return;
2074
2075 auto valueInit = def.getValueInit("printer");
2076 StringInit *stringInit = dyn_cast<StringInit>(valueInit);
2077 if (!stringInit)
2078 return;
2079
2080 auto *method =
2081 opClass.addMethodAndPrune("void", "print", "::mlir::OpAsmPrinter &", "p");
2082 FmtContext fctx;
2083 fctx.addSubst("cppClass", opClass.getClassName());
2084 auto printer = stringInit->getValue().ltrim().rtrim(" \t\v\f\r");
2085 method->body() << " " << tgfmt(printer, &fctx);
2086 }
2087
genVerifier()2088 void OpEmitter::genVerifier() {
2089 auto *method = opClass.addMethodAndPrune("::mlir::LogicalResult", "verify");
2090 auto &body = method->body();
2091 body << " if (failed(" << op.getAdaptorName()
2092 << "(*this).verify((*this)->getLoc()))) "
2093 << "return ::mlir::failure();\n";
2094
2095 auto *valueInit = def.getValueInit("verifier");
2096 StringInit *stringInit = dyn_cast<StringInit>(valueInit);
2097 bool hasCustomVerify = stringInit && !stringInit->getValue().empty();
2098 populateSubstitutions(op, "(*this)->getAttr", "this->getODSOperands",
2099 "this->getODSResults", verifyCtx);
2100
2101 genAttributeVerifier(op, "(*this)->getAttr", "emitOpError(",
2102 /*emitVerificationRequiringOp=*/true, verifyCtx, body);
2103 genOperandResultVerifier(body, op.getOperands(), "operand");
2104 genOperandResultVerifier(body, op.getResults(), "result");
2105
2106 for (auto &trait : op.getTraits()) {
2107 if (auto *t = dyn_cast<tblgen::PredTrait>(&trait)) {
2108 body << tgfmt(" if (!($0))\n "
2109 "return emitOpError(\"failed to verify that $1\");\n",
2110 &verifyCtx, tgfmt(t->getPredTemplate(), &verifyCtx),
2111 t->getSummary());
2112 }
2113 }
2114
2115 genRegionVerifier(body);
2116 genSuccessorVerifier(body);
2117
2118 if (hasCustomVerify) {
2119 FmtContext fctx;
2120 fctx.addSubst("cppClass", opClass.getClassName());
2121 auto printer = stringInit->getValue().ltrim().rtrim(" \t\v\f\r");
2122 body << " " << tgfmt(printer, &fctx);
2123 } else {
2124 body << " return ::mlir::success();\n";
2125 }
2126 }
2127
genOperandResultVerifier(OpMethodBody & body,Operator::value_range values,StringRef valueKind)2128 void OpEmitter::genOperandResultVerifier(OpMethodBody &body,
2129 Operator::value_range values,
2130 StringRef valueKind) {
2131 FmtContext fctx;
2132
2133 body << " {\n";
2134 body << " unsigned index = 0; (void)index;\n";
2135
2136 for (auto staticValue : llvm::enumerate(values)) {
2137 bool hasPredicate = staticValue.value().hasPredicate();
2138 bool isOptional = staticValue.value().isOptional();
2139 if (!hasPredicate && !isOptional)
2140 continue;
2141 body << formatv(" auto valueGroup{2} = getODS{0}{1}s({2});\n",
2142 // Capitalize the first letter to match the function name
2143 valueKind.substr(0, 1).upper(), valueKind.substr(1),
2144 staticValue.index());
2145
2146 // If the constraint is optional check that the value group has at most 1
2147 // value.
2148 if (isOptional) {
2149 body << formatv(" if (valueGroup{0}.size() > 1)\n"
2150 " return emitOpError(\"{1} group starting at #\") "
2151 "<< index << \" requires 0 or 1 element, but found \" << "
2152 "valueGroup{0}.size();\n",
2153 staticValue.index(), valueKind);
2154 }
2155
2156 // Otherwise, if there is no predicate there is nothing left to do.
2157 if (!hasPredicate)
2158 continue;
2159 // Emit a loop to check all the dynamic values in the pack.
2160 StringRef constraintFn = staticVerifierEmitter.getTypeConstraintFn(
2161 staticValue.value().constraint);
2162 body << " for (::mlir::Value v : valueGroup" << staticValue.index()
2163 << ") {\n"
2164 << " if (::mlir::failed(" << constraintFn
2165 << "(getOperation(), v.getType(), \"" << valueKind << "\", index)))\n"
2166 << " return ::mlir::failure();\n"
2167 << " ++index;\n"
2168 << " }\n";
2169 }
2170
2171 body << " }\n";
2172 }
2173
genRegionVerifier(OpMethodBody & body)2174 void OpEmitter::genRegionVerifier(OpMethodBody &body) {
2175 // If we have no regions, there is nothing more to do.
2176 unsigned numRegions = op.getNumRegions();
2177 if (numRegions == 0)
2178 return;
2179
2180 body << "{\n";
2181 body << " unsigned index = 0; (void)index;\n";
2182
2183 for (unsigned i = 0; i < numRegions; ++i) {
2184 const auto ®ion = op.getRegion(i);
2185 if (region.constraint.getPredicate().isNull())
2186 continue;
2187
2188 body << " for (::mlir::Region ®ion : ";
2189 body << formatv(region.isVariadic()
2190 ? "{0}()"
2191 : "::mlir::MutableArrayRef<::mlir::Region>((*this)"
2192 "->getRegion({1}))",
2193 region.name, i);
2194 body << ") {\n";
2195 auto constraint = tgfmt(region.constraint.getConditionTemplate(),
2196 &verifyCtx.withSelf("region"))
2197 .str();
2198
2199 body << formatv(" (void)region;\n"
2200 " if (!({0})) {\n "
2201 "return emitOpError(\"region #\") << index << \" {1}"
2202 "failed to "
2203 "verify constraint: {2}\";\n }\n",
2204 constraint,
2205 region.name.empty() ? "" : "('" + region.name + "') ",
2206 region.constraint.getSummary())
2207 << " ++index;\n"
2208 << " }\n";
2209 }
2210 body << " }\n";
2211 }
2212
genSuccessorVerifier(OpMethodBody & body)2213 void OpEmitter::genSuccessorVerifier(OpMethodBody &body) {
2214 // If we have no successors, there is nothing more to do.
2215 unsigned numSuccessors = op.getNumSuccessors();
2216 if (numSuccessors == 0)
2217 return;
2218
2219 body << "{\n";
2220 body << " unsigned index = 0; (void)index;\n";
2221
2222 for (unsigned i = 0; i < numSuccessors; ++i) {
2223 const auto &successor = op.getSuccessor(i);
2224 if (successor.constraint.getPredicate().isNull())
2225 continue;
2226
2227 if (successor.isVariadic()) {
2228 body << formatv(" for (::mlir::Block *successor : {0}()) {\n",
2229 successor.name);
2230 } else {
2231 body << " {\n";
2232 body << formatv(" ::mlir::Block *successor = {0}();\n",
2233 successor.name);
2234 }
2235 auto constraint = tgfmt(successor.constraint.getConditionTemplate(),
2236 &verifyCtx.withSelf("successor"))
2237 .str();
2238
2239 body << formatv(" (void)successor;\n"
2240 " if (!({0})) {\n "
2241 "return emitOpError(\"successor #\") << index << \"('{1}') "
2242 "failed to "
2243 "verify constraint: {2}\";\n }\n",
2244 constraint, successor.name,
2245 successor.constraint.getSummary())
2246 << " ++index;\n"
2247 << " }\n";
2248 }
2249 body << " }\n";
2250 }
2251
2252 /// Add a size count trait to the given operation class.
addSizeCountTrait(OpClass & opClass,StringRef traitKind,int numTotal,int numVariadic)2253 static void addSizeCountTrait(OpClass &opClass, StringRef traitKind,
2254 int numTotal, int numVariadic) {
2255 if (numVariadic != 0) {
2256 if (numTotal == numVariadic)
2257 opClass.addTrait("::mlir::OpTrait::Variadic" + traitKind + "s");
2258 else
2259 opClass.addTrait("::mlir::OpTrait::AtLeastN" + traitKind + "s<" +
2260 Twine(numTotal - numVariadic) + ">::Impl");
2261 return;
2262 }
2263 switch (numTotal) {
2264 case 0:
2265 opClass.addTrait("::mlir::OpTrait::Zero" + traitKind);
2266 break;
2267 case 1:
2268 opClass.addTrait("::mlir::OpTrait::One" + traitKind);
2269 break;
2270 default:
2271 opClass.addTrait("::mlir::OpTrait::N" + traitKind + "s<" + Twine(numTotal) +
2272 ">::Impl");
2273 break;
2274 }
2275 }
2276
genTraits()2277 void OpEmitter::genTraits() {
2278 // Add region size trait.
2279 unsigned numRegions = op.getNumRegions();
2280 unsigned numVariadicRegions = op.getNumVariadicRegions();
2281 addSizeCountTrait(opClass, "Region", numRegions, numVariadicRegions);
2282
2283 // Add result size traits.
2284 int numResults = op.getNumResults();
2285 int numVariadicResults = op.getNumVariableLengthResults();
2286 addSizeCountTrait(opClass, "Result", numResults, numVariadicResults);
2287
2288 // For single result ops with a known specific type, generate a OneTypedResult
2289 // trait.
2290 if (numResults == 1 && numVariadicResults == 0) {
2291 auto cppName = op.getResults().begin()->constraint.getCPPClassName();
2292 opClass.addTrait("::mlir::OpTrait::OneTypedResult<" + cppName + ">::Impl");
2293 }
2294
2295 // Add successor size trait.
2296 unsigned numSuccessors = op.getNumSuccessors();
2297 unsigned numVariadicSuccessors = op.getNumVariadicSuccessors();
2298 addSizeCountTrait(opClass, "Successor", numSuccessors, numVariadicSuccessors);
2299
2300 // Add variadic size trait and normal op traits.
2301 int numOperands = op.getNumOperands();
2302 int numVariadicOperands = op.getNumVariableLengthOperands();
2303
2304 // Add operand size trait.
2305 if (numVariadicOperands != 0) {
2306 if (numOperands == numVariadicOperands)
2307 opClass.addTrait("::mlir::OpTrait::VariadicOperands");
2308 else
2309 opClass.addTrait("::mlir::OpTrait::AtLeastNOperands<" +
2310 Twine(numOperands - numVariadicOperands) + ">::Impl");
2311 } else {
2312 switch (numOperands) {
2313 case 0:
2314 opClass.addTrait("::mlir::OpTrait::ZeroOperands");
2315 break;
2316 case 1:
2317 opClass.addTrait("::mlir::OpTrait::OneOperand");
2318 break;
2319 default:
2320 opClass.addTrait("::mlir::OpTrait::NOperands<" + Twine(numOperands) +
2321 ">::Impl");
2322 break;
2323 }
2324 }
2325
2326 // Add the native and interface traits.
2327 for (const auto &trait : op.getTraits()) {
2328 if (auto opTrait = dyn_cast<tblgen::NativeTrait>(&trait))
2329 opClass.addTrait(opTrait->getFullyQualifiedTraitName());
2330 else if (auto opTrait = dyn_cast<tblgen::InterfaceTrait>(&trait))
2331 opClass.addTrait(opTrait->getFullyQualifiedTraitName());
2332 }
2333 }
2334
genOpNameGetter()2335 void OpEmitter::genOpNameGetter() {
2336 auto *method = opClass.addMethodAndPrune(
2337 "::llvm::StringLiteral", "getOperationName",
2338 OpMethod::Property(OpMethod::MP_Static | OpMethod::MP_Constexpr));
2339 method->body() << " return ::llvm::StringLiteral(\"" << op.getOperationName()
2340 << "\");";
2341 }
2342
genOpAsmInterface()2343 void OpEmitter::genOpAsmInterface() {
2344 // If the user only has one results or specifically added the Asm trait,
2345 // then don't generate it for them. We specifically only handle multi result
2346 // operations, because the name of a single result in the common case is not
2347 // interesting(generally 'result'/'output'/etc.).
2348 // TODO: We could also add a flag to allow operations to opt in to this
2349 // generation, even if they only have a single operation.
2350 int numResults = op.getNumResults();
2351 if (numResults <= 1 || op.getTrait("::mlir::OpAsmOpInterface::Trait"))
2352 return;
2353
2354 SmallVector<StringRef, 4> resultNames(numResults);
2355 for (int i = 0; i != numResults; ++i)
2356 resultNames[i] = op.getResultName(i);
2357
2358 // Don't add the trait if none of the results have a valid name.
2359 if (llvm::all_of(resultNames, [](StringRef name) { return name.empty(); }))
2360 return;
2361 opClass.addTrait("::mlir::OpAsmOpInterface::Trait");
2362
2363 // Generate the right accessor for the number of results.
2364 auto *method = opClass.addMethodAndPrune(
2365 "void", "getAsmResultNames", "::mlir::OpAsmSetValueNameFn", "setNameFn");
2366 auto &body = method->body();
2367 for (int i = 0; i != numResults; ++i) {
2368 body << " auto resultGroup" << i << " = getODSResults(" << i << ");\n"
2369 << " if (!llvm::empty(resultGroup" << i << "))\n"
2370 << " setNameFn(*resultGroup" << i << ".begin(), \""
2371 << resultNames[i] << "\");\n";
2372 }
2373 }
2374
2375 //===----------------------------------------------------------------------===//
2376 // OpOperandAdaptor emitter
2377 //===----------------------------------------------------------------------===//
2378
namespace {
// Helper class to emit Op operand adaptors to an output stream. Operand
// adaptors are wrappers around a range of operand values (together with the
// op's attribute dictionary and regions) that provide named operand getters
// identical to those defined in the Op.
//
// The class is only usable through the two static entry points below; the
// constructor is private.
class OpOperandAdaptorEmitter {
public:
  // Static entry points taking the op to process and the stream to write to.
  // NOTE(review): presumably these write the adaptor class declaration and
  // its definition respectively -- confirm against their implementations.
  static void emitDecl(const Operator &op, raw_ostream &os);
  static void emitDef(const Operator &op, raw_ostream &os);

private:
  // Builds the adaptor class for `op`.
  explicit OpOperandAdaptorEmitter(const Operator &op);

  // Add verification function. This generates a verify method for the adaptor
  // which verifies all the op-independent attribute constraints.
  void addVerification();

  // The op the adaptor is generated for.
  const Operator &op;
  // The adaptor class being assembled; it is named after the op's adaptor
  // name.
  Class adaptor;
};
} // end namespace
2399
OpOperandAdaptorEmitter(const Operator & op)2400 OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(const Operator &op)
2401 : op(op), adaptor(op.getAdaptorName()) {
2402 adaptor.newField("::mlir::ValueRange", "odsOperands");
2403 adaptor.newField("::mlir::DictionaryAttr", "odsAttrs");
2404 adaptor.newField("::mlir::RegionRange", "odsRegions");
2405 const auto *attrSizedOperands =
2406 op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments");
2407 {
2408 SmallVector<OpMethodParameter, 2> paramList;
2409 paramList.emplace_back("::mlir::ValueRange", "values");
2410 paramList.emplace_back("::mlir::DictionaryAttr", "attrs",
2411 attrSizedOperands ? "" : "nullptr");
2412 paramList.emplace_back("::mlir::RegionRange", "regions", "{}");
2413 auto *constructor = adaptor.addConstructorAndPrune(std::move(paramList));
2414
2415 constructor->addMemberInitializer("odsOperands", "values");
2416 constructor->addMemberInitializer("odsAttrs", "attrs");
2417 constructor->addMemberInitializer("odsRegions", "regions");
2418 }
2419
2420 {
2421 auto *constructor = adaptor.addConstructorAndPrune(
2422 llvm::formatv("{0}&", op.getCppClassName()).str(), "op");
2423 constructor->addMemberInitializer("odsOperands", "op->getOperands()");
2424 constructor->addMemberInitializer("odsAttrs", "op->getAttrDictionary()");
2425 constructor->addMemberInitializer("odsRegions", "op->getRegions()");
2426 }
2427
2428 {
2429 auto *m = adaptor.addMethodAndPrune("::mlir::ValueRange", "getOperands");
2430 m->body() << " return odsOperands;";
2431 }
2432 std::string sizeAttrInit =
2433 formatv(adapterSegmentSizeAttrInitCode, "operand_segment_sizes");
2434 generateNamedOperandGetters(op, adaptor, sizeAttrInit,
2435 /*rangeType=*/"::mlir::ValueRange",
2436 /*rangeBeginCall=*/"odsOperands.begin()",
2437 /*rangeSizeCall=*/"odsOperands.size()",
2438 /*getOperandCallPattern=*/"odsOperands[{0}]");
2439
2440 FmtContext fctx;
2441 fctx.withBuilder("::mlir::Builder(odsAttrs.getContext())");
2442
2443 auto emitAttr = [&](StringRef name, Attribute attr) {
2444 auto &body = adaptor.addMethodAndPrune(attr.getStorageType(), name)->body();
2445 body << " assert(odsAttrs && \"no attributes when constructing adapter\");"
2446 << "\n " << attr.getStorageType() << " attr = "
2447 << "odsAttrs.get(\"" << name << "\").";
2448 if (attr.hasDefaultValue() || attr.isOptional())
2449 body << "dyn_cast_or_null<";
2450 else
2451 body << "cast<";
2452 body << attr.getStorageType() << ">();\n";
2453
2454 if (attr.hasDefaultValue()) {
2455 // Use the default value if attribute is not set.
2456 // TODO: this is inefficient, we are recreating the attribute for every
2457 // call. This should be set instead.
2458 std::string defaultValue = std::string(
2459 tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
2460 body << " if (!attr)\n attr = " << defaultValue << ";\n";
2461 }
2462 body << " return attr;\n";
2463 };
2464
2465 {
2466 auto *m =
2467 adaptor.addMethodAndPrune("::mlir::DictionaryAttr", "getAttributes");
2468 m->body() << " return odsAttrs;";
2469 }
2470 for (auto &namedAttr : op.getAttributes()) {
2471 const auto &name = namedAttr.name;
2472 const auto &attr = namedAttr.attr;
2473 if (!attr.isDerivedAttr())
2474 emitAttr(name, attr);
2475 }
2476
2477 unsigned numRegions = op.getNumRegions();
2478 if (numRegions > 0) {
2479 auto *m = adaptor.addMethodAndPrune("::mlir::RegionRange", "getRegions");
2480 m->body() << " return odsRegions;";
2481 }
2482 for (unsigned i = 0; i < numRegions; ++i) {
2483 const auto ®ion = op.getRegion(i);
2484 if (region.name.empty())
2485 continue;
2486
2487 // Generate the accessors for a variadic region.
2488 if (region.isVariadic()) {
2489 auto *m = adaptor.addMethodAndPrune("::mlir::RegionRange", region.name);
2490 m->body() << formatv(" return odsRegions.drop_front({0});", i);
2491 continue;
2492 }
2493
2494 auto *m = adaptor.addMethodAndPrune("::mlir::Region &", region.name);
2495 m->body() << formatv(" return *odsRegions[{0}];", i);
2496 }
2497
2498 // Add verification function.
2499 addVerification();
2500 }
2501
addVerification()2502 void OpOperandAdaptorEmitter::addVerification() {
2503 auto *method = adaptor.addMethodAndPrune("::mlir::LogicalResult", "verify",
2504 "::mlir::Location", "loc");
2505 auto &body = method->body();
2506
2507 const char *checkAttrSizedValueSegmentsCode = R"(
2508 {
2509 auto sizeAttr = odsAttrs.get("{0}").cast<::mlir::DenseIntElementsAttr>();
2510 auto numElements = sizeAttr.getType().cast<::mlir::ShapedType>().getNumElements();
2511 if (numElements != {1})
2512 return emitError(loc, "'{0}' attribute for specifying {2} segments "
2513 "must have {1} elements, but got ") << numElements;
2514 }
2515 )";
2516
2517 // Verify a few traits first so that we can use
2518 // getODSOperands()/getODSResults() in the rest of the verifier.
2519 for (auto &trait : op.getTraits()) {
2520 if (auto *t = dyn_cast<tblgen::NativeTrait>(&trait)) {
2521 if (t->getFullyQualifiedTraitName() ==
2522 "::mlir::OpTrait::AttrSizedOperandSegments") {
2523 body << formatv(checkAttrSizedValueSegmentsCode,
2524 "operand_segment_sizes", op.getNumOperands(),
2525 "operand");
2526 } else if (t->getFullyQualifiedTraitName() ==
2527 "::mlir::OpTrait::AttrSizedResultSegments") {
2528 body << formatv(checkAttrSizedValueSegmentsCode, "result_segment_sizes",
2529 op.getNumResults(), "result");
2530 }
2531 }
2532 }
2533
2534 FmtContext verifyCtx;
2535 populateSubstitutions(op, "odsAttrs.get", "getODSOperands",
2536 "<no results should be generated>", verifyCtx);
2537 genAttributeVerifier(op, "odsAttrs.get",
2538 Twine("emitError(loc, \"'") + op.getOperationName() +
2539 "' op \"",
2540 /*emitVerificationRequiringOp*/ false, verifyCtx, body);
2541
2542 body << " return ::mlir::success();";
2543 }
2544
emitDecl(const Operator & op,raw_ostream & os)2545 void OpOperandAdaptorEmitter::emitDecl(const Operator &op, raw_ostream &os) {
2546 OpOperandAdaptorEmitter(op).adaptor.writeDeclTo(os);
2547 }
2548
emitDef(const Operator & op,raw_ostream & os)2549 void OpOperandAdaptorEmitter::emitDef(const Operator &op, raw_ostream &os) {
2550 OpOperandAdaptorEmitter(op).adaptor.writeDefTo(os);
2551 }
2552
2553 // Emits the opcode enum and op classes.
emitOpClasses(const RecordKeeper & recordKeeper,const std::vector<Record * > & defs,raw_ostream & os,bool emitDecl)2554 static void emitOpClasses(const RecordKeeper &recordKeeper,
2555 const std::vector<Record *> &defs, raw_ostream &os,
2556 bool emitDecl) {
2557 // First emit forward declaration for each class, this allows them to refer
2558 // to each others in traits for example.
2559 if (emitDecl) {
2560 os << "#if defined(GET_OP_CLASSES) || defined(GET_OP_FWD_DEFINES)\n";
2561 os << "#undef GET_OP_FWD_DEFINES\n";
2562 for (auto *def : defs) {
2563 Operator op(*def);
2564 NamespaceEmitter emitter(os, op.getCppNamespace());
2565 os << "class " << op.getCppClassName() << ";\n";
2566 }
2567 os << "#endif\n\n";
2568 }
2569
2570 IfDefScope scope("GET_OP_CLASSES", os);
2571 if (defs.empty())
2572 return;
2573
2574 // Generate all of the locally instantiated methods first.
2575 StaticVerifierFunctionEmitter staticVerifierEmitter(recordKeeper, defs, os,
2576 emitDecl);
2577 for (auto *def : defs) {
2578 Operator op(*def);
2579 NamespaceEmitter emitter(os, op.getCppNamespace());
2580 if (emitDecl) {
2581 os << formatv(opCommentHeader, op.getQualCppClassName(), "declarations");
2582 OpOperandAdaptorEmitter::emitDecl(op, os);
2583 OpEmitter::emitDecl(op, os, staticVerifierEmitter);
2584 } else {
2585 os << formatv(opCommentHeader, op.getQualCppClassName(), "definitions");
2586 OpOperandAdaptorEmitter::emitDef(op, os);
2587 OpEmitter::emitDef(op, os, staticVerifierEmitter);
2588 }
2589 }
2590 }
2591
2592 // Emits a comma-separated list of the ops.
emitOpList(const std::vector<Record * > & defs,raw_ostream & os)2593 static void emitOpList(const std::vector<Record *> &defs, raw_ostream &os) {
2594 IfDefScope scope("GET_OP_LIST", os);
2595
2596 interleave(
2597 // TODO: We are constructing the Operator wrapper instance just for
2598 // getting it's qualified class name here. Reduce the overhead by having a
2599 // lightweight version of Operator class just for that purpose.
2600 defs, [&os](Record *def) { os << Operator(def).getQualCppClassName(); },
2601 [&os]() { os << ",\n"; });
2602 }
2603
emitOpDecls(const RecordKeeper & recordKeeper,raw_ostream & os)2604 static bool emitOpDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
2605 emitSourceFileHeader("Op Declarations", os);
2606
2607 std::vector<Record *> defs = getRequestedOpDefinitions(recordKeeper);
2608 emitOpClasses(recordKeeper, defs, os, /*emitDecl=*/true);
2609
2610 return false;
2611 }
2612
emitOpDefs(const RecordKeeper & recordKeeper,raw_ostream & os)2613 static bool emitOpDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
2614 emitSourceFileHeader("Op Definitions", os);
2615
2616 std::vector<Record *> defs = getRequestedOpDefinitions(recordKeeper);
2617 emitOpList(defs, os);
2618 emitOpClasses(recordKeeper, defs, os, /*emitDecl=*/false);
2619
2620 return false;
2621 }
2622
2623 static mlir::GenRegistration
2624 genOpDecls("gen-op-decls", "Generate op declarations",
__anon96b8615e1802(const RecordKeeper &records, raw_ostream &os) 2625 [](const RecordKeeper &records, raw_ostream &os) {
2626 return emitOpDecls(records, os);
2627 });
2628
2629 static mlir::GenRegistration genOpDefs("gen-op-defs", "Generate op definitions",
2630 [](const RecordKeeper &records,
__anon96b8615e1902(const RecordKeeper &records, raw_ostream &os) 2631 raw_ostream &os) {
2632 return emitOpDefs(records, os);
2633 });
2634