1 //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // RewriterGen uses pattern rewrite definitions to generate rewriter matchers.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Support/IndentedOstream.h"
14 #include "mlir/TableGen/Attribute.h"
15 #include "mlir/TableGen/Format.h"
16 #include "mlir/TableGen/GenInfo.h"
17 #include "mlir/TableGen/Operator.h"
18 #include "mlir/TableGen/Pattern.h"
19 #include "mlir/TableGen/Predicate.h"
20 #include "mlir/TableGen/Type.h"
21 #include "llvm/ADT/StringExtras.h"
22 #include "llvm/ADT/StringSet.h"
23 #include "llvm/Support/CommandLine.h"
24 #include "llvm/Support/Debug.h"
25 #include "llvm/Support/FormatAdapters.h"
26 #include "llvm/Support/PrettyStackTrace.h"
27 #include "llvm/Support/Signals.h"
28 #include "llvm/TableGen/Error.h"
29 #include "llvm/TableGen/Main.h"
30 #include "llvm/TableGen/Record.h"
31 #include "llvm/TableGen/TableGenBackend.h"
32 
33 using namespace mlir;
34 using namespace mlir::tblgen;
35 
36 using llvm::formatv;
37 using llvm::Record;
38 using llvm::RecordKeeper;
39 
40 #define DEBUG_TYPE "mlir-tblgen-rewritergen"
41 
42 namespace llvm {
43 template <>
44 struct format_provider<mlir::tblgen::Pattern::IdentifierLine> {
formatllvm::format_provider45   static void format(const mlir::tblgen::Pattern::IdentifierLine &v,
46                      raw_ostream &os, StringRef style) {
47     os << v.first << ":" << v.second;
48   }
49 };
50 } // end namespace llvm
51 
52 //===----------------------------------------------------------------------===//
53 // PatternEmitter
54 //===----------------------------------------------------------------------===//
55 
56 namespace {
57 class PatternEmitter {
58 public:
59   PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os);
60 
61   // Emits the mlir::RewritePattern struct named `rewriteName`.
62   void emit(StringRef rewriteName);
63 
64 private:
65   // Emits the code for matching ops.
66   void emitMatchLogic(DagNode tree, StringRef opName);
67 
68   // Emits the code for rewriting ops.
69   void emitRewriteLogic();
70 
71   //===--------------------------------------------------------------------===//
72   // Match utilities
73   //===--------------------------------------------------------------------===//
74 
75   // Emits C++ statements for matching the DAG structure.
76   void emitMatch(DagNode tree, StringRef name, int depth);
77 
78   // Emits C++ statements for matching using a native code call.
79   void emitNativeCodeMatch(DagNode tree, StringRef name, int depth);
80 
81   // Emits C++ statements for matching the op constrained by the given DAG
82   // `tree` returning the op's variable name.
83   void emitOpMatch(DagNode tree, StringRef opName, int depth);
84 
85   // Emits C++ statements for matching the `argIndex`-th argument of the given
86   // DAG `tree` as an operand. operandIndex is the index in the DAG excluding
87   // the preceding attributes.
88   void emitOperandMatch(DagNode tree, StringRef opName, int argIndex,
89                         int operandIndex, int depth);
90 
91   // Emits C++ statements for matching the `argIndex`-th argument of the given
92   // DAG `tree` as an attribute.
93   void emitAttributeMatch(DagNode tree, StringRef opName, int argIndex,
94                           int depth);
95 
96   // Emits C++ for checking a match with a corresponding match failure
97   // diagnostic.
98   void emitMatchCheck(StringRef opName, const FmtObjectBase &matchFmt,
99                       const llvm::formatv_object_base &failureFmt);
100 
101   // Emits C++ for checking a match with a corresponding match failure
102   // diagnostics.
103   void emitMatchCheck(StringRef opName, const std::string &matchStr,
104                       const std::string &failureStr);
105 
106   //===--------------------------------------------------------------------===//
107   // Rewrite utilities
108   //===--------------------------------------------------------------------===//
109 
110   // The entry point for handling a result pattern rooted at `resultTree`. This
111   // method dispatches to concrete handlers according to `resultTree`'s kind and
112   // returns a symbol representing the whole value pack. Callers are expected to
113   // further resolve the symbol according to the specific use case.
114   //
115   // `depth` is the nesting level of `resultTree`; 0 means top-level result
116   // pattern. For top-level result pattern, `resultIndex` indicates which result
117   // of the matched root op this pattern is intended to replace, which can be
118   // used to deduce the result type of the op generated from this result
119   // pattern.
120   std::string handleResultPattern(DagNode resultTree, int resultIndex,
121                                   int depth);
122 
123   // Emits the C++ statement to replace the matched DAG with a value built via
124   // calling native C++ code.
125   std::string handleReplaceWithNativeCodeCall(DagNode resultTree, int depth);
126 
127   // Returns the symbol of the old value serving as the replacement.
128   StringRef handleReplaceWithValue(DagNode tree);
129 
130   // Returns the location value to use.
131   std::pair<bool, std::string> getLocation(DagNode tree);
132 
133   // Returns the location value to use.
134   std::string handleLocationDirective(DagNode tree);
135 
136   // Emits the C++ statement to build a new op out of the given DAG `tree` and
137   // returns the variable name that this op is assigned to. If the root op in
138   // DAG `tree` has a specified name, the created op will be assigned to a
139   // variable of the given name. Otherwise, a unique name will be used as the
140   // result value name.
141   std::string handleOpCreation(DagNode tree, int resultIndex, int depth);
142 
143   using ChildNodeIndexNameMap = DenseMap<unsigned, std::string>;
144 
145   // Emits a local variable for each value and attribute to be used for creating
146   // an op.
147   void createSeparateLocalVarsForOpArgs(DagNode node,
148                                         ChildNodeIndexNameMap &childNodeNames);
149 
150   // Emits the concrete arguments used to call an op's builder.
151   void supplyValuesForOpArgs(DagNode node,
152                              const ChildNodeIndexNameMap &childNodeNames,
153                              int depth);
154 
155   // Emits the local variables for holding all values as a whole and all named
156   // attributes as a whole to be used for creating an op.
157   void createAggregateLocalVarsForOpArgs(
158       DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth);
159 
160   // Returns the C++ expression to construct a constant attribute of the given
161   // `value` for the given attribute kind `attr`.
162   std::string handleConstantAttr(Attribute attr, StringRef value);
163 
164   // Returns the C++ expression to build an argument from the given DAG `leaf`.
165   // `patArgName` is used to bound the argument to the source pattern.
166   std::string handleOpArgument(DagLeaf leaf, StringRef patArgName);
167 
168   //===--------------------------------------------------------------------===//
169   // General utilities
170   //===--------------------------------------------------------------------===//
171 
172   // Collects all of the operations within the given dag tree.
173   void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops);
174 
175   // Returns a unique symbol for a local variable of the given `op`.
176   std::string getUniqueSymbol(const Operator *op);
177 
178   //===--------------------------------------------------------------------===//
179   // Symbol utilities
180   //===--------------------------------------------------------------------===//
181 
182   // Returns how many static values the given DAG `node` correspond to.
183   int getNodeValueCount(DagNode node);
184 
185 private:
186   // Pattern instantiation location followed by the location of multiclass
187   // prototypes used. This is intended to be used as a whole to
188   // PrintFatalError() on errors.
189   ArrayRef<llvm::SMLoc> loc;
190 
191   // Op's TableGen Record to wrapper object.
192   RecordOperatorMap *opMap;
193 
194   // Handy wrapper for pattern being emitted.
195   Pattern pattern;
196 
197   // Map for all bound symbols' info.
198   SymbolInfoMap symbolInfoMap;
199 
200   // The next unused ID for newly created values.
201   unsigned nextValueId;
202 
203   raw_indented_ostream os;
204 
205   // Format contexts containing placeholder substitutions.
206   FmtContext fmtCtx;
207 
208   // Number of op processed.
209   int opCounter = 0;
210 };
211 } // end anonymous namespace
212 
PatternEmitter(Record * pat,RecordOperatorMap * mapper,raw_ostream & os)213 PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper,
214                                raw_ostream &os)
215     : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper),
216       symbolInfoMap(pat->getLoc()), nextValueId(0), os(os) {
217   fmtCtx.withBuilder("rewriter");
218 }
219 
handleConstantAttr(Attribute attr,StringRef value)220 std::string PatternEmitter::handleConstantAttr(Attribute attr,
221                                                StringRef value) {
222   if (!attr.isConstBuildable())
223     PrintFatalError(loc, "Attribute " + attr.getAttrDefName() +
224                              " does not have the 'constBuilderCall' field");
225 
226   // TODO: Verify the constants here
227   return std::string(tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, value));
228 }
229 
230 // Helper function to match patterns.
emitMatch(DagNode tree,StringRef name,int depth)231 void PatternEmitter::emitMatch(DagNode tree, StringRef name, int depth) {
232   if (tree.isNativeCodeCall()) {
233     emitNativeCodeMatch(tree, name, depth);
234     return;
235   }
236 
237   if (tree.isOperation()) {
238     emitOpMatch(tree, name, depth);
239     return;
240   }
241 
242   PrintFatalError(loc, "encountered non-op, non-NativeCodeCall match.");
243 }
244 
245 // Helper function to match patterns.
emitNativeCodeMatch(DagNode tree,StringRef opName,int depth)246 void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
247                                          int depth) {
248   LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall matcher pattern: ");
249   LLVM_DEBUG(tree.print(llvm::dbgs()));
250   LLVM_DEBUG(llvm::dbgs() << '\n');
251 
252   // TODO(suderman): iterate through arguments, determine their types, output
253   // names.
254   SmallVector<std::string, 8> capture;
255 
256   raw_indented_ostream::DelimitedScope scope(os);
257 
258   for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
259     std::string argName = formatv("arg{0}_{1}", depth, i);
260     if (DagNode argTree = tree.getArgAsNestedDag(i)) {
261       os << "Value " << argName << ";\n";
262     } else {
263       auto leaf = tree.getArgAsLeaf(i);
264       if (leaf.isAttrMatcher() || leaf.isConstantAttr()) {
265         os << "Attribute " << argName << ";\n";
266       } else {
267         os << "Value " << argName << ";\n";
268       }
269     }
270 
271     capture.push_back(std::move(argName));
272   }
273 
274   bool hasLocationDirective;
275   std::string locToUse;
276   std::tie(hasLocationDirective, locToUse) = getLocation(tree);
277 
278   auto fmt = tree.getNativeCodeTemplate();
279   if (fmt.count("$_self") != 1)
280     PrintFatalError(loc, "NativeCodeCall must have $_self as argument for "
281                          "passing the defining Operation");
282 
283   auto nativeCodeCall = std::string(tgfmt(
284       fmt, &fmtCtx.addSubst("_loc", locToUse).withSelf(opName.str()), capture));
285 
286   emitMatchCheck(opName, formatv("!failed({0})", nativeCodeCall),
287                  formatv("\"{0} return failure\"", nativeCodeCall));
288 
289   for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
290     auto name = tree.getArgName(i);
291     if (!name.empty() && name != "_") {
292       os << formatv("{0} = {1};\n", name, capture[i]);
293     }
294   }
295 
296   for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
297     std::string argName = capture[i];
298 
299     // Handle nested DAG construct first
300     if (DagNode argTree = tree.getArgAsNestedDag(i)) {
301       PrintFatalError(
302           loc, formatv("Matching nested tree in NativeCodecall not support for "
303                        "{0} as arg {1}",
304                        argName, i));
305     }
306 
307     DagLeaf leaf = tree.getArgAsLeaf(i);
308 
309     // The parameter for native function doesn't bind any constraints.
310     if (leaf.isUnspecified())
311       continue;
312 
313     auto constraint = leaf.getAsConstraint();
314 
315     std::string self;
316     if (leaf.isAttrMatcher() || leaf.isConstantAttr())
317       self = argName;
318     else
319       self = formatv("{0}.getType()", argName);
320     emitMatchCheck(
321         opName,
322         tgfmt(constraint.getConditionTemplate(), &fmtCtx.withSelf(self)),
323         formatv("\"operand {0} of native code call '{1}' failed to satisfy "
324                 "constraint: "
325                 "'{2}'\"",
326                 i, tree.getNativeCodeTemplate(), constraint.getSummary()));
327   }
328 
329   LLVM_DEBUG(llvm::dbgs() << "done emitting match for native code call\n");
330 }
331 
332 // Helper function to match patterns.
emitOpMatch(DagNode tree,StringRef opName,int depth)333 void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
334   Operator &op = tree.getDialectOp(opMap);
335   LLVM_DEBUG(llvm::dbgs() << "start emitting match for op '"
336                           << op.getOperationName() << "' at depth " << depth
337                           << '\n');
338 
339   std::string castedName = formatv("castedOp{0}", depth);
340   os << formatv("auto {0} = ::llvm::dyn_cast<{2}>({1}); "
341                 "(void){0};\n",
342                 castedName, opName, op.getQualCppClassName());
343 
344   // Skip the operand matching at depth 0 as the pattern rewriter already does.
345   if (depth != 0)
346     emitMatchCheck(opName, /*matchStr=*/castedName,
347                    formatv("\"{0} is not {1} type\"", castedName,
348                            op.getQualCppClassName()));
349 
350   if (tree.getNumArgs() != op.getNumArgs())
351     PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in "
352                                  "pattern vs. {2} in definition",
353                                  op.getOperationName(), tree.getNumArgs(),
354                                  op.getNumArgs()));
355 
356   // If the operand's name is set, set to that variable.
357   auto name = tree.getSymbol();
358   if (!name.empty())
359     os << formatv("{0} = {1};\n", name, castedName);
360 
361   for (int i = 0, e = tree.getNumArgs(), nextOperand = 0; i != e; ++i) {
362     auto opArg = op.getArg(i);
363     std::string argName = formatv("op{0}", depth + 1);
364 
365     // Handle nested DAG construct first
366     if (DagNode argTree = tree.getArgAsNestedDag(i)) {
367       if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) {
368         if (operand->isVariableLength()) {
369           auto error = formatv("use nested DAG construct to match op {0}'s "
370                                "variadic operand #{1} unsupported now",
371                                op.getOperationName(), i);
372           PrintFatalError(loc, error);
373         }
374       }
375       os << "{\n";
376 
377       // Attributes don't count for getODSOperands.
378       // TODO: Operand is a Value, check if we should remove `getDefiningOp()`.
379       os.indent() << formatv(
380           "auto *{0} = "
381           "(*{1}.getODSOperands({2}).begin()).getDefiningOp();\n",
382           argName, castedName, nextOperand);
383       // Null check of operand's definingOp
384       emitMatchCheck(castedName, /*matchStr=*/argName,
385                      formatv("\"Operand {0} of {1} has null definingOp\"",
386                              nextOperand++, castedName));
387       emitMatch(argTree, argName, depth + 1);
388       os << formatv("tblgen_ops[{0}] = {1};\n", ++opCounter, argName);
389       os.unindent() << "}\n";
390       continue;
391     }
392 
393     // Next handle DAG leaf: operand or attribute
394     if (opArg.is<NamedTypeConstraint *>()) {
395       // emitOperandMatch's argument indexing counts attributes.
396       emitOperandMatch(tree, castedName, i, nextOperand, depth);
397       ++nextOperand;
398     } else if (opArg.is<NamedAttribute *>()) {
399       emitAttributeMatch(tree, opName, i, depth);
400     } else {
401       PrintFatalError(loc, "unhandled case when matching op");
402     }
403   }
404   LLVM_DEBUG(llvm::dbgs() << "done emitting match for op '"
405                           << op.getOperationName() << "' at depth " << depth
406                           << '\n');
407 }
408 
emitOperandMatch(DagNode tree,StringRef opName,int argIndex,int operandIndex,int depth)409 void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
410                                       int argIndex, int operandIndex,
411                                       int depth) {
412   Operator &op = tree.getDialectOp(opMap);
413   auto *operand = op.getArg(argIndex).get<NamedTypeConstraint *>();
414   auto matcher = tree.getArgAsLeaf(argIndex);
415 
416   // If a constraint is specified, we need to generate C++ statements to
417   // check the constraint.
418   if (!matcher.isUnspecified()) {
419     if (!matcher.isOperandMatcher()) {
420       PrintFatalError(
421           loc, formatv("the {1}-th argument of op '{0}' should be an operand",
422                        op.getOperationName(), argIndex + 1));
423     }
424 
425     // Only need to verify if the matcher's type is different from the one
426     // of op definition.
427     Constraint constraint = matcher.getAsConstraint();
428     if (operand->constraint != constraint) {
429       if (operand->isVariableLength()) {
430         auto error = formatv(
431             "further constrain op {0}'s variadic operand #{1} unsupported now",
432             op.getOperationName(), argIndex);
433         PrintFatalError(loc, error);
434       }
435       auto self = formatv("(*{0}.getODSOperands({1}).begin()).getType()",
436                           opName, operandIndex);
437       emitMatchCheck(
438           opName,
439           tgfmt(constraint.getConditionTemplate(), &fmtCtx.withSelf(self)),
440           formatv("\"operand {0} of op '{1}' failed to satisfy constraint: "
441                   "'{2}'\"",
442                   operand - op.operand_begin(), op.getOperationName(),
443                   constraint.getSummary()));
444     }
445   }
446 
447   // Capture the value
448   auto name = tree.getArgName(argIndex);
449   // `$_` is a special symbol to ignore op argument matching.
450   if (!name.empty() && name != "_") {
451     // We need to subtract the number of attributes before this operand to get
452     // the index in the operand list.
453     auto numPrevAttrs = std::count_if(
454         op.arg_begin(), op.arg_begin() + argIndex,
455         [](const Argument &arg) { return arg.is<NamedAttribute *>(); });
456 
457     auto res = symbolInfoMap.findBoundSymbol(name, tree, op, argIndex);
458     os << formatv("{0} = {1}.getODSOperands({2});\n",
459                   res->second.getVarName(name), opName,
460                   argIndex - numPrevAttrs);
461   }
462 }
463 
emitAttributeMatch(DagNode tree,StringRef opName,int argIndex,int depth)464 void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName,
465                                         int argIndex, int depth) {
466   Operator &op = tree.getDialectOp(opMap);
467   auto *namedAttr = op.getArg(argIndex).get<NamedAttribute *>();
468   const auto &attr = namedAttr->attr;
469 
470   os << "{\n";
471   os.indent() << formatv("auto tblgen_attr = {0}->getAttrOfType<{1}>(\"{2}\");"
472                          "(void)tblgen_attr;\n",
473                          opName, attr.getStorageType(), namedAttr->name);
474 
475   // TODO: This should use getter method to avoid duplication.
476   if (attr.hasDefaultValue()) {
477     os << "if (!tblgen_attr) tblgen_attr = "
478        << std::string(tgfmt(attr.getConstBuilderTemplate(), &fmtCtx,
479                             attr.getDefaultValue()))
480        << ";\n";
481   } else if (attr.isOptional()) {
482     // For a missing attribute that is optional according to definition, we
483     // should just capture a mlir::Attribute() to signal the missing state.
484     // That is precisely what getAttr() returns on missing attributes.
485   } else {
486     emitMatchCheck(opName, tgfmt("tblgen_attr", &fmtCtx),
487                    formatv("\"expected op '{0}' to have attribute '{1}' "
488                            "of type '{2}'\"",
489                            op.getOperationName(), namedAttr->name,
490                            attr.getStorageType()));
491   }
492 
493   auto matcher = tree.getArgAsLeaf(argIndex);
494   if (!matcher.isUnspecified()) {
495     if (!matcher.isAttrMatcher()) {
496       PrintFatalError(
497           loc, formatv("the {1}-th argument of op '{0}' should be an attribute",
498                        op.getOperationName(), argIndex + 1));
499     }
500 
501     // If a constraint is specified, we need to generate C++ statements to
502     // check the constraint.
503     emitMatchCheck(
504         opName,
505         tgfmt(matcher.getConditionTemplate(), &fmtCtx.withSelf("tblgen_attr")),
506         formatv("\"op '{0}' attribute '{1}' failed to satisfy constraint: "
507                 "{2}\"",
508                 op.getOperationName(), namedAttr->name,
509                 matcher.getAsConstraint().getSummary()));
510   }
511 
512   // Capture the value
513   auto name = tree.getArgName(argIndex);
514   // `$_` is a special symbol to ignore op argument matching.
515   if (!name.empty() && name != "_") {
516     os << formatv("{0} = tblgen_attr;\n", name);
517   }
518 
519   os.unindent() << "}\n";
520 }
521 
emitMatchCheck(StringRef opName,const FmtObjectBase & matchFmt,const llvm::formatv_object_base & failureFmt)522 void PatternEmitter::emitMatchCheck(
523     StringRef opName, const FmtObjectBase &matchFmt,
524     const llvm::formatv_object_base &failureFmt) {
525   emitMatchCheck(opName, matchFmt.str(), failureFmt.str());
526 }
527 
emitMatchCheck(StringRef opName,const std::string & matchStr,const std::string & failureStr)528 void PatternEmitter::emitMatchCheck(StringRef opName,
529                                     const std::string &matchStr,
530                                     const std::string &failureStr) {
531 
532   os << "if (!(" << matchStr << "))";
533   os.scope("{\n", "\n}\n").os << "return rewriter.notifyMatchFailure(" << opName
534                               << ", [&](::mlir::Diagnostic &diag) {\n  diag << "
535                               << failureStr << ";\n});";
536 }
537 
emitMatchLogic(DagNode tree,StringRef opName)538 void PatternEmitter::emitMatchLogic(DagNode tree, StringRef opName) {
539   LLVM_DEBUG(llvm::dbgs() << "--- start emitting match logic ---\n");
540   int depth = 0;
541   emitMatch(tree, opName, depth);
542 
543   for (auto &appliedConstraint : pattern.getConstraints()) {
544     auto &constraint = appliedConstraint.constraint;
545     auto &entities = appliedConstraint.entities;
546 
547     auto condition = constraint.getConditionTemplate();
548     if (isa<TypeConstraint>(constraint)) {
549       auto self = formatv("({0}.getType())",
550                           symbolInfoMap.getValueAndRangeUse(entities.front()));
551       emitMatchCheck(
552           opName, tgfmt(condition, &fmtCtx.withSelf(self.str())),
553           formatv("\"value entity '{0}' failed to satisfy constraint: {1}\"",
554                   entities.front(), constraint.getSummary()));
555 
556     } else if (isa<AttrConstraint>(constraint)) {
557       PrintFatalError(
558           loc, "cannot use AttrConstraint in Pattern multi-entity constraints");
559     } else {
560       // TODO: replace formatv arguments with the exact specified
561       // args.
562       if (entities.size() > 4) {
563         PrintFatalError(loc, "only support up to 4-entity constraints now");
564       }
565       SmallVector<std::string, 4> names;
566       int i = 0;
567       for (int e = entities.size(); i < e; ++i)
568         names.push_back(symbolInfoMap.getValueAndRangeUse(entities[i]));
569       std::string self = appliedConstraint.self;
570       if (!self.empty())
571         self = symbolInfoMap.getValueAndRangeUse(self);
572       for (; i < 4; ++i)
573         names.push_back("<unused>");
574       emitMatchCheck(opName,
575                      tgfmt(condition, &fmtCtx.withSelf(self), names[0],
576                            names[1], names[2], names[3]),
577                      formatv("\"entities '{0}' failed to satisfy constraint: "
578                              "{1}\"",
579                              llvm::join(entities, ", "),
580                              constraint.getSummary()));
581     }
582   }
583 
584   // Some of the operands could be bound to the same symbol name, we need
585   // to enforce equality constraint on those.
586   // TODO: we should be able to emit equality checks early
587   // and short circuit unnecessary work if vars are not equal.
588   for (auto symbolInfoIt = symbolInfoMap.begin();
589        symbolInfoIt != symbolInfoMap.end();) {
590     auto range = symbolInfoMap.getRangeOfEqualElements(symbolInfoIt->first);
591     auto startRange = range.first;
592     auto endRange = range.second;
593 
594     auto firstOperand = symbolInfoIt->second.getVarName(symbolInfoIt->first);
595     for (++startRange; startRange != endRange; ++startRange) {
596       auto secondOperand = startRange->second.getVarName(symbolInfoIt->first);
597       emitMatchCheck(
598           opName,
599           formatv("*{0}.begin() == *{1}.begin()", firstOperand, secondOperand),
600           formatv("\"Operands '{0}' and '{1}' must be equal\"", firstOperand,
601                   secondOperand));
602     }
603 
604     symbolInfoIt = endRange;
605   }
606 
607   LLVM_DEBUG(llvm::dbgs() << "--- done emitting match logic ---\n");
608 }
609 
collectOps(DagNode tree,llvm::SmallPtrSetImpl<const Operator * > & ops)610 void PatternEmitter::collectOps(DagNode tree,
611                                 llvm::SmallPtrSetImpl<const Operator *> &ops) {
612   // Check if this tree is an operation.
613   if (tree.isOperation()) {
614     const Operator &op = tree.getDialectOp(opMap);
615     LLVM_DEBUG(llvm::dbgs()
616                << "found operation " << op.getOperationName() << '\n');
617     ops.insert(&op);
618   }
619 
620   // Recurse the arguments of the tree.
621   for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i)
622     if (auto child = tree.getArgAsNestedDag(i))
623       collectOps(child, ops);
624 }
625 
emit(StringRef rewriteName)626 void PatternEmitter::emit(StringRef rewriteName) {
627   // Get the DAG tree for the source pattern.
628   DagNode sourceTree = pattern.getSourcePattern();
629 
630   const Operator &rootOp = pattern.getSourceRootOp();
631   auto rootName = rootOp.getOperationName();
632 
633   // Collect the set of result operations.
634   llvm::SmallPtrSet<const Operator *, 4> resultOps;
635   LLVM_DEBUG(llvm::dbgs() << "start collecting ops used in result patterns\n");
636   for (unsigned i = 0, e = pattern.getNumResultPatterns(); i != e; ++i) {
637     collectOps(pattern.getResultPattern(i), resultOps);
638   }
639   LLVM_DEBUG(llvm::dbgs() << "done collecting ops used in result patterns\n");
640 
641   // Emit RewritePattern for Pattern.
642   auto locs = pattern.getLocation();
643   os << formatv("/* Generated from:\n    {0:$[ instantiating\n    ]}\n*/\n",
644                 make_range(locs.rbegin(), locs.rend()));
645   os << formatv(R"(struct {0} : public ::mlir::RewritePattern {
646   {0}(::mlir::MLIRContext *context)
647       : ::mlir::RewritePattern("{1}", {2}, context, {{)",
648                 rewriteName, rootName, pattern.getBenefit());
649   // Sort result operators by name.
650   llvm::SmallVector<const Operator *, 4> sortedResultOps(resultOps.begin(),
651                                                          resultOps.end());
652   llvm::sort(sortedResultOps, [&](const Operator *lhs, const Operator *rhs) {
653     return lhs->getOperationName() < rhs->getOperationName();
654   });
655   llvm::interleaveComma(sortedResultOps, os, [&](const Operator *op) {
656     os << '"' << op->getOperationName() << '"';
657   });
658   os << "}) {}\n";
659 
660   // Emit matchAndRewrite() function.
661   {
662     auto classScope = os.scope();
663     os.reindent(R"(
664     ::mlir::LogicalResult matchAndRewrite(::mlir::Operation *op0,
665         ::mlir::PatternRewriter &rewriter) const override {)")
666         << '\n';
667     {
668       auto functionScope = os.scope();
669 
670       // Register all symbols bound in the source pattern.
671       pattern.collectSourcePatternBoundSymbols(symbolInfoMap);
672 
673       LLVM_DEBUG(llvm::dbgs()
674                  << "start creating local variables for capturing matches\n");
675       os << "// Variables for capturing values and attributes used while "
676             "creating ops\n";
677       // Create local variables for storing the arguments and results bound
678       // to symbols.
679       for (const auto &symbolInfoPair : symbolInfoMap) {
680         const auto &symbol = symbolInfoPair.first;
681         const auto &info = symbolInfoPair.second;
682 
683         os << info.getVarDecl(symbol);
684       }
685       // TODO: capture ops with consistent numbering so that it can be
686       // reused for fused loc.
687       os << formatv("::mlir::Operation *tblgen_ops[{0}];\n\n",
688                     pattern.getSourcePattern().getNumOps());
689       LLVM_DEBUG(llvm::dbgs()
690                  << "done creating local variables for capturing matches\n");
691 
692       os << "// Match\n";
693       os << "tblgen_ops[0] = op0;\n";
694       emitMatchLogic(sourceTree, "op0");
695 
696       os << "\n// Rewrite\n";
697       emitRewriteLogic();
698 
699       os << "return ::mlir::success();\n";
700     }
701     os << "};\n";
702   }
703   os << "};\n\n";
704 }
705 
emitRewriteLogic()706 void PatternEmitter::emitRewriteLogic() {
707   LLVM_DEBUG(llvm::dbgs() << "--- start emitting rewrite logic ---\n");
708   const Operator &rootOp = pattern.getSourceRootOp();
709   int numExpectedResults = rootOp.getNumResults();
710   int numResultPatterns = pattern.getNumResultPatterns();
711 
712   // First register all symbols bound to ops generated in result patterns.
713   pattern.collectResultPatternBoundSymbols(symbolInfoMap);
714 
715   // Only the last N static values generated are used to replace the matched
716   // root N-result op. We need to calculate the starting index (of the results
717   // of the matched op) each result pattern is to replace.
718   SmallVector<int, 4> offsets(numResultPatterns + 1, numExpectedResults);
719   // If we don't need to replace any value at all, set the replacement starting
720   // index as the number of result patterns so we skip all of them when trying
721   // to replace the matched op's results.
722   int replStartIndex = numExpectedResults == 0 ? numResultPatterns : -1;
723   for (int i = numResultPatterns - 1; i >= 0; --i) {
724     auto numValues = getNodeValueCount(pattern.getResultPattern(i));
725     offsets[i] = offsets[i + 1] - numValues;
726     if (offsets[i] == 0) {
727       if (replStartIndex == -1)
728         replStartIndex = i;
729     } else if (offsets[i] < 0 && offsets[i + 1] > 0) {
730       auto error = formatv(
731           "cannot use the same multi-result op '{0}' to generate both "
732           "auxiliary values and values to be used for replacing the matched op",
733           pattern.getResultPattern(i).getSymbol());
734       PrintFatalError(loc, error);
735     }
736   }
737 
738   if (offsets.front() > 0) {
739     const char error[] = "no enough values generated to replace the matched op";
740     PrintFatalError(loc, error);
741   }
742 
743   os << "auto odsLoc = rewriter.getFusedLoc({";
744   for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) {
745     os << (i ? ", " : "") << "tblgen_ops[" << i << "]->getLoc()";
746   }
747   os << "}); (void)odsLoc;\n";
748 
749   // Process auxiliary result patterns.
750   for (int i = 0; i < replStartIndex; ++i) {
751     DagNode resultTree = pattern.getResultPattern(i);
752     auto val = handleResultPattern(resultTree, offsets[i], 0);
753     // Normal op creation will be streamed to `os` by the above call; but
754     // NativeCodeCall will only be materialized to `os` if it is used. Here
755     // we are handling auxiliary patterns so we want the side effect even if
756     // NativeCodeCall is not replacing matched root op's results.
757     if (resultTree.isNativeCodeCall() &&
758         resultTree.getNumReturnsOfNativeCode() == 0)
759       os << val << ";\n";
760   }
761 
762   if (numExpectedResults == 0) {
763     assert(replStartIndex >= numResultPatterns &&
764            "invalid auxiliary vs. replacement pattern division!");
765     // No result to replace. Just erase the op.
766     os << "rewriter.eraseOp(op0);\n";
767   } else {
768     // Process replacement result patterns.
769     os << "::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values;\n";
770     for (int i = replStartIndex; i < numResultPatterns; ++i) {
771       DagNode resultTree = pattern.getResultPattern(i);
772       auto val = handleResultPattern(resultTree, offsets[i], 0);
773       os << "\n";
774       // Resolve each symbol for all range use so that we can loop over them.
775       // We need an explicit cast to `SmallVector` to capture the cases where
776       // `{0}` resolves to an `Operation::result_range` as well as cases that
777       // are not iterable (e.g. vector that gets wrapped in additional braces by
778       // RewriterGen).
779       // TODO: Revisit the need for materializing a vector.
780       os << symbolInfoMap.getAllRangeUse(
781           val,
782           "for (auto v: ::llvm::SmallVector<::mlir::Value, 4>{ {0} }) {{\n"
783           "  tblgen_repl_values.push_back(v);\n}\n",
784           "\n");
785     }
786     os << "\nrewriter.replaceOp(op0, tblgen_repl_values);\n";
787   }
788 
789   LLVM_DEBUG(llvm::dbgs() << "--- done emitting rewrite logic ---\n");
790 }
791 
getUniqueSymbol(const Operator * op)792 std::string PatternEmitter::getUniqueSymbol(const Operator *op) {
793   return std::string(
794       formatv("tblgen_{0}_{1}", op->getCppClassName(), nextValueId++));
795 }
796 
handleResultPattern(DagNode resultTree,int resultIndex,int depth)797 std::string PatternEmitter::handleResultPattern(DagNode resultTree,
798                                                 int resultIndex, int depth) {
799   LLVM_DEBUG(llvm::dbgs() << "handle result pattern: ");
800   LLVM_DEBUG(resultTree.print(llvm::dbgs()));
801   LLVM_DEBUG(llvm::dbgs() << '\n');
802 
803   if (resultTree.isLocationDirective()) {
804     PrintFatalError(loc,
805                     "location directive can only be used with op creation");
806   }
807 
808   if (resultTree.isNativeCodeCall())
809     return handleReplaceWithNativeCodeCall(resultTree, depth);
810 
811   if (resultTree.isReplaceWithValue())
812     return handleReplaceWithValue(resultTree).str();
813 
814   // Normal op creation.
815   auto symbol = handleOpCreation(resultTree, resultIndex, depth);
816   if (resultTree.getSymbol().empty()) {
817     // This is an op not explicitly bound to a symbol in the rewrite rule.
818     // Register the auto-generated symbol for it.
819     symbolInfoMap.bindOpResult(symbol, pattern.getDialectOp(resultTree));
820   }
821   return symbol;
822 }
823 
handleReplaceWithValue(DagNode tree)824 StringRef PatternEmitter::handleReplaceWithValue(DagNode tree) {
825   assert(tree.isReplaceWithValue());
826 
827   if (tree.getNumArgs() != 1) {
828     PrintFatalError(
829         loc, "replaceWithValue directive must take exactly one argument");
830   }
831 
832   if (!tree.getSymbol().empty()) {
833     PrintFatalError(loc, "cannot bind symbol to replaceWithValue");
834   }
835 
836   return tree.getArgName(0);
837 }
838 
handleLocationDirective(DagNode tree)839 std::string PatternEmitter::handleLocationDirective(DagNode tree) {
840   assert(tree.isLocationDirective());
841   auto lookUpArgLoc = [this, &tree](int idx) {
842     const auto *const lookupFmt = "(*{0}.begin()).getLoc()";
843     return symbolInfoMap.getAllRangeUse(tree.getArgName(idx), lookupFmt);
844   };
845 
846   if (tree.getNumArgs() == 0)
847     llvm::PrintFatalError(
848         "At least one argument to location directive required");
849 
850   if (!tree.getSymbol().empty())
851     PrintFatalError(loc, "cannot bind symbol to location");
852 
853   if (tree.getNumArgs() == 1) {
854     DagLeaf leaf = tree.getArgAsLeaf(0);
855     if (leaf.isStringAttr())
856       return formatv("::mlir::NameLoc::get(rewriter.getIdentifier(\"{0}\"))",
857                      leaf.getStringAttr())
858           .str();
859     return lookUpArgLoc(0);
860   }
861 
862   std::string ret;
863   llvm::raw_string_ostream os(ret);
864   std::string strAttr;
865   os << "rewriter.getFusedLoc({";
866   bool first = true;
867   for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
868     DagLeaf leaf = tree.getArgAsLeaf(i);
869     // Handle the optional string value.
870     if (leaf.isStringAttr()) {
871       if (!strAttr.empty())
872         llvm::PrintFatalError("Only one string attribute may be specified");
873       strAttr = leaf.getStringAttr();
874       continue;
875     }
876     os << (first ? "" : ", ") << lookUpArgLoc(i);
877     first = false;
878   }
879   os << "}";
880   if (!strAttr.empty()) {
881     os << ", rewriter.getStringAttr(\"" << strAttr << "\")";
882   }
883   os << ")";
884   return os.str();
885 }
886 
handleOpArgument(DagLeaf leaf,StringRef patArgName)887 std::string PatternEmitter::handleOpArgument(DagLeaf leaf,
888                                              StringRef patArgName) {
889   if (leaf.isStringAttr())
890     PrintFatalError(loc, "raw string not supported as argument");
891   if (leaf.isConstantAttr()) {
892     auto constAttr = leaf.getAsConstantAttr();
893     return handleConstantAttr(constAttr.getAttribute(),
894                               constAttr.getConstantValue());
895   }
896   if (leaf.isEnumAttrCase()) {
897     auto enumCase = leaf.getAsEnumAttrCase();
898     if (enumCase.isStrCase())
899       return handleConstantAttr(enumCase, enumCase.getSymbol());
900     // This is an enum case backed by an IntegerAttr. We need to get its value
901     // to build the constant.
902     std::string val = std::to_string(enumCase.getValue());
903     return handleConstantAttr(enumCase, val);
904   }
905 
906   LLVM_DEBUG(llvm::dbgs() << "handle argument '" << patArgName << "'\n");
907   auto argName = symbolInfoMap.getValueAndRangeUse(patArgName);
908   if (leaf.isUnspecified() || leaf.isOperandMatcher()) {
909     LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << argName
910                             << "' (via symbol ref)\n");
911     return argName;
912   }
913   if (leaf.isNativeCodeCall()) {
914     auto repl = tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(argName));
915     LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << repl
916                             << "' (via NativeCodeCall)\n");
917     return std::string(repl);
918   }
919   PrintFatalError(loc, "unhandled case when rewriting op");
920 }
921 
handleReplaceWithNativeCodeCall(DagNode tree,int depth)922 std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree,
923                                                             int depth) {
924   LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall pattern: ");
925   LLVM_DEBUG(tree.print(llvm::dbgs()));
926   LLVM_DEBUG(llvm::dbgs() << '\n');
927 
928   auto fmt = tree.getNativeCodeTemplate();
929 
930   SmallVector<std::string, 16> attrs;
931 
932   bool hasLocationDirective;
933   std::string locToUse;
934   std::tie(hasLocationDirective, locToUse) = getLocation(tree);
935 
936   for (int i = 0, e = tree.getNumArgs() - hasLocationDirective; i != e; ++i) {
937     if (tree.isNestedDagArg(i)) {
938       attrs.push_back(
939           handleResultPattern(tree.getArgAsNestedDag(i), i, depth + 1));
940     } else {
941       attrs.push_back(
942           handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i)));
943     }
944     LLVM_DEBUG(llvm::dbgs() << "NativeCodeCall argument #" << i
945                             << " replacement: " << attrs[i] << "\n");
946   }
947 
948   std::string symbol = tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse), attrs);
949 
950   // In general, NativeCodeCall without naming binding don't need this. To
951   // ensure void helper function has been correctly labeled, i.e., use
952   // NativeCodeCallVoid, we cache the result to a local variable so that we will
953   // get a compilation error in the auto-generated file.
954   // Example.
955   //   // In the td file
956   //   Pat<(...), (NativeCodeCall<Foo> ...)>
957   //
958   //   ---
959   //
960   //   // In the auto-generated .cpp
961   //   ...
962   //   // Causes compilation error if Foo() returns void.
963   //   auto nativeVar = Foo();
964   //   ...
965   if (tree.getNumReturnsOfNativeCode() != 0) {
966     // Determine the local variable name for return value.
967     std::string varName =
968         SymbolInfoMap::getValuePackName(tree.getSymbol()).str();
969     if (varName.empty()) {
970       varName = formatv("nativeVar_{0}", nextValueId++);
971       // Register the local variable for later uses.
972       symbolInfoMap.bindValues(varName, tree.getNumReturnsOfNativeCode());
973     }
974 
975     // Catch the return value of helper function.
976     os << formatv("auto {0} = {1}; (void){0};\n", varName, symbol);
977 
978     if (!tree.getSymbol().empty())
979       symbol = tree.getSymbol().str();
980     else
981       symbol = varName;
982   }
983 
984   return symbol;
985 }
986 
getNodeValueCount(DagNode node)987 int PatternEmitter::getNodeValueCount(DagNode node) {
988   if (node.isOperation()) {
989     // If the op is bound to a symbol in the rewrite rule, query its result
990     // count from the symbol info map.
991     auto symbol = node.getSymbol();
992     if (!symbol.empty()) {
993       return symbolInfoMap.getStaticValueCount(symbol);
994     }
995     // Otherwise this is an unbound op; we will use all its results.
996     return pattern.getDialectOp(node).getNumResults();
997   }
998 
999   if (node.isNativeCodeCall())
1000     return node.getNumReturnsOfNativeCode();
1001 
1002   return 1;
1003 }
1004 
getLocation(DagNode tree)1005 std::pair<bool, std::string> PatternEmitter::getLocation(DagNode tree) {
1006   auto numPatArgs = tree.getNumArgs();
1007 
1008   if (numPatArgs != 0) {
1009     if (auto lastArg = tree.getArgAsNestedDag(numPatArgs - 1))
1010       if (lastArg.isLocationDirective()) {
1011         return std::make_pair(true, handleLocationDirective(lastArg));
1012       }
1013   }
1014 
1015   // If no explicit location is given, use the default, all fused, location.
1016   return std::make_pair(false, "odsLoc");
1017 }
1018 
handleOpCreation(DagNode tree,int resultIndex,int depth)1019 std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
1020                                              int depth) {
1021   LLVM_DEBUG(llvm::dbgs() << "create op for pattern: ");
1022   LLVM_DEBUG(tree.print(llvm::dbgs()));
1023   LLVM_DEBUG(llvm::dbgs() << '\n');
1024 
1025   Operator &resultOp = tree.getDialectOp(opMap);
1026   auto numOpArgs = resultOp.getNumArgs();
1027   auto numPatArgs = tree.getNumArgs();
1028 
1029   bool hasLocationDirective;
1030   std::string locToUse;
1031   std::tie(hasLocationDirective, locToUse) = getLocation(tree);
1032 
1033   auto inPattern = numPatArgs - hasLocationDirective;
1034   if (numOpArgs != inPattern) {
1035     PrintFatalError(loc,
1036                     formatv("resultant op '{0}' argument number mismatch: "
1037                             "{1} in pattern vs. {2} in definition",
1038                             resultOp.getOperationName(), inPattern, numOpArgs));
1039   }
1040 
1041   // A map to collect all nested DAG child nodes' names, with operand index as
1042   // the key. This includes both bound and unbound child nodes.
1043   ChildNodeIndexNameMap childNodeNames;
1044 
1045   // First go through all the child nodes who are nested DAG constructs to
1046   // create ops for them and remember the symbol names for them, so that we can
1047   // use the results in the current node. This happens in a recursive manner.
1048   for (int i = 0, e = tree.getNumArgs() - hasLocationDirective; i != e; ++i) {
1049     if (auto child = tree.getArgAsNestedDag(i))
1050       childNodeNames[i] = handleResultPattern(child, i, depth + 1);
1051   }
1052 
1053   // The name of the local variable holding this op.
1054   std::string valuePackName;
1055   // The symbol for holding the result of this pattern. Note that the result of
1056   // this pattern is not necessarily the same as the variable created by this
1057   // pattern because we can use `__N` suffix to refer only a specific result if
1058   // the generated op is a multi-result op.
1059   std::string resultValue;
1060   if (tree.getSymbol().empty()) {
1061     // No symbol is explicitly bound to this op in the pattern. Generate a
1062     // unique name.
1063     valuePackName = resultValue = getUniqueSymbol(&resultOp);
1064   } else {
1065     resultValue = std::string(tree.getSymbol());
1066     // Strip the index to get the name for the value pack and use it to name the
1067     // local variable for the op.
1068     valuePackName = std::string(SymbolInfoMap::getValuePackName(resultValue));
1069   }
1070 
1071   // Create the local variable for this op.
1072   os << formatv("{0} {1};\n{{\n", resultOp.getQualCppClassName(),
1073                 valuePackName);
1074 
1075   // Right now ODS don't have general type inference support. Except a few
1076   // special cases listed below, DRR needs to supply types for all results
1077   // when building an op.
1078   bool isSameOperandsAndResultType =
1079       resultOp.getTrait("::mlir::OpTrait::SameOperandsAndResultType");
1080   bool useFirstAttr =
1081       resultOp.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType");
1082 
1083   if (isSameOperandsAndResultType || useFirstAttr) {
1084     // We know how to deduce the result type for ops with these traits and we've
1085     // generated builders taking aggregate parameters. Use those builders to
1086     // create the ops.
1087 
1088     // First prepare local variables for op arguments used in builder call.
1089     createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth);
1090 
1091     // Then create the op.
1092     os.scope("", "\n}\n").os << formatv(
1093         "{0} = rewriter.create<{1}>({2}, tblgen_values, tblgen_attrs);",
1094         valuePackName, resultOp.getQualCppClassName(), locToUse);
1095     return resultValue;
1096   }
1097 
1098   bool usePartialResults = valuePackName != resultValue;
1099 
1100   if (usePartialResults || depth > 0 || resultIndex < 0) {
1101     // For these cases (broadcastable ops, op results used both as auxiliary
1102     // values and replacement values, ops in nested patterns, auxiliary ops), we
1103     // still need to supply the result types when building the op. But because
1104     // we don't generate a builder automatically with ODS for them, it's the
1105     // developer's responsibility to make sure such a builder (with result type
1106     // deduction ability) exists. We go through the separate-parameter builder
1107     // here given that it's easier for developers to write compared to
1108     // aggregate-parameter builders.
1109     createSeparateLocalVarsForOpArgs(tree, childNodeNames);
1110 
1111     os.scope().os << formatv("{0} = rewriter.create<{1}>({2}", valuePackName,
1112                              resultOp.getQualCppClassName(), locToUse);
1113     supplyValuesForOpArgs(tree, childNodeNames, depth);
1114     os << "\n  );\n}\n";
1115     return resultValue;
1116   }
1117 
1118   // If depth == 0 and resultIndex >= 0, it means we are replacing the values
1119   // generated from the source pattern root op. Then we can use the source
1120   // pattern's value types to determine the value type of the generated op
1121   // here.
1122 
1123   // First prepare local variables for op arguments used in builder call.
1124   createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth);
1125 
1126   // Then prepare the result types. We need to specify the types for all
1127   // results.
1128   os.indent() << formatv("::mlir::SmallVector<::mlir::Type, 4> tblgen_types; "
1129                          "(void)tblgen_types;\n");
1130   int numResults = resultOp.getNumResults();
1131   if (numResults != 0) {
1132     for (int i = 0; i < numResults; ++i)
1133       os << formatv("for (auto v: castedOp0.getODSResults({0})) {{\n"
1134                     "  tblgen_types.push_back(v.getType());\n}\n",
1135                     resultIndex + i);
1136   }
1137   os << formatv("{0} = rewriter.create<{1}>({2}, tblgen_types, "
1138                 "tblgen_values, tblgen_attrs);\n",
1139                 valuePackName, resultOp.getQualCppClassName(), locToUse);
1140   os.unindent() << "}\n";
1141   return resultValue;
1142 }
1143 
createSeparateLocalVarsForOpArgs(DagNode node,ChildNodeIndexNameMap & childNodeNames)1144 void PatternEmitter::createSeparateLocalVarsForOpArgs(
1145     DagNode node, ChildNodeIndexNameMap &childNodeNames) {
1146   Operator &resultOp = node.getDialectOp(opMap);
1147 
1148   // Now prepare operands used for building this op:
1149   // * If the operand is non-variadic, we create a `Value` local variable.
1150   // * If the operand is variadic, we create a `SmallVector<Value>` local
1151   //   variable.
1152 
1153   int valueIndex = 0; // An index for uniquing local variable names.
1154   for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
1155     const auto *operand =
1156         resultOp.getArg(argIndex).dyn_cast<NamedTypeConstraint *>();
1157     // We do not need special handling for attributes.
1158     if (!operand)
1159       continue;
1160 
1161     raw_indented_ostream::DelimitedScope scope(os);
1162     std::string varName;
1163     if (operand->isVariadic()) {
1164       varName = std::string(formatv("tblgen_values_{0}", valueIndex++));
1165       os << formatv("::mlir::SmallVector<::mlir::Value, 4> {0};\n", varName);
1166       std::string range;
1167       if (node.isNestedDagArg(argIndex)) {
1168         range = childNodeNames[argIndex];
1169       } else {
1170         range = std::string(node.getArgName(argIndex));
1171       }
1172       // Resolve the symbol for all range use so that we have a uniform way of
1173       // capturing the values.
1174       range = symbolInfoMap.getValueAndRangeUse(range);
1175       os << formatv("for (auto v: {0}) {{\n  {1}.push_back(v);\n}\n", range,
1176                     varName);
1177     } else {
1178       varName = std::string(formatv("tblgen_value_{0}", valueIndex++));
1179       os << formatv("::mlir::Value {0} = ", varName);
1180       if (node.isNestedDagArg(argIndex)) {
1181         os << symbolInfoMap.getValueAndRangeUse(childNodeNames[argIndex]);
1182       } else {
1183         DagLeaf leaf = node.getArgAsLeaf(argIndex);
1184         auto symbol =
1185             symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex));
1186         if (leaf.isNativeCodeCall()) {
1187           os << std::string(
1188               tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol)));
1189         } else {
1190           os << symbol;
1191         }
1192       }
1193       os << ";\n";
1194     }
1195 
1196     // Update to use the newly created local variable for building the op later.
1197     childNodeNames[argIndex] = varName;
1198   }
1199 }
1200 
supplyValuesForOpArgs(DagNode node,const ChildNodeIndexNameMap & childNodeNames,int depth)1201 void PatternEmitter::supplyValuesForOpArgs(
1202     DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) {
1203   Operator &resultOp = node.getDialectOp(opMap);
1204   for (int argIndex = 0, numOpArgs = resultOp.getNumArgs();
1205        argIndex != numOpArgs; ++argIndex) {
1206     // Start each argument on its own line.
1207     os << ",\n    ";
1208 
1209     Argument opArg = resultOp.getArg(argIndex);
1210     // Handle the case of operand first.
1211     if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) {
1212       if (!operand->name.empty())
1213         os << "/*" << operand->name << "=*/";
1214       os << childNodeNames.lookup(argIndex);
1215       continue;
1216     }
1217 
1218     // The argument in the op definition.
1219     auto opArgName = resultOp.getArgName(argIndex);
1220     if (auto subTree = node.getArgAsNestedDag(argIndex)) {
1221       if (!subTree.isNativeCodeCall())
1222         PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
1223                              "for creating attribute");
1224       os << formatv("/*{0}=*/{1}", opArgName, childNodeNames.lookup(argIndex));
1225     } else {
1226       auto leaf = node.getArgAsLeaf(argIndex);
1227       // The argument in the result DAG pattern.
1228       auto patArgName = node.getArgName(argIndex);
1229       if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) {
1230         // TODO: Refactor out into map to avoid recomputing these.
1231         if (!opArg.is<NamedAttribute *>())
1232           PrintFatalError(loc, Twine("expected attribute ") + Twine(argIndex));
1233         if (!patArgName.empty())
1234           os << "/*" << patArgName << "=*/";
1235       } else {
1236         os << "/*" << opArgName << "=*/";
1237       }
1238       os << handleOpArgument(leaf, patArgName);
1239     }
1240   }
1241 }
1242 
createAggregateLocalVarsForOpArgs(DagNode node,const ChildNodeIndexNameMap & childNodeNames,int depth)1243 void PatternEmitter::createAggregateLocalVarsForOpArgs(
1244     DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) {
1245   Operator &resultOp = node.getDialectOp(opMap);
1246 
1247   auto scope = os.scope();
1248   os << formatv("::mlir::SmallVector<::mlir::Value, 4> "
1249                 "tblgen_values; (void)tblgen_values;\n");
1250   os << formatv("::mlir::SmallVector<::mlir::NamedAttribute, 4> "
1251                 "tblgen_attrs; (void)tblgen_attrs;\n");
1252 
1253   const char *addAttrCmd =
1254       "if (auto tmpAttr = {1}) {\n"
1255       "  tblgen_attrs.emplace_back(rewriter.getIdentifier(\"{0}\"), "
1256       "tmpAttr);\n}\n";
1257   for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
1258     if (resultOp.getArg(argIndex).is<NamedAttribute *>()) {
1259       // The argument in the op definition.
1260       auto opArgName = resultOp.getArgName(argIndex);
1261       if (auto subTree = node.getArgAsNestedDag(argIndex)) {
1262         if (!subTree.isNativeCodeCall())
1263           PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
1264                                "for creating attribute");
1265         os << formatv(addAttrCmd, opArgName, childNodeNames.lookup(argIndex));
1266       } else {
1267         auto leaf = node.getArgAsLeaf(argIndex);
1268         // The argument in the result DAG pattern.
1269         auto patArgName = node.getArgName(argIndex);
1270         os << formatv(addAttrCmd, opArgName,
1271                       handleOpArgument(leaf, patArgName));
1272       }
1273       continue;
1274     }
1275 
1276     const auto *operand =
1277         resultOp.getArg(argIndex).get<NamedTypeConstraint *>();
1278     std::string varName;
1279     if (operand->isVariadic()) {
1280       std::string range;
1281       if (node.isNestedDagArg(argIndex)) {
1282         range = childNodeNames.lookup(argIndex);
1283       } else {
1284         range = std::string(node.getArgName(argIndex));
1285       }
1286       // Resolve the symbol for all range use so that we have a uniform way of
1287       // capturing the values.
1288       range = symbolInfoMap.getValueAndRangeUse(range);
1289       os << formatv("for (auto v: {0}) {{\n  tblgen_values.push_back(v);\n}\n",
1290                     range);
1291     } else {
1292       os << formatv("tblgen_values.push_back(");
1293       if (node.isNestedDagArg(argIndex)) {
1294         os << symbolInfoMap.getValueAndRangeUse(
1295             childNodeNames.lookup(argIndex));
1296       } else {
1297         DagLeaf leaf = node.getArgAsLeaf(argIndex);
1298         auto symbol =
1299             symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex));
1300         if (leaf.isNativeCodeCall()) {
1301           os << std::string(
1302               tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol)));
1303         } else {
1304           os << symbol;
1305         }
1306       }
1307       os << ");\n";
1308     }
1309   }
1310 }
1311 
emitRewriters(const RecordKeeper & recordKeeper,raw_ostream & os)1312 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
1313   emitSourceFileHeader("Rewriters", os);
1314 
1315   const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern");
1316   auto numPatterns = patterns.size();
1317 
1318   // We put the map here because it can be shared among multiple patterns.
1319   RecordOperatorMap recordOpMap;
1320 
1321   std::vector<std::string> rewriterNames;
1322   rewriterNames.reserve(numPatterns);
1323 
1324   std::string baseRewriterName = "GeneratedConvert";
1325   int rewriterIndex = 0;
1326 
1327   for (Record *p : patterns) {
1328     std::string name;
1329     if (p->isAnonymous()) {
1330       // If no name is provided, ensure unique rewriter names simply by
1331       // appending unique suffix.
1332       name = baseRewriterName + llvm::utostr(rewriterIndex++);
1333     } else {
1334       name = std::string(p->getName());
1335     }
1336     LLVM_DEBUG(llvm::dbgs()
1337                << "=== start generating pattern '" << name << "' ===\n");
1338     PatternEmitter(p, &recordOpMap, os).emit(name);
1339     LLVM_DEBUG(llvm::dbgs()
1340                << "=== done generating pattern '" << name << "' ===\n");
1341     rewriterNames.push_back(std::move(name));
1342   }
1343 
1344   // Emit function to add the generated matchers to the pattern list.
1345   os << "void LLVM_ATTRIBUTE_UNUSED populateWithGenerated("
1346         "::mlir::RewritePatternSet &patterns) {\n";
1347   for (const auto &name : rewriterNames) {
1348     os << "  patterns.add<" << name << ">(patterns.getContext());\n";
1349   }
1350   os << "}\n";
1351 }
1352 
1353 static mlir::GenRegistration
1354     genRewriters("gen-rewriters", "Generate pattern rewriters",
__anonb034f2370602(const RecordKeeper &records, raw_ostream &os) 1355                  [](const RecordKeeper &records, raw_ostream &os) {
1356                    emitRewriters(records, os);
1357                    return false;
1358                  });
1359