1 //===- OpDefinitionsGen.cpp - MLIR op definitions generator ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // OpDefinitionsGen uses the description of operations to generate C++
10 // definitions for ops.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "OpFormatGen.h"
15 #include "mlir/TableGen/CodeGenHelpers.h"
16 #include "mlir/TableGen/Format.h"
17 #include "mlir/TableGen/GenInfo.h"
18 #include "mlir/TableGen/Interfaces.h"
19 #include "mlir/TableGen/OpClass.h"
20 #include "mlir/TableGen/OpTrait.h"
21 #include "mlir/TableGen/Operator.h"
22 #include "mlir/TableGen/SideEffects.h"
23 #include "llvm/ADT/Sequence.h"
24 #include "llvm/ADT/StringExtras.h"
25 #include "llvm/Support/CommandLine.h"
26 #include "llvm/Support/Regex.h"
27 #include "llvm/Support/Signals.h"
28 #include "llvm/TableGen/Error.h"
29 #include "llvm/TableGen/Record.h"
30 #include "llvm/TableGen/TableGenBackend.h"
31
32 #define DEBUG_TYPE "mlir-tblgen-opdefgen"
33
34 using namespace llvm;
35 using namespace mlir;
36 using namespace mlir::tblgen;
37
// Command-line option category that groups the flags of the -gen-op-defs and
// -gen-op-decls backends.
// NOTE(review): not `static`; it may be referenced from other TableGen
// backend files — confirm before changing its linkage.
cl::OptionCategory opDefGenCat("Options for -gen-op-defs and -gen-op-decls");

// Optional regexes over op names used to include/exclude ops from generation.
// An empty value means no filtering (per the option descriptions); the code
// applying the filters is outside this section.
static cl::opt<std::string> opIncFilter(
    "op-include-regex",
    cl::desc("Regex of name of op's to include (no filter if empty)"),
    cl::cat(opDefGenCat));
static cl::opt<std::string> opExcFilter(
    "op-exclude-regex",
    cl::desc("Regex of name of op's to exclude (no filter if empty)"),
    cl::cat(opDefGenCat));

// Prefix prepended to attribute names for locals in generated verifier code,
// so the locals do not hide the attribute accessors of the same name.
static const char *const tblgenNamePrefix = "tblgen_";
// Base name for unnamed operands; getArgumentName() formats it as
// `odsArg_<index>`.
static const char *const generatedArgName = "odsArg";
// Parameter names used in generated build() methods.
// TODO(review): confirm — the users of these two are outside this section.
static const char *const odsBuilder = "odsBuilder";
static const char *const builderOpState = "odsState";
53
// The logic to calculate the actual value range for a declared operand/result
// of an op with variadic operands/results. Note that this logic is not for
// general use; it assumes all variadic operands/results must have the same
// number of values.
//
// This is a formatv() template: literal braces in the generated code are
// written as `{{` so they are not taken as replacement fields. The generated
// snippet is the body of a method taking an `unsigned index` and returning a
// (start, size) pair.
//
// {0}: The list of whether each declared operand/result is variadic.
// {1}: The total number of non-variadic operands/results.
// {2}: The total number of variadic operands/results.
// {3}: The total number of actual values.
// {4}: "operand" or "result" (only used inside the generated comments).
const char *sameVariadicSizeValueRangeCalcCode = R"(
bool isVariadic[] = {{{0}};
int prevVariadicCount = 0;
for (unsigned i = 0; i < index; ++i)
if (isVariadic[i]) ++prevVariadicCount;

// Calculate how many dynamic values a static variadic {4} corresponds to.
// This assumes all static variadic {4}s have the same dynamic value count.
int variadicSize = ({3} - {1}) / {2};
// `index` passed in as the parameter is the static index which counts each
// {4} (variadic or not) as size 1. So here for each previous static variadic
// {4}, we need to offset by (variadicSize - 1) to get where the dynamic
// value pack for this static {4} starts.
int start = index + (variadicSize - 1) * prevVariadicCount;
int size = isVariadic[index] ? variadicSize : 1;
return {{start, size};
)";
81
// The logic to calculate the actual value range for a declared operand/result
// of an op with variadic operands/results. Note that this logic assumes
// the op has an attribute specifying the size of each operand/result segment
// (variadic or not).
//
// The two "init" snippets below each define a local `sizeAttr`: one reads it
// from an adapter's `odsAttrs` dictionary, the other from the op itself.
// {0}: The name of the attribute specifying the segment sizes.
const char *adapterSegmentSizeAttrInitCode = R"(
assert(odsAttrs && "missing segment size attribute for op");
auto sizeAttr = odsAttrs.get("{0}").cast<::mlir::DenseIntElementsAttr>();
)";
const char *opSegmentSizeAttrInitCode = R"(
auto sizeAttr = (*this)->getAttrOfType<::mlir::DenseIntElementsAttr>("{0}");
)";
// Appended verbatim after one of the init snippets above (it is not passed
// through formatv, so its braces are single). Sums the sizes of all segments
// before `index` to get the start, and reads the size of segment `index`.
const char *attrSizedSegmentValueRangeCalcCode = R"(
unsigned start = 0;
for (unsigned i = 0; i < index; ++i)
start += (*(sizeAttr.begin() + i)).getZExtValue();
unsigned size = (*(sizeAttr.begin() + index)).getZExtValue();
return {start, size};
)";
102
// The logic to build a range of either operand or result values.
// formatv() template; `{{` yields a literal brace in the generated code.
//
// {0}: The begin iterator of the actual values.
// {1}: The call to generate the start and length of the value range.
const char *valueRangeReturnCode = R"(
auto valueRange = {1};
return {{std::next({0}, valueRange.first),
std::next({0}, valueRange.first + valueRange.second)};
)";

