1 //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // RewriterGen uses pattern rewrite definitions to generate rewriter matchers.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "mlir/Support/IndentedOstream.h"
14 #include "mlir/TableGen/Attribute.h"
15 #include "mlir/TableGen/Format.h"
16 #include "mlir/TableGen/GenInfo.h"
17 #include "mlir/TableGen/Operator.h"
18 #include "mlir/TableGen/Pattern.h"
19 #include "mlir/TableGen/Predicate.h"
20 #include "mlir/TableGen/Type.h"
21 #include "llvm/ADT/FunctionExtras.h"
22 #include "llvm/ADT/StringExtras.h"
23 #include "llvm/ADT/StringSet.h"
24 #include "llvm/Support/CommandLine.h"
25 #include "llvm/Support/Debug.h"
26 #include "llvm/Support/FormatAdapters.h"
27 #include "llvm/Support/PrettyStackTrace.h"
28 #include "llvm/Support/Signals.h"
29 #include "llvm/TableGen/Error.h"
30 #include "llvm/TableGen/Main.h"
31 #include "llvm/TableGen/Record.h"
32 #include "llvm/TableGen/TableGenBackend.h"
33
34 using namespace mlir;
35 using namespace mlir::tblgen;
36
37 using llvm::formatv;
38 using llvm::Record;
39 using llvm::RecordKeeper;
40
41 #define DEBUG_TYPE "mlir-tblgen-rewritergen"
42
43 namespace llvm {
44 template <>
45 struct format_provider<mlir::tblgen::Pattern::IdentifierLine> {
formatllvm::format_provider46 static void format(const mlir::tblgen::Pattern::IdentifierLine &v,
47 raw_ostream &os, StringRef style) {
48 os << v.first << ":" << v.second;
49 }
50 };
51 } // end namespace llvm
52
53 //===----------------------------------------------------------------------===//
54 // PatternEmitter
55 //===----------------------------------------------------------------------===//
56
57 namespace {
58
59 class StaticMatcherHelper;
60
61 class PatternEmitter {
62 public:
63 PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os,
64 StaticMatcherHelper &helper);
65
66 // Emits the mlir::RewritePattern struct named `rewriteName`.
67 void emit(StringRef rewriteName);
68
69 // Emits the static function of DAG matcher.
70 void emitStaticMatcher(DagNode tree, std::string funcName);
71
72 private:
73 // Emits the code for matching ops.
74 void emitMatchLogic(DagNode tree, StringRef opName);
75
76 // Emits the code for rewriting ops.
77 void emitRewriteLogic();
78
79 //===--------------------------------------------------------------------===//
80 // Match utilities
81 //===--------------------------------------------------------------------===//
82
83 // Emits C++ statements for matching the DAG structure.
84 void emitMatch(DagNode tree, StringRef name, int depth);
85
86 // Emit C++ function call to static DAG matcher.
87 void emitStaticMatchCall(DagNode tree, StringRef name);
88
89 // Emits C++ statements for matching using a native code call.
90 void emitNativeCodeMatch(DagNode tree, StringRef name, int depth);
91
92 // Emits C++ statements for matching the op constrained by the given DAG
93 // `tree` returning the op's variable name.
94 void emitOpMatch(DagNode tree, StringRef opName, int depth);
95
96 // Emits C++ statements for matching the `argIndex`-th argument of the given
97 // DAG `tree` as an operand. operandIndex is the index in the DAG excluding
98 // the preceding attributes.
99 void emitOperandMatch(DagNode tree, StringRef opName, int argIndex,
100 int operandIndex, int depth);
101
102 // Emits C++ statements for matching the `argIndex`-th argument of the given
103 // DAG `tree` as an attribute.
104 void emitAttributeMatch(DagNode tree, StringRef opName, int argIndex,
105 int depth);
106
107 // Emits C++ for checking a match with a corresponding match failure
108 // diagnostic.
109 void emitMatchCheck(StringRef opName, const FmtObjectBase &matchFmt,
110 const llvm::formatv_object_base &failureFmt);
111
112 // Emits C++ for checking a match with a corresponding match failure
113 // diagnostics.
114 void emitMatchCheck(StringRef opName, const std::string &matchStr,
115 const std::string &failureStr);
116
117 //===--------------------------------------------------------------------===//
118 // Rewrite utilities
119 //===--------------------------------------------------------------------===//
120
121 // The entry point for handling a result pattern rooted at `resultTree`. This
122 // method dispatches to concrete handlers according to `resultTree`'s kind and
123 // returns a symbol representing the whole value pack. Callers are expected to
124 // further resolve the symbol according to the specific use case.
125 //
126 // `depth` is the nesting level of `resultTree`; 0 means top-level result
127 // pattern. For top-level result pattern, `resultIndex` indicates which result
128 // of the matched root op this pattern is intended to replace, which can be
129 // used to deduce the result type of the op generated from this result
130 // pattern.
131 std::string handleResultPattern(DagNode resultTree, int resultIndex,
132 int depth);
133
134 // Emits the C++ statement to replace the matched DAG with a value built via
135 // calling native C++ code.
136 std::string handleReplaceWithNativeCodeCall(DagNode resultTree, int depth);
137
138 // Returns the symbol of the old value serving as the replacement.
139 StringRef handleReplaceWithValue(DagNode tree);
140
141 // Trailing directives are used at the end of DAG node argument lists to
142 // specify additional behaviour for op matchers and creators, etc.
143 struct TrailingDirectives {
144 // DAG node containing the `location` directive. Null if there is none.
145 DagNode location;
146
147 // DAG node containing the `returnType` directive. Null if there is none.
148 DagNode returnType;
149
150 // Number of found trailing directives.
151 int numDirectives;
152 };
153
154 // Collect any trailing directives.
155 TrailingDirectives getTrailingDirectives(DagNode tree);
156
157 // Returns the location value to use.
158 std::string getLocation(TrailingDirectives &tail);
159
160 // Returns the location value to use.
161 std::string handleLocationDirective(DagNode tree);
162
163 // Emit return type argument.
164 std::string handleReturnTypeArg(DagNode returnType, int i, int depth);
165
166 // Emits the C++ statement to build a new op out of the given DAG `tree` and
167 // returns the variable name that this op is assigned to. If the root op in
168 // DAG `tree` has a specified name, the created op will be assigned to a
169 // variable of the given name. Otherwise, a unique name will be used as the
170 // result value name.
171 std::string handleOpCreation(DagNode tree, int resultIndex, int depth);
172
173 using ChildNodeIndexNameMap = DenseMap<unsigned, std::string>;
174
175 // Emits a local variable for each value and attribute to be used for creating
176 // an op.
177 void createSeparateLocalVarsForOpArgs(DagNode node,
178 ChildNodeIndexNameMap &childNodeNames);
179
180 // Emits the concrete arguments used to call an op's builder.
181 void supplyValuesForOpArgs(DagNode node,
182 const ChildNodeIndexNameMap &childNodeNames,
183 int depth);
184
185 // Emits the local variables for holding all values as a whole and all named
186 // attributes as a whole to be used for creating an op.
187 void createAggregateLocalVarsForOpArgs(
188 DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth);
189
190 // Returns the C++ expression to construct a constant attribute of the given
191 // `value` for the given attribute kind `attr`.
192 std::string handleConstantAttr(Attribute attr, StringRef value);
193
194 // Returns the C++ expression to build an argument from the given DAG `leaf`.
195 // `patArgName` is used to bound the argument to the source pattern.
196 std::string handleOpArgument(DagLeaf leaf, StringRef patArgName);
197
198 //===--------------------------------------------------------------------===//
199 // General utilities
200 //===--------------------------------------------------------------------===//
201
202 // Collects all of the operations within the given dag tree.
203 void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops);
204
205 // Returns a unique symbol for a local variable of the given `op`.
206 std::string getUniqueSymbol(const Operator *op);
207
208 //===--------------------------------------------------------------------===//
209 // Symbol utilities
210 //===--------------------------------------------------------------------===//
211
212 // Returns how many static values the given DAG `node` correspond to.
213 int getNodeValueCount(DagNode node);
214
215 private:
216 // Pattern instantiation location followed by the location of multiclass
217 // prototypes used. This is intended to be used as a whole to
218 // PrintFatalError() on errors.
219 ArrayRef<llvm::SMLoc> loc;
220
221 // Op's TableGen Record to wrapper object.
222 RecordOperatorMap *opMap;
223
224 // Handy wrapper for pattern being emitted.
225 Pattern pattern;
226
227 // Map for all bound symbols' info.
228 SymbolInfoMap symbolInfoMap;
229
230 StaticMatcherHelper &staticMatcherHelper;
231
232 // The next unused ID for newly created values.
233 unsigned nextValueId;
234
235 raw_indented_ostream os;
236
237 // Format contexts containing placeholder substitutions.
238 FmtContext fmtCtx;
239 };
240
241 // Tracks DagNode's reference multiple times across patterns. Enables generating
242 // static matcher functions for DagNode's referenced multiple times rather than
243 // inlining them.
244 class StaticMatcherHelper {
245 public:
246 StaticMatcherHelper(RecordOperatorMap &mapper);
247
248 // Determine if we should inline the match logic or delegate to a static
249 // function.
useStaticMatcher(DagNode node)250 bool useStaticMatcher(DagNode node) {
251 return refStats[node] > kStaticMatcherThreshold;
252 }
253
254 // Get the name of the static DAG matcher function corresponding to the node.
getMatcherName(DagNode node)255 std::string getMatcherName(DagNode node) {
256 assert(useStaticMatcher(node));
257 return matcherNames[node];
258 }
259
260 // Collect the `Record`s, i.e., the DRR, so that we can get the information of
261 // the duplicated DAGs.
262 void addPattern(Record *record);
263
264 // Emit all static functions of DAG Matcher.
265 void populateStaticMatchers(raw_ostream &os);
266
267 private:
268 static constexpr unsigned kStaticMatcherThreshold = 1;
269
270 // Consider two patterns as down below,
271 // DagNode_Root_A DagNode_Root_B
272 // \ \
273 // DagNode_C DagNode_C
274 // \ \
275 // DagNode_D DagNode_D
276 //
277 // DagNode_Root_A and DagNode_Root_B share the same subtree which consists of
278 // DagNode_C and DagNode_D. Both DagNode_C and DagNode_D are referenced
279 // multiple times so we'll have static matchers for both of them. When we're
280 // emitting the match logic for DagNode_C, we will check if DagNode_D has the
281 // static matcher generated. If so, then we'll generate a call to the
282 // function, inline otherwise. In this case, inlining is not what we want. As
283 // a result, generate the static matcher in topological order to ensure all
284 // the dependent static matchers are generated and we can avoid accidentally
285 // inlining.
286 //
287 // The topological order of all the DagNodes among all patterns.
288 SmallVector<std::pair<DagNode, Record *>> topologicalOrder;
289
290 RecordOperatorMap &opMap;
291
292 // Records of the static function name of each DagNode
293 DenseMap<DagNode, std::string> matcherNames;
294
295 // After collecting all the DagNode in each pattern, `refStats` records the
296 // number of users for each DagNode. We will generate the static matcher for a
297 // DagNode while the number of users exceeds a certain threshold.
298 DenseMap<DagNode, unsigned> refStats;
299
300 // Number of static matcher generated. This is used to generate a unique name
301 // for each DagNode.
302 int staticMatcherCounter = 0;
303 };
304
305 } // end anonymous namespace
306
PatternEmitter(Record * pat,RecordOperatorMap * mapper,raw_ostream & os,StaticMatcherHelper & helper)307 PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper,
308 raw_ostream &os, StaticMatcherHelper &helper)
309 : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper),
310 symbolInfoMap(pat->getLoc()), staticMatcherHelper(helper), nextValueId(0),
311 os(os) {
312 fmtCtx.withBuilder("rewriter");
313 }
314
handleConstantAttr(Attribute attr,StringRef value)315 std::string PatternEmitter::handleConstantAttr(Attribute attr,
316 StringRef value) {
317 if (!attr.isConstBuildable())
318 PrintFatalError(loc, "Attribute " + attr.getAttrDefName() +
319 " does not have the 'constBuilderCall' field");
320
321 // TODO: Verify the constants here
322 return std::string(tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, value));
323 }
324
emitStaticMatcher(DagNode tree,std::string funcName)325 void PatternEmitter::emitStaticMatcher(DagNode tree, std::string funcName) {
326 os << formatv(
327 "static ::mlir::LogicalResult {0}(::mlir::PatternRewriter &rewriter, "
328 "::mlir::Operation *op0, ::llvm::SmallVector<::mlir::Operation "
329 "*, 4> &tblgen_ops",
330 funcName);
331
332 // We pass the reference of the variables that need to be captured. Hence we
333 // need to collect all the symbols in the tree first.
334 pattern.collectBoundSymbols(tree, symbolInfoMap, /*isSrcPattern=*/true);
335 symbolInfoMap.assignUniqueAlternativeNames();
336 for (const auto &info : symbolInfoMap)
337 os << formatv(", {0}", info.second.getArgDecl(info.first));
338
339 os << ") {\n";
340 os.indent();
341 os << "(void)tblgen_ops;\n";
342
343 // Note that a static matcher is considered at least one step from the match
344 // entry.
345 emitMatch(tree, "op0", /*depth=*/1);
346
347 os << "return ::mlir::success();\n";
348 os.unindent();
349 os << "}\n\n";
350 }
351
352 // Helper function to match patterns.
emitMatch(DagNode tree,StringRef name,int depth)353 void PatternEmitter::emitMatch(DagNode tree, StringRef name, int depth) {
354 if (tree.isNativeCodeCall()) {
355 emitNativeCodeMatch(tree, name, depth);
356 return;
357 }
358
359 if (tree.isOperation()) {
360 emitOpMatch(tree, name, depth);
361 return;
362 }
363
364 PrintFatalError(loc, "encountered non-op, non-NativeCodeCall match.");
365 }
366
emitStaticMatchCall(DagNode tree,StringRef opName)367 void PatternEmitter::emitStaticMatchCall(DagNode tree, StringRef opName) {
368 std::string funcName = staticMatcherHelper.getMatcherName(tree);
369 os << formatv("if(failed({0}(rewriter, {1}, tblgen_ops", funcName, opName);
370
371 // TODO(chiahungduan): Add a lookupBoundSymbols() to do the subtree lookup in
372 // one pass.
373
374 // In general, bound symbol should have the unique name in the pattern but
375 // for the operand, binding same symbol to multiple operands imply a
376 // constraint at the same time. In this case, we will rename those operands
377 // with different names. As a result, we need to collect all the symbolInfos
378 // from the DagNode then get the updated name of the local variables from the
379 // global symbolInfoMap.
380
381 // Collect all the bound symbols in the Dag
382 SymbolInfoMap localSymbolMap(loc);
383 pattern.collectBoundSymbols(tree, localSymbolMap, /*isSrcPattern=*/true);
384
385 for (const auto &info : localSymbolMap) {
386 auto name = info.first;
387 auto symboInfo = info.second;
388 auto ret = symbolInfoMap.findBoundSymbol(name, symboInfo);
389 os << formatv(", {0}", ret->second.getVarName(name));
390 }
391
392 os << "))) {\n";
393 os.scope().os << "return ::mlir::failure();\n";
394 os << "}\n";
395 }
396
397 // Helper function to match patterns.
emitNativeCodeMatch(DagNode tree,StringRef opName,int depth)398 void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
399 int depth) {
400 LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall matcher pattern: ");
401 LLVM_DEBUG(tree.print(llvm::dbgs()));
402 LLVM_DEBUG(llvm::dbgs() << '\n');
403
404 // The order of generating static matcher follows the topological order so
405 // that for every dependent DagNode already have their static matcher
406 // generated if needed. The reason we check if `getMatcherName(tree).empty()`
407 // is when we are generating the static matcher for a DagNode itself. In this
408 // case, we need to emit the function body rather than a function call.
409 if (staticMatcherHelper.useStaticMatcher(tree) &&
410 !staticMatcherHelper.getMatcherName(tree).empty()) {
411 emitStaticMatchCall(tree, opName);
412
413 // NativeCodeCall will never be at depth 0 so that we don't need to catch
414 // the root operation as emitOpMatch();
415
416 return;
417 }
418
419 // TODO(suderman): iterate through arguments, determine their types, output
420 // names.
421 SmallVector<std::string, 8> capture;
422
423 raw_indented_ostream::DelimitedScope scope(os);
424
425 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
426 std::string argName = formatv("arg{0}_{1}", depth, i);
427 if (DagNode argTree = tree.getArgAsNestedDag(i)) {
428 os << "Value " << argName << ";\n";
429 } else {
430 auto leaf = tree.getArgAsLeaf(i);
431 if (leaf.isAttrMatcher() || leaf.isConstantAttr()) {
432 os << "Attribute " << argName << ";\n";
433 } else {
434 os << "Value " << argName << ";\n";
435 }
436 }
437
438 capture.push_back(std::move(argName));
439 }
440
441 auto tail = getTrailingDirectives(tree);
442 if (tail.returnType)
443 PrintFatalError(loc, "`NativeCodeCall` cannot have return type specifier");
444 auto locToUse = getLocation(tail);
445
446 auto fmt = tree.getNativeCodeTemplate();
447 if (fmt.count("$_self") != 1)
448 PrintFatalError(loc, "NativeCodeCall must have $_self as argument for "
449 "passing the defining Operation");
450
451 auto nativeCodeCall = std::string(
452 tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse).withSelf(opName.str()),
453 static_cast<ArrayRef<std::string>>(capture)));
454
455 emitMatchCheck(opName, formatv("!failed({0})", nativeCodeCall),
456 formatv("\"{0} return failure\"", nativeCodeCall));
457
458 for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) {
459 auto name = tree.getArgName(i);
460 if (!name.empty() && name != "_") {
461 os << formatv("{0} = {1};\n", name, capture[i]);
462 }
463 }
464
465 for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) {
466 std::string argName = capture[i];
467
468 // Handle nested DAG construct first
469 if (DagNode argTree = tree.getArgAsNestedDag(i)) {
470 PrintFatalError(
471 loc, formatv("Matching nested tree in NativeCodecall not support for "
472 "{0} as arg {1}",
473 argName, i));
474 }
475
476 DagLeaf leaf = tree.getArgAsLeaf(i);
477
478 // The parameter for native function doesn't bind any constraints.
479 if (leaf.isUnspecified())
480 continue;
481
482 auto constraint = leaf.getAsConstraint();
483
484 std::string self;
485 if (leaf.isAttrMatcher() || leaf.isConstantAttr())
486 self = argName;
487 else
488 self = formatv("{0}.getType()", argName);
489 emitMatchCheck(
490 opName,
491 tgfmt(constraint.getConditionTemplate(), &fmtCtx.withSelf(self)),
492 formatv("\"operand {0} of native code call '{1}' failed to satisfy "
493 "constraint: "
494 "'{2}'\"",
495 i, tree.getNativeCodeTemplate(), constraint.getSummary()));
496 }
497
498 LLVM_DEBUG(llvm::dbgs() << "done emitting match for native code call\n");
499 }
500
// Emits the C++ statements matching the operation DAG node `tree` against the
// generated variable `opName` at nesting level `depth` (0 is the root op).
//
// The emitted code dyn_casts `opName` to the op class as `castedOp<depth>`.
// Below the root it also emits a match-failure check on that cast. It then
// binds the node's symbol, if any, and matches every argument:
// - a nested DAG is matched recursively on the defining op of the
//   corresponding ODS operand, and that op is pushed onto `tblgen_ops`;
// - a leaf is matched as an operand or an attribute, depending on the op
//   definition's argument kind.
// When a static matcher exists for `tree`, only a call to it is emitted
// instead.
void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
  Operator &op = tree.getDialectOp(opMap);
  LLVM_DEBUG(llvm::dbgs() << "start emitting match for op '"
                          << op.getOperationName() << "' at depth " << depth
                          << '\n');

  // Name of the generated variable holding the casted op at this depth.
  auto getCastedName = [depth]() -> std::string {
    return formatv("castedOp{0}", depth);
  };

  // The order of generating static matcher follows the topological order so
  // that for every dependent DagNode already have their static matcher
  // generated if needed. The reason we check if `getMatcherName(tree).empty()`
  // is when we are generating the static matcher for a DagNode itself. In this
  // case, we need to emit the function body rather than a function call.
  if (staticMatcherHelper.useStaticMatcher(tree) &&
      !staticMatcherHelper.getMatcherName(tree).empty()) {
    emitStaticMatchCall(tree, opName);
    // In the codegen of rewriter, we suppose that castedOp0 will capture the
    // root operation. Manually add it if the root DagNode is a static matcher.
    if (depth == 0)
      os << formatv("auto {2} = ::llvm::dyn_cast_or_null<{1}>({0}); "
                    "(void){2};\n",
                    opName, op.getQualCppClassName(), getCastedName());
    return;
  }

  std::string castedName = getCastedName();
  os << formatv("auto {0} = ::llvm::dyn_cast<{2}>({1}); "
                "(void){0};\n",
                castedName, opName, op.getQualCppClassName());

  // Skip the operand matching at depth 0 as the pattern rewriter already does.
  if (depth != 0)
    emitMatchCheck(opName, /*matchStr=*/castedName,
                   formatv("\"{0} is not {1} type\"", castedName,
                           op.getQualCppClassName()));

  if (tree.getNumArgs() != op.getNumArgs())
    PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in "
                                 "pattern vs. {2} in definition",
                                 op.getOperationName(), tree.getNumArgs(),
                                 op.getNumArgs()));

  // If the operand's name is set, set to that variable.
  auto name = tree.getSymbol();
  if (!name.empty())
    os << formatv("{0} = {1};\n", name, castedName);

  // `nextOperand` counts only operand arguments (nested DAGs and operand
  // leaves), giving the ODS operand index; attributes do not advance it.
  for (int i = 0, e = tree.getNumArgs(), nextOperand = 0; i != e; ++i) {
    auto opArg = op.getArg(i);
    std::string argName = formatv("op{0}", depth + 1);

    // Handle nested DAG construct first
    if (DagNode argTree = tree.getArgAsNestedDag(i)) {
      if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) {
        if (operand->isVariableLength()) {
          auto error = formatv("use nested DAG construct to match op {0}'s "
                               "variadic operand #{1} unsupported now",
                               op.getOperationName(), i);
          PrintFatalError(loc, error);
        }
      }
      // Each nested match gets its own C++ scope so that `op<depth+1>` can be
      // reused by sibling arguments.
      os << "{\n";

      // Attributes don't count for getODSOperands.
      // TODO: Operand is a Value, check if we should remove `getDefiningOp()`.
      os.indent() << formatv(
          "auto *{0} = "
          "(*{1}.getODSOperands({2}).begin()).getDefiningOp();\n",
          argName, castedName, nextOperand);
      // Null check of operand's definingOp
      emitMatchCheck(castedName, /*matchStr=*/argName,
                     formatv("\"Operand {0} of {1} has null definingOp\"",
                             nextOperand++, castedName));
      emitMatch(argTree, argName, depth + 1);
      os << formatv("tblgen_ops.push_back({0});\n", argName);
      os.unindent() << "}\n";
      continue;
    }

    // Next handle DAG leaf: operand or attribute
    if (opArg.is<NamedTypeConstraint *>()) {
      // emitOperandMatch's argument indexing counts attributes.
      emitOperandMatch(tree, castedName, i, nextOperand, depth);
      ++nextOperand;
    } else if (opArg.is<NamedAttribute *>()) {
      // Attributes are read off the untyped op via `opName`.
      emitAttributeMatch(tree, opName, i, depth);
    } else {
      PrintFatalError(loc, "unhandled case when matching op");
    }
  }
  LLVM_DEBUG(llvm::dbgs() << "done emitting match for op '"
                          << op.getOperationName() << "' at depth " << depth
                          << '\n');
}
598
emitOperandMatch(DagNode tree,StringRef opName,int argIndex,int operandIndex,int depth)599 void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
600 int argIndex, int operandIndex,
601 int depth) {
602 Operator &op = tree.getDialectOp(opMap);
603 auto *operand = op.getArg(argIndex).get<NamedTypeConstraint *>();
604 auto matcher = tree.getArgAsLeaf(argIndex);
605
606 // If a constraint is specified, we need to generate C++ statements to
607 // check the constraint.
608 if (!matcher.isUnspecified()) {
609 if (!matcher.isOperandMatcher()) {
610 PrintFatalError(
611 loc, formatv("the {1}-th argument of op '{0}' should be an operand",
612 op.getOperationName(), argIndex + 1));
613 }
614
615 // Only need to verify if the matcher's type is different from the one
616 // of op definition.
617 Constraint constraint = matcher.getAsConstraint();
618 if (operand->constraint != constraint) {
619 if (operand->isVariableLength()) {
620 auto error = formatv(
621 "further constrain op {0}'s variadic operand #{1} unsupported now",
622 op.getOperationName(), argIndex);
623 PrintFatalError(loc, error);
624 }
625 auto self = formatv("(*{0}.getODSOperands({1}).begin()).getType()",
626 opName, operandIndex);
627 emitMatchCheck(
628 opName,
629 tgfmt(constraint.getConditionTemplate(), &fmtCtx.withSelf(self)),
630 formatv("\"operand {0} of op '{1}' failed to satisfy constraint: "
631 "'{2}'\"",
632 operand - op.operand_begin(), op.getOperationName(),
633 constraint.getSummary()));
634 }
635 }
636
637 // Capture the value
638 auto name = tree.getArgName(argIndex);
639 // `$_` is a special symbol to ignore op argument matching.
640 if (!name.empty() && name != "_") {
641 // We need to subtract the number of attributes before this operand to get
642 // the index in the operand list.
643 auto numPrevAttrs = std::count_if(
644 op.arg_begin(), op.arg_begin() + argIndex,
645 [](const Argument &arg) { return arg.is<NamedAttribute *>(); });
646
647 auto res = symbolInfoMap.findBoundSymbol(name, tree, op, argIndex);
648 os << formatv("{0} = {1}.getODSOperands({2});\n",
649 res->second.getVarName(name), opName,
650 argIndex - numPrevAttrs);
651 }
652 }
653
emitAttributeMatch(DagNode tree,StringRef opName,int argIndex,int depth)654 void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName,
655 int argIndex, int depth) {
656 Operator &op = tree.getDialectOp(opMap);
657 auto *namedAttr = op.getArg(argIndex).get<NamedAttribute *>();
658 const auto &attr = namedAttr->attr;
659
660 os << "{\n";
661 os.indent() << formatv("auto tblgen_attr = {0}->getAttrOfType<{1}>(\"{2}\");"
662 "(void)tblgen_attr;\n",
663 opName, attr.getStorageType(), namedAttr->name);
664
665 // TODO: This should use getter method to avoid duplication.
666 if (attr.hasDefaultValue()) {
667 os << "if (!tblgen_attr) tblgen_attr = "
668 << std::string(tgfmt(attr.getConstBuilderTemplate(), &fmtCtx,
669 attr.getDefaultValue()))
670 << ";\n";
671 } else if (attr.isOptional()) {
672 // For a missing attribute that is optional according to definition, we
673 // should just capture a mlir::Attribute() to signal the missing state.
674 // That is precisely what getAttr() returns on missing attributes.
675 } else {
676 emitMatchCheck(opName, tgfmt("tblgen_attr", &fmtCtx),
677 formatv("\"expected op '{0}' to have attribute '{1}' "
678 "of type '{2}'\"",
679 op.getOperationName(), namedAttr->name,
680 attr.getStorageType()));
681 }
682
683 auto matcher = tree.getArgAsLeaf(argIndex);
684 if (!matcher.isUnspecified()) {
685 if (!matcher.isAttrMatcher()) {
686 PrintFatalError(
687 loc, formatv("the {1}-th argument of op '{0}' should be an attribute",
688 op.getOperationName(), argIndex + 1));
689 }
690
691 // If a constraint is specified, we need to generate C++ statements to
692 // check the constraint.
693 emitMatchCheck(
694 opName,
695 tgfmt(matcher.getConditionTemplate(), &fmtCtx.withSelf("tblgen_attr")),
696 formatv("\"op '{0}' attribute '{1}' failed to satisfy constraint: "
697 "{2}\"",
698 op.getOperationName(), namedAttr->name,
699 matcher.getAsConstraint().getSummary()));
700 }
701
702 // Capture the value
703 auto name = tree.getArgName(argIndex);
704 // `$_` is a special symbol to ignore op argument matching.
705 if (!name.empty() && name != "_") {
706 os << formatv("{0} = tblgen_attr;\n", name);
707 }
708
709 os.unindent() << "}\n";
710 }
711
emitMatchCheck(StringRef opName,const FmtObjectBase & matchFmt,const llvm::formatv_object_base & failureFmt)712 void PatternEmitter::emitMatchCheck(
713 StringRef opName, const FmtObjectBase &matchFmt,
714 const llvm::formatv_object_base &failureFmt) {
715 emitMatchCheck(opName, matchFmt.str(), failureFmt.str());
716 }
717
emitMatchCheck(StringRef opName,const std::string & matchStr,const std::string & failureStr)718 void PatternEmitter::emitMatchCheck(StringRef opName,
719 const std::string &matchStr,
720 const std::string &failureStr) {
721
722 os << "if (!(" << matchStr << "))";
723 os.scope("{\n", "\n}\n").os << "return rewriter.notifyMatchFailure(" << opName
724 << ", [&](::mlir::Diagnostic &diag) {\n diag << "
725 << failureStr << ";\n});";
726 }
727
// Emits the complete match logic for the source pattern rooted at `tree`,
// whose root op is referred to as `opName` in the generated code:
//   1. the structural DAG match (emitMatch at depth 0),
//   2. the pattern's additional constraints (pattern.getConstraints()),
//   3. equality checks between operands bound to the same symbol name.
// Each check is emitted through emitMatchCheck(), i.e. as an early
// `return rewriter.notifyMatchFailure(...)`.
void PatternEmitter::emitMatchLogic(DagNode tree, StringRef opName) {
  LLVM_DEBUG(llvm::dbgs() << "--- start emitting match logic ---\n");
  int depth = 0;
  emitMatch(tree, opName, depth);

  for (auto &appliedConstraint : pattern.getConstraints()) {
    auto &constraint = appliedConstraint.constraint;
    auto &entities = appliedConstraint.entities;

    auto condition = constraint.getConditionTemplate();
    if (isa<TypeConstraint>(constraint)) {
      // Type constraints are checked against the type of the first entity.
      auto self = formatv("({0}.getType())",
                          symbolInfoMap.getValueAndRangeUse(entities.front()));
      emitMatchCheck(
          opName, tgfmt(condition, &fmtCtx.withSelf(self.str())),
          formatv("\"value entity '{0}' failed to satisfy constraint: {1}\"",
                  entities.front(), constraint.getSummary()));

    } else if (isa<AttrConstraint>(constraint)) {
      PrintFatalError(
          loc, "cannot use AttrConstraint in Pattern multi-entity constraints");
    } else {
      // TODO: replace formatv arguments with the exact specified
      // args.
      if (entities.size() > 4) {
        PrintFatalError(loc, "only support up to 4-entity constraints now");
      }
      // Resolve each entity to its generated use, then pad the remaining
      // positional placeholders (four in total) with "<unused>".
      SmallVector<std::string, 4> names;
      int i = 0;
      for (int e = entities.size(); i < e; ++i)
        names.push_back(symbolInfoMap.getValueAndRangeUse(entities[i]));
      std::string self = appliedConstraint.self;
      if (!self.empty())
        self = symbolInfoMap.getValueAndRangeUse(self);
      for (; i < 4; ++i)
        names.push_back("<unused>");
      emitMatchCheck(opName,
                     tgfmt(condition, &fmtCtx.withSelf(self), names[0],
                           names[1], names[2], names[3]),
                     formatv("\"entities '{0}' failed to satisfy constraint: "
                             "{1}\"",
                             llvm::join(entities, ", "),
                             constraint.getSummary()));
    }
  }

  // Some of the operands could be bound to the same symbol name, we need
  // to enforce equality constraint on those.
  // TODO: we should be able to emit equality checks early
  // and short circuit unnecessary work if vars are not equal.
  for (auto symbolInfoIt = symbolInfoMap.begin();
       symbolInfoIt != symbolInfoMap.end();) {
    // All entries sharing this key are compared against the first one.
    auto range = symbolInfoMap.getRangeOfEqualElements(symbolInfoIt->first);
    auto startRange = range.first;
    auto endRange = range.second;

    auto firstOperand = symbolInfoIt->second.getVarName(symbolInfoIt->first);
    for (++startRange; startRange != endRange; ++startRange) {
      auto secondOperand = startRange->second.getVarName(symbolInfoIt->first);
      emitMatchCheck(
          opName,
          formatv("*{0}.begin() == *{1}.begin()", firstOperand, secondOperand),
          formatv("\"Operands '{0}' and '{1}' must be equal\"", firstOperand,
                  secondOperand));
    }

    // Jump past every entry with this key.
    symbolInfoIt = endRange;
  }

  LLVM_DEBUG(llvm::dbgs() << "--- done emitting match logic ---\n");
}
799
collectOps(DagNode tree,llvm::SmallPtrSetImpl<const Operator * > & ops)800 void PatternEmitter::collectOps(DagNode tree,
801 llvm::SmallPtrSetImpl<const Operator *> &ops) {
802 // Check if this tree is an operation.
803 if (tree.isOperation()) {
804 const Operator &op = tree.getDialectOp(opMap);
805 LLVM_DEBUG(llvm::dbgs()
806 << "found operation " << op.getOperationName() << '\n');
807 ops.insert(&op);
808 }
809
810 // Recurse the arguments of the tree.
811 for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i)
812 if (auto child = tree.getArgAsNestedDag(i))
813 collectOps(child, ops);
814 }
815
/// Emits the complete RewritePattern subclass named `rewriteName` for this
/// pattern. The generated constructor forwards the root op name, the pattern
/// benefit and the name-sorted list of ops created by the result patterns to
/// ::mlir::RewritePattern. The generated matchAndRewrite() declares a local
/// for every symbol bound in the source pattern, emits the match logic for
/// the source DAG rooted at `op0`, then the rewrite logic, and finally
/// returns success.
emit(StringRef rewriteName)816 void PatternEmitter::emit(StringRef rewriteName) {
817 // Get the DAG tree for the source pattern.
818 DagNode sourceTree = pattern.getSourcePattern();
819
820 const Operator &rootOp = pattern.getSourceRootOp();
821 auto rootName = rootOp.getOperationName();
822
823 // Collect the set of result operations.
// These are the ops the generated pattern may create; they are listed in the
// generated RewritePattern constructor below.
824 llvm::SmallPtrSet<const Operator *, 4> resultOps;
825 LLVM_DEBUG(llvm::dbgs() << "start collecting ops used in result patterns\n");
826 for (unsigned i = 0, e = pattern.getNumResultPatterns(); i != e; ++i) {
827 collectOps(pattern.getResultPattern(i), resultOps);
828 }
829 LLVM_DEBUG(llvm::dbgs() << "done collecting ops used in result patterns\n");
830
831 // Emit RewritePattern for Pattern.
// First a comment listing the definition/instantiation locations of the
// pattern, in reverse of the order returned by getLocation().
832 auto locs = pattern.getLocation();
833 os << formatv("/* Generated from:\n {0:$[ instantiating\n ]}\n*/\n",
834 make_range(locs.rbegin(), locs.rend()));
// Then the struct header and the start of the constructor's initializer; the
// `{{` escape opens the brace-list of generated op names emitted below.
835 os << formatv(R"(struct {0} : public ::mlir::RewritePattern {
836 {0}(::mlir::MLIRContext *context)
837 : ::mlir::RewritePattern("{1}", {2}, context, {{)",
838 rewriteName, rootName, pattern.getBenefit());
839 // Sort result operators by name.
// NOTE(review): presumably done for deterministic output, since SmallPtrSet
// iteration order depends on pointer values.
840 llvm::SmallVector<const Operator *, 4> sortedResultOps(resultOps.begin(),
841 resultOps.end());
842 llvm::sort(sortedResultOps, [&](const Operator *lhs, const Operator *rhs) {
843 return lhs->getOperationName() < rhs->getOperationName();
844 });
845 llvm::interleaveComma(sortedResultOps, os, [&](const Operator *op) {
846 os << '"' << op->getOperationName() << '"';
847 });
// Close the op-name list, the base-class call and the (empty) ctor body.
848 os << "}) {}\n";
849
850 // Emit matchAndRewrite() function.
// classScope indents everything emitted inside the struct body;
// functionScope (below) further indents the function body.
851 {
852 auto classScope = os.scope();
853 os.reindent(R"(
854 ::mlir::LogicalResult matchAndRewrite(::mlir::Operation *op0,
855 ::mlir::PatternRewriter &rewriter) const override {)")
856 << '\n';
857 {
858 auto functionScope = os.scope();
859
860 // Register all symbols bound in the source pattern.
// This must happen before the loop below, which declares one local per
// registered symbol.
861 pattern.collectSourcePatternBoundSymbols(symbolInfoMap);
862
863 LLVM_DEBUG(llvm::dbgs()
864 << "start creating local variables for capturing matches\n");
865 os << "// Variables for capturing values and attributes used while "
866 "creating ops\n";
867 // Create local variables for storing the arguments and results bound
868 // to symbols.
869 for (const auto &symbolInfoPair : symbolInfoMap) {
870 const auto &symbol = symbolInfoPair.first;
871 const auto &info = symbolInfoPair.second;
872
873 os << info.getVarDecl(symbol);
874 }
875 // TODO: capture ops with consistent numbering so that it can be
876 // reused for fused loc.
// tblgen_ops collects the matched ops; emitRewriteLogic() fuses the
// locations of its entries into `odsLoc`.
877 os << "::llvm::SmallVector<::mlir::Operation *, 4> tblgen_ops;\n\n";
878 LLVM_DEBUG(llvm::dbgs()
879 << "done creating local variables for capturing matches\n");
880
881 os << "// Match\n";
882 os << "tblgen_ops.push_back(op0);\n";
883 emitMatchLogic(sourceTree, "op0");
884
885 os << "\n// Rewrite\n";
886 emitRewriteLogic();
887
888 os << "return ::mlir::success();\n";
889 }
// Close matchAndRewrite() and then the struct.
890 os << "};\n";
891 }
892 os << "};\n\n";
893 }
894
emitRewriteLogic()895 void PatternEmitter::emitRewriteLogic() {
896 LLVM_DEBUG(llvm::dbgs() << "--- start emitting rewrite logic ---\n");
897 const Operator &rootOp = pattern.getSourceRootOp();
898 int numExpectedResults = rootOp.getNumResults();
899 int numResultPatterns = pattern.getNumResultPatterns();
900
901 // First register all symbols bound to ops generated in result patterns.
902 pattern.collectResultPatternBoundSymbols(symbolInfoMap);
903
904 // Only the last N static values generated are used to replace the matched
905 // root N-result op. We need to calculate the starting index (of the results
906 // of the matched op) each result pattern is to replace.
907 SmallVector<int, 4> offsets(numResultPatterns + 1, numExpectedResults);
908 // If we don't need to replace any value at all, set the replacement starting
909 // index as the number of result patterns so we skip all of them when trying
910 // to replace the matched op's results.
911 int replStartIndex = numExpectedResults == 0 ? numResultPatterns : -1;
912 for (int i = numResultPatterns - 1; i >= 0; --i) {
913 auto numValues = getNodeValueCount(pattern.getResultPattern(i));
914 offsets[i] = offsets[i + 1] - numValues;
915 if (offsets[i] == 0) {
916 if (replStartIndex == -1)
917 replStartIndex = i;
918 } else if (offsets[i] < 0 && offsets[i + 1] > 0) {
919 auto error = formatv(
920 "cannot use the same multi-result op '{0}' to generate both "
921 "auxiliary values and values to be used for replacing the matched op",
922 pattern.getResultPattern(i).getSymbol());
923 PrintFatalError(loc, error);
924 }
925 }
926
927 if (offsets.front() > 0) {
928 const char error[] = "no enough values generated to replace the matched op";
929 PrintFatalError(loc, error);
930 }
931
932 os << "auto odsLoc = rewriter.getFusedLoc({";
933 for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) {
934 os << (i ? ", " : "") << "tblgen_ops[" << i << "]->getLoc()";
935 }
936 os << "}); (void)odsLoc;\n";
937
938 // Process auxiliary result patterns.
939 for (int i = 0; i < replStartIndex; ++i) {
940 DagNode resultTree = pattern.getResultPattern(i);
941 auto val = handleResultPattern(resultTree, offsets[i], 0);
942 // Normal op creation will be streamed to `os` by the above call; but
943 // NativeCodeCall will only be materialized to `os` if it is used. Here
944 // we are handling auxiliary patterns so we want the side effect even if
945 // NativeCodeCall is not replacing matched root op's results.
946 if (resultTree.isNativeCodeCall() &&
947 resultTree.getNumReturnsOfNativeCode() == 0)
948 os << val << ";\n";
949 }
950
951 if (numExpectedResults == 0) {
952 assert(replStartIndex >= numResultPatterns &&
953 "invalid auxiliary vs. replacement pattern division!");
954 // No result to replace. Just erase the op.
955 os << "rewriter.eraseOp(op0);\n";
956 } else {
957 // Process replacement result patterns.
958 os << "::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values;\n";
959 for (int i = replStartIndex; i < numResultPatterns; ++i) {
960 DagNode resultTree = pattern.getResultPattern(i);
961 auto val = handleResultPattern(resultTree, offsets[i], 0);
962 os << "\n";
963 // Resolve each symbol for all range use so that we can loop over them.
964 // We need an explicit cast to `SmallVector` to capture the cases where
965 // `{0}` resolves to an `Operation::result_range` as well as cases that
966 // are not iterable (e.g. vector that gets wrapped in additional braces by
967 // RewriterGen).
968 // TODO: Revisit the need for materializing a vector.
969 os << symbolInfoMap.getAllRangeUse(
970 val,
971 "for (auto v: ::llvm::SmallVector<::mlir::Value, 4>{ {0} }) {{\n"
972 " tblgen_repl_values.push_back(v);\n}\n",
973 "\n");
974 }
975 os << "\nrewriter.replaceOp(op0, tblgen_repl_values);\n";
976 }
977
978 LLVM_DEBUG(llvm::dbgs() << "--- done emitting rewrite logic ---\n");
979 }
980
getUniqueSymbol(const Operator * op)981 std::string PatternEmitter::getUniqueSymbol(const Operator *op) {
982 return std::string(
983 formatv("tblgen_{0}_{1}", op->getCppClassName(), nextValueId++));
984 }
985
handleResultPattern(DagNode resultTree,int resultIndex,int depth)986 std::string PatternEmitter::handleResultPattern(DagNode resultTree,
987 int resultIndex, int depth) {
988 LLVM_DEBUG(llvm::dbgs() << "handle result pattern: ");
989 LLVM_DEBUG(resultTree.print(llvm::dbgs()));
990 LLVM_DEBUG(llvm::dbgs() << '\n');
991
992 if (resultTree.isLocationDirective()) {
993 PrintFatalError(loc,
994 "location directive can only be used with op creation");
995 }
996
997 if (resultTree.isNativeCodeCall())
998 return handleReplaceWithNativeCodeCall(resultTree, depth);
999
1000 if (resultTree.isReplaceWithValue())
1001 return handleReplaceWithValue(resultTree).str();
1002
1003 // Normal op creation.
1004 auto symbol = handleOpCreation(resultTree, resultIndex, depth);
1005 if (resultTree.getSymbol().empty()) {
1006 // This is an op not explicitly bound to a symbol in the rewrite rule.
1007 // Register the auto-generated symbol for it.
1008 symbolInfoMap.bindOpResult(symbol, pattern.getDialectOp(resultTree));
1009 }
1010 return symbol;
1011 }
1012
handleReplaceWithValue(DagNode tree)1013 StringRef PatternEmitter::handleReplaceWithValue(DagNode tree) {
1014 assert(tree.isReplaceWithValue());
1015
1016 if (tree.getNumArgs() != 1) {
1017 PrintFatalError(
1018 loc, "replaceWithValue directive must take exactly one argument");
1019 }
1020
1021 if (!tree.getSymbol().empty()) {
1022 PrintFatalError(loc, "cannot bind symbol to replaceWithValue");
1023 }
1024
1025 return tree.getArgName(0);
1026 }
1027
handleLocationDirective(DagNode tree)1028 std::string PatternEmitter::handleLocationDirective(DagNode tree) {
1029 assert(tree.isLocationDirective());
1030 auto lookUpArgLoc = [this, &tree](int idx) {
1031 const auto *const lookupFmt = "(*{0}.begin()).getLoc()";
1032 return symbolInfoMap.getAllRangeUse(tree.getArgName(idx), lookupFmt);
1033 };
1034
1035 if (tree.getNumArgs() == 0)
1036 llvm::PrintFatalError(
1037 "At least one argument to location directive required");
1038
1039 if (!tree.getSymbol().empty())
1040 PrintFatalError(loc, "cannot bind symbol to location");
1041
1042 if (tree.getNumArgs() == 1) {
1043 DagLeaf leaf = tree.getArgAsLeaf(0);
1044 if (leaf.isStringAttr())
1045 return formatv("::mlir::NameLoc::get(rewriter.getIdentifier(\"{0}\"))",
1046 leaf.getStringAttr())
1047 .str();
1048 return lookUpArgLoc(0);
1049 }
1050
1051 std::string ret;
1052 llvm::raw_string_ostream os(ret);
1053 std::string strAttr;
1054 os << "rewriter.getFusedLoc({";
1055 bool first = true;
1056 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
1057 DagLeaf leaf = tree.getArgAsLeaf(i);
1058 // Handle the optional string value.
1059 if (leaf.isStringAttr()) {
1060 if (!strAttr.empty())
1061 llvm::PrintFatalError("Only one string attribute may be specified");
1062 strAttr = leaf.getStringAttr();
1063 continue;
1064 }
1065 os << (first ? "" : ", ") << lookUpArgLoc(i);
1066 first = false;
1067 }
1068 os << "}";
1069 if (!strAttr.empty()) {
1070 os << ", rewriter.getStringAttr(\"" << strAttr << "\")";
1071 }
1072 os << ")";
1073 return os.str();
1074 }
1075
handleReturnTypeArg(DagNode returnType,int i,int depth)1076 std::string PatternEmitter::handleReturnTypeArg(DagNode returnType, int i,
1077 int depth) {
1078 // Nested NativeCodeCall.
1079 if (auto dagNode = returnType.getArgAsNestedDag(i)) {
1080 if (!dagNode.isNativeCodeCall())
1081 PrintFatalError(loc, "nested DAG in `returnType` must be a native code "
1082 "call");
1083 return handleReplaceWithNativeCodeCall(dagNode, depth);
1084 }
1085 // String literal.
1086 auto dagLeaf = returnType.getArgAsLeaf(i);
1087 if (dagLeaf.isStringAttr())
1088 return tgfmt(dagLeaf.getStringAttr(), &fmtCtx);
1089 return tgfmt(
1090 "$0.getType()", &fmtCtx,
1091 handleOpArgument(returnType.getArgAsLeaf(i), returnType.getArgName(i)));
1092 }
1093
handleOpArgument(DagLeaf leaf,StringRef patArgName)1094 std::string PatternEmitter::handleOpArgument(DagLeaf leaf,
1095 StringRef patArgName) {
1096 if (leaf.isStringAttr())
1097 PrintFatalError(loc, "raw string not supported as argument");
1098 if (leaf.isConstantAttr()) {
1099 auto constAttr = leaf.getAsConstantAttr();
1100 return handleConstantAttr(constAttr.getAttribute(),
1101 constAttr.getConstantValue());
1102 }
1103 if (leaf.isEnumAttrCase()) {
1104 auto enumCase = leaf.getAsEnumAttrCase();
1105 if (enumCase.isStrCase())
1106 return handleConstantAttr(enumCase, enumCase.getSymbol());
1107 // This is an enum case backed by an IntegerAttr. We need to get its value
1108 // to build the constant.
1109 std::string val = std::to_string(enumCase.getValue());
1110 return handleConstantAttr(enumCase, val);
1111 }
1112
1113 LLVM_DEBUG(llvm::dbgs() << "handle argument '" << patArgName << "'\n");
1114 auto argName = symbolInfoMap.getValueAndRangeUse(patArgName);
1115 if (leaf.isUnspecified() || leaf.isOperandMatcher()) {
1116 LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << argName
1117 << "' (via symbol ref)\n");
1118 return argName;
1119 }
1120 if (leaf.isNativeCodeCall()) {
1121 auto repl = tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(argName));
1122 LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << repl
1123 << "' (via NativeCodeCall)\n");
1124 return std::string(repl);
1125 }
1126 PrintFatalError(loc, "unhandled case when rewriting op");
1127 }
1128
handleReplaceWithNativeCodeCall(DagNode tree,int depth)1129 std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree,
1130 int depth) {
1131 LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall pattern: ");
1132 LLVM_DEBUG(tree.print(llvm::dbgs()));
1133 LLVM_DEBUG(llvm::dbgs() << '\n');
1134
1135 auto fmt = tree.getNativeCodeTemplate();
1136
1137 SmallVector<std::string, 16> attrs;
1138
1139 auto tail = getTrailingDirectives(tree);
1140 if (tail.returnType)
1141 PrintFatalError(loc, "`NativeCodeCall` cannot have return type specifier");
1142 auto locToUse = getLocation(tail);
1143
1144 for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) {
1145 if (tree.isNestedDagArg(i)) {
1146 attrs.push_back(
1147 handleResultPattern(tree.getArgAsNestedDag(i), i, depth + 1));
1148 } else {
1149 attrs.push_back(
1150 handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i)));
1151 }
1152 LLVM_DEBUG(llvm::dbgs() << "NativeCodeCall argument #" << i
1153 << " replacement: " << attrs[i] << "\n");
1154 }
1155
1156 std::string symbol = tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse),
1157 static_cast<ArrayRef<std::string>>(attrs));
1158
1159 // In general, NativeCodeCall without naming binding don't need this. To
1160 // ensure void helper function has been correctly labeled, i.e., use
1161 // NativeCodeCallVoid, we cache the result to a local variable so that we will
1162 // get a compilation error in the auto-generated file.
1163 // Example.
1164 // // In the td file
1165 // Pat<(...), (NativeCodeCall<Foo> ...)>
1166 //
1167 // ---
1168 //
1169 // // In the auto-generated .cpp
1170 // ...
1171 // // Causes compilation error if Foo() returns void.
1172 // auto nativeVar = Foo();
1173 // ...
1174 if (tree.getNumReturnsOfNativeCode() != 0) {
1175 // Determine the local variable name for return value.
1176 std::string varName =
1177 SymbolInfoMap::getValuePackName(tree.getSymbol()).str();
1178 if (varName.empty()) {
1179 varName = formatv("nativeVar_{0}", nextValueId++);
1180 // Register the local variable for later uses.
1181 symbolInfoMap.bindValues(varName, tree.getNumReturnsOfNativeCode());
1182 }
1183
1184 // Catch the return value of helper function.
1185 os << formatv("auto {0} = {1}; (void){0};\n", varName, symbol);
1186
1187 if (!tree.getSymbol().empty())
1188 symbol = tree.getSymbol().str();
1189 else
1190 symbol = varName;
1191 }
1192
1193 return symbol;
1194 }
1195
getNodeValueCount(DagNode node)1196 int PatternEmitter::getNodeValueCount(DagNode node) {
1197 if (node.isOperation()) {
1198 // If the op is bound to a symbol in the rewrite rule, query its result
1199 // count from the symbol info map.
1200 auto symbol = node.getSymbol();
1201 if (!symbol.empty()) {
1202 return symbolInfoMap.getStaticValueCount(symbol);
1203 }
1204 // Otherwise this is an unbound op; we will use all its results.
1205 return pattern.getDialectOp(node).getNumResults();
1206 }
1207
1208 if (node.isNativeCodeCall())
1209 return node.getNumReturnsOfNativeCode();
1210
1211 return 1;
1212 }
1213
1214 PatternEmitter::TrailingDirectives
getTrailingDirectives(DagNode tree)1215 PatternEmitter::getTrailingDirectives(DagNode tree) {
1216 TrailingDirectives tail = {DagNode(nullptr), DagNode(nullptr), 0};
1217
1218 // Look backwards through the arguments.
1219 auto numPatArgs = tree.getNumArgs();
1220 for (int i = numPatArgs - 1; i >= 0; --i) {
1221 auto dagArg = tree.getArgAsNestedDag(i);
1222 // A leaf is not a directive. Stop looking.
1223 if (!dagArg)
1224 break;
1225
1226 auto isLocation = dagArg.isLocationDirective();
1227 auto isReturnType = dagArg.isReturnTypeDirective();
1228 // If encountered a DAG node that isn't a trailing directive, stop looking.
1229 if (!(isLocation || isReturnType))
1230 break;
1231 // Save the directive, but error if one of the same type was already
1232 // found.
1233 ++tail.numDirectives;
1234 if (isLocation) {
1235 if (tail.location)
1236 PrintFatalError(loc, "`location` directive can only be specified "
1237 "once");
1238 tail.location = dagArg;
1239 } else if (isReturnType) {
1240 if (tail.returnType)
1241 PrintFatalError(loc, "`returnType` directive can only be specified "
1242 "once");
1243 tail.returnType = dagArg;
1244 }
1245 }
1246
1247 return tail;
1248 }
1249
1250 std::string
getLocation(PatternEmitter::TrailingDirectives & tail)1251 PatternEmitter::getLocation(PatternEmitter::TrailingDirectives &tail) {
1252 if (tail.location)
1253 return handleLocationDirective(tail.location);
1254
1255 // If no explicit location is given, use the default, all fused, location.
1256 return "odsLoc";
1257 }
1258
handleOpCreation(DagNode tree,int resultIndex,int depth)1259 std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
1260 int depth) {
1261 LLVM_DEBUG(llvm::dbgs() << "create op for pattern: ");
1262 LLVM_DEBUG(tree.print(llvm::dbgs()));
1263 LLVM_DEBUG(llvm::dbgs() << '\n');
1264
1265 Operator &resultOp = tree.getDialectOp(opMap);
1266 auto numOpArgs = resultOp.getNumArgs();
1267 auto numPatArgs = tree.getNumArgs();
1268
1269 auto tail = getTrailingDirectives(tree);
1270 auto locToUse = getLocation(tail);
1271
1272 auto inPattern = numPatArgs - tail.numDirectives;
1273 if (numOpArgs != inPattern) {
1274 PrintFatalError(loc,
1275 formatv("resultant op '{0}' argument number mismatch: "
1276 "{1} in pattern vs. {2} in definition",
1277 resultOp.getOperationName(), inPattern, numOpArgs));
1278 }
1279
1280 // A map to collect all nested DAG child nodes' names, with operand index as
1281 // the key. This includes both bound and unbound child nodes.
1282 ChildNodeIndexNameMap childNodeNames;
1283
1284 // First go through all the child nodes who are nested DAG constructs to
1285 // create ops for them and remember the symbol names for them, so that we can
1286 // use the results in the current node. This happens in a recursive manner.
1287 for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) {
1288 if (auto child = tree.getArgAsNestedDag(i))
1289 childNodeNames[i] = handleResultPattern(child, i, depth + 1);
1290 }
1291
1292 // The name of the local variable holding this op.
1293 std::string valuePackName;
1294 // The symbol for holding the result of this pattern. Note that the result of
1295 // this pattern is not necessarily the same as the variable created by this
1296 // pattern because we can use `__N` suffix to refer only a specific result if
1297 // the generated op is a multi-result op.
1298 std::string resultValue;
1299 if (tree.getSymbol().empty()) {
1300 // No symbol is explicitly bound to this op in the pattern. Generate a
1301 // unique name.
1302 valuePackName = resultValue = getUniqueSymbol(&resultOp);
1303 } else {
1304 resultValue = std::string(tree.getSymbol());
1305 // Strip the index to get the name for the value pack and use it to name the
1306 // local variable for the op.
1307 valuePackName = std::string(SymbolInfoMap::getValuePackName(resultValue));
1308 }
1309
1310 // Create the local variable for this op.
1311 os << formatv("{0} {1};\n{{\n", resultOp.getQualCppClassName(),
1312 valuePackName);
1313
1314 // Right now ODS don't have general type inference support. Except a few
1315 // special cases listed below, DRR needs to supply types for all results
1316 // when building an op.
1317 bool isSameOperandsAndResultType =
1318 resultOp.getTrait("::mlir::OpTrait::SameOperandsAndResultType");
1319 bool useFirstAttr =
1320 resultOp.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType");
1321
1322 if (!tail.returnType && (isSameOperandsAndResultType || useFirstAttr)) {
1323 // We know how to deduce the result type for ops with these traits and we've
1324 // generated builders taking aggregate parameters. Use those builders to
1325 // create the ops.
1326
1327 // First prepare local variables for op arguments used in builder call.
1328 createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth);
1329
1330 // Then create the op.
1331 os.scope("", "\n}\n").os << formatv(
1332 "{0} = rewriter.create<{1}>({2}, tblgen_values, tblgen_attrs);",
1333 valuePackName, resultOp.getQualCppClassName(), locToUse);
1334 return resultValue;
1335 }
1336
1337 bool usePartialResults = valuePackName != resultValue;
1338
1339 if (!tail.returnType && (usePartialResults || depth > 0 || resultIndex < 0)) {
1340 // For these cases (broadcastable ops, op results used both as auxiliary
1341 // values and replacement values, ops in nested patterns, auxiliary ops), we
1342 // still need to supply the result types when building the op. But because
1343 // we don't generate a builder automatically with ODS for them, it's the
1344 // developer's responsibility to make sure such a builder (with result type
1345 // deduction ability) exists. We go through the separate-parameter builder
1346 // here given that it's easier for developers to write compared to
1347 // aggregate-parameter builders.
1348 createSeparateLocalVarsForOpArgs(tree, childNodeNames);
1349
1350 os.scope().os << formatv("{0} = rewriter.create<{1}>({2}", valuePackName,
1351 resultOp.getQualCppClassName(), locToUse);
1352 supplyValuesForOpArgs(tree, childNodeNames, depth);
1353 os << "\n );\n}\n";
1354 return resultValue;
1355 }
1356
1357 // If we are provided explicit return types, use them to build the op.
1358 // However, if depth == 0 and resultIndex >= 0, it means we are replacing
1359 // the values generated from the source pattern root op. Then we must use the
1360 // source pattern's value types to determine the value type of the generated
1361 // op here.
1362 if (depth == 0 && resultIndex >= 0 && tail.returnType)
1363 PrintFatalError(loc, "Cannot specify explicit return types in an op whose "
1364 "return values replace the source pattern's root op");
1365
1366 // First prepare local variables for op arguments used in builder call.
1367 createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth);
1368
1369 // Then prepare the result types. We need to specify the types for all
1370 // results.
1371 os.indent() << formatv("::mlir::SmallVector<::mlir::Type, 4> tblgen_types; "
1372 "(void)tblgen_types;\n");
1373 int numResults = resultOp.getNumResults();
1374 if (tail.returnType) {
1375 auto numRetTys = tail.returnType.getNumArgs();
1376 for (int i = 0; i < numRetTys; ++i) {
1377 auto varName = handleReturnTypeArg(tail.returnType, i, depth + 1);
1378 os << "tblgen_types.push_back(" << varName << ");\n";
1379 }
1380 } else {
1381 if (numResults != 0) {
1382 // Copy the result types from the source pattern.
1383 for (int i = 0; i < numResults; ++i)
1384 os << formatv("for (auto v: castedOp0.getODSResults({0})) {{\n"
1385 " tblgen_types.push_back(v.getType());\n}\n",
1386 resultIndex + i);
1387 }
1388 }
1389 os << formatv("{0} = rewriter.create<{1}>({2}, tblgen_types, "
1390 "tblgen_values, tblgen_attrs);\n",
1391 valuePackName, resultOp.getQualCppClassName(), locToUse);
1392 os.unindent() << "}\n";
1393 return resultValue;
1394 }
1395
createSeparateLocalVarsForOpArgs(DagNode node,ChildNodeIndexNameMap & childNodeNames)1396 void PatternEmitter::createSeparateLocalVarsForOpArgs(
1397 DagNode node, ChildNodeIndexNameMap &childNodeNames) {
1398 Operator &resultOp = node.getDialectOp(opMap);
1399
1400 // Now prepare operands used for building this op:
1401 // * If the operand is non-variadic, we create a `Value` local variable.
1402 // * If the operand is variadic, we create a `SmallVector<Value>` local
1403 // variable.
1404
1405 int valueIndex = 0; // An index for uniquing local variable names.
1406 for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
1407 const auto *operand =
1408 resultOp.getArg(argIndex).dyn_cast<NamedTypeConstraint *>();
1409 // We do not need special handling for attributes.
1410 if (!operand)
1411 continue;
1412
1413 raw_indented_ostream::DelimitedScope scope(os);
1414 std::string varName;
1415 if (operand->isVariadic()) {
1416 varName = std::string(formatv("tblgen_values_{0}", valueIndex++));
1417 os << formatv("::mlir::SmallVector<::mlir::Value, 4> {0};\n", varName);
1418 std::string range;
1419 if (node.isNestedDagArg(argIndex)) {
1420 range = childNodeNames[argIndex];
1421 } else {
1422 range = std::string(node.getArgName(argIndex));
1423 }
1424 // Resolve the symbol for all range use so that we have a uniform way of
1425 // capturing the values.
1426 range = symbolInfoMap.getValueAndRangeUse(range);
1427 os << formatv("for (auto v: {0}) {{\n {1}.push_back(v);\n}\n", range,
1428 varName);
1429 } else {
1430 varName = std::string(formatv("tblgen_value_{0}", valueIndex++));
1431 os << formatv("::mlir::Value {0} = ", varName);
1432 if (node.isNestedDagArg(argIndex)) {
1433 os << symbolInfoMap.getValueAndRangeUse(childNodeNames[argIndex]);
1434 } else {
1435 DagLeaf leaf = node.getArgAsLeaf(argIndex);
1436 auto symbol =
1437 symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex));
1438 if (leaf.isNativeCodeCall()) {
1439 os << std::string(
1440 tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol)));
1441 } else {
1442 os << symbol;
1443 }
1444 }
1445 os << ";\n";
1446 }
1447
1448 // Update to use the newly created local variable for building the op later.
1449 childNodeNames[argIndex] = varName;
1450 }
1451 }
1452
supplyValuesForOpArgs(DagNode node,const ChildNodeIndexNameMap & childNodeNames,int depth)1453 void PatternEmitter::supplyValuesForOpArgs(
1454 DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) {
1455 Operator &resultOp = node.getDialectOp(opMap);
1456 for (int argIndex = 0, numOpArgs = resultOp.getNumArgs();
1457 argIndex != numOpArgs; ++argIndex) {
1458 // Start each argument on its own line.
1459 os << ",\n ";
1460
1461 Argument opArg = resultOp.getArg(argIndex);
1462 // Handle the case of operand first.
1463 if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) {
1464 if (!operand->name.empty())
1465 os << "/*" << operand->name << "=*/";
1466 os << childNodeNames.lookup(argIndex);
1467 continue;
1468 }
1469
1470 // The argument in the op definition.
1471 auto opArgName = resultOp.getArgName(argIndex);
1472 if (auto subTree = node.getArgAsNestedDag(argIndex)) {
1473 if (!subTree.isNativeCodeCall())
1474 PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
1475 "for creating attribute");
1476 os << formatv("/*{0}=*/{1}", opArgName, childNodeNames.lookup(argIndex));
1477 } else {
1478 auto leaf = node.getArgAsLeaf(argIndex);
1479 // The argument in the result DAG pattern.
1480 auto patArgName = node.getArgName(argIndex);
1481 if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) {
1482 // TODO: Refactor out into map to avoid recomputing these.
1483 if (!opArg.is<NamedAttribute *>())
1484 PrintFatalError(loc, Twine("expected attribute ") + Twine(argIndex));
1485 if (!patArgName.empty())
1486 os << "/*" << patArgName << "=*/";
1487 } else {
1488 os << "/*" << opArgName << "=*/";
1489 }
1490 os << handleOpArgument(leaf, patArgName);
1491 }
1492 }
1493 }
1494
createAggregateLocalVarsForOpArgs(DagNode node,const ChildNodeIndexNameMap & childNodeNames,int depth)1495 void PatternEmitter::createAggregateLocalVarsForOpArgs(
1496 DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) {
1497 Operator &resultOp = node.getDialectOp(opMap);
1498
1499 auto scope = os.scope();
1500 os << formatv("::mlir::SmallVector<::mlir::Value, 4> "
1501 "tblgen_values; (void)tblgen_values;\n");
1502 os << formatv("::mlir::SmallVector<::mlir::NamedAttribute, 4> "
1503 "tblgen_attrs; (void)tblgen_attrs;\n");
1504
1505 const char *addAttrCmd =
1506 "if (auto tmpAttr = {1}) {\n"
1507 " tblgen_attrs.emplace_back(rewriter.getIdentifier(\"{0}\"), "
1508 "tmpAttr);\n}\n";
1509 for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
1510 if (resultOp.getArg(argIndex).is<NamedAttribute *>()) {
1511 // The argument in the op definition.
1512 auto opArgName = resultOp.getArgName(argIndex);
1513 if (auto subTree = node.getArgAsNestedDag(argIndex)) {
1514 if (!subTree.isNativeCodeCall())
1515 PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
1516 "for creating attribute");
1517 os << formatv(addAttrCmd, opArgName, childNodeNames.lookup(argIndex));
1518 } else {
1519 auto leaf = node.getArgAsLeaf(argIndex);
1520 // The argument in the result DAG pattern.
1521 auto patArgName = node.getArgName(argIndex);
1522 os << formatv(addAttrCmd, opArgName,
1523 handleOpArgument(leaf, patArgName));
1524 }
1525 continue;
1526 }
1527
1528 const auto *operand =
1529 resultOp.getArg(argIndex).get<NamedTypeConstraint *>();
1530 std::string varName;
1531 if (operand->isVariadic()) {
1532 std::string range;
1533 if (node.isNestedDagArg(argIndex)) {
1534 range = childNodeNames.lookup(argIndex);
1535 } else {
1536 range = std::string(node.getArgName(argIndex));
1537 }
1538 // Resolve the symbol for all range use so that we have a uniform way of
1539 // capturing the values.
1540 range = symbolInfoMap.getValueAndRangeUse(range);
1541 os << formatv("for (auto v: {0}) {{\n tblgen_values.push_back(v);\n}\n",
1542 range);
1543 } else {
1544 os << formatv("tblgen_values.push_back(");
1545 if (node.isNestedDagArg(argIndex)) {
1546 os << symbolInfoMap.getValueAndRangeUse(
1547 childNodeNames.lookup(argIndex));
1548 } else {
1549 DagLeaf leaf = node.getArgAsLeaf(argIndex);
1550 if (leaf.isConstantAttr())
1551 // TODO: Use better location
1552 PrintFatalError(
1553 loc,
1554 "attribute found where value was expected, if attempting to use "
1555 "constant value, construct a constant op with given attribute "
1556 "instead");
1557
1558 auto symbol =
1559 symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex));
1560 if (leaf.isNativeCodeCall()) {
1561 os << std::string(
1562 tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol)));
1563 } else {
1564 os << symbol;
1565 }
1566 }
1567 os << ");\n";
1568 }
1569 }
1570 }
1571
StaticMatcherHelper(RecordOperatorMap & mapper)1572 StaticMatcherHelper::StaticMatcherHelper(RecordOperatorMap &mapper)
1573 : opMap(mapper) {}
1574
populateStaticMatchers(raw_ostream & os)1575 void StaticMatcherHelper::populateStaticMatchers(raw_ostream &os) {
1576 // PatternEmitter will use the static matcher if there's one generated. To
1577 // ensure that all the dependent static matchers are generated before emitting
1578 // the matching logic of the DagNode, we use topological order to achieve it.
1579 for (auto &dagInfo : topologicalOrder) {
1580 DagNode node = dagInfo.first;
1581 if (!useStaticMatcher(node))
1582 continue;
1583
1584 std::string funcName =
1585 formatv("static_dag_matcher_{0}", staticMatcherCounter++);
1586 assert(matcherNames.find(node) == matcherNames.end());
1587 PatternEmitter(dagInfo.second, &opMap, os, *this)
1588 .emitStaticMatcher(node, funcName);
1589 matcherNames[node] = funcName;
1590 }
1591 }
1592
addPattern(Record * record)1593 void StaticMatcherHelper::addPattern(Record *record) {
1594 Pattern pat(record, &opMap);
1595
1596 // While generating the function body of the DAG matcher, it may depends on
1597 // other DAG matchers. To ensure the dependent matchers are ready, we compute
1598 // the topological order for all the DAGs and emit the DAG matchers in this
1599 // order.
1600 llvm::unique_function<void(DagNode)> dfs = [&](DagNode node) {
1601 ++refStats[node];
1602
1603 if (refStats[node] != 1)
1604 return;
1605
1606 for (unsigned i = 0, e = node.getNumArgs(); i < e; ++i)
1607 if (DagNode sibling = node.getArgAsNestedDag(i))
1608 dfs(sibling);
1609
1610 topologicalOrder.push_back(std::make_pair(node, record));
1611 };
1612
1613 dfs(pat.getSourcePattern());
1614 }
1615
emitRewriters(const RecordKeeper & recordKeeper,raw_ostream & os)1616 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
1617 emitSourceFileHeader("Rewriters", os);
1618
1619 const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern");
1620
1621 // We put the map here because it can be shared among multiple patterns.
1622 RecordOperatorMap recordOpMap;
1623
1624 // Exam all the patterns and generate static matcher for the duplicated
1625 // DagNode.
1626 StaticMatcherHelper staticMatcher(recordOpMap);
1627 for (Record *p : patterns)
1628 staticMatcher.addPattern(p);
1629 staticMatcher.populateStaticMatchers(os);
1630
1631 std::vector<std::string> rewriterNames;
1632 rewriterNames.reserve(patterns.size());
1633
1634 std::string baseRewriterName = "GeneratedConvert";
1635 int rewriterIndex = 0;
1636
1637 for (Record *p : patterns) {
1638 std::string name;
1639 if (p->isAnonymous()) {
1640 // If no name is provided, ensure unique rewriter names simply by
1641 // appending unique suffix.
1642 name = baseRewriterName + llvm::utostr(rewriterIndex++);
1643 } else {
1644 name = std::string(p->getName());
1645 }
1646 LLVM_DEBUG(llvm::dbgs()
1647 << "=== start generating pattern '" << name << "' ===\n");
1648 PatternEmitter(p, &recordOpMap, os, staticMatcher).emit(name);
1649 LLVM_DEBUG(llvm::dbgs()
1650 << "=== done generating pattern '" << name << "' ===\n");
1651 rewriterNames.push_back(std::move(name));
1652 }
1653
1654 // Emit function to add the generated matchers to the pattern list.
1655 os << "void LLVM_ATTRIBUTE_UNUSED populateWithGenerated("
1656 "::mlir::RewritePatternSet &patterns) {\n";
1657 for (const auto &name : rewriterNames) {
1658 os << " patterns.add<" << name << ">(patterns.getContext());\n";
1659 }
1660 os << "}\n";
1661 }
1662
1663 static mlir::GenRegistration
1664 genRewriters("gen-rewriters", "Generate pattern rewriters",
__anon78740d160802(const RecordKeeper &records, raw_ostream &os) 1665 [](const RecordKeeper &records, raw_ostream &os) {
1666 emitRewriters(records, os);
1667 return false;
1668 });
1669