1 //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // RewriterGen uses pattern rewrite definitions to generate rewriter matchers.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "mlir/Support/IndentedOstream.h"
14 #include "mlir/TableGen/Attribute.h"
15 #include "mlir/TableGen/Format.h"
16 #include "mlir/TableGen/GenInfo.h"
17 #include "mlir/TableGen/Operator.h"
18 #include "mlir/TableGen/Pattern.h"
19 #include "mlir/TableGen/Predicate.h"
20 #include "mlir/TableGen/Type.h"
21 #include "llvm/ADT/StringExtras.h"
22 #include "llvm/ADT/StringSet.h"
23 #include "llvm/Support/CommandLine.h"
24 #include "llvm/Support/Debug.h"
25 #include "llvm/Support/FormatAdapters.h"
26 #include "llvm/Support/PrettyStackTrace.h"
27 #include "llvm/Support/Signals.h"
28 #include "llvm/TableGen/Error.h"
29 #include "llvm/TableGen/Main.h"
30 #include "llvm/TableGen/Record.h"
31 #include "llvm/TableGen/TableGenBackend.h"
32
33 using namespace mlir;
34 using namespace mlir::tblgen;
35
36 using llvm::formatv;
37 using llvm::Record;
38 using llvm::RecordKeeper;
39
40 #define DEBUG_TYPE "mlir-tblgen-rewritergen"
41
42 namespace llvm {
43 template <>
44 struct format_provider<mlir::tblgen::Pattern::IdentifierLine> {
formatllvm::format_provider45 static void format(const mlir::tblgen::Pattern::IdentifierLine &v,
46 raw_ostream &os, StringRef style) {
47 os << v.first << ":" << v.second;
48 }
49 };
50 } // end namespace llvm
51
52 //===----------------------------------------------------------------------===//
53 // PatternEmitter
54 //===----------------------------------------------------------------------===//
55
56 namespace {
57 class PatternEmitter {
58 public:
59 PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os);
60
61 // Emits the mlir::RewritePattern struct named `rewriteName`.
62 void emit(StringRef rewriteName);
63
64 private:
65 // Emits the code for matching ops.
66 void emitMatchLogic(DagNode tree, StringRef opName);
67
68 // Emits the code for rewriting ops.
69 void emitRewriteLogic();
70
71 //===--------------------------------------------------------------------===//
72 // Match utilities
73 //===--------------------------------------------------------------------===//
74
75 // Emits C++ statements for matching the DAG structure.
76 void emitMatch(DagNode tree, StringRef name, int depth);
77
78 // Emits C++ statements for matching using a native code call.
79 void emitNativeCodeMatch(DagNode tree, StringRef name, int depth);
80
81 // Emits C++ statements for matching the op constrained by the given DAG
82 // `tree` returning the op's variable name.
83 void emitOpMatch(DagNode tree, StringRef opName, int depth);
84
85 // Emits C++ statements for matching the `argIndex`-th argument of the given
86 // DAG `tree` as an operand. operandIndex is the index in the DAG excluding
87 // the preceding attributes.
88 void emitOperandMatch(DagNode tree, StringRef opName, int argIndex,
89 int operandIndex, int depth);
90
91 // Emits C++ statements for matching the `argIndex`-th argument of the given
92 // DAG `tree` as an attribute.
93 void emitAttributeMatch(DagNode tree, StringRef opName, int argIndex,
94 int depth);
95
96 // Emits C++ for checking a match with a corresponding match failure
97 // diagnostic.
98 void emitMatchCheck(StringRef opName, const FmtObjectBase &matchFmt,
99 const llvm::formatv_object_base &failureFmt);
100
101 // Emits C++ for checking a match with a corresponding match failure
102 // diagnostics.
103 void emitMatchCheck(StringRef opName, const std::string &matchStr,
104 const std::string &failureStr);
105
106 //===--------------------------------------------------------------------===//
107 // Rewrite utilities
108 //===--------------------------------------------------------------------===//
109
110 // The entry point for handling a result pattern rooted at `resultTree`. This
111 // method dispatches to concrete handlers according to `resultTree`'s kind and
112 // returns a symbol representing the whole value pack. Callers are expected to
113 // further resolve the symbol according to the specific use case.
114 //
115 // `depth` is the nesting level of `resultTree`; 0 means top-level result
116 // pattern. For top-level result pattern, `resultIndex` indicates which result
117 // of the matched root op this pattern is intended to replace, which can be
118 // used to deduce the result type of the op generated from this result
119 // pattern.
120 std::string handleResultPattern(DagNode resultTree, int resultIndex,
121 int depth);
122
123 // Emits the C++ statement to replace the matched DAG with a value built via
124 // calling native C++ code.
125 std::string handleReplaceWithNativeCodeCall(DagNode resultTree, int depth);
126
127 // Returns the symbol of the old value serving as the replacement.
128 StringRef handleReplaceWithValue(DagNode tree);
129
130 // Returns the location value to use.
131 std::pair<bool, std::string> getLocation(DagNode tree);
132
133 // Returns the location value to use.
134 std::string handleLocationDirective(DagNode tree);
135
136 // Emits the C++ statement to build a new op out of the given DAG `tree` and
137 // returns the variable name that this op is assigned to. If the root op in
138 // DAG `tree` has a specified name, the created op will be assigned to a
139 // variable of the given name. Otherwise, a unique name will be used as the
140 // result value name.
141 std::string handleOpCreation(DagNode tree, int resultIndex, int depth);
142
143 using ChildNodeIndexNameMap = DenseMap<unsigned, std::string>;
144
145 // Emits a local variable for each value and attribute to be used for creating
146 // an op.
147 void createSeparateLocalVarsForOpArgs(DagNode node,
148 ChildNodeIndexNameMap &childNodeNames);
149
150 // Emits the concrete arguments used to call an op's builder.
151 void supplyValuesForOpArgs(DagNode node,
152 const ChildNodeIndexNameMap &childNodeNames,
153 int depth);
154
155 // Emits the local variables for holding all values as a whole and all named
156 // attributes as a whole to be used for creating an op.
157 void createAggregateLocalVarsForOpArgs(
158 DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth);
159
160 // Returns the C++ expression to construct a constant attribute of the given
161 // `value` for the given attribute kind `attr`.
162 std::string handleConstantAttr(Attribute attr, StringRef value);
163
164 // Returns the C++ expression to build an argument from the given DAG `leaf`.
165 // `patArgName` is used to bound the argument to the source pattern.
166 std::string handleOpArgument(DagLeaf leaf, StringRef patArgName);
167
168 //===--------------------------------------------------------------------===//
169 // General utilities
170 //===--------------------------------------------------------------------===//
171
172 // Collects all of the operations within the given dag tree.
173 void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops);
174
175 // Returns a unique symbol for a local variable of the given `op`.
176 std::string getUniqueSymbol(const Operator *op);
177
178 //===--------------------------------------------------------------------===//
179 // Symbol utilities
180 //===--------------------------------------------------------------------===//
181
182 // Returns how many static values the given DAG `node` correspond to.
183 int getNodeValueCount(DagNode node);
184
185 private:
186 // Pattern instantiation location followed by the location of multiclass
187 // prototypes used. This is intended to be used as a whole to
188 // PrintFatalError() on errors.
189 ArrayRef<llvm::SMLoc> loc;
190
191 // Op's TableGen Record to wrapper object.
192 RecordOperatorMap *opMap;
193
194 // Handy wrapper for pattern being emitted.
195 Pattern pattern;
196
197 // Map for all bound symbols' info.
198 SymbolInfoMap symbolInfoMap;
199
200 // The next unused ID for newly created values.
201 unsigned nextValueId;
202
203 raw_indented_ostream os;
204
205 // Format contexts containing placeholder substitutions.
206 FmtContext fmtCtx;
207
208 // Number of op processed.
209 int opCounter = 0;
210 };
211 } // end anonymous namespace
212
PatternEmitter(Record * pat,RecordOperatorMap * mapper,raw_ostream & os)213 PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper,
214 raw_ostream &os)
215 : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper),
216 symbolInfoMap(pat->getLoc()), nextValueId(0), os(os) {
217 fmtCtx.withBuilder("rewriter");
218 }
219
handleConstantAttr(Attribute attr,StringRef value)220 std::string PatternEmitter::handleConstantAttr(Attribute attr,
221 StringRef value) {
222 if (!attr.isConstBuildable())
223 PrintFatalError(loc, "Attribute " + attr.getAttrDefName() +
224 " does not have the 'constBuilderCall' field");
225
226 // TODO: Verify the constants here
227 return std::string(tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, value));
228 }
229
230 // Helper function to match patterns.
emitMatch(DagNode tree,StringRef name,int depth)231 void PatternEmitter::emitMatch(DagNode tree, StringRef name, int depth) {
232 if (tree.isNativeCodeCall()) {
233 emitNativeCodeMatch(tree, name, depth);
234 return;
235 }
236
237 if (tree.isOperation()) {
238 emitOpMatch(tree, name, depth);
239 return;
240 }
241
242 PrintFatalError(loc, "encountered non-op, non-NativeCodeCall match.");
243 }
244
245 // Helper function to match patterns.
emitNativeCodeMatch(DagNode tree,StringRef opName,int depth)246 void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
247 int depth) {
248 LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall matcher pattern: ");
249 LLVM_DEBUG(tree.print(llvm::dbgs()));
250 LLVM_DEBUG(llvm::dbgs() << '\n');
251
252 // TODO(suderman): iterate through arguments, determine their types, output
253 // names.
254 SmallVector<std::string, 8> capture;
255
256 raw_indented_ostream::DelimitedScope scope(os);
257
258 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
259 std::string argName = formatv("arg{0}_{1}", depth, i);
260 if (DagNode argTree = tree.getArgAsNestedDag(i)) {
261 os << "Value " << argName << ";\n";
262 } else {
263 auto leaf = tree.getArgAsLeaf(i);
264 if (leaf.isAttrMatcher() || leaf.isConstantAttr()) {
265 os << "Attribute " << argName << ";\n";
266 } else {
267 os << "Value " << argName << ";\n";
268 }
269 }
270
271 capture.push_back(std::move(argName));
272 }
273
274 bool hasLocationDirective;
275 std::string locToUse;
276 std::tie(hasLocationDirective, locToUse) = getLocation(tree);
277
278 auto fmt = tree.getNativeCodeTemplate();
279 if (fmt.count("$_self") != 1)
280 PrintFatalError(loc, "NativeCodeCall must have $_self as argument for "
281 "passing the defining Operation");
282
283 auto nativeCodeCall = std::string(tgfmt(
284 fmt, &fmtCtx.addSubst("_loc", locToUse).withSelf(opName.str()), capture));
285
286 emitMatchCheck(opName, formatv("!failed({0})", nativeCodeCall),
287 formatv("\"{0} return failure\"", nativeCodeCall));
288
289 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
290 auto name = tree.getArgName(i);
291 if (!name.empty() && name != "_") {
292 os << formatv("{0} = {1};\n", name, capture[i]);
293 }
294 }
295
296 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
297 std::string argName = capture[i];
298
299 // Handle nested DAG construct first
300 if (DagNode argTree = tree.getArgAsNestedDag(i)) {
301 PrintFatalError(
302 loc, formatv("Matching nested tree in NativeCodecall not support for "
303 "{0} as arg {1}",
304 argName, i));
305 }
306
307 DagLeaf leaf = tree.getArgAsLeaf(i);
308
309 // The parameter for native function doesn't bind any constraints.
310 if (leaf.isUnspecified())
311 continue;
312
313 auto constraint = leaf.getAsConstraint();
314
315 std::string self;
316 if (leaf.isAttrMatcher() || leaf.isConstantAttr())
317 self = argName;
318 else
319 self = formatv("{0}.getType()", argName);
320 emitMatchCheck(
321 opName,
322 tgfmt(constraint.getConditionTemplate(), &fmtCtx.withSelf(self)),
323 formatv("\"operand {0} of native code call '{1}' failed to satisfy "
324 "constraint: "
325 "'{2}'\"",
326 i, tree.getNativeCodeTemplate(), constraint.getSummary()));
327 }
328
329 LLVM_DEBUG(llvm::dbgs() << "done emitting match for native code call\n");
330 }
331
332 // Helper function to match patterns.
emitOpMatch(DagNode tree,StringRef opName,int depth)333 void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
334 Operator &op = tree.getDialectOp(opMap);
335 LLVM_DEBUG(llvm::dbgs() << "start emitting match for op '"
336 << op.getOperationName() << "' at depth " << depth
337 << '\n');
338
339 std::string castedName = formatv("castedOp{0}", depth);
340 os << formatv("auto {0} = ::llvm::dyn_cast<{2}>({1}); "
341 "(void){0};\n",
342 castedName, opName, op.getQualCppClassName());
343
344 // Skip the operand matching at depth 0 as the pattern rewriter already does.
345 if (depth != 0)
346 emitMatchCheck(opName, /*matchStr=*/castedName,
347 formatv("\"{0} is not {1} type\"", castedName,
348 op.getQualCppClassName()));
349
350 if (tree.getNumArgs() != op.getNumArgs())
351 PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in "
352 "pattern vs. {2} in definition",
353 op.getOperationName(), tree.getNumArgs(),
354 op.getNumArgs()));
355
356 // If the operand's name is set, set to that variable.
357 auto name = tree.getSymbol();
358 if (!name.empty())
359 os << formatv("{0} = {1};\n", name, castedName);
360
361 for (int i = 0, e = tree.getNumArgs(), nextOperand = 0; i != e; ++i) {
362 auto opArg = op.getArg(i);
363 std::string argName = formatv("op{0}", depth + 1);
364
365 // Handle nested DAG construct first
366 if (DagNode argTree = tree.getArgAsNestedDag(i)) {
367 if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) {
368 if (operand->isVariableLength()) {
369 auto error = formatv("use nested DAG construct to match op {0}'s "
370 "variadic operand #{1} unsupported now",
371 op.getOperationName(), i);
372 PrintFatalError(loc, error);
373 }
374 }
375 os << "{\n";
376
377 // Attributes don't count for getODSOperands.
378 // TODO: Operand is a Value, check if we should remove `getDefiningOp()`.
379 os.indent() << formatv(
380 "auto *{0} = "
381 "(*{1}.getODSOperands({2}).begin()).getDefiningOp();\n",
382 argName, castedName, nextOperand);
383 // Null check of operand's definingOp
384 emitMatchCheck(castedName, /*matchStr=*/argName,
385 formatv("\"Operand {0} of {1} has null definingOp\"",
386 nextOperand++, castedName));
387 emitMatch(argTree, argName, depth + 1);
388 os << formatv("tblgen_ops[{0}] = {1};\n", ++opCounter, argName);
389 os.unindent() << "}\n";
390 continue;
391 }
392
393 // Next handle DAG leaf: operand or attribute
394 if (opArg.is<NamedTypeConstraint *>()) {
395 // emitOperandMatch's argument indexing counts attributes.
396 emitOperandMatch(tree, castedName, i, nextOperand, depth);
397 ++nextOperand;
398 } else if (opArg.is<NamedAttribute *>()) {
399 emitAttributeMatch(tree, opName, i, depth);
400 } else {
401 PrintFatalError(loc, "unhandled case when matching op");
402 }
403 }
404 LLVM_DEBUG(llvm::dbgs() << "done emitting match for op '"
405 << op.getOperationName() << "' at depth " << depth
406 << '\n');
407 }
408
emitOperandMatch(DagNode tree,StringRef opName,int argIndex,int operandIndex,int depth)409 void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
410 int argIndex, int operandIndex,
411 int depth) {
412 Operator &op = tree.getDialectOp(opMap);
413 auto *operand = op.getArg(argIndex).get<NamedTypeConstraint *>();
414 auto matcher = tree.getArgAsLeaf(argIndex);
415
416 // If a constraint is specified, we need to generate C++ statements to
417 // check the constraint.
418 if (!matcher.isUnspecified()) {
419 if (!matcher.isOperandMatcher()) {
420 PrintFatalError(
421 loc, formatv("the {1}-th argument of op '{0}' should be an operand",
422 op.getOperationName(), argIndex + 1));
423 }
424
425 // Only need to verify if the matcher's type is different from the one
426 // of op definition.
427 Constraint constraint = matcher.getAsConstraint();
428 if (operand->constraint != constraint) {
429 if (operand->isVariableLength()) {
430 auto error = formatv(
431 "further constrain op {0}'s variadic operand #{1} unsupported now",
432 op.getOperationName(), argIndex);
433 PrintFatalError(loc, error);
434 }
435 auto self = formatv("(*{0}.getODSOperands({1}).begin()).getType()",
436 opName, operandIndex);
437 emitMatchCheck(
438 opName,
439 tgfmt(constraint.getConditionTemplate(), &fmtCtx.withSelf(self)),
440 formatv("\"operand {0} of op '{1}' failed to satisfy constraint: "
441 "'{2}'\"",
442 operand - op.operand_begin(), op.getOperationName(),
443 constraint.getSummary()));
444 }
445 }
446
447 // Capture the value
448 auto name = tree.getArgName(argIndex);
449 // `$_` is a special symbol to ignore op argument matching.
450 if (!name.empty() && name != "_") {
451 // We need to subtract the number of attributes before this operand to get
452 // the index in the operand list.
453 auto numPrevAttrs = std::count_if(
454 op.arg_begin(), op.arg_begin() + argIndex,
455 [](const Argument &arg) { return arg.is<NamedAttribute *>(); });
456
457 auto res = symbolInfoMap.findBoundSymbol(name, tree, op, argIndex);
458 os << formatv("{0} = {1}.getODSOperands({2});\n",
459 res->second.getVarName(name), opName,
460 argIndex - numPrevAttrs);
461 }
462 }
463
emitAttributeMatch(DagNode tree,StringRef opName,int argIndex,int depth)464 void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName,
465 int argIndex, int depth) {
466 Operator &op = tree.getDialectOp(opMap);
467 auto *namedAttr = op.getArg(argIndex).get<NamedAttribute *>();
468 const auto &attr = namedAttr->attr;
469
470 os << "{\n";
471 os.indent() << formatv("auto tblgen_attr = {0}->getAttrOfType<{1}>(\"{2}\");"
472 "(void)tblgen_attr;\n",
473 opName, attr.getStorageType(), namedAttr->name);
474
475 // TODO: This should use getter method to avoid duplication.
476 if (attr.hasDefaultValue()) {
477 os << "if (!tblgen_attr) tblgen_attr = "
478 << std::string(tgfmt(attr.getConstBuilderTemplate(), &fmtCtx,
479 attr.getDefaultValue()))
480 << ";\n";
481 } else if (attr.isOptional()) {
482 // For a missing attribute that is optional according to definition, we
483 // should just capture a mlir::Attribute() to signal the missing state.
484 // That is precisely what getAttr() returns on missing attributes.
485 } else {
486 emitMatchCheck(opName, tgfmt("tblgen_attr", &fmtCtx),
487 formatv("\"expected op '{0}' to have attribute '{1}' "
488 "of type '{2}'\"",
489 op.getOperationName(), namedAttr->name,
490 attr.getStorageType()));
491 }
492
493 auto matcher = tree.getArgAsLeaf(argIndex);
494 if (!matcher.isUnspecified()) {
495 if (!matcher.isAttrMatcher()) {
496 PrintFatalError(
497 loc, formatv("the {1}-th argument of op '{0}' should be an attribute",
498 op.getOperationName(), argIndex + 1));
499 }
500
501 // If a constraint is specified, we need to generate C++ statements to
502 // check the constraint.
503 emitMatchCheck(
504 opName,
505 tgfmt(matcher.getConditionTemplate(), &fmtCtx.withSelf("tblgen_attr")),
506 formatv("\"op '{0}' attribute '{1}' failed to satisfy constraint: "
507 "{2}\"",
508 op.getOperationName(), namedAttr->name,
509 matcher.getAsConstraint().getSummary()));
510 }
511
512 // Capture the value
513 auto name = tree.getArgName(argIndex);
514 // `$_` is a special symbol to ignore op argument matching.
515 if (!name.empty() && name != "_") {
516 os << formatv("{0} = tblgen_attr;\n", name);
517 }
518
519 os.unindent() << "}\n";
520 }
521
emitMatchCheck(StringRef opName,const FmtObjectBase & matchFmt,const llvm::formatv_object_base & failureFmt)522 void PatternEmitter::emitMatchCheck(
523 StringRef opName, const FmtObjectBase &matchFmt,
524 const llvm::formatv_object_base &failureFmt) {
525 emitMatchCheck(opName, matchFmt.str(), failureFmt.str());
526 }
527
emitMatchCheck(StringRef opName,const std::string & matchStr,const std::string & failureStr)528 void PatternEmitter::emitMatchCheck(StringRef opName,
529 const std::string &matchStr,
530 const std::string &failureStr) {
531
532 os << "if (!(" << matchStr << "))";
533 os.scope("{\n", "\n}\n").os << "return rewriter.notifyMatchFailure(" << opName
534 << ", [&](::mlir::Diagnostic &diag) {\n diag << "
535 << failureStr << ";\n});";
536 }
537
emitMatchLogic(DagNode tree,StringRef opName)538 void PatternEmitter::emitMatchLogic(DagNode tree, StringRef opName) {
539 LLVM_DEBUG(llvm::dbgs() << "--- start emitting match logic ---\n");
540 int depth = 0;
541 emitMatch(tree, opName, depth);
542
543 for (auto &appliedConstraint : pattern.getConstraints()) {
544 auto &constraint = appliedConstraint.constraint;
545 auto &entities = appliedConstraint.entities;
546
547 auto condition = constraint.getConditionTemplate();
548 if (isa<TypeConstraint>(constraint)) {
549 auto self = formatv("({0}.getType())",
550 symbolInfoMap.getValueAndRangeUse(entities.front()));
551 emitMatchCheck(
552 opName, tgfmt(condition, &fmtCtx.withSelf(self.str())),
553 formatv("\"value entity '{0}' failed to satisfy constraint: {1}\"",
554 entities.front(), constraint.getSummary()));
555
556 } else if (isa<AttrConstraint>(constraint)) {
557 PrintFatalError(
558 loc, "cannot use AttrConstraint in Pattern multi-entity constraints");
559 } else {
560 // TODO: replace formatv arguments with the exact specified
561 // args.
562 if (entities.size() > 4) {
563 PrintFatalError(loc, "only support up to 4-entity constraints now");
564 }
565 SmallVector<std::string, 4> names;
566 int i = 0;
567 for (int e = entities.size(); i < e; ++i)
568 names.push_back(symbolInfoMap.getValueAndRangeUse(entities[i]));
569 std::string self = appliedConstraint.self;
570 if (!self.empty())
571 self = symbolInfoMap.getValueAndRangeUse(self);
572 for (; i < 4; ++i)
573 names.push_back("<unused>");
574 emitMatchCheck(opName,
575 tgfmt(condition, &fmtCtx.withSelf(self), names[0],
576 names[1], names[2], names[3]),
577 formatv("\"entities '{0}' failed to satisfy constraint: "
578 "{1}\"",
579 llvm::join(entities, ", "),
580 constraint.getSummary()));
581 }
582 }
583
584 // Some of the operands could be bound to the same symbol name, we need
585 // to enforce equality constraint on those.
586 // TODO: we should be able to emit equality checks early
587 // and short circuit unnecessary work if vars are not equal.
588 for (auto symbolInfoIt = symbolInfoMap.begin();
589 symbolInfoIt != symbolInfoMap.end();) {
590 auto range = symbolInfoMap.getRangeOfEqualElements(symbolInfoIt->first);
591 auto startRange = range.first;
592 auto endRange = range.second;
593
594 auto firstOperand = symbolInfoIt->second.getVarName(symbolInfoIt->first);
595 for (++startRange; startRange != endRange; ++startRange) {
596 auto secondOperand = startRange->second.getVarName(symbolInfoIt->first);
597 emitMatchCheck(
598 opName,
599 formatv("*{0}.begin() == *{1}.begin()", firstOperand, secondOperand),
600 formatv("\"Operands '{0}' and '{1}' must be equal\"", firstOperand,
601 secondOperand));
602 }
603
604 symbolInfoIt = endRange;
605 }
606
607 LLVM_DEBUG(llvm::dbgs() << "--- done emitting match logic ---\n");
608 }
609
collectOps(DagNode tree,llvm::SmallPtrSetImpl<const Operator * > & ops)610 void PatternEmitter::collectOps(DagNode tree,
611 llvm::SmallPtrSetImpl<const Operator *> &ops) {
612 // Check if this tree is an operation.
613 if (tree.isOperation()) {
614 const Operator &op = tree.getDialectOp(opMap);
615 LLVM_DEBUG(llvm::dbgs()
616 << "found operation " << op.getOperationName() << '\n');
617 ops.insert(&op);
618 }
619
620 // Recurse the arguments of the tree.
621 for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i)
622 if (auto child = tree.getArgAsNestedDag(i))
623 collectOps(child, ops);
624 }
625
emit(StringRef rewriteName)626 void PatternEmitter::emit(StringRef rewriteName) {
627 // Get the DAG tree for the source pattern.
628 DagNode sourceTree = pattern.getSourcePattern();
629
630 const Operator &rootOp = pattern.getSourceRootOp();
631 auto rootName = rootOp.getOperationName();
632
633 // Collect the set of result operations.
634 llvm::SmallPtrSet<const Operator *, 4> resultOps;
635 LLVM_DEBUG(llvm::dbgs() << "start collecting ops used in result patterns\n");
636 for (unsigned i = 0, e = pattern.getNumResultPatterns(); i != e; ++i) {
637 collectOps(pattern.getResultPattern(i), resultOps);
638 }
639 LLVM_DEBUG(llvm::dbgs() << "done collecting ops used in result patterns\n");
640
641 // Emit RewritePattern for Pattern.
642 auto locs = pattern.getLocation();
643 os << formatv("/* Generated from:\n {0:$[ instantiating\n ]}\n*/\n",
644 make_range(locs.rbegin(), locs.rend()));
645 os << formatv(R"(struct {0} : public ::mlir::RewritePattern {
646 {0}(::mlir::MLIRContext *context)
647 : ::mlir::RewritePattern("{1}", {2}, context, {{)",
648 rewriteName, rootName, pattern.getBenefit());
649 // Sort result operators by name.
650 llvm::SmallVector<const Operator *, 4> sortedResultOps(resultOps.begin(),
651 resultOps.end());
652 llvm::sort(sortedResultOps, [&](const Operator *lhs, const Operator *rhs) {
653 return lhs->getOperationName() < rhs->getOperationName();
654 });
655 llvm::interleaveComma(sortedResultOps, os, [&](const Operator *op) {
656 os << '"' << op->getOperationName() << '"';
657 });
658 os << "}) {}\n";
659
660 // Emit matchAndRewrite() function.
661 {
662 auto classScope = os.scope();
663 os.reindent(R"(
664 ::mlir::LogicalResult matchAndRewrite(::mlir::Operation *op0,
665 ::mlir::PatternRewriter &rewriter) const override {)")
666 << '\n';
667 {
668 auto functionScope = os.scope();
669
670 // Register all symbols bound in the source pattern.
671 pattern.collectSourcePatternBoundSymbols(symbolInfoMap);
672
673 LLVM_DEBUG(llvm::dbgs()
674 << "start creating local variables for capturing matches\n");
675 os << "// Variables for capturing values and attributes used while "
676 "creating ops\n";
677 // Create local variables for storing the arguments and results bound
678 // to symbols.
679 for (const auto &symbolInfoPair : symbolInfoMap) {
680 const auto &symbol = symbolInfoPair.first;
681 const auto &info = symbolInfoPair.second;
682
683 os << info.getVarDecl(symbol);
684 }
685 // TODO: capture ops with consistent numbering so that it can be
686 // reused for fused loc.
687 os << formatv("::mlir::Operation *tblgen_ops[{0}];\n\n",
688 pattern.getSourcePattern().getNumOps());
689 LLVM_DEBUG(llvm::dbgs()
690 << "done creating local variables for capturing matches\n");
691
692 os << "// Match\n";
693 os << "tblgen_ops[0] = op0;\n";
694 emitMatchLogic(sourceTree, "op0");
695
696 os << "\n// Rewrite\n";
697 emitRewriteLogic();
698
699 os << "return ::mlir::success();\n";
700 }
701 os << "};\n";
702 }
703 os << "};\n\n";
704 }
705
emitRewriteLogic()706 void PatternEmitter::emitRewriteLogic() {
707 LLVM_DEBUG(llvm::dbgs() << "--- start emitting rewrite logic ---\n");
708 const Operator &rootOp = pattern.getSourceRootOp();
709 int numExpectedResults = rootOp.getNumResults();
710 int numResultPatterns = pattern.getNumResultPatterns();
711
712 // First register all symbols bound to ops generated in result patterns.
713 pattern.collectResultPatternBoundSymbols(symbolInfoMap);
714
715 // Only the last N static values generated are used to replace the matched
716 // root N-result op. We need to calculate the starting index (of the results
717 // of the matched op) each result pattern is to replace.
718 SmallVector<int, 4> offsets(numResultPatterns + 1, numExpectedResults);
719 // If we don't need to replace any value at all, set the replacement starting
720 // index as the number of result patterns so we skip all of them when trying
721 // to replace the matched op's results.
722 int replStartIndex = numExpectedResults == 0 ? numResultPatterns : -1;
723 for (int i = numResultPatterns - 1; i >= 0; --i) {
724 auto numValues = getNodeValueCount(pattern.getResultPattern(i));
725 offsets[i] = offsets[i + 1] - numValues;
726 if (offsets[i] == 0) {
727 if (replStartIndex == -1)
728 replStartIndex = i;
729 } else if (offsets[i] < 0 && offsets[i + 1] > 0) {
730 auto error = formatv(
731 "cannot use the same multi-result op '{0}' to generate both "
732 "auxiliary values and values to be used for replacing the matched op",
733 pattern.getResultPattern(i).getSymbol());
734 PrintFatalError(loc, error);
735 }
736 }
737
738 if (offsets.front() > 0) {
739 const char error[] = "no enough values generated to replace the matched op";
740 PrintFatalError(loc, error);
741 }
742
743 os << "auto odsLoc = rewriter.getFusedLoc({";
744 for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) {
745 os << (i ? ", " : "") << "tblgen_ops[" << i << "]->getLoc()";
746 }
747 os << "}); (void)odsLoc;\n";
748
749 // Process auxiliary result patterns.
750 for (int i = 0; i < replStartIndex; ++i) {
751 DagNode resultTree = pattern.getResultPattern(i);
752 auto val = handleResultPattern(resultTree, offsets[i], 0);
753 // Normal op creation will be streamed to `os` by the above call; but
754 // NativeCodeCall will only be materialized to `os` if it is used. Here
755 // we are handling auxiliary patterns so we want the side effect even if
756 // NativeCodeCall is not replacing matched root op's results.
757 if (resultTree.isNativeCodeCall() &&
758 resultTree.getNumReturnsOfNativeCode() == 0)
759 os << val << ";\n";
760 }
761
762 if (numExpectedResults == 0) {
763 assert(replStartIndex >= numResultPatterns &&
764 "invalid auxiliary vs. replacement pattern division!");
765 // No result to replace. Just erase the op.
766 os << "rewriter.eraseOp(op0);\n";
767 } else {
768 // Process replacement result patterns.
769 os << "::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values;\n";
770 for (int i = replStartIndex; i < numResultPatterns; ++i) {
771 DagNode resultTree = pattern.getResultPattern(i);
772 auto val = handleResultPattern(resultTree, offsets[i], 0);
773 os << "\n";
774 // Resolve each symbol for all range use so that we can loop over them.
775 // We need an explicit cast to `SmallVector` to capture the cases where
776 // `{0}` resolves to an `Operation::result_range` as well as cases that
777 // are not iterable (e.g. vector that gets wrapped in additional braces by
778 // RewriterGen).
779 // TODO: Revisit the need for materializing a vector.
780 os << symbolInfoMap.getAllRangeUse(
781 val,
782 "for (auto v: ::llvm::SmallVector<::mlir::Value, 4>{ {0} }) {{\n"
783 " tblgen_repl_values.push_back(v);\n}\n",
784 "\n");
785 }
786 os << "\nrewriter.replaceOp(op0, tblgen_repl_values);\n";
787 }
788
789 LLVM_DEBUG(llvm::dbgs() << "--- done emitting rewrite logic ---\n");
790 }
791
getUniqueSymbol(const Operator * op)792 std::string PatternEmitter::getUniqueSymbol(const Operator *op) {
793 return std::string(
794 formatv("tblgen_{0}_{1}", op->getCppClassName(), nextValueId++));
795 }
796
handleResultPattern(DagNode resultTree,int resultIndex,int depth)797 std::string PatternEmitter::handleResultPattern(DagNode resultTree,
798 int resultIndex, int depth) {
799 LLVM_DEBUG(llvm::dbgs() << "handle result pattern: ");
800 LLVM_DEBUG(resultTree.print(llvm::dbgs()));
801 LLVM_DEBUG(llvm::dbgs() << '\n');
802
803 if (resultTree.isLocationDirective()) {
804 PrintFatalError(loc,
805 "location directive can only be used with op creation");
806 }
807
808 if (resultTree.isNativeCodeCall())
809 return handleReplaceWithNativeCodeCall(resultTree, depth);
810
811 if (resultTree.isReplaceWithValue())
812 return handleReplaceWithValue(resultTree).str();
813
814 // Normal op creation.
815 auto symbol = handleOpCreation(resultTree, resultIndex, depth);
816 if (resultTree.getSymbol().empty()) {
817 // This is an op not explicitly bound to a symbol in the rewrite rule.
818 // Register the auto-generated symbol for it.
819 symbolInfoMap.bindOpResult(symbol, pattern.getDialectOp(resultTree));
820 }
821 return symbol;
822 }
823
handleReplaceWithValue(DagNode tree)824 StringRef PatternEmitter::handleReplaceWithValue(DagNode tree) {
825 assert(tree.isReplaceWithValue());
826
827 if (tree.getNumArgs() != 1) {
828 PrintFatalError(
829 loc, "replaceWithValue directive must take exactly one argument");
830 }
831
832 if (!tree.getSymbol().empty()) {
833 PrintFatalError(loc, "cannot bind symbol to replaceWithValue");
834 }
835
836 return tree.getArgName(0);
837 }
838
handleLocationDirective(DagNode tree)839 std::string PatternEmitter::handleLocationDirective(DagNode tree) {
840 assert(tree.isLocationDirective());
841 auto lookUpArgLoc = [this, &tree](int idx) {
842 const auto *const lookupFmt = "(*{0}.begin()).getLoc()";
843 return symbolInfoMap.getAllRangeUse(tree.getArgName(idx), lookupFmt);
844 };
845
846 if (tree.getNumArgs() == 0)
847 llvm::PrintFatalError(
848 "At least one argument to location directive required");
849
850 if (!tree.getSymbol().empty())
851 PrintFatalError(loc, "cannot bind symbol to location");
852
853 if (tree.getNumArgs() == 1) {
854 DagLeaf leaf = tree.getArgAsLeaf(0);
855 if (leaf.isStringAttr())
856 return formatv("::mlir::NameLoc::get(rewriter.getIdentifier(\"{0}\"))",
857 leaf.getStringAttr())
858 .str();
859 return lookUpArgLoc(0);
860 }
861
862 std::string ret;
863 llvm::raw_string_ostream os(ret);
864 std::string strAttr;
865 os << "rewriter.getFusedLoc({";
866 bool first = true;
867 for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
868 DagLeaf leaf = tree.getArgAsLeaf(i);
869 // Handle the optional string value.
870 if (leaf.isStringAttr()) {
871 if (!strAttr.empty())
872 llvm::PrintFatalError("Only one string attribute may be specified");
873 strAttr = leaf.getStringAttr();
874 continue;
875 }
876 os << (first ? "" : ", ") << lookUpArgLoc(i);
877 first = false;
878 }
879 os << "}";
880 if (!strAttr.empty()) {
881 os << ", rewriter.getStringAttr(\"" << strAttr << "\")";
882 }
883 os << ")";
884 return os.str();
885 }
886
handleOpArgument(DagLeaf leaf,StringRef patArgName)887 std::string PatternEmitter::handleOpArgument(DagLeaf leaf,
888 StringRef patArgName) {
889 if (leaf.isStringAttr())
890 PrintFatalError(loc, "raw string not supported as argument");
891 if (leaf.isConstantAttr()) {
892 auto constAttr = leaf.getAsConstantAttr();
893 return handleConstantAttr(constAttr.getAttribute(),
894 constAttr.getConstantValue());
895 }
896 if (leaf.isEnumAttrCase()) {
897 auto enumCase = leaf.getAsEnumAttrCase();
898 if (enumCase.isStrCase())
899 return handleConstantAttr(enumCase, enumCase.getSymbol());
900 // This is an enum case backed by an IntegerAttr. We need to get its value
901 // to build the constant.
902 std::string val = std::to_string(enumCase.getValue());
903 return handleConstantAttr(enumCase, val);
904 }
905
906 LLVM_DEBUG(llvm::dbgs() << "handle argument '" << patArgName << "'\n");
907 auto argName = symbolInfoMap.getValueAndRangeUse(patArgName);
908 if (leaf.isUnspecified() || leaf.isOperandMatcher()) {
909 LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << argName
910 << "' (via symbol ref)\n");
911 return argName;
912 }
913 if (leaf.isNativeCodeCall()) {
914 auto repl = tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(argName));
915 LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << repl
916 << "' (via NativeCodeCall)\n");
917 return std::string(repl);
918 }
919 PrintFatalError(loc, "unhandled case when rewriting op");
920 }
921
handleReplaceWithNativeCodeCall(DagNode tree,int depth)922 std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree,
923 int depth) {
924 LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall pattern: ");
925 LLVM_DEBUG(tree.print(llvm::dbgs()));
926 LLVM_DEBUG(llvm::dbgs() << '\n');
927
928 auto fmt = tree.getNativeCodeTemplate();
929
930 SmallVector<std::string, 16> attrs;
931
932 bool hasLocationDirective;
933 std::string locToUse;
934 std::tie(hasLocationDirective, locToUse) = getLocation(tree);
935
936 for (int i = 0, e = tree.getNumArgs() - hasLocationDirective; i != e; ++i) {
937 if (tree.isNestedDagArg(i)) {
938 attrs.push_back(
939 handleResultPattern(tree.getArgAsNestedDag(i), i, depth + 1));
940 } else {
941 attrs.push_back(
942 handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i)));
943 }
944 LLVM_DEBUG(llvm::dbgs() << "NativeCodeCall argument #" << i
945 << " replacement: " << attrs[i] << "\n");
946 }
947
948 std::string symbol = tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse), attrs);
949
950 // In general, NativeCodeCall without naming binding don't need this. To
951 // ensure void helper function has been correctly labeled, i.e., use
952 // NativeCodeCallVoid, we cache the result to a local variable so that we will
953 // get a compilation error in the auto-generated file.
954 // Example.
955 // // In the td file
956 // Pat<(...), (NativeCodeCall<Foo> ...)>
957 //
958 // ---
959 //
960 // // In the auto-generated .cpp
961 // ...
962 // // Causes compilation error if Foo() returns void.
963 // auto nativeVar = Foo();
964 // ...
965 if (tree.getNumReturnsOfNativeCode() != 0) {
966 // Determine the local variable name for return value.
967 std::string varName =
968 SymbolInfoMap::getValuePackName(tree.getSymbol()).str();
969 if (varName.empty()) {
970 varName = formatv("nativeVar_{0}", nextValueId++);
971 // Register the local variable for later uses.
972 symbolInfoMap.bindValues(varName, tree.getNumReturnsOfNativeCode());
973 }
974
975 // Catch the return value of helper function.
976 os << formatv("auto {0} = {1}; (void){0};\n", varName, symbol);
977
978 if (!tree.getSymbol().empty())
979 symbol = tree.getSymbol().str();
980 else
981 symbol = varName;
982 }
983
984 return symbol;
985 }
986
getNodeValueCount(DagNode node)987 int PatternEmitter::getNodeValueCount(DagNode node) {
988 if (node.isOperation()) {
989 // If the op is bound to a symbol in the rewrite rule, query its result
990 // count from the symbol info map.
991 auto symbol = node.getSymbol();
992 if (!symbol.empty()) {
993 return symbolInfoMap.getStaticValueCount(symbol);
994 }
995 // Otherwise this is an unbound op; we will use all its results.
996 return pattern.getDialectOp(node).getNumResults();
997 }
998
999 if (node.isNativeCodeCall())
1000 return node.getNumReturnsOfNativeCode();
1001
1002 return 1;
1003 }
1004
getLocation(DagNode tree)1005 std::pair<bool, std::string> PatternEmitter::getLocation(DagNode tree) {
1006 auto numPatArgs = tree.getNumArgs();
1007
1008 if (numPatArgs != 0) {
1009 if (auto lastArg = tree.getArgAsNestedDag(numPatArgs - 1))
1010 if (lastArg.isLocationDirective()) {
1011 return std::make_pair(true, handleLocationDirective(lastArg));
1012 }
1013 }
1014
1015 // If no explicit location is given, use the default, all fused, location.
1016 return std::make_pair(false, "odsLoc");
1017 }
1018
handleOpCreation(DagNode tree,int resultIndex,int depth)1019 std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
1020 int depth) {
1021 LLVM_DEBUG(llvm::dbgs() << "create op for pattern: ");
1022 LLVM_DEBUG(tree.print(llvm::dbgs()));
1023 LLVM_DEBUG(llvm::dbgs() << '\n');
1024
1025 Operator &resultOp = tree.getDialectOp(opMap);
1026 auto numOpArgs = resultOp.getNumArgs();
1027 auto numPatArgs = tree.getNumArgs();
1028
1029 bool hasLocationDirective;
1030 std::string locToUse;
1031 std::tie(hasLocationDirective, locToUse) = getLocation(tree);
1032
1033 auto inPattern = numPatArgs - hasLocationDirective;
1034 if (numOpArgs != inPattern) {
1035 PrintFatalError(loc,
1036 formatv("resultant op '{0}' argument number mismatch: "
1037 "{1} in pattern vs. {2} in definition",
1038 resultOp.getOperationName(), inPattern, numOpArgs));
1039 }
1040
1041 // A map to collect all nested DAG child nodes' names, with operand index as
1042 // the key. This includes both bound and unbound child nodes.
1043 ChildNodeIndexNameMap childNodeNames;
1044
1045 // First go through all the child nodes who are nested DAG constructs to
1046 // create ops for them and remember the symbol names for them, so that we can
1047 // use the results in the current node. This happens in a recursive manner.
1048 for (int i = 0, e = tree.getNumArgs() - hasLocationDirective; i != e; ++i) {
1049 if (auto child = tree.getArgAsNestedDag(i))
1050 childNodeNames[i] = handleResultPattern(child, i, depth + 1);
1051 }
1052
1053 // The name of the local variable holding this op.
1054 std::string valuePackName;
1055 // The symbol for holding the result of this pattern. Note that the result of
1056 // this pattern is not necessarily the same as the variable created by this
1057 // pattern because we can use `__N` suffix to refer only a specific result if
1058 // the generated op is a multi-result op.
1059 std::string resultValue;
1060 if (tree.getSymbol().empty()) {
1061 // No symbol is explicitly bound to this op in the pattern. Generate a
1062 // unique name.
1063 valuePackName = resultValue = getUniqueSymbol(&resultOp);
1064 } else {
1065 resultValue = std::string(tree.getSymbol());
1066 // Strip the index to get the name for the value pack and use it to name the
1067 // local variable for the op.
1068 valuePackName = std::string(SymbolInfoMap::getValuePackName(resultValue));
1069 }
1070
1071 // Create the local variable for this op.
1072 os << formatv("{0} {1};\n{{\n", resultOp.getQualCppClassName(),
1073 valuePackName);
1074
1075 // Right now ODS don't have general type inference support. Except a few
1076 // special cases listed below, DRR needs to supply types for all results
1077 // when building an op.
1078 bool isSameOperandsAndResultType =
1079 resultOp.getTrait("::mlir::OpTrait::SameOperandsAndResultType");
1080 bool useFirstAttr =
1081 resultOp.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType");
1082
1083 if (isSameOperandsAndResultType || useFirstAttr) {
1084 // We know how to deduce the result type for ops with these traits and we've
1085 // generated builders taking aggregate parameters. Use those builders to
1086 // create the ops.
1087
1088 // First prepare local variables for op arguments used in builder call.
1089 createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth);
1090
1091 // Then create the op.
1092 os.scope("", "\n}\n").os << formatv(
1093 "{0} = rewriter.create<{1}>({2}, tblgen_values, tblgen_attrs);",
1094 valuePackName, resultOp.getQualCppClassName(), locToUse);
1095 return resultValue;
1096 }
1097
1098 bool usePartialResults = valuePackName != resultValue;
1099
1100 if (usePartialResults || depth > 0 || resultIndex < 0) {
1101 // For these cases (broadcastable ops, op results used both as auxiliary
1102 // values and replacement values, ops in nested patterns, auxiliary ops), we
1103 // still need to supply the result types when building the op. But because
1104 // we don't generate a builder automatically with ODS for them, it's the
1105 // developer's responsibility to make sure such a builder (with result type
1106 // deduction ability) exists. We go through the separate-parameter builder
1107 // here given that it's easier for developers to write compared to
1108 // aggregate-parameter builders.
1109 createSeparateLocalVarsForOpArgs(tree, childNodeNames);
1110
1111 os.scope().os << formatv("{0} = rewriter.create<{1}>({2}", valuePackName,
1112 resultOp.getQualCppClassName(), locToUse);
1113 supplyValuesForOpArgs(tree, childNodeNames, depth);
1114 os << "\n );\n}\n";
1115 return resultValue;
1116 }
1117
1118 // If depth == 0 and resultIndex >= 0, it means we are replacing the values
1119 // generated from the source pattern root op. Then we can use the source
1120 // pattern's value types to determine the value type of the generated op
1121 // here.
1122
1123 // First prepare local variables for op arguments used in builder call.
1124 createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth);
1125
1126 // Then prepare the result types. We need to specify the types for all
1127 // results.
1128 os.indent() << formatv("::mlir::SmallVector<::mlir::Type, 4> tblgen_types; "
1129 "(void)tblgen_types;\n");
1130 int numResults = resultOp.getNumResults();
1131 if (numResults != 0) {
1132 for (int i = 0; i < numResults; ++i)
1133 os << formatv("for (auto v: castedOp0.getODSResults({0})) {{\n"
1134 " tblgen_types.push_back(v.getType());\n}\n",
1135 resultIndex + i);
1136 }
1137 os << formatv("{0} = rewriter.create<{1}>({2}, tblgen_types, "
1138 "tblgen_values, tblgen_attrs);\n",
1139 valuePackName, resultOp.getQualCppClassName(), locToUse);
1140 os.unindent() << "}\n";
1141 return resultValue;
1142 }
1143
createSeparateLocalVarsForOpArgs(DagNode node,ChildNodeIndexNameMap & childNodeNames)1144 void PatternEmitter::createSeparateLocalVarsForOpArgs(
1145 DagNode node, ChildNodeIndexNameMap &childNodeNames) {
1146 Operator &resultOp = node.getDialectOp(opMap);
1147
1148 // Now prepare operands used for building this op:
1149 // * If the operand is non-variadic, we create a `Value` local variable.
1150 // * If the operand is variadic, we create a `SmallVector<Value>` local
1151 // variable.
1152
1153 int valueIndex = 0; // An index for uniquing local variable names.
1154 for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
1155 const auto *operand =
1156 resultOp.getArg(argIndex).dyn_cast<NamedTypeConstraint *>();
1157 // We do not need special handling for attributes.
1158 if (!operand)
1159 continue;
1160
1161 raw_indented_ostream::DelimitedScope scope(os);
1162 std::string varName;
1163 if (operand->isVariadic()) {
1164 varName = std::string(formatv("tblgen_values_{0}", valueIndex++));
1165 os << formatv("::mlir::SmallVector<::mlir::Value, 4> {0};\n", varName);
1166 std::string range;
1167 if (node.isNestedDagArg(argIndex)) {
1168 range = childNodeNames[argIndex];
1169 } else {
1170 range = std::string(node.getArgName(argIndex));
1171 }
1172 // Resolve the symbol for all range use so that we have a uniform way of
1173 // capturing the values.
1174 range = symbolInfoMap.getValueAndRangeUse(range);
1175 os << formatv("for (auto v: {0}) {{\n {1}.push_back(v);\n}\n", range,
1176 varName);
1177 } else {
1178 varName = std::string(formatv("tblgen_value_{0}", valueIndex++));
1179 os << formatv("::mlir::Value {0} = ", varName);
1180 if (node.isNestedDagArg(argIndex)) {
1181 os << symbolInfoMap.getValueAndRangeUse(childNodeNames[argIndex]);
1182 } else {
1183 DagLeaf leaf = node.getArgAsLeaf(argIndex);
1184 auto symbol =
1185 symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex));
1186 if (leaf.isNativeCodeCall()) {
1187 os << std::string(
1188 tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol)));
1189 } else {
1190 os << symbol;
1191 }
1192 }
1193 os << ";\n";
1194 }
1195
1196 // Update to use the newly created local variable for building the op later.
1197 childNodeNames[argIndex] = varName;
1198 }
1199 }
1200
supplyValuesForOpArgs(DagNode node,const ChildNodeIndexNameMap & childNodeNames,int depth)1201 void PatternEmitter::supplyValuesForOpArgs(
1202 DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) {
1203 Operator &resultOp = node.getDialectOp(opMap);
1204 for (int argIndex = 0, numOpArgs = resultOp.getNumArgs();
1205 argIndex != numOpArgs; ++argIndex) {
1206 // Start each argument on its own line.
1207 os << ",\n ";
1208
1209 Argument opArg = resultOp.getArg(argIndex);
1210 // Handle the case of operand first.
1211 if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) {
1212 if (!operand->name.empty())
1213 os << "/*" << operand->name << "=*/";
1214 os << childNodeNames.lookup(argIndex);
1215 continue;
1216 }
1217
1218 // The argument in the op definition.
1219 auto opArgName = resultOp.getArgName(argIndex);
1220 if (auto subTree = node.getArgAsNestedDag(argIndex)) {
1221 if (!subTree.isNativeCodeCall())
1222 PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
1223 "for creating attribute");
1224 os << formatv("/*{0}=*/{1}", opArgName, childNodeNames.lookup(argIndex));
1225 } else {
1226 auto leaf = node.getArgAsLeaf(argIndex);
1227 // The argument in the result DAG pattern.
1228 auto patArgName = node.getArgName(argIndex);
1229 if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) {
1230 // TODO: Refactor out into map to avoid recomputing these.
1231 if (!opArg.is<NamedAttribute *>())
1232 PrintFatalError(loc, Twine("expected attribute ") + Twine(argIndex));
1233 if (!patArgName.empty())
1234 os << "/*" << patArgName << "=*/";
1235 } else {
1236 os << "/*" << opArgName << "=*/";
1237 }
1238 os << handleOpArgument(leaf, patArgName);
1239 }
1240 }
1241 }
1242
createAggregateLocalVarsForOpArgs(DagNode node,const ChildNodeIndexNameMap & childNodeNames,int depth)1243 void PatternEmitter::createAggregateLocalVarsForOpArgs(
1244 DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) {
1245 Operator &resultOp = node.getDialectOp(opMap);
1246
1247 auto scope = os.scope();
1248 os << formatv("::mlir::SmallVector<::mlir::Value, 4> "
1249 "tblgen_values; (void)tblgen_values;\n");
1250 os << formatv("::mlir::SmallVector<::mlir::NamedAttribute, 4> "
1251 "tblgen_attrs; (void)tblgen_attrs;\n");
1252
1253 const char *addAttrCmd =
1254 "if (auto tmpAttr = {1}) {\n"
1255 " tblgen_attrs.emplace_back(rewriter.getIdentifier(\"{0}\"), "
1256 "tmpAttr);\n}\n";
1257 for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
1258 if (resultOp.getArg(argIndex).is<NamedAttribute *>()) {
1259 // The argument in the op definition.
1260 auto opArgName = resultOp.getArgName(argIndex);
1261 if (auto subTree = node.getArgAsNestedDag(argIndex)) {
1262 if (!subTree.isNativeCodeCall())
1263 PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
1264 "for creating attribute");
1265 os << formatv(addAttrCmd, opArgName, childNodeNames.lookup(argIndex));
1266 } else {
1267 auto leaf = node.getArgAsLeaf(argIndex);
1268 // The argument in the result DAG pattern.
1269 auto patArgName = node.getArgName(argIndex);
1270 os << formatv(addAttrCmd, opArgName,
1271 handleOpArgument(leaf, patArgName));
1272 }
1273 continue;
1274 }
1275
1276 const auto *operand =
1277 resultOp.getArg(argIndex).get<NamedTypeConstraint *>();
1278 std::string varName;
1279 if (operand->isVariadic()) {
1280 std::string range;
1281 if (node.isNestedDagArg(argIndex)) {
1282 range = childNodeNames.lookup(argIndex);
1283 } else {
1284 range = std::string(node.getArgName(argIndex));
1285 }
1286 // Resolve the symbol for all range use so that we have a uniform way of
1287 // capturing the values.
1288 range = symbolInfoMap.getValueAndRangeUse(range);
1289 os << formatv("for (auto v: {0}) {{\n tblgen_values.push_back(v);\n}\n",
1290 range);
1291 } else {
1292 os << formatv("tblgen_values.push_back(");
1293 if (node.isNestedDagArg(argIndex)) {
1294 os << symbolInfoMap.getValueAndRangeUse(
1295 childNodeNames.lookup(argIndex));
1296 } else {
1297 DagLeaf leaf = node.getArgAsLeaf(argIndex);
1298 auto symbol =
1299 symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex));
1300 if (leaf.isNativeCodeCall()) {
1301 os << std::string(
1302 tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol)));
1303 } else {
1304 os << symbol;
1305 }
1306 }
1307 os << ");\n";
1308 }
1309 }
1310 }
1311
emitRewriters(const RecordKeeper & recordKeeper,raw_ostream & os)1312 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
1313 emitSourceFileHeader("Rewriters", os);
1314
1315 const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern");
1316 auto numPatterns = patterns.size();
1317
1318 // We put the map here because it can be shared among multiple patterns.
1319 RecordOperatorMap recordOpMap;
1320
1321 std::vector<std::string> rewriterNames;
1322 rewriterNames.reserve(numPatterns);
1323
1324 std::string baseRewriterName = "GeneratedConvert";
1325 int rewriterIndex = 0;
1326
1327 for (Record *p : patterns) {
1328 std::string name;
1329 if (p->isAnonymous()) {
1330 // If no name is provided, ensure unique rewriter names simply by
1331 // appending unique suffix.
1332 name = baseRewriterName + llvm::utostr(rewriterIndex++);
1333 } else {
1334 name = std::string(p->getName());
1335 }
1336 LLVM_DEBUG(llvm::dbgs()
1337 << "=== start generating pattern '" << name << "' ===\n");
1338 PatternEmitter(p, &recordOpMap, os).emit(name);
1339 LLVM_DEBUG(llvm::dbgs()
1340 << "=== done generating pattern '" << name << "' ===\n");
1341 rewriterNames.push_back(std::move(name));
1342 }
1343
1344 // Emit function to add the generated matchers to the pattern list.
1345 os << "void LLVM_ATTRIBUTE_UNUSED populateWithGenerated("
1346 "::mlir::RewritePatternSet &patterns) {\n";
1347 for (const auto &name : rewriterNames) {
1348 os << " patterns.add<" << name << ">(patterns.getContext());\n";
1349 }
1350 os << "}\n";
1351 }
1352
1353 static mlir::GenRegistration
1354 genRewriters("gen-rewriters", "Generate pattern rewriters",
__anonb034f2370602(const RecordKeeper &records, raw_ostream &os) 1355 [](const RecordKeeper &records, raw_ostream &os) {
1356 emitRewriters(records, os);
1357 return false;
1358 });
1359