// Banner comment emitted into the generated file before a group of entities.
// formatv() template, e.g. {0} = "Local Utility Method", {1} = "Definitions".
static const char *const opCommentHeader = R"(
//===----------------------------------------------------------------------===//
// {0} {1}
//===----------------------------------------------------------------------===//

)";
119
120 //===----------------------------------------------------------------------===//
121 // StaticVerifierFunctionEmitter
122 //===----------------------------------------------------------------------===//
123
124 namespace {
125 /// This class deduplicates shared operation verification code by emitting
126 /// static functions alongside the op definitions. These methods are local to
127 /// the definition file, and are invoked within the operation verify methods.
128 /// An example is shown below:
129 ///
130 /// static LogicalResult localVerify(...)
131 ///
132 /// LogicalResult OpA::verify(...) {
133 /// if (failed(localVerify(...)))
134 /// return failure();
135 /// ...
136 /// }
137 ///
138 /// LogicalResult OpB::verify(...) {
139 /// if (failed(localVerify(...)))
140 /// return failure();
141 /// ...
142 /// }
143 ///
144 class StaticVerifierFunctionEmitter {
145 public:
146 StaticVerifierFunctionEmitter(const llvm::RecordKeeper &records,
147 ArrayRef<llvm::Record *> opDefs,
148 raw_ostream &os, bool emitDecl);
149
150 /// Get the name of the local function used for the given type constraint.
151 /// These functions are used for operand and result constraints and have the
152 /// form:
153 /// LogicalResult(Operation *op, Type type, StringRef valueKind,
154 /// unsigned valueGroupStartIndex);
getTypeConstraintFn(const Constraint & constraint) const155 StringRef getTypeConstraintFn(const Constraint &constraint) const {
156 auto it = localTypeConstraints.find(constraint.getAsOpaquePointer());
157 assert(it != localTypeConstraints.end() && "expected valid constraint fn");
158 return it->second;
159 }
160
161 private:
162 /// Returns a unique name to use when generating local methods.
163 static std::string getUniqueName(const llvm::RecordKeeper &records);
164
165 /// Emit local methods for the type constraints used within the provided op
166 /// definitions.
167 void emitTypeConstraintMethods(ArrayRef<llvm::Record *> opDefs,
168 raw_ostream &os, bool emitDecl);
169
170 /// A unique label for the file currently being generated. This is used to
171 /// ensure that the local functions have a unique name.
172 std::string uniqueOutputLabel;
173
174 /// A set of functions implementing type constraints, used for operand and
175 /// result verification.
176 llvm::DenseMap<const void *, std::string> localTypeConstraints;
177 };
178 } // namespace
179
StaticVerifierFunctionEmitter(const llvm::RecordKeeper & records,ArrayRef<llvm::Record * > opDefs,raw_ostream & os,bool emitDecl)180 StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter(
181 const llvm::RecordKeeper &records, ArrayRef<llvm::Record *> opDefs,
182 raw_ostream &os, bool emitDecl)
183 : uniqueOutputLabel(getUniqueName(records)) {
184 llvm::Optional<NamespaceEmitter> namespaceEmitter;
185 if (!emitDecl) {
186 os << formatv(opCommentHeader, "Local Utility Method", "Definitions");
187 namespaceEmitter.emplace(os, Operator(*opDefs[0]).getDialect());
188 }
189
190 emitTypeConstraintMethods(opDefs, os, emitDecl);
191 }
192
getUniqueName(const llvm::RecordKeeper & records)193 std::string StaticVerifierFunctionEmitter::getUniqueName(
194 const llvm::RecordKeeper &records) {
195 // Use the input file name when generating a unique name.
196 std::string inputFilename = records.getInputFilename();
197
198 // Drop all but the base filename.
199 StringRef nameRef = llvm::sys::path::filename(inputFilename);
200 nameRef.consume_back(".td");
201
202 // Sanitize any invalid characters.
203 std::string uniqueName;
204 for (char c : nameRef) {
205 if (llvm::isAlnum(c) || c == '_')
206 uniqueName.push_back(c);
207 else
208 uniqueName.append(llvm::utohexstr((unsigned char)c));
209 }
210 return uniqueName;
211 }
212
emitTypeConstraintMethods(ArrayRef<llvm::Record * > opDefs,raw_ostream & os,bool emitDecl)213 void StaticVerifierFunctionEmitter::emitTypeConstraintMethods(
214 ArrayRef<llvm::Record *> opDefs, raw_ostream &os, bool emitDecl) {
215 // Collect a set of all of the used type constraints within the operation
216 // definitions.
217 llvm::SetVector<const void *> typeConstraints;
218 for (Record *def : opDefs) {
219 Operator op(*def);
220 for (NamedTypeConstraint &operand : op.getOperands())
221 if (operand.hasPredicate())
222 typeConstraints.insert(operand.constraint.getAsOpaquePointer());
223 for (NamedTypeConstraint &result : op.getResults())
224 if (result.hasPredicate())
225 typeConstraints.insert(result.constraint.getAsOpaquePointer());
226 }
227
228 FmtContext fctx;
229 for (auto it : llvm::enumerate(typeConstraints)) {
230 // Generate an obscure and unique name for this type constraint.
231 std::string name = (Twine("__mlir_ods_local_type_constraint_") +
232 uniqueOutputLabel + Twine(it.index()))
233 .str();
234 localTypeConstraints.try_emplace(it.value(), name);
235
236 // Only generate the methods if we are generating definitions.
237 if (emitDecl)
238 continue;
239
240 Constraint constraint = Constraint::getFromOpaquePointer(it.value());
241 os << "static ::mlir::LogicalResult " << name
242 << "(::mlir::Operation *op, ::mlir::Type type, ::llvm::StringRef "
243 "valueKind, unsigned valueGroupStartIndex) {\n";
244
245 os << " if (!("
246 << tgfmt(constraint.getConditionTemplate(), &fctx.withSelf("type"))
247 << ")) {\n"
248 << formatv(
249 " return op->emitOpError(valueKind) << \" #\" << "
250 "valueGroupStartIndex << \" must be {0}, but got \" << type;\n",
251 constraint.getSummary())
252 << " }\n"
253 << " return ::mlir::success();\n"
254 << "}\n\n";
255 }
256 }
257
258 //===----------------------------------------------------------------------===//
259 // Utility structs and functions
260 //===----------------------------------------------------------------------===//
261
// Returns a copy of `str` with every occurrence of `match` replaced by
// `substitute`. Occurrences are found left to right without overlap, and the
// text just substituted in is never rescanned, so `substitute` may itself
// contain `match`. An empty `match` leaves `str` unchanged.
static std::string replaceAllSubstrs(std::string str, const std::string &match,
                                     const std::string &substitute) {
  // An empty pattern matches at every position; with an empty substitute the
  // scan position would never advance and the loop would not terminate.
  if (match.empty())
    return str;

  std::string::size_type scanLoc = 0;
  std::string::size_type matchLoc;
  while ((matchLoc = str.find(match, scanLoc)) != std::string::npos) {
    str.replace(matchLoc, match.size(), substitute);
    // Resume after the inserted text so it is never matched again.
    scanLoc = matchLoc + substitute.size();
  }
  return str;
}
272
273 // Returns whether the record has a value of the given name that can be returned
274 // via getValueAsString.
hasStringAttribute(const Record & record,StringRef fieldName)275 static inline bool hasStringAttribute(const Record &record,
276 StringRef fieldName) {
277 auto valueInit = record.getValueInit(fieldName);
278 return isa<StringInit>(valueInit);
279 }
280
getArgumentName(const Operator & op,int index)281 static std::string getArgumentName(const Operator &op, int index) {
282 const auto &operand = op.getOperand(index);
283 if (!operand.name.empty())
284 return std::string(operand.name);
285 else
286 return std::string(formatv("{0}_{1}", generatedArgName, index));
287 }
288
289 // Returns true if we can use unwrapped value for the given `attr` in builders.
canUseUnwrappedRawValue(const tblgen::Attribute & attr)290 static bool canUseUnwrappedRawValue(const tblgen::Attribute &attr) {
291 return attr.getReturnType() != attr.getStorageType() &&
292 // We need to wrap the raw value into an attribute in the builder impl
293 // so we need to make sure that the attribute specifies how to do that.
294 !attr.getConstBuilderTemplate().empty();
295 }
296
297 //===----------------------------------------------------------------------===//
298 // Op emitter
299 //===----------------------------------------------------------------------===//
300
301 namespace {
302 // Helper class to emit a record into the given output stream.
303 class OpEmitter {
304 public:
305 static void
306 emitDecl(const Operator &op, raw_ostream &os,
307 const StaticVerifierFunctionEmitter &staticVerifierEmitter);
308 static void
309 emitDef(const Operator &op, raw_ostream &os,
310 const StaticVerifierFunctionEmitter &staticVerifierEmitter);
311
312 private:
313 OpEmitter(const Operator &op,
314 const StaticVerifierFunctionEmitter &staticVerifierEmitter);
315
316 void emitDecl(raw_ostream &os);
317 void emitDef(raw_ostream &os);
318
319 // Generates the OpAsmOpInterface for this operation if possible.
320 void genOpAsmInterface();
321
322 // Generates the `getOperationName` method for this op.
323 void genOpNameGetter();
324
325 // Generates getters for the attributes.
326 void genAttrGetters();
327
328 // Generates setter for the attributes.
329 void genAttrSetters();
330
331 // Generates removers for optional attributes.
332 void genOptionalAttrRemovers();
333
334 // Generates getters for named operands.
335 void genNamedOperandGetters();
336
337 // Generates setters for named operands.
338 void genNamedOperandSetters();
339
340 // Generates getters for named results.
341 void genNamedResultGetters();
342
343 // Generates getters for named regions.
344 void genNamedRegionGetters();
345
346 // Generates getters for named successors.
347 void genNamedSuccessorGetters();
348
349 // Generates builder methods for the operation.
350 void genBuilder();
351
352 // Generates the build() method that takes each operand/attribute
353 // as a stand-alone parameter.
354 void genSeparateArgParamBuilder();
355
356 // Generates the build() method that takes each operand/attribute as a
357 // stand-alone parameter. The generated build() method uses first operand's
358 // type as all results' types.
359 void genUseOperandAsResultTypeSeparateParamBuilder();
360
361 // Generates the build() method that takes all operands/attributes
362 // collectively as one parameter. The generated build() method uses first
363 // operand's type as all results' types.
364 void genUseOperandAsResultTypeCollectiveParamBuilder();
365
366 // Generates the build() method that takes aggregate operands/attributes
367 // parameters. This build() method uses inferred types as result types.
368 // Requires: The type needs to be inferable via InferTypeOpInterface.
369 void genInferredTypeCollectiveParamBuilder();
370
371 // Generates the build() method that takes each operand/attribute as a
372 // stand-alone parameter. The generated build() method uses first attribute's
373 // type as all result's types.
374 void genUseAttrAsResultTypeBuilder();
375
376 // Generates the build() method that takes all result types collectively as
377 // one parameter. Similarly for operands and attributes.
378 void genCollectiveParamBuilder();
379
380 // The kind of parameter to generate for result types in builders.
381 enum class TypeParamKind {
382 None, // No result type in parameter list.
383 Separate, // A separate parameter for each result type.
384 Collective, // An ArrayRef<Type> for all result types.
385 };
386
387 // The kind of parameter to generate for attributes in builders.
388 enum class AttrParamKind {
389 WrappedAttr, // A wrapped MLIR Attribute instance.
390 UnwrappedValue, // A raw value without MLIR Attribute wrapper.
391 };
392
393 // Builds the parameter list for build() method of this op. This method writes
394 // to `paramList` the comma-separated parameter list and updates
395 // `resultTypeNames` with the names for parameters for specifying result
396 // types. The given `typeParamKind` and `attrParamKind` controls how result
397 // types and attributes are placed in the parameter list.
398 void buildParamList(llvm::SmallVectorImpl<OpMethodParameter> ¶mList,
399 SmallVectorImpl<std::string> &resultTypeNames,
400 TypeParamKind typeParamKind,
401 AttrParamKind attrParamKind = AttrParamKind::WrappedAttr);
402
403 // Adds op arguments and regions into operation state for build() methods.
404 void genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
405 bool isRawValueAttr = false);
406
407 // Generates canonicalizer declaration for the operation.
408 void genCanonicalizerDecls();
409
410 // Generates the folder declaration for the operation.
411 void genFolderDecls();
412
413 // Generates the parser for the operation.
414 void genParser();
415
416 // Generates the printer for the operation.
417 void genPrinter();
418
419 // Generates verify method for the operation.
420 void genVerifier();
421
422 // Generates verify statements for operands and results in the operation.
423 // The generated code will be attached to `body`.
424 void genOperandResultVerifier(OpMethodBody &body,
425 Operator::value_range values,
426 StringRef valueKind);
427
428 // Generates verify statements for regions in the operation.
429 // The generated code will be attached to `body`.
430 void genRegionVerifier(OpMethodBody &body);
431
432 // Generates verify statements for successors in the operation.
433 // The generated code will be attached to `body`.
434 void genSuccessorVerifier(OpMethodBody &body);
435
436 // Generates the traits used by the object.
437 void genTraits();
438
439 // Generate the OpInterface methods for all interfaces.
440 void genOpInterfaceMethods();
441
442 // Generate op interface methods for the given interface.
443 void genOpInterfaceMethods(const tblgen::InterfaceOpTrait *trait);
444
445 // Generate op interface method for the given interface method. If
446 // 'declaration' is true, generates a declaration, else a definition.
447 OpMethod *genOpInterfaceMethod(const tblgen::InterfaceMethod &method,
448 bool declaration = true);
449
450 // Generate the side effect interface methods.
451 void genSideEffectInterfaceMethods();
452
453 // Generate the type inference interface methods.
454 void genTypeInterfaceMethods();
455
456 private:
457 // The TableGen record for this op.
458 // TODO: OpEmitter should not have a Record directly,
459 // it should rather go through the Operator for better abstraction.
460 const Record &def;
461
462 // The wrapper operator class for querying information from this op.
463 Operator op;
464
465 // The C++ code builder for this op
466 OpClass opClass;
467
468 // The format context for verification code generation.
469 FmtContext verifyCtx;
470
471 // The emitter containing all of the locally emitted verification functions.
472 const StaticVerifierFunctionEmitter &staticVerifierEmitter;
473 };
474 } // end anonymous namespace
475
476 // Populate the format context `ctx` with substitutions of attributes, operands
477 // and results.
478 // - attrGet corresponds to the name of the function to call to get value of
479 // attribute (the generated function call returns an Attribute);
480 // - operandGet corresponds to the name of the function with which to retrieve
481 // an operand (the generated function call returns an OperandRange);
482 // - resultGet corresponds to the name of the function to get an result (the
483 // generated function call returns a ValueRange);
populateSubstitutions(const Operator & op,const char * attrGet,const char * operandGet,const char * resultGet,FmtContext & ctx)484 static void populateSubstitutions(const Operator &op, const char *attrGet,
485 const char *operandGet, const char *resultGet,
486 FmtContext &ctx) {
487 // Populate substitutions for attributes and named operands.
488 for (const auto &namedAttr : op.getAttributes())
489 ctx.addSubst(namedAttr.name,
490 formatv("{0}(\"{1}\")", attrGet, namedAttr.name));
491 for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
492 auto &value = op.getOperand(i);
493 if (value.name.empty())
494 continue;
495
496 if (value.isVariadic())
497 ctx.addSubst(value.name, formatv("{0}({1})", operandGet, i));
498 else
499 ctx.addSubst(value.name, formatv("(*{0}({1}).begin())", operandGet, i));
500 }
501
502 // Populate substitutions for results.
503 for (int i = 0, e = op.getNumResults(); i < e; ++i) {
504 auto &value = op.getResult(i);
505 if (value.name.empty())
506 continue;
507
508 if (value.isVariadic())
509 ctx.addSubst(value.name, formatv("{0}({1})", resultGet, i));
510 else
511 ctx.addSubst(value.name, formatv("(*{0}({1}).begin())", resultGet, i));
512 }
513 }
514
515 // Generate attribute verification. If emitVerificationRequiringOp is set then
516 // only verification for attributes whose value depend on op being known are
517 // emitted, else only verification that doesn't depend on the op being known are
518 // generated.
519 // - emitErrorPrefix is the prefix for the error emitting call which consists
520 // of the entire function call up to start of error message fragment;
521 // - emitVerificationRequiringOp specifies whether verification should be
522 // emitted for verification that require the op to exist;
genAttributeVerifier(const Operator & op,const char * attrGet,const Twine & emitErrorPrefix,bool emitVerificationRequiringOp,FmtContext & ctx,OpMethodBody & body)523 static void genAttributeVerifier(const Operator &op, const char *attrGet,
524 const Twine &emitErrorPrefix,
525 bool emitVerificationRequiringOp,
526 FmtContext &ctx, OpMethodBody &body) {
527 for (const auto &namedAttr : op.getAttributes()) {
528 const auto &attr = namedAttr.attr;
529 if (attr.isDerivedAttr())
530 continue;
531
532 auto attrName = namedAttr.name;
533 bool allowMissingAttr = attr.hasDefaultValue() || attr.isOptional();
534 auto attrPred = attr.getPredicate();
535 auto condition = attrPred.isNull() ? "" : attrPred.getCondition();
536 // There is a condition to emit only if the use of $_op and whether to
537 // emit verifications for op matches.
538 bool hasConditionToEmit = (!(condition.find("$_op") != StringRef::npos) ^
539 emitVerificationRequiringOp);
540
541 // Prefix with `tblgen_` to avoid hiding the attribute accessor.
542 auto varName = tblgenNamePrefix + attrName;
543
544 // If the attribute is
545 // 1. Required (not allowed missing) and not in op verification, or
546 // 2. Has a condition that will get verified
547 // then the variable will be used.
548 //
549 // Therefore, for optional attributes whose verification requires that an
550 // op already exists for verification/emitVerificationRequiringOp is set
551 // has nothing that can be verified here.
552 if ((allowMissingAttr || emitVerificationRequiringOp) &&
553 !hasConditionToEmit)
554 continue;
555
556 body << formatv(" {\n auto {0} = {1}(\"{2}\");\n", varName, attrGet,
557 attrName);
558
559 if (!emitVerificationRequiringOp && !allowMissingAttr) {
560 body << " if (!" << varName << ") return " << emitErrorPrefix
561 << "\"requires attribute '" << attrName << "'\");\n";
562 }
563
564 if (!hasConditionToEmit) {
565 body << " }\n";
566 continue;
567 }
568
569 if (allowMissingAttr) {
570 // If the attribute has a default value, then only verify the predicate if
571 // set. This does effectively assume that the default value is valid.
572 // TODO: verify the debug value is valid (perhaps in debug mode only).
573 body << " if (" << varName << ") {\n";
574 }
575
576 body << tgfmt(" if (!($0)) return $1\"attribute '$2' "
577 "failed to satisfy constraint: $3\");\n",
578 /*ctx=*/nullptr, tgfmt(condition, &ctx.withSelf(varName)),
579 emitErrorPrefix, attrName, attr.getSummary());
580 if (allowMissingAttr)
581 body << " }\n";
582 body << " }\n";
583 }
584 }
585
// Builds the complete OpClass for `op` eagerly: traits first, then every
// generated method. emitDecl()/emitDef() later only print `opClass`.
OpEmitter::OpEmitter(const Operator &op,
                     const StaticVerifierFunctionEmitter &staticVerifierEmitter)
    : def(op.getDef()), op(op),
      opClass(op.getCppClassName(), op.getExtraClassDeclaration()),
      staticVerifierEmitter(staticVerifierEmitter) {
  // Op expression used when formatting verification code templates.
  verifyCtx.withOp("(*this->getOperation())");

  genTraits();

  // Generate C++ code for various op methods. The order here determines the
  // methods in the generated file.
  genOpAsmInterface();
  genOpNameGetter();
  genNamedOperandGetters();
  genNamedOperandSetters();
  genNamedResultGetters();
  genNamedRegionGetters();
  genNamedSuccessorGetters();
  genAttrGetters();
  genAttrSetters();
  genOptionalAttrRemovers();
  genBuilder();
  genParser();
  genPrinter();
  genVerifier();
  genCanonicalizerDecls();
  genFolderDecls();
  genTypeInterfaceMethods();
  genOpInterfaceMethods();
  // Declarative assembly format, if any (implemented in OpFormatGen).
  generateOpFormat(op, opClass);
  genSideEffectInterfaceMethods();
}
618
emitDecl(const Operator & op,raw_ostream & os,const StaticVerifierFunctionEmitter & staticVerifierEmitter)619 void OpEmitter::emitDecl(
620 const Operator &op, raw_ostream &os,
621 const StaticVerifierFunctionEmitter &staticVerifierEmitter) {
622 OpEmitter(op, staticVerifierEmitter).emitDecl(os);
623 }
624
emitDef(const Operator & op,raw_ostream & os,const StaticVerifierFunctionEmitter & staticVerifierEmitter)625 void OpEmitter::emitDef(
626 const Operator &op, raw_ostream &os,
627 const StaticVerifierFunctionEmitter &staticVerifierEmitter) {
628 OpEmitter(op, staticVerifierEmitter).emitDef(os);
629 }
630
emitDecl(raw_ostream & os)631 void OpEmitter::emitDecl(raw_ostream &os) { opClass.writeDeclTo(os); }
632
emitDef(raw_ostream & os)633 void OpEmitter::emitDef(raw_ostream &os) { opClass.writeDefTo(os); }
634
genAttrGetters()635 void OpEmitter::genAttrGetters() {
636 FmtContext fctx;
637 fctx.withBuilder("::mlir::Builder((*this)->getContext())");
638
639 Dialect opDialect = op.getDialect();
640 // Emit the derived attribute body.
641 auto emitDerivedAttr = [&](StringRef name, Attribute attr) {
642 auto *method = opClass.addMethodAndPrune(attr.getReturnType(), name);
643 if (!method)
644 return;
645 auto &body = method->body();
646 body << " " << attr.getDerivedCodeBody() << "\n";
647 };
648
649 // Emit with return type specified.
650 auto emitAttrWithReturnType = [&](StringRef name, Attribute attr) {
651 auto *method = opClass.addMethodAndPrune(attr.getReturnType(), name);
652 auto &body = method->body();
653 body << " auto attr = " << name << "Attr();\n";
654 if (attr.hasDefaultValue()) {
655 // Returns the default value if not set.
656 // TODO: this is inefficient, we are recreating the attribute for every
657 // call. This should be set instead.
658 std::string defaultValue = std::string(
659 tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
660 body << " if (!attr)\n return "
661 << tgfmt(attr.getConvertFromStorageCall(),
662 &fctx.withSelf(defaultValue))
663 << ";\n";
664 }
665 body << " return "
666 << tgfmt(attr.getConvertFromStorageCall(), &fctx.withSelf("attr"))
667 << ";\n";
668 };
669
670 // Generate raw named accessor type. This is a wrapper class that allows
671 // referring to the attributes via accessors instead of having to use
672 // the string interface for better compile time verification.
673 auto emitAttrWithStorageType = [&](StringRef name, Attribute attr) {
674 auto *method =
675 opClass.addMethodAndPrune(attr.getStorageType(), (name + "Attr").str());
676 if (!method)
677 return;
678 auto &body = method->body();
679 body << " return (*this)->getAttr(\"" << name << "\").template ";
680 if (attr.isOptional() || attr.hasDefaultValue())
681 body << "dyn_cast_or_null<";
682 else
683 body << "cast<";
684 body << attr.getStorageType() << ">();";
685 };
686
687 for (auto &namedAttr : op.getAttributes()) {
688 const auto &name = namedAttr.name;
689 const auto &attr = namedAttr.attr;
690 if (attr.isDerivedAttr()) {
691 emitDerivedAttr(name, attr);
692 } else {
693 emitAttrWithStorageType(name, attr);
694 emitAttrWithReturnType(name, attr);
695 }
696 }
697
698 auto derivedAttrs = make_filter_range(op.getAttributes(),
699 [](const NamedAttribute &namedAttr) {
700 return namedAttr.attr.isDerivedAttr();
701 });
702 if (!derivedAttrs.empty()) {
703 opClass.addTrait("::mlir::DerivedAttributeOpInterface::Trait");
704 // Generate helper method to query whether a named attribute is a derived
705 // attribute. This enables, for example, avoiding adding an attribute that
706 // overlaps with a derived attribute.
707 {
708 auto *method = opClass.addMethodAndPrune("bool", "isDerivedAttribute",
709 OpMethod::MP_Static,
710 "::llvm::StringRef", "name");
711 auto &body = method->body();
712 for (auto namedAttr : derivedAttrs)
713 body << " if (name == \"" << namedAttr.name << "\") return true;\n";
714 body << " return false;";
715 }
716 // Generate method to materialize derived attributes as a DictionaryAttr.
717 {
718 auto *method = opClass.addMethodAndPrune("::mlir::DictionaryAttr",
719 "materializeDerivedAttributes");
720 auto &body = method->body();
721
722 auto nonMaterializable =
723 make_filter_range(derivedAttrs, [](const NamedAttribute &namedAttr) {
724 return namedAttr.attr.getConvertFromStorageCall().empty();
725 });
726 if (!nonMaterializable.empty()) {
727 std::string attrs;
728 llvm::raw_string_ostream os(attrs);
729 interleaveComma(nonMaterializable, os,
730 [&](const NamedAttribute &attr) { os << attr.name; });
731 PrintWarning(
732 op.getLoc(),
733 formatv(
734 "op has non-materializable derived attributes '{0}', skipping",
735 os.str()));
736 body << formatv(" emitOpError(\"op has non-materializable derived "
737 "attributes '{0}'\");\n",
738 attrs);
739 body << " return nullptr;";
740 return;
741 }
742
743 body << " ::mlir::MLIRContext* ctx = getContext();\n";
744 body << " ::mlir::Builder odsBuilder(ctx); (void)odsBuilder;\n";
745 body << " return ::mlir::DictionaryAttr::get({\n";
746 interleave(
747 derivedAttrs, body,
748 [&](const NamedAttribute &namedAttr) {
749 auto tmpl = namedAttr.attr.getConvertFromStorageCall();
750 body << " {::mlir::Identifier::get(\"" << namedAttr.name
751 << "\", ctx),\n"
752 << tgfmt(tmpl, &fctx.withSelf(namedAttr.name + "()")
753 .withBuilder("odsBuilder")
754 .addSubst("_ctx", "ctx"))
755 << "}";
756 },
757 ",\n");
758 body << "\n }, ctx);";
759 }
760 }
761 }
762
genAttrSetters()763 void OpEmitter::genAttrSetters() {
764 // Generate raw named setter type. This is a wrapper class that allows setting
765 // to the attributes via setters instead of having to use the string interface
766 // for better compile time verification.
767 auto emitAttrWithStorageType = [&](StringRef name, Attribute attr) {
768 auto *method = opClass.addMethodAndPrune("void", (name + "Attr").str(),
769 attr.getStorageType(), "attr");
770 if (!method)
771 return;
772 auto &body = method->body();
773 body << " (*this)->setAttr(\"" << name << "\", attr);";
774 };
775
776 for (auto &namedAttr : op.getAttributes()) {
777 const auto &name = namedAttr.name;
778 const auto &attr = namedAttr.attr;
779 if (!attr.isDerivedAttr())
780 emitAttrWithStorageType(name, attr);
781 }
782 }
783
genOptionalAttrRemovers()784 void OpEmitter::genOptionalAttrRemovers() {
785 // Generate methods for removing optional attributes, instead of having to
786 // use the string interface. Enables better compile time verification.
787 auto emitRemoveAttr = [&](StringRef name) {
788 auto upperInitial = name.take_front().upper();
789 auto suffix = name.drop_front();
790 auto *method = opClass.addMethodAndPrune(
791 "::mlir::Attribute", ("remove" + upperInitial + suffix + "Attr").str());
792 if (!method)
793 return;
794 auto &body = method->body();
795 body << " return (*this)->removeAttr(\"" << name << "\");";
796 };
797
798 for (const auto &namedAttr : op.getAttributes()) {
799 const auto &name = namedAttr.name;
800 const auto &attr = namedAttr.attr;
801 if (attr.isOptional())
802 emitRemoveAttr(name);
803 }
804 }
805
806 // Generates the code to compute the start and end index of an operand or result
807 // range.
808 template <typename RangeT>
809 static void
generateValueRangeStartAndEnd(Class & opClass,StringRef methodName,int numVariadic,int numNonVariadic,StringRef rangeSizeCall,bool hasAttrSegmentSize,StringRef sizeAttrInit,RangeT && odsValues)810 generateValueRangeStartAndEnd(Class &opClass, StringRef methodName,
811 int numVariadic, int numNonVariadic,
812 StringRef rangeSizeCall, bool hasAttrSegmentSize,
813 StringRef sizeAttrInit, RangeT &&odsValues) {
814 auto *method = opClass.addMethodAndPrune("std::pair<unsigned, unsigned>",
815 methodName, "unsigned", "index");
816 if (!method)
817 return;
818 auto &body = method->body();
819 if (numVariadic == 0) {
820 body << " return {index, 1};\n";
821 } else if (hasAttrSegmentSize) {
822 body << sizeAttrInit << attrSizedSegmentValueRangeCalcCode;
823 } else {
824 // Because the op can have arbitrarily interleaved variadic and non-variadic
825 // operands, we need to embed a list in the "sink" getter method for
826 // calculation at run-time.
827 llvm::SmallVector<StringRef, 4> isVariadic;
828 isVariadic.reserve(llvm::size(odsValues));
829 for (auto &it : odsValues)
830 isVariadic.push_back(it.isVariableLength() ? "true" : "false");
831 std::string isVariadicList = llvm::join(isVariadic, ", ");
832 body << formatv(sameVariadicSizeValueRangeCalcCode, isVariadicList,
833 numNonVariadic, numVariadic, rangeSizeCall, "operand");
834 }
835 }
836
837 // Generates the named operand getter methods for the given Operator `op` and
838 // puts them in `opClass`. Uses `rangeType` as the return type of getters that
839 // return a range of operands (individual operands are `Value ` and each
840 // element in the range must also be `Value `); use `rangeBeginCall` to get
841 // an iterator to the beginning of the operand range; use `rangeSizeCall` to
842 // obtain the number of operands. `getOperandCallPattern` contains the code
843 // necessary to obtain a single operand whose position will be substituted
844 // instead of
845 // "{0}" marker in the pattern. Note that the pattern should work for any kind
846 // of ops, in particular for one-operand ops that may not have the
847 // `getOperand(unsigned)` method.
generateNamedOperandGetters(const Operator & op,Class & opClass,StringRef sizeAttrInit,StringRef rangeType,StringRef rangeBeginCall,StringRef rangeSizeCall,StringRef getOperandCallPattern)848 static void generateNamedOperandGetters(const Operator &op, Class &opClass,
849 StringRef sizeAttrInit,
850 StringRef rangeType,
851 StringRef rangeBeginCall,
852 StringRef rangeSizeCall,
853 StringRef getOperandCallPattern) {
854 const int numOperands = op.getNumOperands();
855 const int numVariadicOperands = op.getNumVariableLengthOperands();
856 const int numNormalOperands = numOperands - numVariadicOperands;
857
858 const auto *sameVariadicSize =
859 op.getTrait("::mlir::OpTrait::SameVariadicOperandSize");
860 const auto *attrSizedOperands =
861 op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments");
862
863 if (numVariadicOperands > 1 && !sameVariadicSize && !attrSizedOperands) {
864 PrintFatalError(op.getLoc(), "op has multiple variadic operands but no "
865 "specification over their sizes");
866 }
867
868 if (numVariadicOperands < 2 && attrSizedOperands) {
869 PrintFatalError(op.getLoc(), "op must have at least two variadic operands "
870 "to use 'AttrSizedOperandSegments' trait");
871 }
872
873 if (attrSizedOperands && sameVariadicSize) {
874 PrintFatalError(op.getLoc(),
875 "op cannot have both 'AttrSizedOperandSegments' and "
876 "'SameVariadicOperandSize' traits");
877 }
878
879 // First emit a few "sink" getter methods upon which we layer all nicer named
880 // getter methods.
881 generateValueRangeStartAndEnd(opClass, "getODSOperandIndexAndLength",
882 numVariadicOperands, numNormalOperands,
883 rangeSizeCall, attrSizedOperands, sizeAttrInit,
884 const_cast<Operator &>(op).getOperands());
885
886 auto *m = opClass.addMethodAndPrune(rangeType, "getODSOperands", "unsigned",
887 "index");
888 auto &body = m->body();
889 body << formatv(valueRangeReturnCode, rangeBeginCall,
890 "getODSOperandIndexAndLength(index)");
891
892 // Then we emit nicer named getter methods by redirecting to the "sink" getter
893 // method.
894 for (int i = 0; i != numOperands; ++i) {
895 const auto &operand = op.getOperand(i);
896 if (operand.name.empty())
897 continue;
898
899 if (operand.isOptional()) {
900 m = opClass.addMethodAndPrune("::mlir::Value", operand.name);
901 m->body() << " auto operands = getODSOperands(" << i << ");\n"
902 << " return operands.empty() ? Value() : *operands.begin();";
903 } else if (operand.isVariadic()) {
904 m = opClass.addMethodAndPrune(rangeType, operand.name);
905 m->body() << " return getODSOperands(" << i << ");";
906 } else {
907 m = opClass.addMethodAndPrune("::mlir::Value", operand.name);
908 m->body() << " return *getODSOperands(" << i << ").begin();";
909 }
910 }
911 }
912
genNamedOperandGetters()913 void OpEmitter::genNamedOperandGetters() {
914 generateNamedOperandGetters(
915 op, opClass,
916 /*sizeAttrInit=*/
917 formatv(opSegmentSizeAttrInitCode, "operand_segment_sizes").str(),
918 /*rangeType=*/"::mlir::Operation::operand_range",
919 /*rangeBeginCall=*/"getOperation()->operand_begin()",
920 /*rangeSizeCall=*/"getOperation()->getNumOperands()",
921 /*getOperandCallPattern=*/"getOperation()->getOperand({0})");
922 }
923
genNamedOperandSetters()924 void OpEmitter::genNamedOperandSetters() {
925 auto *attrSizedOperands =
926 op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments");
927 for (int i = 0, e = op.getNumOperands(); i != e; ++i) {
928 const auto &operand = op.getOperand(i);
929 if (operand.name.empty())
930 continue;
931 auto *m = opClass.addMethodAndPrune("::mlir::MutableOperandRange",
932 (operand.name + "Mutable").str());
933 auto &body = m->body();
934 body << " auto range = getODSOperandIndexAndLength(" << i << ");\n"
935 << " return ::mlir::MutableOperandRange(getOperation(), "
936 "range.first, range.second";
937 if (attrSizedOperands)
938 body << ", ::mlir::MutableOperandRange::OperandSegment(" << i
939 << "u, *getOperation()->getAttrDictionary().getNamed("
940 "\"operand_segment_sizes\"))";
941 body << ");\n";
942 }
943 }
944
genNamedResultGetters()945 void OpEmitter::genNamedResultGetters() {
946 const int numResults = op.getNumResults();
947 const int numVariadicResults = op.getNumVariableLengthResults();
948 const int numNormalResults = numResults - numVariadicResults;
949
950 // If we have more than one variadic results, we need more complicated logic
951 // to calculate the value range for each result.
952
953 const auto *sameVariadicSize =
954 op.getTrait("::mlir::OpTrait::SameVariadicResultSize");
955 const auto *attrSizedResults =
956 op.getTrait("::mlir::OpTrait::AttrSizedResultSegments");
957
958 if (numVariadicResults > 1 && !sameVariadicSize && !attrSizedResults) {
959 PrintFatalError(op.getLoc(), "op has multiple variadic results but no "
960 "specification over their sizes");
961 }
962
963 if (numVariadicResults < 2 && attrSizedResults) {
964 PrintFatalError(op.getLoc(), "op must have at least two variadic results "
965 "to use 'AttrSizedResultSegments' trait");
966 }
967
968 if (attrSizedResults && sameVariadicSize) {
969 PrintFatalError(op.getLoc(),
970 "op cannot have both 'AttrSizedResultSegments' and "
971 "'SameVariadicResultSize' traits");
972 }
973
974 generateValueRangeStartAndEnd(
975 opClass, "getODSResultIndexAndLength", numVariadicResults,
976 numNormalResults, "getOperation()->getNumResults()", attrSizedResults,
977 formatv(opSegmentSizeAttrInitCode, "result_segment_sizes").str(),
978 op.getResults());
979
980 auto *m = opClass.addMethodAndPrune("::mlir::Operation::result_range",
981 "getODSResults", "unsigned", "index");
982 m->body() << formatv(valueRangeReturnCode, "getOperation()->result_begin()",
983 "getODSResultIndexAndLength(index)");
984
985 for (int i = 0; i != numResults; ++i) {
986 const auto &result = op.getResult(i);
987 if (result.name.empty())
988 continue;
989
990 if (result.isOptional()) {
991 m = opClass.addMethodAndPrune("::mlir::Value", result.name);
992 m->body()
993 << " auto results = getODSResults(" << i << ");\n"
994 << " return results.empty() ? ::mlir::Value() : *results.begin();";
995 } else if (result.isVariadic()) {
996 m = opClass.addMethodAndPrune("::mlir::Operation::result_range",
997 result.name);
998 m->body() << " return getODSResults(" << i << ");";
999 } else {
1000 m = opClass.addMethodAndPrune("::mlir::Value", result.name);
1001 m->body() << " return *getODSResults(" << i << ").begin();";
1002 }
1003 }
1004 }
1005
genNamedRegionGetters()1006 void OpEmitter::genNamedRegionGetters() {
1007 unsigned numRegions = op.getNumRegions();
1008 for (unsigned i = 0; i < numRegions; ++i) {
1009 const auto ®ion = op.getRegion(i);
1010 if (region.name.empty())
1011 continue;
1012
1013 // Generate the accessors for a variadic region.
1014 if (region.isVariadic()) {
1015 auto *m = opClass.addMethodAndPrune("::mlir::MutableArrayRef<Region>",
1016 region.name);
1017 m->body() << formatv(" return (*this)->getRegions().drop_front({0});",
1018 i);
1019 continue;
1020 }
1021
1022 auto *m = opClass.addMethodAndPrune("::mlir::Region &", region.name);
1023 m->body() << formatv(" return (*this)->getRegion({0});", i);
1024 }
1025 }
1026
genNamedSuccessorGetters()1027 void OpEmitter::genNamedSuccessorGetters() {
1028 unsigned numSuccessors = op.getNumSuccessors();
1029 for (unsigned i = 0; i < numSuccessors; ++i) {
1030 const NamedSuccessor &successor = op.getSuccessor(i);
1031 if (successor.name.empty())
1032 continue;
1033
1034 // Generate the accessors for a variadic successor list.
1035 if (successor.isVariadic()) {
1036 auto *m =
1037 opClass.addMethodAndPrune("::mlir::SuccessorRange", successor.name);
1038 m->body() << formatv(
1039 " return {std::next((*this)->successor_begin(), {0}), "
1040 "(*this)->successor_end()};",
1041 i);
1042 continue;
1043 }
1044
1045 auto *m = opClass.addMethodAndPrune("::mlir::Block *", successor.name);
1046 m->body() << formatv(" return (*this)->getSuccessor({0});", i);
1047 }
1048 }
1049
canGenerateUnwrappedBuilder(Operator & op)1050 static bool canGenerateUnwrappedBuilder(Operator &op) {
1051 // If this op does not have native attributes at all, return directly to avoid
1052 // redefining builders.
1053 if (op.getNumNativeAttributes() == 0)
1054 return false;
1055
1056 bool canGenerate = false;
1057 // We are generating builders that take raw values for attributes. We need to
1058 // make sure the native attributes have a meaningful "unwrapped" value type
1059 // different from the wrapped mlir::Attribute type to avoid redefining
1060 // builders. This checks for the op has at least one such native attribute.
1061 for (int i = 0, e = op.getNumNativeAttributes(); i < e; ++i) {
1062 NamedAttribute &namedAttr = op.getAttribute(i);
1063 if (canUseUnwrappedRawValue(namedAttr.attr)) {
1064 canGenerate = true;
1065 break;
1066 }
1067 }
1068 return canGenerate;
1069 }
1070
canInferType(Operator & op)1071 static bool canInferType(Operator &op) {
1072 return op.getTrait("::mlir::InferTypeOpInterface::Trait") &&
1073 op.getNumRegions() == 0;
1074 }
1075
genSeparateArgParamBuilder()1076 void OpEmitter::genSeparateArgParamBuilder() {
1077 SmallVector<AttrParamKind, 2> attrBuilderType;
1078 attrBuilderType.push_back(AttrParamKind::WrappedAttr);
1079 if (canGenerateUnwrappedBuilder(op))
1080 attrBuilderType.push_back(AttrParamKind::UnwrappedValue);
1081
1082 // Emit with separate builders with or without unwrapped attributes and/or
1083 // inferring result type.
1084 auto emit = [&](AttrParamKind attrType, TypeParamKind paramKind,
1085 bool inferType) {
1086 llvm::SmallVector<OpMethodParameter, 4> paramList;
1087 llvm::SmallVector<std::string, 4> resultNames;
1088 buildParamList(paramList, resultNames, paramKind, attrType);
1089
1090 auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
1091 std::move(paramList));
1092 // If the builder is redundant, skip generating the method.
1093 if (!m)
1094 return;
1095 auto &body = m->body();
1096 genCodeForAddingArgAndRegionForBuilder(
1097 body, /*isRawValueAttr=*/attrType == AttrParamKind::UnwrappedValue);
1098
1099 // Push all result types to the operation state
1100
1101 if (inferType) {
1102 // Generate builder that infers type too.
1103 // TODO: Subsume this with general checking if type can be
1104 // inferred automatically.
1105 // TODO: Expand to handle regions.
1106 body << formatv(R"(
1107 ::llvm::SmallVector<::mlir::Type, 2> inferredReturnTypes;
1108 if (succeeded({0}::inferReturnTypes(odsBuilder.getContext(),
1109 {1}.location, {1}.operands,
1110 {1}.attributes.getDictionary({1}.getContext()),
1111 /*regions=*/{{}, inferredReturnTypes)))
1112 {1}.addTypes(inferredReturnTypes);
1113 else
1114 ::llvm::report_fatal_error("Failed to infer result type(s).");)",
1115 opClass.getClassName(), builderOpState);
1116 return;
1117 }
1118
1119 switch (paramKind) {
1120 case TypeParamKind::None:
1121 return;
1122 case TypeParamKind::Separate:
1123 for (int i = 0, e = op.getNumResults(); i < e; ++i) {
1124 if (op.getResult(i).isOptional())
1125 body << " if (" << resultNames[i] << ")\n ";
1126 body << " " << builderOpState << ".addTypes(" << resultNames[i]
1127 << ");\n";
1128 }
1129 return;
1130 case TypeParamKind::Collective: {
1131 int numResults = op.getNumResults();
1132 int numVariadicResults = op.getNumVariableLengthResults();
1133 int numNonVariadicResults = numResults - numVariadicResults;
1134 bool hasVariadicResult = numVariadicResults != 0;
1135
1136 // Avoid emitting "resultTypes.size() >= 0u" which is always true.
1137 if (!(hasVariadicResult && numNonVariadicResults == 0))
1138 body << " "
1139 << "assert(resultTypes.size() "
1140 << (hasVariadicResult ? ">=" : "==") << " "
1141 << numNonVariadicResults
1142 << "u && \"mismatched number of results\");\n";
1143 body << " " << builderOpState << ".addTypes(resultTypes);\n";
1144 }
1145 return;
1146 }
1147 llvm_unreachable("unhandled TypeParamKind");
1148 };
1149
1150 // Some of the build methods generated here may be ambiguous, but TableGen's
1151 // ambiguous function detection will elide those ones.
1152 for (auto attrType : attrBuilderType) {
1153 emit(attrType, TypeParamKind::Separate, /*inferType=*/false);
1154 if (canInferType(op))
1155 emit(attrType, TypeParamKind::None, /*inferType=*/true);
1156 emit(attrType, TypeParamKind::Collective, /*inferType=*/false);
1157 }
1158 }
1159
genUseOperandAsResultTypeCollectiveParamBuilder()1160 void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() {
1161 int numResults = op.getNumResults();
1162
1163 // Signature
1164 llvm::SmallVector<OpMethodParameter, 4> paramList;
1165 paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
1166 paramList.emplace_back("::mlir::OperationState &", builderOpState);
1167 paramList.emplace_back("::mlir::ValueRange", "operands");
1168 // Provide default value for `attributes` when its the last parameter
1169 StringRef attributesDefaultValue = op.getNumVariadicRegions() ? "" : "{}";
1170 paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
1171 "attributes", attributesDefaultValue);
1172 if (op.getNumVariadicRegions())
1173 paramList.emplace_back("unsigned", "numRegions");
1174
1175 auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
1176 std::move(paramList));
1177 // If the builder is redundant, skip generating the method
1178 if (!m)
1179 return;
1180 auto &body = m->body();
1181
1182 // Operands
1183 body << " " << builderOpState << ".addOperands(operands);\n";
1184
1185 // Attributes
1186 body << " " << builderOpState << ".addAttributes(attributes);\n";
1187
1188 // Create the correct number of regions
1189 if (int numRegions = op.getNumRegions()) {
1190 body << llvm::formatv(
1191 " for (unsigned i = 0; i != {0}; ++i)\n",
1192 (op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions)));
1193 body << " (void)" << builderOpState << ".addRegion();\n";
1194 }
1195
1196 // Result types
1197 SmallVector<std::string, 2> resultTypes(numResults, "operands[0].getType()");
1198 body << " " << builderOpState << ".addTypes({"
1199 << llvm::join(resultTypes, ", ") << "});\n\n";
1200 }
1201
genInferredTypeCollectiveParamBuilder()1202 void OpEmitter::genInferredTypeCollectiveParamBuilder() {
1203 // TODO: Expand to support regions.
1204 SmallVector<OpMethodParameter, 4> paramList;
1205 paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
1206 paramList.emplace_back("::mlir::OperationState &", builderOpState);
1207 paramList.emplace_back("::mlir::ValueRange", "operands");
1208 paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
1209 "attributes", "{}");
1210 auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
1211 std::move(paramList));
1212 // If the builder is redundant, skip generating the method
1213 if (!m)
1214 return;
1215 auto &body = m->body();
1216
1217 int numResults = op.getNumResults();
1218 int numVariadicResults = op.getNumVariableLengthResults();
1219 int numNonVariadicResults = numResults - numVariadicResults;
1220
1221 int numOperands = op.getNumOperands();
1222 int numVariadicOperands = op.getNumVariableLengthOperands();
1223 int numNonVariadicOperands = numOperands - numVariadicOperands;
1224
1225 // Operands
1226 if (numVariadicOperands == 0 || numNonVariadicOperands != 0)
1227 body << " assert(operands.size()"
1228 << (numVariadicOperands != 0 ? " >= " : " == ")
1229 << numNonVariadicOperands
1230 << "u && \"mismatched number of parameters\");\n";
1231 body << " " << builderOpState << ".addOperands(operands);\n";
1232 body << " " << builderOpState << ".addAttributes(attributes);\n";
1233
1234 // Create the correct number of regions
1235 if (int numRegions = op.getNumRegions()) {
1236 body << llvm::formatv(
1237 " for (unsigned i = 0; i != {0}; ++i)\n",
1238 (op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions)));
1239 body << " (void)" << builderOpState << ".addRegion();\n";
1240 }
1241
1242 // Result types
1243 body << formatv(R"(
1244 ::mlir::SmallVector<::mlir::Type, 2> inferredReturnTypes;
1245 if (succeeded({0}::inferReturnTypes(odsBuilder.getContext(),
1246 {1}.location, operands,
1247 {1}.attributes.getDictionary({1}.getContext()),
1248 /*regions=*/{{}, inferredReturnTypes))) {{)",
1249 opClass.getClassName(), builderOpState);
1250 if (numVariadicResults == 0 || numNonVariadicResults != 0)
1251 body << " assert(inferredReturnTypes.size()"
1252 << (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults
1253 << "u && \"mismatched number of return types\");\n";
1254 body << " " << builderOpState << ".addTypes(inferredReturnTypes);";
1255
1256 body << formatv(R"(
1257 } else
1258 ::llvm::report_fatal_error("Failed to infer result type(s).");)",
1259 opClass.getClassName(), builderOpState);
1260 }
1261
genUseOperandAsResultTypeSeparateParamBuilder()1262 void OpEmitter::genUseOperandAsResultTypeSeparateParamBuilder() {
1263 llvm::SmallVector<OpMethodParameter, 4> paramList;
1264 llvm::SmallVector<std::string, 4> resultNames;
1265 buildParamList(paramList, resultNames, TypeParamKind::None);
1266
1267 auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
1268 std::move(paramList));
1269 // If the builder is redundant, skip generating the method
1270 if (!m)
1271 return;
1272 auto &body = m->body();
1273 genCodeForAddingArgAndRegionForBuilder(body);
1274
1275 auto numResults = op.getNumResults();
1276 if (numResults == 0)
1277 return;
1278
1279 // Push all result types to the operation state
1280 const char *index = op.getOperand(0).isVariadic() ? ".front()" : "";
1281 std::string resultType =
1282 formatv("{0}{1}.getType()", getArgumentName(op, 0), index).str();
1283 body << " " << builderOpState << ".addTypes({" << resultType;
1284 for (int i = 1; i != numResults; ++i)
1285 body << ", " << resultType;
1286 body << "});\n\n";
1287 }
1288
genUseAttrAsResultTypeBuilder()1289 void OpEmitter::genUseAttrAsResultTypeBuilder() {
1290 SmallVector<OpMethodParameter, 4> paramList;
1291 paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
1292 paramList.emplace_back("::mlir::OperationState &", builderOpState);
1293 paramList.emplace_back("::mlir::ValueRange", "operands");
1294 paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
1295 "attributes", "{}");
1296 auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
1297 std::move(paramList));
1298 // If the builder is redundant, skip generating the method
1299 if (!m)
1300 return;
1301
1302 auto &body = m->body();
1303
1304 // Push all result types to the operation state
1305 std::string resultType;
1306 const auto &namedAttr = op.getAttribute(0);
1307
1308 body << " for (auto attr : attributes) {\n";
1309 body << " if (attr.first != \"" << namedAttr.name << "\") continue;\n";
1310 if (namedAttr.attr.isTypeAttr()) {
1311 resultType = "attr.second.cast<::mlir::TypeAttr>().getValue()";
1312 } else {
1313 resultType = "attr.second.getType()";
1314 }
1315
1316 // Operands
1317 body << " " << builderOpState << ".addOperands(operands);\n";
1318
1319 // Attributes
1320 body << " " << builderOpState << ".addAttributes(attributes);\n";
1321
1322 // Result types
1323 SmallVector<std::string, 2> resultTypes(op.getNumResults(), resultType);
1324 body << " " << builderOpState << ".addTypes({"
1325 << llvm::join(resultTypes, ", ") << "});\n";
1326 body << " }\n";
1327 }
1328
1329 /// Returns a signature of the builder. Updates the context `fctx` to enable
1330 /// replacement of $_builder and $_state in the body.
getBuilderSignature(const Builder & builder)1331 static std::string getBuilderSignature(const Builder &builder) {
1332 ArrayRef<Builder::Parameter> params(builder.getParameters());
1333
1334 // Inject builder and state arguments.
1335 llvm::SmallVector<std::string, 8> arguments;
1336 arguments.reserve(params.size() + 2);
1337 arguments.push_back(
1338 llvm::formatv("::mlir::OpBuilder &{0}", odsBuilder).str());
1339 arguments.push_back(
1340 llvm::formatv("::mlir::OperationState &{0}", builderOpState).str());
1341
1342 for (unsigned i = 0, e = params.size(); i < e; ++i) {
1343 // If no name is provided, generate one.
1344 Optional<StringRef> paramName = params[i].getName();
1345 std::string name =
1346 paramName ? paramName->str() : "odsArg" + std::to_string(i);
1347
1348 std::string defaultValue;
1349 if (Optional<StringRef> defaultParamValue = params[i].getDefaultValue())
1350 defaultValue = llvm::formatv(" = {0}", *defaultParamValue).str();
1351 arguments.push_back(
1352 llvm::formatv("{0} {1}{2}", params[i].getCppType(), name, defaultValue)
1353 .str());
1354 }
1355
1356 return llvm::join(arguments, ", ");
1357 }
1358
genBuilder()1359 void OpEmitter::genBuilder() {
1360 // Handle custom builders if provided.
1361 for (const Builder &builder : op.getBuilders()) {
1362 std::string paramStr = getBuilderSignature(builder);
1363
1364 Optional<StringRef> body = builder.getBody();
1365 OpMethod::Property properties =
1366 body ? OpMethod::MP_Static : OpMethod::MP_StaticDeclaration;
1367 auto *method =
1368 opClass.addMethodAndPrune("void", "build", properties, paramStr);
1369
1370 FmtContext fctx;
1371 fctx.withBuilder(odsBuilder);
1372 fctx.addSubst("_state", builderOpState);
1373 if (body)
1374 method->body() << tgfmt(*body, &fctx);
1375 }
1376
1377 // Generate default builders that requires all result type, operands, and
1378 // attributes as parameters.
1379 if (op.skipDefaultBuilders())
1380 return;
1381
1382 // We generate three classes of builders here:
1383 // 1. one having a stand-alone parameter for each operand / attribute, and
1384 genSeparateArgParamBuilder();
1385 // 2. one having an aggregated parameter for all result types / operands /
1386 // attributes, and
1387 genCollectiveParamBuilder();
1388 // 3. one having a stand-alone parameter for each operand and attribute,
1389 // use the first operand or attribute's type as all result types
1390 // to facilitate different call patterns.
1391 if (op.getNumVariableLengthResults() == 0) {
1392 if (op.getTrait("::mlir::OpTrait::SameOperandsAndResultType")) {
1393 genUseOperandAsResultTypeSeparateParamBuilder();
1394 genUseOperandAsResultTypeCollectiveParamBuilder();
1395 }
1396 if (op.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType"))
1397 genUseAttrAsResultTypeBuilder();
1398 }
1399 }
1400
genCollectiveParamBuilder()1401 void OpEmitter::genCollectiveParamBuilder() {
1402 int numResults = op.getNumResults();
1403 int numVariadicResults = op.getNumVariableLengthResults();
1404 int numNonVariadicResults = numResults - numVariadicResults;
1405
1406 int numOperands = op.getNumOperands();
1407 int numVariadicOperands = op.getNumVariableLengthOperands();
1408 int numNonVariadicOperands = numOperands - numVariadicOperands;
1409
1410 SmallVector<OpMethodParameter, 4> paramList;
1411 paramList.emplace_back("::mlir::OpBuilder &", "");
1412 paramList.emplace_back("::mlir::OperationState &", builderOpState);
1413 paramList.emplace_back("::mlir::TypeRange", "resultTypes");
1414 paramList.emplace_back("::mlir::ValueRange", "operands");
1415 // Provide default value for `attributes` when its the last parameter
1416 StringRef attributesDefaultValue = op.getNumVariadicRegions() ? "" : "{}";
1417 paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
1418 "attributes", attributesDefaultValue);
1419 if (op.getNumVariadicRegions())
1420 paramList.emplace_back("unsigned", "numRegions");
1421
1422 auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
1423 std::move(paramList));
1424 // If the builder is redundant, skip generating the method
1425 if (!m)
1426 return;
1427 auto &body = m->body();
1428
1429 // Operands
1430 if (numVariadicOperands == 0 || numNonVariadicOperands != 0)
1431 body << " assert(operands.size()"
1432 << (numVariadicOperands != 0 ? " >= " : " == ")
1433 << numNonVariadicOperands
1434 << "u && \"mismatched number of parameters\");\n";
1435 body << " " << builderOpState << ".addOperands(operands);\n";
1436
1437 // Attributes
1438 body << " " << builderOpState << ".addAttributes(attributes);\n";
1439
1440 // Create the correct number of regions
1441 if (int numRegions = op.getNumRegions()) {
1442 body << llvm::formatv(
1443 " for (unsigned i = 0; i != {0}; ++i)\n",
1444 (op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions)));
1445 body << " (void)" << builderOpState << ".addRegion();\n";
1446 }
1447
1448 // Result types
1449 if (numVariadicResults == 0 || numNonVariadicResults != 0)
1450 body << " assert(resultTypes.size()"
1451 << (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults
1452 << "u && \"mismatched number of return types\");\n";
1453 body << " " << builderOpState << ".addTypes(resultTypes);\n";
1454
1455 // Generate builder that infers type too.
1456 // TODO: Expand to handle regions and successors.
1457 if (canInferType(op) && op.getNumSuccessors() == 0)
1458 genInferredTypeCollectiveParamBuilder();
1459 }
1460
buildParamList(SmallVectorImpl<OpMethodParameter> & paramList,SmallVectorImpl<std::string> & resultTypeNames,TypeParamKind typeParamKind,AttrParamKind attrParamKind)1461 void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> ¶mList,
1462 SmallVectorImpl<std::string> &resultTypeNames,
1463 TypeParamKind typeParamKind,
1464 AttrParamKind attrParamKind) {
1465 resultTypeNames.clear();
1466 auto numResults = op.getNumResults();
1467 resultTypeNames.reserve(numResults);
1468
1469 paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
1470 paramList.emplace_back("::mlir::OperationState &", builderOpState);
1471
1472 switch (typeParamKind) {
1473 case TypeParamKind::None:
1474 break;
1475 case TypeParamKind::Separate: {
1476 // Add parameters for all return types
1477 for (int i = 0; i < numResults; ++i) {
1478 const auto &result = op.getResult(i);
1479 std::string resultName = std::string(result.name);
1480 if (resultName.empty())
1481 resultName = std::string(formatv("resultType{0}", i));
1482
1483 StringRef type =
1484 result.isVariadic() ? "::mlir::TypeRange" : "::mlir::Type";
1485 OpMethodParameter::Property properties = OpMethodParameter::PP_None;
1486 if (result.isOptional())
1487 properties = OpMethodParameter::PP_Optional;
1488
1489 paramList.emplace_back(type, resultName, properties);
1490 resultTypeNames.emplace_back(std::move(resultName));
1491 }
1492 } break;
1493 case TypeParamKind::Collective: {
1494 paramList.emplace_back("::mlir::TypeRange", "resultTypes");
1495 resultTypeNames.push_back("resultTypes");
1496 } break;
1497 }
1498
1499 // Add parameters for all arguments (operands and attributes).
1500
1501 int numOperands = 0;
1502 int numAttrs = 0;
1503
1504 int defaultValuedAttrStartIndex = op.getNumArgs();
1505 if (attrParamKind == AttrParamKind::UnwrappedValue) {
1506 // Calculate the start index from which we can attach default values in the
1507 // builder declaration.
1508 for (int i = op.getNumArgs() - 1; i >= 0; --i) {
1509 auto *namedAttr = op.getArg(i).dyn_cast<tblgen::NamedAttribute *>();
1510 if (!namedAttr || !namedAttr->attr.hasDefaultValue())
1511 break;
1512
1513 if (!canUseUnwrappedRawValue(namedAttr->attr))
1514 break;
1515
1516 // Creating an APInt requires us to provide bitwidth, value, and
1517 // signedness, which is complicated compared to others. Similarly
1518 // for APFloat.
1519 // TODO: Adjust the 'returnType' field of such attributes
1520 // to support them.
1521 StringRef retType = namedAttr->attr.getReturnType();
1522 if (retType == "::llvm::APInt" || retType == "::llvm::APFloat")
1523 break;
1524
1525 defaultValuedAttrStartIndex = i;
1526 }
1527 }
1528
1529 for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
1530 auto argument = op.getArg(i);
1531 if (argument.is<tblgen::NamedTypeConstraint *>()) {
1532 const auto &operand = op.getOperand(numOperands);
1533 StringRef type =
1534 operand.isVariadic() ? "::mlir::ValueRange" : "::mlir::Value";
1535 OpMethodParameter::Property properties = OpMethodParameter::PP_None;
1536 if (operand.isOptional())
1537 properties = OpMethodParameter::PP_Optional;
1538
1539 paramList.emplace_back(type, getArgumentName(op, numOperands),
1540 properties);
1541 ++numOperands;
1542 } else {
1543 const auto &namedAttr = op.getAttribute(numAttrs);
1544 const auto &attr = namedAttr.attr;
1545
1546 OpMethodParameter::Property properties = OpMethodParameter::PP_None;
1547 if (attr.isOptional())
1548 properties = OpMethodParameter::PP_Optional;
1549
1550 StringRef type;
1551 switch (attrParamKind) {
1552 case AttrParamKind::WrappedAttr:
1553 type = attr.getStorageType();
1554 break;
1555 case AttrParamKind::UnwrappedValue:
1556 if (canUseUnwrappedRawValue(attr))
1557 type = attr.getReturnType();
1558 else
1559 type = attr.getStorageType();
1560 break;
1561 }
1562
1563 std::string defaultValue;
1564 // Attach default value if requested and possible.
1565 if (attrParamKind == AttrParamKind::UnwrappedValue &&
1566 i >= defaultValuedAttrStartIndex) {
1567 bool isString = attr.getReturnType() == "::llvm::StringRef";
1568 if (isString)
1569 defaultValue.append("\"");
1570 defaultValue += attr.getDefaultValue();
1571 if (isString)
1572 defaultValue.append("\"");
1573 }
1574 paramList.emplace_back(type, namedAttr.name, defaultValue, properties);
1575 ++numAttrs;
1576 }
1577 }
1578
1579 /// Insert parameters for each successor.
1580 for (const NamedSuccessor &succ : op.getSuccessors()) {
1581 StringRef type =
1582 succ.isVariadic() ? "::mlir::BlockRange" : "::mlir::Block *";
1583 paramList.emplace_back(type, succ.name);
1584 }
1585
1586 /// Insert parameters for variadic regions.
1587 for (const NamedRegion ®ion : op.getRegions())
1588 if (region.isVariadic())
1589 paramList.emplace_back("unsigned",
1590 llvm::formatv("{0}Count", region.name).str());
1591 }
1592
genCodeForAddingArgAndRegionForBuilder(OpMethodBody & body,bool isRawValueAttr)1593 void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
1594 bool isRawValueAttr) {
1595 // Push all operands to the result.
1596 for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
1597 std::string argName = getArgumentName(op, i);
1598 if (op.getOperand(i).isOptional())
1599 body << " if (" << argName << ")\n ";
1600 body << " " << builderOpState << ".addOperands(" << argName << ");\n";
1601 }
1602
1603 // If the operation has the operand segment size attribute, add it here.
1604 if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
1605 body << " " << builderOpState
1606 << ".addAttribute(\"operand_segment_sizes\", "
1607 "odsBuilder.getI32VectorAttr({";
1608 interleaveComma(llvm::seq<int>(0, op.getNumOperands()), body, [&](int i) {
1609 if (op.getOperand(i).isOptional())
1610 body << "(" << getArgumentName(op, i) << " ? 1 : 0)";
1611 else if (op.getOperand(i).isVariadic())
1612 body << "static_cast<int32_t>(" << getArgumentName(op, i) << ".size())";
1613 else
1614 body << "1";
1615 });
1616 body << "}));\n";
1617 }
1618
1619 // Push all attributes to the result.
1620 for (const auto &namedAttr : op.getAttributes()) {
1621 auto &attr = namedAttr.attr;
1622 if (!attr.isDerivedAttr()) {
1623 bool emitNotNullCheck = attr.isOptional();
1624 if (emitNotNullCheck) {
1625 body << formatv(" if ({0}) ", namedAttr.name) << "{\n";
1626 }
1627 if (isRawValueAttr && canUseUnwrappedRawValue(attr)) {
1628 // If this is a raw value, then we need to wrap it in an Attribute
1629 // instance.
1630 FmtContext fctx;
1631 fctx.withBuilder("odsBuilder");
1632
1633 std::string builderTemplate =
1634 std::string(attr.getConstBuilderTemplate());
1635
1636 // For StringAttr, its constant builder call will wrap the input in
1637 // quotes, which is correct for normal string literals, but incorrect
1638 // here given we use function arguments. So we need to strip the
1639 // wrapping quotes.
1640 if (StringRef(builderTemplate).contains("\"$0\""))
1641 builderTemplate = replaceAllSubstrs(builderTemplate, "\"$0\"", "$0");
1642
1643 std::string value =
1644 std::string(tgfmt(builderTemplate, &fctx, namedAttr.name));
1645 body << formatv(" {0}.addAttribute(\"{1}\", {2});\n", builderOpState,
1646 namedAttr.name, value);
1647 } else {
1648 body << formatv(" {0}.addAttribute(\"{1}\", {1});\n", builderOpState,
1649 namedAttr.name);
1650 }
1651 if (emitNotNullCheck) {
1652 body << " }\n";
1653 }
1654 }
1655 }
1656
1657 // Create the correct number of regions.
1658 for (const NamedRegion ®ion : op.getRegions()) {
1659 if (region.isVariadic())
1660 body << formatv(" for (unsigned i = 0; i < {0}Count; ++i)\n ",
1661 region.name);
1662
1663 body << " (void)" << builderOpState << ".addRegion();\n";
1664 }
1665
1666 // Push all successors to the result.
1667 for (const NamedSuccessor &namedSuccessor : op.getSuccessors()) {
1668 body << formatv(" {0}.addSuccessors({1});\n", builderOpState,
1669 namedSuccessor.name);
1670 }
1671 }
1672
genCanonicalizerDecls()1673 void OpEmitter::genCanonicalizerDecls() {
1674 if (!def.getValueAsBit("hasCanonicalizer"))
1675 return;
1676
1677 SmallVector<OpMethodParameter, 2> paramList;
1678 paramList.emplace_back("::mlir::OwningRewritePatternList &", "results");
1679 paramList.emplace_back("::mlir::MLIRContext *", "context");
1680 opClass.addMethodAndPrune("void", "getCanonicalizationPatterns",
1681 OpMethod::MP_StaticDeclaration,
1682 std::move(paramList));
1683 }
1684
genFolderDecls()1685 void OpEmitter::genFolderDecls() {
1686 bool hasSingleResult =
1687 op.getNumResults() == 1 && op.getNumVariableLengthResults() == 0;
1688
1689 if (def.getValueAsBit("hasFolder")) {
1690 if (hasSingleResult) {
1691 opClass.addMethodAndPrune(
1692 "::mlir::OpFoldResult", "fold", OpMethod::MP_Declaration,
1693 "::llvm::ArrayRef<::mlir::Attribute>", "operands");
1694 } else {
1695 SmallVector<OpMethodParameter, 2> paramList;
1696 paramList.emplace_back("::llvm::ArrayRef<::mlir::Attribute>", "operands");
1697 paramList.emplace_back("::llvm::SmallVectorImpl<::mlir::OpFoldResult> &",
1698 "results");
1699 opClass.addMethodAndPrune("::mlir::LogicalResult", "fold",
1700 OpMethod::MP_Declaration, std::move(paramList));
1701 }
1702 }
1703 }
1704
genOpInterfaceMethods(const tblgen::InterfaceOpTrait * opTrait)1705 void OpEmitter::genOpInterfaceMethods(const tblgen::InterfaceOpTrait *opTrait) {
1706 auto interface = opTrait->getOpInterface();
1707
1708 // Get the set of methods that should always be declared.
1709 auto alwaysDeclaredMethodsVec = opTrait->getAlwaysDeclaredMethods();
1710 llvm::StringSet<> alwaysDeclaredMethods;
1711 alwaysDeclaredMethods.insert(alwaysDeclaredMethodsVec.begin(),
1712 alwaysDeclaredMethodsVec.end());
1713
1714 for (const InterfaceMethod &method : interface.getMethods()) {
1715 // Don't declare if the method has a body.
1716 if (method.getBody())
1717 continue;
1718 // Don't declare if the method has a default implementation and the op
1719 // didn't request that it always be declared.
1720 if (method.getDefaultImplementation() &&
1721 !alwaysDeclaredMethods.count(method.getName()))
1722 continue;
1723 genOpInterfaceMethod(method);
1724 }
1725 }
1726
genOpInterfaceMethod(const InterfaceMethod & method,bool declaration)1727 OpMethod *OpEmitter::genOpInterfaceMethod(const InterfaceMethod &method,
1728 bool declaration) {
1729 SmallVector<OpMethodParameter, 4> paramList;
1730 for (const InterfaceMethod::Argument &arg : method.getArguments())
1731 paramList.emplace_back(arg.type, arg.name);
1732
1733 auto properties = method.isStatic() ? OpMethod::MP_Static : OpMethod::MP_None;
1734 if (declaration)
1735 properties =
1736 static_cast<OpMethod::Property>(properties | OpMethod::MP_Declaration);
1737 return opClass.addMethodAndPrune(method.getReturnType(), method.getName(),
1738 properties, std::move(paramList));
1739 }
1740
genOpInterfaceMethods()1741 void OpEmitter::genOpInterfaceMethods() {
1742 for (const auto &trait : op.getTraits()) {
1743 if (const auto *opTrait = dyn_cast<tblgen::InterfaceOpTrait>(&trait))
1744 if (opTrait->shouldDeclareMethods())
1745 genOpInterfaceMethods(opTrait);
1746 }
1747 }
1748
/// Generates `getEffects` methods for the side-effect interfaces used by the
/// op. Effects are collected from:
///   - side-effect traits on the op (emitted with no attached value),
///   - decorators on operands and on SymbolRefAttr-based attributes,
///   - decorators on results.
/// Effects are grouped by their base effect name, and one `getEffects`
/// method is emitted per group. Each method appends an EffectInstance for
/// every recorded location.
void OpEmitter::genSideEffectInterfaceMethods() {
  enum EffectKind { Operand, Result, Symbol, Static };
  struct EffectLocation {
    /// The effect applied.
    SideEffect effect;

    /// The index if the kind is not static. For operands and results this is
    /// the ODS operand/result group index passed to getODSOperands /
    /// getODSResults. For symbols it is the op argument index, used with
    /// op.getArg().
    unsigned index : 30;

    /// The kind of the location (an EffectKind value).
    unsigned kind : 2;
  };

  // Base effect name -> every location where an effect of that base applies.
  StringMap<SmallVector<EffectLocation, 1>> interfaceEffects;
  // Records each SideEffect decorator in `decorators` under its base effect
  // name, and adds that effect's interface trait to the op class.
  auto resolveDecorators = [&](Operator::var_decorator_range decorators,
                               unsigned index, unsigned kind) {
    for (auto decorator : decorators)
      if (SideEffect *effect = dyn_cast<SideEffect>(&decorator)) {
        opClass.addTrait(effect->getInterfaceTrait());
        interfaceEffects[effect->getBaseEffectName()].push_back(
            EffectLocation{*effect, index, kind});
      }
  };

  // Collect effects that were specified via:
  /// Traits.
  for (const auto &trait : op.getTraits()) {
    const auto *opTrait = dyn_cast<tblgen::SideEffectTrait>(&trait);
    if (!opTrait)
      continue;
    auto &effects = interfaceEffects[opTrait->getBaseEffectName()];
    for (auto decorator : opTrait->getEffects())
      effects.push_back(EffectLocation{cast<SideEffect>(decorator),
                                       /*index=*/0, EffectKind::Static});
  }
  /// Attributes and Operands.
  // `i` walks all op arguments, while `operandIt` counts only the operands
  // seen so far (the ODS operand index). Symbol effects are keyed by `i`.
  for (unsigned i = 0, operandIt = 0, e = op.getNumArgs(); i != e; ++i) {
    Argument arg = op.getArg(i);
    if (arg.is<NamedTypeConstraint *>()) {
      resolveDecorators(op.getArgDecorators(i), operandIt, EffectKind::Operand);
      ++operandIt;
      continue;
    }
    // Only attributes whose base attribute is a symbol reference can carry
    // effects; decorators on other attributes are ignored.
    const NamedAttribute *attr = arg.get<NamedAttribute *>();
    if (attr->attr.getBaseAttr().isSymbolRefAttr())
      resolveDecorators(op.getArgDecorators(i), i, EffectKind::Symbol);
  }
  /// Results.
  for (unsigned i = 0, e = op.getNumResults(); i != e; ++i)
    resolveDecorators(op.getResultDecorators(i), i, EffectKind::Result);

  // The code used to add an effect instance.
  // {0}: The effect class.
  // {1}: Optional value or symbol reference (including a trailing ", ").
  // {2}: The resource class.
  const char *addEffectCode =
      " effects.emplace_back({0}::get(), {1}{2}::get());\n";

  for (auto &it : interfaceEffects) {
    // Generate the 'getEffects' method.
    std::string type = llvm::formatv("::mlir::SmallVectorImpl<::mlir::"
                                     "SideEffects::EffectInstance<{0}>> &",
                                     it.first())
                           .str();
    auto *getEffects =
        opClass.addMethodAndPrune("void", "getEffects", type, "effects");
    auto &body = getEffects->body();

    // Add effect instances for each of the locations marked on the operation.
    for (auto &location : it.second) {
      StringRef effect = location.effect.getName();
      StringRef resource = location.effect.getResource();
      if (location.kind == EffectKind::Static) {
        // A static instance has no attached value.
        body << llvm::formatv(addEffectCode, effect, "", resource).str();
      } else if (location.kind == EffectKind::Symbol) {
        // A symbol reference requires adding the proper attribute. An
        // optional symbol is only reported when the attribute is present.
        const auto *attr = op.getArg(location.index).get<NamedAttribute *>();
        if (attr->attr.isOptional()) {
          body << " if (auto symbolRef = " << attr->name << "Attr())\n "
               << llvm::formatv(addEffectCode, effect, "symbolRef, ", resource)
                      .str();
        } else {
          body << llvm::formatv(addEffectCode, effect, attr->name + "(), ",
                                resource)
                      .str();
        }
      } else {
        // Otherwise this is an operand/result, so we need to attach the Value.
        // One effect instance is emitted per value in the ODS group.
        body << " for (::mlir::Value value : getODS"
             << (location.kind == EffectKind::Operand ? "Operands" : "Results")
             << "(" << location.index << "))\n "
             << llvm::formatv(addEffectCode, effect, "value, ", resource).str();
      }
    }
  }
}
1846
genTypeInterfaceMethods()1847 void OpEmitter::genTypeInterfaceMethods() {
1848 if (!op.allResultTypesKnown())
1849 return;
1850 // Generate 'inferReturnTypes' method declaration using the interface method
1851 // declared in 'InferTypeOpInterface' op interface.
1852 const auto *trait = dyn_cast<InterfaceOpTrait>(
1853 op.getTrait("::mlir::InferTypeOpInterface::Trait"));
1854 auto interface = trait->getOpInterface();
1855 OpMethod *method = [&]() -> OpMethod * {
1856 for (const InterfaceMethod &interfaceMethod : interface.getMethods()) {
1857 if (interfaceMethod.getName() == "inferReturnTypes") {
1858 return genOpInterfaceMethod(interfaceMethod, /*declaration=*/false);
1859 }
1860 }
1861 assert(0 && "unable to find inferReturnTypes interface method");
1862 return nullptr;
1863 }();
1864 auto &body = method->body();
1865 body << " inferredReturnTypes.resize(" << op.getNumResults() << ");\n";
1866
1867 FmtContext fctx;
1868 fctx.withBuilder("odsBuilder");
1869 body << " ::mlir::Builder odsBuilder(context);\n";
1870
1871 auto emitType =
1872 [&](const tblgen::Operator::ArgOrType &type) -> OpMethodBody & {
1873 if (type.isArg()) {
1874 auto argIndex = type.getArg();
1875 assert(!op.getArg(argIndex).is<NamedAttribute *>());
1876 auto arg = op.getArgToOperandOrAttribute(argIndex);
1877 if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand)
1878 return body << "operands[" << arg.operandOrAttributeIndex()
1879 << "].getType()";
1880 return body << "attributes[" << arg.operandOrAttributeIndex()
1881 << "].getType()";
1882 } else {
1883 return body << tgfmt(*type.getType().getBuilderCall(), &fctx);
1884 }
1885 };
1886
1887 for (int i = 0, e = op.getNumResults(); i != e; ++i) {
1888 body << " inferredReturnTypes[" << i << "] = ";
1889 auto types = op.getSameTypeAsResult(i);
1890 emitType(types[0]) << ";\n";
1891 if (types.size() == 1)
1892 continue;
1893 // TODO: We could verify equality here, but skipping that for verification.
1894 }
1895 body << " return ::mlir::success();";
1896 }
1897
genParser()1898 void OpEmitter::genParser() {
1899 if (!hasStringAttribute(def, "parser") ||
1900 hasStringAttribute(def, "assemblyFormat"))
1901 return;
1902
1903 SmallVector<OpMethodParameter, 2> paramList;
1904 paramList.emplace_back("::mlir::OpAsmParser &", "parser");
1905 paramList.emplace_back("::mlir::OperationState &", "result");
1906 auto *method =
1907 opClass.addMethodAndPrune("::mlir::ParseResult", "parse",
1908 OpMethod::MP_Static, std::move(paramList));
1909
1910 FmtContext fctx;
1911 fctx.addSubst("cppClass", opClass.getClassName());
1912 auto parser = def.getValueAsString("parser").ltrim().rtrim(" \t\v\f\r");
1913 method->body() << " " << tgfmt(parser, &fctx);
1914 }
1915
genPrinter()1916 void OpEmitter::genPrinter() {
1917 if (hasStringAttribute(def, "assemblyFormat"))
1918 return;
1919
1920 auto valueInit = def.getValueInit("printer");
1921 StringInit *stringInit = dyn_cast<StringInit>(valueInit);
1922 if (!stringInit)
1923 return;
1924
1925 auto *method =
1926 opClass.addMethodAndPrune("void", "print", "::mlir::OpAsmPrinter &", "p");
1927 FmtContext fctx;
1928 fctx.addSubst("cppClass", opClass.getClassName());
1929 auto printer = stringInit->getValue().ltrim().rtrim(" \t\v\f\r");
1930 method->body() << " " << tgfmt(printer, &fctx);
1931 }
1932
genVerifier()1933 void OpEmitter::genVerifier() {
1934 auto *method = opClass.addMethodAndPrune("::mlir::LogicalResult", "verify");
1935 auto &body = method->body();
1936 body << " if (failed(" << op.getAdaptorName()
1937 << "(*this).verify((*this)->getLoc()))) "
1938 << "return ::mlir::failure();\n";
1939
1940 auto *valueInit = def.getValueInit("verifier");
1941 StringInit *stringInit = dyn_cast<StringInit>(valueInit);
1942 bool hasCustomVerify = stringInit && !stringInit->getValue().empty();
1943 populateSubstitutions(op, "(*this)->getAttr", "this->getODSOperands",
1944 "this->getODSResults", verifyCtx);
1945
1946 genAttributeVerifier(op, "(*this)->getAttr", "emitOpError(",
1947 /*emitVerificationRequiringOp=*/true, verifyCtx, body);
1948 genOperandResultVerifier(body, op.getOperands(), "operand");
1949 genOperandResultVerifier(body, op.getResults(), "result");
1950
1951 for (auto &trait : op.getTraits()) {
1952 if (auto *t = dyn_cast<tblgen::PredOpTrait>(&trait)) {
1953 body << tgfmt(" if (!($0))\n "
1954 "return emitOpError(\"failed to verify that $1\");\n",
1955 &verifyCtx, tgfmt(t->getPredTemplate(), &verifyCtx),
1956 t->getSummary());
1957 }
1958 }
1959
1960 genRegionVerifier(body);
1961 genSuccessorVerifier(body);
1962
1963 if (hasCustomVerify) {
1964 FmtContext fctx;
1965 fctx.addSubst("cppClass", opClass.getClassName());
1966 auto printer = stringInit->getValue().ltrim().rtrim(" \t\v\f\r");
1967 body << " " << tgfmt(printer, &fctx);
1968 } else {
1969 body << " return ::mlir::success();\n";
1970 }
1971 }
1972
genOperandResultVerifier(OpMethodBody & body,Operator::value_range values,StringRef valueKind)1973 void OpEmitter::genOperandResultVerifier(OpMethodBody &body,
1974 Operator::value_range values,
1975 StringRef valueKind) {
1976 FmtContext fctx;
1977
1978 body << " {\n";
1979 body << " unsigned index = 0; (void)index;\n";
1980
1981 for (auto staticValue : llvm::enumerate(values)) {
1982 bool hasPredicate = staticValue.value().hasPredicate();
1983 bool isOptional = staticValue.value().isOptional();
1984 if (!hasPredicate && !isOptional)
1985 continue;
1986 body << formatv(" auto valueGroup{2} = getODS{0}{1}s({2});\n",
1987 // Capitalize the first letter to match the function name
1988 valueKind.substr(0, 1).upper(), valueKind.substr(1),
1989 staticValue.index());
1990
1991 // If the constraint is optional check that the value group has at most 1
1992 // value.
1993 if (isOptional) {
1994 body << formatv(" if (valueGroup{0}.size() > 1)\n"
1995 " return emitOpError(\"{1} group starting at #\") "
1996 "<< index << \" requires 0 or 1 element, but found \" << "
1997 "valueGroup{0}.size();\n",
1998 staticValue.index(), valueKind);
1999 }
2000
2001 // Otherwise, if there is no predicate there is nothing left to do.
2002 if (!hasPredicate)
2003 continue;
2004 // Emit a loop to check all the dynamic values in the pack.
2005 StringRef constraintFn = staticVerifierEmitter.getTypeConstraintFn(
2006 staticValue.value().constraint);
2007 body << " for (::mlir::Value v : valueGroup" << staticValue.index()
2008 << ") {\n"
2009 << " if (::mlir::failed(" << constraintFn
2010 << "(getOperation(), v.getType(), \"" << valueKind << "\", index)))\n"
2011 << " return ::mlir::failure();\n"
2012 << " ++index;\n"
2013 << " }\n";
2014 }
2015
2016 body << " }\n";
2017 }
2018
genRegionVerifier(OpMethodBody & body)2019 void OpEmitter::genRegionVerifier(OpMethodBody &body) {
2020 // If we have no regions, there is nothing more to do.
2021 unsigned numRegions = op.getNumRegions();
2022 if (numRegions == 0)
2023 return;
2024
2025 body << "{\n";
2026 body << " unsigned index = 0; (void)index;\n";
2027
2028 for (unsigned i = 0; i < numRegions; ++i) {
2029 const auto ®ion = op.getRegion(i);
2030 if (region.constraint.getPredicate().isNull())
2031 continue;
2032
2033 body << " for (::mlir::Region ®ion : ";
2034 body << formatv(region.isVariadic()
2035 ? "{0}()"
2036 : "::mlir::MutableArrayRef<::mlir::Region>((*this)"
2037 "->getRegion({1}))",
2038 region.name, i);
2039 body << ") {\n";
2040 auto constraint = tgfmt(region.constraint.getConditionTemplate(),
2041 &verifyCtx.withSelf("region"))
2042 .str();
2043
2044 body << formatv(" (void)region;\n"
2045 " if (!({0})) {\n "
2046 "return emitOpError(\"region #\") << index << \" {1}"
2047 "failed to "
2048 "verify constraint: {2}\";\n }\n",
2049 constraint,
2050 region.name.empty() ? "" : "('" + region.name + "') ",
2051 region.constraint.getSummary())
2052 << " ++index;\n"
2053 << " }\n";
2054 }
2055 body << " }\n";
2056 }
2057
genSuccessorVerifier(OpMethodBody & body)2058 void OpEmitter::genSuccessorVerifier(OpMethodBody &body) {
2059 // If we have no successors, there is nothing more to do.
2060 unsigned numSuccessors = op.getNumSuccessors();
2061 if (numSuccessors == 0)
2062 return;
2063
2064 body << "{\n";
2065 body << " unsigned index = 0; (void)index;\n";
2066
2067 for (unsigned i = 0; i < numSuccessors; ++i) {
2068 const auto &successor = op.getSuccessor(i);
2069 if (successor.constraint.getPredicate().isNull())
2070 continue;
2071
2072 if (successor.isVariadic()) {
2073 body << formatv(" for (::mlir::Block *successor : {0}()) {\n",
2074 successor.name);
2075 } else {
2076 body << " {\n";
2077 body << formatv(" ::mlir::Block *successor = {0}();\n",
2078 successor.name);
2079 }
2080 auto constraint = tgfmt(successor.constraint.getConditionTemplate(),
2081 &verifyCtx.withSelf("successor"))
2082 .str();
2083
2084 body << formatv(" (void)successor;\n"
2085 " if (!({0})) {\n "
2086 "return emitOpError(\"successor #\") << index << \"('{1}') "
2087 "failed to "
2088 "verify constraint: {2}\";\n }\n",
2089 constraint, successor.name,
2090 successor.constraint.getSummary())
2091 << " ++index;\n"
2092 << " }\n";
2093 }
2094 body << " }\n";
2095 }
2096
2097 /// Add a size count trait to the given operation class.
addSizeCountTrait(OpClass & opClass,StringRef traitKind,int numTotal,int numVariadic)2098 static void addSizeCountTrait(OpClass &opClass, StringRef traitKind,
2099 int numTotal, int numVariadic) {
2100 if (numVariadic != 0) {
2101 if (numTotal == numVariadic)
2102 opClass.addTrait("::mlir::OpTrait::Variadic" + traitKind + "s");
2103 else
2104 opClass.addTrait("::mlir::OpTrait::AtLeastN" + traitKind + "s<" +
2105 Twine(numTotal - numVariadic) + ">::Impl");
2106 return;
2107 }
2108 switch (numTotal) {
2109 case 0:
2110 opClass.addTrait("::mlir::OpTrait::Zero" + traitKind);
2111 break;
2112 case 1:
2113 opClass.addTrait("::mlir::OpTrait::One" + traitKind);
2114 break;
2115 default:
2116 opClass.addTrait("::mlir::OpTrait::N" + traitKind + "s<" + Twine(numTotal) +
2117 ">::Impl");
2118 break;
2119 }
2120 }
2121
genTraits()2122 void OpEmitter::genTraits() {
2123 // Add region size trait.
2124 unsigned numRegions = op.getNumRegions();
2125 unsigned numVariadicRegions = op.getNumVariadicRegions();
2126 addSizeCountTrait(opClass, "Region", numRegions, numVariadicRegions);
2127
2128 // Add result size traits.
2129 int numResults = op.getNumResults();
2130 int numVariadicResults = op.getNumVariableLengthResults();
2131 addSizeCountTrait(opClass, "Result", numResults, numVariadicResults);
2132
2133 // For single result ops with a known specific type, generate a OneTypedResult
2134 // trait.
2135 if (numResults == 1 && numVariadicResults == 0) {
2136 auto cppName = op.getResults().begin()->constraint.getCPPClassName();
2137 opClass.addTrait("::mlir::OpTrait::OneTypedResult<" + cppName + ">::Impl");
2138 }
2139
2140 // Add successor size trait.
2141 unsigned numSuccessors = op.getNumSuccessors();
2142 unsigned numVariadicSuccessors = op.getNumVariadicSuccessors();
2143 addSizeCountTrait(opClass, "Successor", numSuccessors, numVariadicSuccessors);
2144
2145 // Add variadic size trait and normal op traits.
2146 int numOperands = op.getNumOperands();
2147 int numVariadicOperands = op.getNumVariableLengthOperands();
2148
2149 // Add operand size trait.
2150 if (numVariadicOperands != 0) {
2151 if (numOperands == numVariadicOperands)
2152 opClass.addTrait("::mlir::OpTrait::VariadicOperands");
2153 else
2154 opClass.addTrait("::mlir::OpTrait::AtLeastNOperands<" +
2155 Twine(numOperands - numVariadicOperands) + ">::Impl");
2156 } else {
2157 switch (numOperands) {
2158 case 0:
2159 opClass.addTrait("::mlir::OpTrait::ZeroOperands");
2160 break;
2161 case 1:
2162 opClass.addTrait("::mlir::OpTrait::OneOperand");
2163 break;
2164 default:
2165 opClass.addTrait("::mlir::OpTrait::NOperands<" + Twine(numOperands) +
2166 ">::Impl");
2167 break;
2168 }
2169 }
2170
2171 // Add the native and interface traits.
2172 for (const auto &trait : op.getTraits()) {
2173 if (auto opTrait = dyn_cast<tblgen::NativeOpTrait>(&trait))
2174 opClass.addTrait(opTrait->getTrait());
2175 else if (auto opTrait = dyn_cast<tblgen::InterfaceOpTrait>(&trait))
2176 opClass.addTrait(opTrait->getTrait());
2177 }
2178 }
2179
genOpNameGetter()2180 void OpEmitter::genOpNameGetter() {
2181 auto *method = opClass.addMethodAndPrune(
2182 "::llvm::StringRef", "getOperationName", OpMethod::MP_Static);
2183 method->body() << " return \"" << op.getOperationName() << "\";\n";
2184 }
2185
genOpAsmInterface()2186 void OpEmitter::genOpAsmInterface() {
2187 // If the user only has one results or specifically added the Asm trait,
2188 // then don't generate it for them. We specifically only handle multi result
2189 // operations, because the name of a single result in the common case is not
2190 // interesting(generally 'result'/'output'/etc.).
2191 // TODO: We could also add a flag to allow operations to opt in to this
2192 // generation, even if they only have a single operation.
2193 int numResults = op.getNumResults();
2194 if (numResults <= 1 || op.getTrait("::mlir::OpAsmOpInterface::Trait"))
2195 return;
2196
2197 SmallVector<StringRef, 4> resultNames(numResults);
2198 for (int i = 0; i != numResults; ++i)
2199 resultNames[i] = op.getResultName(i);
2200
2201 // Don't add the trait if none of the results have a valid name.
2202 if (llvm::all_of(resultNames, [](StringRef name) { return name.empty(); }))
2203 return;
2204 opClass.addTrait("::mlir::OpAsmOpInterface::Trait");
2205
2206 // Generate the right accessor for the number of results.
2207 auto *method = opClass.addMethodAndPrune(
2208 "void", "getAsmResultNames", "::mlir::OpAsmSetValueNameFn", "setNameFn");
2209 auto &body = method->body();
2210 for (int i = 0; i != numResults; ++i) {
2211 body << " auto resultGroup" << i << " = getODSResults(" << i << ");\n"
2212 << " if (!llvm::empty(resultGroup" << i << "))\n"
2213 << " setNameFn(*resultGroup" << i << ".begin(), \""
2214 << resultNames[i] << "\");\n";
2215 }
2216 }
2217
2218 //===----------------------------------------------------------------------===//
2219 // OpOperandAdaptor emitter
2220 //===----------------------------------------------------------------------===//
2221
namespace {
// Helper class to emit Op operand adaptors to an output stream. An operand
// adaptor wraps a ::mlir::ValueRange of operands and a ::mlir::DictionaryAttr
// of attributes. It provides named operand getters and attribute getters
// mirroring those defined in the Op, plus a `verify` method for the
// op-independent attribute constraints.
class OpOperandAdaptorEmitter {
public:
  // Writes the adaptor class declaration for `op` to `os`.
  static void emitDecl(const Operator &op, raw_ostream &os);
  // Writes the adaptor class definition for `op` to `os`.
  static void emitDef(const Operator &op, raw_ostream &os);

private:
  // Builds the complete adaptor class model for `op`. Only reachable through
  // emitDecl/emitDef.
  explicit OpOperandAdaptorEmitter(const Operator &op);

  // Add verification function. This generates a verify method for the adaptor
  // which verifies all the op-independent attribute constraints.
  void addVerification();

  // The op the adaptor is generated for.
  const Operator &op;
  // The adaptor class being built.
  Class adaptor;
};
} // end namespace
2242
// Builds the adaptor class for `op`. The class gets:
//   - two constructors: from (values, attrs), and from an op instance,
//   - named operand getters,
//   - one getter per non-derived attribute,
//   - a `verify` method.
OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(const Operator &op)
    : op(op), adaptor(op.getAdaptorName()) {
  // The adaptor stores the operand values and the attribute dictionary.
  adaptor.newField("::mlir::ValueRange", "odsOperands");
  adaptor.newField("::mlir::DictionaryAttr", "odsAttrs");
  const auto *attrSizedOperands =
      op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments");
  {
    // Constructor from raw values and attributes. `attrs` defaults to nullptr
    // unless the op has AttrSizedOperandSegments, in which case it has no
    // default. NOTE(review): presumably because the operand getters then need
    // the `operand_segment_sizes` attribute — confirm.
    SmallVector<OpMethodParameter, 2> paramList;
    paramList.emplace_back("::mlir::ValueRange", "values");
    paramList.emplace_back("::mlir::DictionaryAttr", "attrs",
                           attrSizedOperands ? "" : "nullptr");
    auto *constructor = adaptor.addConstructorAndPrune(std::move(paramList));

    constructor->addMemberInitializer("odsOperands", "values");
    constructor->addMemberInitializer("odsAttrs", "attrs");
  }

  {
    // Constructor from an op instance: takes its operands and its attribute
    // dictionary.
    auto *constructor = adaptor.addConstructorAndPrune(
        llvm::formatv("{0}&", op.getCppClassName()).str(), "op");
    constructor->addMemberInitializer("odsOperands", "op->getOperands()");
    constructor->addMemberInitializer("odsAttrs", "op->getAttrDictionary()");
  }

  // Named operand getters over `odsOperands`. `sizeAttrInit` is formatted
  // from adapterSegmentSizeAttrInitCode with "operand_segment_sizes".
  std::string sizeAttrInit =
      formatv(adapterSegmentSizeAttrInitCode, "operand_segment_sizes");
  generateNamedOperandGetters(op, adaptor, sizeAttrInit,
                              /*rangeType=*/"::mlir::ValueRange",
                              /*rangeBeginCall=*/"odsOperands.begin()",
                              /*rangeSizeCall=*/"odsOperands.size()",
                              /*getOperandCallPattern=*/"odsOperands[{0}]");

  // Default values are built with a Builder over the dictionary's context.
  FmtContext fctx;
  fctx.withBuilder("::mlir::Builder(odsAttrs.getContext())");

  // Emits a getter `name` returning the attribute from `odsAttrs`. Optional
  // and default-valued attributes are fetched with dyn_cast_or_null (they may
  // be absent); all others use cast. A missing default-valued attribute is
  // replaced by one built from its default value.
  auto emitAttr = [&](StringRef name, Attribute attr) {
    auto &body = adaptor.addMethodAndPrune(attr.getStorageType(), name)->body();
    body << " assert(odsAttrs && \"no attributes when constructing adapter\");"
         << "\n " << attr.getStorageType() << " attr = "
         << "odsAttrs.get(\"" << name << "\").";
    if (attr.hasDefaultValue() || attr.isOptional())
      body << "dyn_cast_or_null<";
    else
      body << "cast<";
    body << attr.getStorageType() << ">();\n";

    if (attr.hasDefaultValue()) {
      // Use the default value if attribute is not set.
      // TODO: this is inefficient, we are recreating the attribute for every
      // call. This should be set instead.
      std::string defaultValue = std::string(
          tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
      body << " if (!attr)\n attr = " << defaultValue << ";\n";
    }
    body << " return attr;\n";
  };

  // Derived attributes are computed from the op, not stored, so they get no
  // adaptor getter.
  for (auto &namedAttr : op.getAttributes()) {
    const auto &name = namedAttr.name;
    const auto &attr = namedAttr.attr;
    if (!attr.isDerivedAttr())
      emitAttr(name, attr);
  }

  // Add verification function.
  addVerification();
}
2310
// Adds `::mlir::LogicalResult verify(::mlir::Location loc)` to the adaptor.
// The method checks:
//   - the element count of `operand_segment_sizes` / `result_segment_sizes`
//     for ops with the AttrSized{Operand,Result}Segments traits,
//   - every attribute constraint that does not require the op itself.
// Errors are reported through emitError(loc, ...).
void OpOperandAdaptorEmitter::addVerification() {
  auto *method = adaptor.addMethodAndPrune("::mlir::LogicalResult", "verify",
                                           "::mlir::Location", "loc");
  auto &body = method->body();

  // Template for the segment-size check.
  // {0}: attribute name, {1}: expected element count, {2}: "operand"/"result".
  // The bare braces are kept literally by formatv, because another '{' occurs
  // before the next '}'.
  const char *checkAttrSizedValueSegmentsCode = R"(
{
auto sizeAttr = odsAttrs.get("{0}").cast<::mlir::DenseIntElementsAttr>();
auto numElements = sizeAttr.getType().cast<::mlir::ShapedType>().getNumElements();
if (numElements != {1})
return emitError(loc, "'{0}' attribute for specifying {2} segments "
"must have {1} elements, but got ") << numElements;
}
)";

  // Verify a few traits first so that we can use
  // getODSOperands()/getODSResults() in the rest of the verifier.
  for (auto &trait : op.getTraits()) {
    if (auto *t = dyn_cast<tblgen::NativeOpTrait>(&trait)) {
      if (t->getTrait() == "::mlir::OpTrait::AttrSizedOperandSegments") {
        body << formatv(checkAttrSizedValueSegmentsCode,
                        "operand_segment_sizes", op.getNumOperands(),
                        "operand");
      } else if (t->getTrait() == "::mlir::OpTrait::AttrSizedResultSegments") {
        body << formatv(checkAttrSizedValueSegmentsCode, "result_segment_sizes",
                        op.getNumResults(), "result");
      }
    }
  }

  // The adaptor has no results, so result substitutions must never be used.
  FmtContext verifyCtx;
  populateSubstitutions(op, "odsAttrs.get", "getODSOperands",
                        "<no results should be generated>", verifyCtx);
  genAttributeVerifier(op, "odsAttrs.get",
                       Twine("emitError(loc, \"'") + op.getOperationName() +
                           "' op \"",
                       /*emitVerificationRequiringOp*/ false, verifyCtx, body);

  body << " return ::mlir::success();";
}
2351
emitDecl(const Operator & op,raw_ostream & os)2352 void OpOperandAdaptorEmitter::emitDecl(const Operator &op, raw_ostream &os) {
2353 OpOperandAdaptorEmitter(op).adaptor.writeDeclTo(os);
2354 }
2355
emitDef(const Operator & op,raw_ostream & os)2356 void OpOperandAdaptorEmitter::emitDef(const Operator &op, raw_ostream &os) {
2357 OpOperandAdaptorEmitter(op).adaptor.writeDefTo(os);
2358 }
2359
// Emits the op classes for `defs`, each preceded by its operand adaptor:
// declarations when `emitDecl` is true, definitions otherwise. In declaration
// mode, forward declarations of every op class are emitted first, guarded by
// GET_OP_CLASSES || GET_OP_FWD_DEFINES. The classes themselves are emitted
// inside an IfDefScope for GET_OP_CLASSES.
static void emitOpClasses(const RecordKeeper &recordKeeper,
                          const std::vector<Record *> &defs, raw_ostream &os,
                          bool emitDecl) {
  // First emit forward declaration for each class, this allows them to refer
  // to each others in traits for example.
  if (emitDecl) {
    os << "#if defined(GET_OP_CLASSES) || defined(GET_OP_FWD_DEFINES)\n";
    os << "#undef GET_OP_FWD_DEFINES\n";
    for (auto *def : defs) {
      Operator op(*def);
      // NOTE(review): presumably opens the op's dialect namespace for the
      // lifetime of `emitter` (RAII) — confirm in CodeGenHelpers.h.
      NamespaceEmitter emitter(os, op.getDialect());
      os << "class " << op.getCppClassName() << ";\n";
    }
    os << "#endif\n\n";
  }

  IfDefScope scope("GET_OP_CLASSES", os);
  if (defs.empty())
    return;

  // Generate all of the locally instantiated methods first.
  // (Shared static verifier functions, referenced by every op's verifier.)
  StaticVerifierFunctionEmitter staticVerifierEmitter(recordKeeper, defs, os,
                                                      emitDecl);
  for (auto *def : defs) {
    Operator op(*def);
    NamespaceEmitter emitter(os, op.getDialect());
    if (emitDecl) {
      os << formatv(opCommentHeader, op.getQualCppClassName(), "declarations");
      OpOperandAdaptorEmitter::emitDecl(op, os);
      OpEmitter::emitDecl(op, os, staticVerifierEmitter);
    } else {
      os << formatv(opCommentHeader, op.getQualCppClassName(), "definitions");
      OpOperandAdaptorEmitter::emitDef(op, os);
      OpEmitter::emitDef(op, os, staticVerifierEmitter);
    }
  }
}
2398
2399 // Emits a comma-separated list of the ops.
emitOpList(const std::vector<Record * > & defs,raw_ostream & os)2400 static void emitOpList(const std::vector<Record *> &defs, raw_ostream &os) {
2401 IfDefScope scope("GET_OP_LIST", os);
2402
2403 interleave(
2404 // TODO: We are constructing the Operator wrapper instance just for
2405 // getting it's qualified class name here. Reduce the overhead by having a
2406 // lightweight version of Operator class just for that purpose.
2407 defs, [&os](Record *def) { os << Operator(def).getQualCppClassName(); },
2408 [&os]() { os << ",\n"; });
2409 }
2410
getOperationName(const Record & def)2411 static std::string getOperationName(const Record &def) {
2412 auto prefix = def.getValueAsDef("opDialect")->getValueAsString("name");
2413 auto opName = def.getValueAsString("opName");
2414 if (prefix.empty())
2415 return std::string(opName);
2416 return std::string(llvm::formatv("{0}.{1}", prefix, opName));
2417 }
2418
2419 static std::vector<Record *>
getAllDerivedDefinitions(const RecordKeeper & recordKeeper,StringRef className)2420 getAllDerivedDefinitions(const RecordKeeper &recordKeeper,
2421 StringRef className) {
2422 Record *classDef = recordKeeper.getClass(className);
2423 if (!classDef)
2424 PrintFatalError("ERROR: Couldn't find the `" + className + "' class!\n");
2425
2426 llvm::Regex includeRegex(opIncFilter), excludeRegex(opExcFilter);
2427 std::vector<Record *> defs;
2428 for (const auto &def : recordKeeper.getDefs()) {
2429 if (!def.second->isSubClassOf(classDef))
2430 continue;
2431 // Include if no include filter or include filter matches.
2432 if (!opIncFilter.empty() &&
2433 !includeRegex.match(getOperationName(*def.second)))
2434 continue;
2435 // Unless there is an exclude filter and it matches.
2436 if (!opExcFilter.empty() &&
2437 excludeRegex.match(getOperationName(*def.second)))
2438 continue;
2439 defs.push_back(def.second.get());
2440 }
2441
2442 return defs;
2443 }
2444
emitOpDecls(const RecordKeeper & recordKeeper,raw_ostream & os)2445 static bool emitOpDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
2446 emitSourceFileHeader("Op Declarations", os);
2447
2448 const auto &defs = getAllDerivedDefinitions(recordKeeper, "Op");
2449 emitOpClasses(recordKeeper, defs, os, /*emitDecl=*/true);
2450
2451 return false;
2452 }
2453
emitOpDefs(const RecordKeeper & recordKeeper,raw_ostream & os)2454 static bool emitOpDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
2455 emitSourceFileHeader("Op Definitions", os);
2456
2457 const auto &defs = getAllDerivedDefinitions(recordKeeper, "Op");
2458 emitOpList(defs, os);
2459 emitOpClasses(recordKeeper, defs, os, /*emitDecl=*/false);
2460
2461 return false;
2462 }
2463
2464 static mlir::GenRegistration
2465 genOpDecls("gen-op-decls", "Generate op declarations",
__anon984378f01602(const RecordKeeper &records, raw_ostream &os) 2466 [](const RecordKeeper &records, raw_ostream &os) {
2467 return emitOpDecls(records, os);
2468 });
2469
2470 static mlir::GenRegistration genOpDefs("gen-op-defs", "Generate op definitions",
2471 [](const RecordKeeper &records,
__anon984378f01702(const RecordKeeper &records, raw_ostream &os) 2472 raw_ostream &os) {
2473 return emitOpDefs(records, os);
2474 });
2475