1 //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // RewriterGen uses pattern rewrite definitions to generate rewriter matchers.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Support/IndentedOstream.h"
14 #include "mlir/TableGen/Attribute.h"
15 #include "mlir/TableGen/Format.h"
16 #include "mlir/TableGen/GenInfo.h"
17 #include "mlir/TableGen/Operator.h"
18 #include "mlir/TableGen/Pattern.h"
19 #include "mlir/TableGen/Predicate.h"
20 #include "mlir/TableGen/Type.h"
21 #include "llvm/ADT/FunctionExtras.h"
22 #include "llvm/ADT/StringExtras.h"
23 #include "llvm/ADT/StringSet.h"
24 #include "llvm/Support/CommandLine.h"
25 #include "llvm/Support/Debug.h"
26 #include "llvm/Support/FormatAdapters.h"
27 #include "llvm/Support/PrettyStackTrace.h"
28 #include "llvm/Support/Signals.h"
29 #include "llvm/TableGen/Error.h"
30 #include "llvm/TableGen/Main.h"
31 #include "llvm/TableGen/Record.h"
32 #include "llvm/TableGen/TableGenBackend.h"
33 
34 using namespace mlir;
35 using namespace mlir::tblgen;
36 
37 using llvm::formatv;
38 using llvm::Record;
39 using llvm::RecordKeeper;
40 
41 #define DEBUG_TYPE "mlir-tblgen-rewritergen"
42 
43 namespace llvm {
44 template <>
45 struct format_provider<mlir::tblgen::Pattern::IdentifierLine> {
formatllvm::format_provider46   static void format(const mlir::tblgen::Pattern::IdentifierLine &v,
47                      raw_ostream &os, StringRef style) {
48     os << v.first << ":" << v.second;
49   }
50 };
51 } // end namespace llvm
52 
53 //===----------------------------------------------------------------------===//
54 // PatternEmitter
55 //===----------------------------------------------------------------------===//
56 
57 namespace {
58 
59 class StaticMatcherHelper;
60 
61 class PatternEmitter {
62 public:
63   PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os,
64                  StaticMatcherHelper &helper);
65 
66   // Emits the mlir::RewritePattern struct named `rewriteName`.
67   void emit(StringRef rewriteName);
68 
69   // Emits the static function of DAG matcher.
70   void emitStaticMatcher(DagNode tree, std::string funcName);
71 
72 private:
73   // Emits the code for matching ops.
74   void emitMatchLogic(DagNode tree, StringRef opName);
75 
76   // Emits the code for rewriting ops.
77   void emitRewriteLogic();
78 
79   //===--------------------------------------------------------------------===//
80   // Match utilities
81   //===--------------------------------------------------------------------===//
82 
83   // Emits C++ statements for matching the DAG structure.
84   void emitMatch(DagNode tree, StringRef name, int depth);
85 
86   // Emit C++ function call to static DAG matcher.
87   void emitStaticMatchCall(DagNode tree, StringRef name);
88 
89   // Emits C++ statements for matching using a native code call.
90   void emitNativeCodeMatch(DagNode tree, StringRef name, int depth);
91 
92   // Emits C++ statements for matching the op constrained by the given DAG
93   // `tree` returning the op's variable name.
94   void emitOpMatch(DagNode tree, StringRef opName, int depth);
95 
96   // Emits C++ statements for matching the `argIndex`-th argument of the given
97   // DAG `tree` as an operand. operandIndex is the index in the DAG excluding
98   // the preceding attributes.
99   void emitOperandMatch(DagNode tree, StringRef opName, int argIndex,
100                         int operandIndex, int depth);
101 
102   // Emits C++ statements for matching the `argIndex`-th argument of the given
103   // DAG `tree` as an attribute.
104   void emitAttributeMatch(DagNode tree, StringRef opName, int argIndex,
105                           int depth);
106 
107   // Emits C++ for checking a match with a corresponding match failure
108   // diagnostic.
109   void emitMatchCheck(StringRef opName, const FmtObjectBase &matchFmt,
110                       const llvm::formatv_object_base &failureFmt);
111 
112   // Emits C++ for checking a match with a corresponding match failure
113   // diagnostics.
114   void emitMatchCheck(StringRef opName, const std::string &matchStr,
115                       const std::string &failureStr);
116 
117   //===--------------------------------------------------------------------===//
118   // Rewrite utilities
119   //===--------------------------------------------------------------------===//
120 
121   // The entry point for handling a result pattern rooted at `resultTree`. This
122   // method dispatches to concrete handlers according to `resultTree`'s kind and
123   // returns a symbol representing the whole value pack. Callers are expected to
124   // further resolve the symbol according to the specific use case.
125   //
126   // `depth` is the nesting level of `resultTree`; 0 means top-level result
127   // pattern. For top-level result pattern, `resultIndex` indicates which result
128   // of the matched root op this pattern is intended to replace, which can be
129   // used to deduce the result type of the op generated from this result
130   // pattern.
131   std::string handleResultPattern(DagNode resultTree, int resultIndex,
132                                   int depth);
133 
134   // Emits the C++ statement to replace the matched DAG with a value built via
135   // calling native C++ code.
136   std::string handleReplaceWithNativeCodeCall(DagNode resultTree, int depth);
137 
138   // Returns the symbol of the old value serving as the replacement.
139   StringRef handleReplaceWithValue(DagNode tree);
140 
141   // Trailing directives are used at the end of DAG node argument lists to
142   // specify additional behaviour for op matchers and creators, etc.
143   struct TrailingDirectives {
144     // DAG node containing the `location` directive. Null if there is none.
145     DagNode location;
146 
147     // DAG node containing the `returnType` directive. Null if there is none.
148     DagNode returnType;
149 
150     // Number of found trailing directives.
151     int numDirectives;
152   };
153 
154   // Collect any trailing directives.
155   TrailingDirectives getTrailingDirectives(DagNode tree);
156 
157   // Returns the location value to use.
158   std::string getLocation(TrailingDirectives &tail);
159 
160   // Returns the location value to use.
161   std::string handleLocationDirective(DagNode tree);
162 
163   // Emit return type argument.
164   std::string handleReturnTypeArg(DagNode returnType, int i, int depth);
165 
166   // Emits the C++ statement to build a new op out of the given DAG `tree` and
167   // returns the variable name that this op is assigned to. If the root op in
168   // DAG `tree` has a specified name, the created op will be assigned to a
169   // variable of the given name. Otherwise, a unique name will be used as the
170   // result value name.
171   std::string handleOpCreation(DagNode tree, int resultIndex, int depth);
172 
173   using ChildNodeIndexNameMap = DenseMap<unsigned, std::string>;
174 
175   // Emits a local variable for each value and attribute to be used for creating
176   // an op.
177   void createSeparateLocalVarsForOpArgs(DagNode node,
178                                         ChildNodeIndexNameMap &childNodeNames);
179 
180   // Emits the concrete arguments used to call an op's builder.
181   void supplyValuesForOpArgs(DagNode node,
182                              const ChildNodeIndexNameMap &childNodeNames,
183                              int depth);
184 
185   // Emits the local variables for holding all values as a whole and all named
186   // attributes as a whole to be used for creating an op.
187   void createAggregateLocalVarsForOpArgs(
188       DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth);
189 
190   // Returns the C++ expression to construct a constant attribute of the given
191   // `value` for the given attribute kind `attr`.
192   std::string handleConstantAttr(Attribute attr, StringRef value);
193 
194   // Returns the C++ expression to build an argument from the given DAG `leaf`.
195   // `patArgName` is used to bound the argument to the source pattern.
196   std::string handleOpArgument(DagLeaf leaf, StringRef patArgName);
197 
198   //===--------------------------------------------------------------------===//
199   // General utilities
200   //===--------------------------------------------------------------------===//
201 
202   // Collects all of the operations within the given dag tree.
203   void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops);
204 
205   // Returns a unique symbol for a local variable of the given `op`.
206   std::string getUniqueSymbol(const Operator *op);
207 
208   //===--------------------------------------------------------------------===//
209   // Symbol utilities
210   //===--------------------------------------------------------------------===//
211 
212   // Returns how many static values the given DAG `node` correspond to.
213   int getNodeValueCount(DagNode node);
214 
215 private:
216   // Pattern instantiation location followed by the location of multiclass
217   // prototypes used. This is intended to be used as a whole to
218   // PrintFatalError() on errors.
219   ArrayRef<llvm::SMLoc> loc;
220 
221   // Op's TableGen Record to wrapper object.
222   RecordOperatorMap *opMap;
223 
224   // Handy wrapper for pattern being emitted.
225   Pattern pattern;
226 
227   // Map for all bound symbols' info.
228   SymbolInfoMap symbolInfoMap;
229 
230   StaticMatcherHelper &staticMatcherHelper;
231 
232   // The next unused ID for newly created values.
233   unsigned nextValueId;
234 
235   raw_indented_ostream os;
236 
237   // Format contexts containing placeholder substitutions.
238   FmtContext fmtCtx;
239 };
240 
241 // Tracks DagNode's reference multiple times across patterns. Enables generating
242 // static matcher functions for DagNode's referenced multiple times rather than
243 // inlining them.
244 class StaticMatcherHelper {
245 public:
246   StaticMatcherHelper(RecordOperatorMap &mapper);
247 
248   // Determine if we should inline the match logic or delegate to a static
249   // function.
useStaticMatcher(DagNode node)250   bool useStaticMatcher(DagNode node) {
251     return refStats[node] > kStaticMatcherThreshold;
252   }
253 
254   // Get the name of the static DAG matcher function corresponding to the node.
getMatcherName(DagNode node)255   std::string getMatcherName(DagNode node) {
256     assert(useStaticMatcher(node));
257     return matcherNames[node];
258   }
259 
260   // Collect the `Record`s, i.e., the DRR, so that we can get the information of
261   // the duplicated DAGs.
262   void addPattern(Record *record);
263 
264   // Emit all static functions of DAG Matcher.
265   void populateStaticMatchers(raw_ostream &os);
266 
267 private:
268   static constexpr unsigned kStaticMatcherThreshold = 1;
269 
270   // Consider two patterns as down below,
271   //   DagNode_Root_A    DagNode_Root_B
272   //       \                 \
273   //     DagNode_C         DagNode_C
274   //         \                 \
275   //       DagNode_D         DagNode_D
276   //
277   // DagNode_Root_A and DagNode_Root_B share the same subtree which consists of
278   // DagNode_C and DagNode_D. Both DagNode_C and DagNode_D are referenced
279   // multiple times so we'll have static matchers for both of them. When we're
280   // emitting the match logic for DagNode_C, we will check if DagNode_D has the
281   // static matcher generated. If so, then we'll generate a call to the
282   // function, inline otherwise. In this case, inlining is not what we want. As
283   // a result, generate the static matcher in topological order to ensure all
284   // the dependent static matchers are generated and we can avoid accidentally
285   // inlining.
286   //
287   // The topological order of all the DagNodes among all patterns.
288   SmallVector<std::pair<DagNode, Record *>> topologicalOrder;
289 
290   RecordOperatorMap &opMap;
291 
292   // Records of the static function name of each DagNode
293   DenseMap<DagNode, std::string> matcherNames;
294 
295   // After collecting all the DagNode in each pattern, `refStats` records the
296   // number of users for each DagNode. We will generate the static matcher for a
297   // DagNode while the number of users exceeds a certain threshold.
298   DenseMap<DagNode, unsigned> refStats;
299 
300   // Number of static matcher generated. This is used to generate a unique name
301   // for each DagNode.
302   int staticMatcherCounter = 0;
303 };
304 
305 } // end anonymous namespace
306 
PatternEmitter(Record * pat,RecordOperatorMap * mapper,raw_ostream & os,StaticMatcherHelper & helper)307 PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper,
308                                raw_ostream &os, StaticMatcherHelper &helper)
309     : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper),
310       symbolInfoMap(pat->getLoc()), staticMatcherHelper(helper), nextValueId(0),
311       os(os) {
312   fmtCtx.withBuilder("rewriter");
313 }
314 
handleConstantAttr(Attribute attr,StringRef value)315 std::string PatternEmitter::handleConstantAttr(Attribute attr,
316                                                StringRef value) {
317   if (!attr.isConstBuildable())
318     PrintFatalError(loc, "Attribute " + attr.getAttrDefName() +
319                              " does not have the 'constBuilderCall' field");
320 
321   // TODO: Verify the constants here
322   return std::string(tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, value));
323 }
324 
emitStaticMatcher(DagNode tree,std::string funcName)325 void PatternEmitter::emitStaticMatcher(DagNode tree, std::string funcName) {
326   os << formatv(
327       "static ::mlir::LogicalResult {0}(::mlir::PatternRewriter &rewriter, "
328       "::mlir::Operation *op0, ::llvm::SmallVector<::mlir::Operation "
329       "*, 4> &tblgen_ops",
330       funcName);
331 
332   // We pass the reference of the variables that need to be captured. Hence we
333   // need to collect all the symbols in the tree first.
334   pattern.collectBoundSymbols(tree, symbolInfoMap, /*isSrcPattern=*/true);
335   symbolInfoMap.assignUniqueAlternativeNames();
336   for (const auto &info : symbolInfoMap)
337     os << formatv(", {0}", info.second.getArgDecl(info.first));
338 
339   os << ") {\n";
340   os.indent();
341   os << "(void)tblgen_ops;\n";
342 
343   // Note that a static matcher is considered at least one step from the match
344   // entry.
345   emitMatch(tree, "op0", /*depth=*/1);
346 
347   os << "return ::mlir::success();\n";
348   os.unindent();
349   os << "}\n\n";
350 }
351 
352 // Helper function to match patterns.
emitMatch(DagNode tree,StringRef name,int depth)353 void PatternEmitter::emitMatch(DagNode tree, StringRef name, int depth) {
354   if (tree.isNativeCodeCall()) {
355     emitNativeCodeMatch(tree, name, depth);
356     return;
357   }
358 
359   if (tree.isOperation()) {
360     emitOpMatch(tree, name, depth);
361     return;
362   }
363 
364   PrintFatalError(loc, "encountered non-op, non-NativeCodeCall match.");
365 }
366 
emitStaticMatchCall(DagNode tree,StringRef opName)367 void PatternEmitter::emitStaticMatchCall(DagNode tree, StringRef opName) {
368   std::string funcName = staticMatcherHelper.getMatcherName(tree);
369   os << formatv("if(failed({0}(rewriter, {1}, tblgen_ops", funcName, opName);
370 
371   // TODO(chiahungduan): Add a lookupBoundSymbols() to do the subtree lookup in
372   // one pass.
373 
374   // In general, bound symbol should have the unique name in the pattern but
375   // for the operand, binding same symbol to multiple operands imply a
376   // constraint at the same time. In this case, we will rename those operands
377   // with different names. As a result, we need to collect all the symbolInfos
378   // from the DagNode then get the updated name of the local variables from the
379   // global symbolInfoMap.
380 
381   // Collect all the bound symbols in the Dag
382   SymbolInfoMap localSymbolMap(loc);
383   pattern.collectBoundSymbols(tree, localSymbolMap, /*isSrcPattern=*/true);
384 
385   for (const auto &info : localSymbolMap) {
386     auto name = info.first;
387     auto symboInfo = info.second;
388     auto ret = symbolInfoMap.findBoundSymbol(name, symboInfo);
389     os << formatv(", {0}", ret->second.getVarName(name));
390   }
391 
392   os << "))) {\n";
393   os.scope().os << "return ::mlir::failure();\n";
394   os << "}\n";
395 }
396 
397 // Helper function to match patterns.
emitNativeCodeMatch(DagNode tree,StringRef opName,int depth)398 void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName,
399                                          int depth) {
400   LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall matcher pattern: ");
401   LLVM_DEBUG(tree.print(llvm::dbgs()));
402   LLVM_DEBUG(llvm::dbgs() << '\n');
403 
404   // The order of generating static matcher follows the topological order so
405   // that for every dependent DagNode already have their static matcher
406   // generated if needed. The reason we check if `getMatcherName(tree).empty()`
407   // is when we are generating the static matcher for a DagNode itself. In this
408   // case, we need to emit the function body rather than a function call.
409   if (staticMatcherHelper.useStaticMatcher(tree) &&
410       !staticMatcherHelper.getMatcherName(tree).empty()) {
411     emitStaticMatchCall(tree, opName);
412 
413     // NativeCodeCall will never be at depth 0 so that we don't need to catch
414     // the root operation as emitOpMatch();
415 
416     return;
417   }
418 
419   // TODO(suderman): iterate through arguments, determine their types, output
420   // names.
421   SmallVector<std::string, 8> capture;
422 
423   raw_indented_ostream::DelimitedScope scope(os);
424 
425   for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
426     std::string argName = formatv("arg{0}_{1}", depth, i);
427     if (DagNode argTree = tree.getArgAsNestedDag(i)) {
428       os << "Value " << argName << ";\n";
429     } else {
430       auto leaf = tree.getArgAsLeaf(i);
431       if (leaf.isAttrMatcher() || leaf.isConstantAttr()) {
432         os << "Attribute " << argName << ";\n";
433       } else {
434         os << "Value " << argName << ";\n";
435       }
436     }
437 
438     capture.push_back(std::move(argName));
439   }
440 
441   auto tail = getTrailingDirectives(tree);
442   if (tail.returnType)
443     PrintFatalError(loc, "`NativeCodeCall` cannot have return type specifier");
444   auto locToUse = getLocation(tail);
445 
446   auto fmt = tree.getNativeCodeTemplate();
447   if (fmt.count("$_self") != 1)
448     PrintFatalError(loc, "NativeCodeCall must have $_self as argument for "
449                          "passing the defining Operation");
450 
451   auto nativeCodeCall = std::string(
452       tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse).withSelf(opName.str()),
453             static_cast<ArrayRef<std::string>>(capture)));
454 
455   emitMatchCheck(opName, formatv("!failed({0})", nativeCodeCall),
456                  formatv("\"{0} return failure\"", nativeCodeCall));
457 
458   for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) {
459     auto name = tree.getArgName(i);
460     if (!name.empty() && name != "_") {
461       os << formatv("{0} = {1};\n", name, capture[i]);
462     }
463   }
464 
465   for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) {
466     std::string argName = capture[i];
467 
468     // Handle nested DAG construct first
469     if (DagNode argTree = tree.getArgAsNestedDag(i)) {
470       PrintFatalError(
471           loc, formatv("Matching nested tree in NativeCodecall not support for "
472                        "{0} as arg {1}",
473                        argName, i));
474     }
475 
476     DagLeaf leaf = tree.getArgAsLeaf(i);
477 
478     // The parameter for native function doesn't bind any constraints.
479     if (leaf.isUnspecified())
480       continue;
481 
482     auto constraint = leaf.getAsConstraint();
483 
484     std::string self;
485     if (leaf.isAttrMatcher() || leaf.isConstantAttr())
486       self = argName;
487     else
488       self = formatv("{0}.getType()", argName);
489     emitMatchCheck(
490         opName,
491         tgfmt(constraint.getConditionTemplate(), &fmtCtx.withSelf(self)),
492         formatv("\"operand {0} of native code call '{1}' failed to satisfy "
493                 "constraint: "
494                 "'{2}'\"",
495                 i, tree.getNativeCodeTemplate(), constraint.getSummary()));
496   }
497 
498   LLVM_DEBUG(llvm::dbgs() << "done emitting match for native code call\n");
499 }
500 
501 // Helper function to match patterns.
emitOpMatch(DagNode tree,StringRef opName,int depth)502 void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) {
503   Operator &op = tree.getDialectOp(opMap);
504   LLVM_DEBUG(llvm::dbgs() << "start emitting match for op '"
505                           << op.getOperationName() << "' at depth " << depth
506                           << '\n');
507 
508   auto getCastedName = [depth]() -> std::string {
509     return formatv("castedOp{0}", depth);
510   };
511 
512   // The order of generating static matcher follows the topological order so
513   // that for every dependent DagNode already have their static matcher
514   // generated if needed. The reason we check if `getMatcherName(tree).empty()`
515   // is when we are generating the static matcher for a DagNode itself. In this
516   // case, we need to emit the function body rather than a function call.
517   if (staticMatcherHelper.useStaticMatcher(tree) &&
518       !staticMatcherHelper.getMatcherName(tree).empty()) {
519     emitStaticMatchCall(tree, opName);
520     // In the codegen of rewriter, we suppose that castedOp0 will capture the
521     // root operation. Manually add it if the root DagNode is a static matcher.
522     if (depth == 0)
523       os << formatv("auto {2} = ::llvm::dyn_cast_or_null<{1}>({0}); "
524                     "(void){2};\n",
525                     opName, op.getQualCppClassName(), getCastedName());
526     return;
527   }
528 
529   std::string castedName = getCastedName();
530   os << formatv("auto {0} = ::llvm::dyn_cast<{2}>({1}); "
531                 "(void){0};\n",
532                 castedName, opName, op.getQualCppClassName());
533 
534   // Skip the operand matching at depth 0 as the pattern rewriter already does.
535   if (depth != 0)
536     emitMatchCheck(opName, /*matchStr=*/castedName,
537                    formatv("\"{0} is not {1} type\"", castedName,
538                            op.getQualCppClassName()));
539 
540   if (tree.getNumArgs() != op.getNumArgs())
541     PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in "
542                                  "pattern vs. {2} in definition",
543                                  op.getOperationName(), tree.getNumArgs(),
544                                  op.getNumArgs()));
545 
546   // If the operand's name is set, set to that variable.
547   auto name = tree.getSymbol();
548   if (!name.empty())
549     os << formatv("{0} = {1};\n", name, castedName);
550 
551   for (int i = 0, e = tree.getNumArgs(), nextOperand = 0; i != e; ++i) {
552     auto opArg = op.getArg(i);
553     std::string argName = formatv("op{0}", depth + 1);
554 
555     // Handle nested DAG construct first
556     if (DagNode argTree = tree.getArgAsNestedDag(i)) {
557       if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) {
558         if (operand->isVariableLength()) {
559           auto error = formatv("use nested DAG construct to match op {0}'s "
560                                "variadic operand #{1} unsupported now",
561                                op.getOperationName(), i);
562           PrintFatalError(loc, error);
563         }
564       }
565       os << "{\n";
566 
567       // Attributes don't count for getODSOperands.
568       // TODO: Operand is a Value, check if we should remove `getDefiningOp()`.
569       os.indent() << formatv(
570           "auto *{0} = "
571           "(*{1}.getODSOperands({2}).begin()).getDefiningOp();\n",
572           argName, castedName, nextOperand);
573       // Null check of operand's definingOp
574       emitMatchCheck(castedName, /*matchStr=*/argName,
575                      formatv("\"Operand {0} of {1} has null definingOp\"",
576                              nextOperand++, castedName));
577       emitMatch(argTree, argName, depth + 1);
578       os << formatv("tblgen_ops.push_back({0});\n", argName);
579       os.unindent() << "}\n";
580       continue;
581     }
582 
583     // Next handle DAG leaf: operand or attribute
584     if (opArg.is<NamedTypeConstraint *>()) {
585       // emitOperandMatch's argument indexing counts attributes.
586       emitOperandMatch(tree, castedName, i, nextOperand, depth);
587       ++nextOperand;
588     } else if (opArg.is<NamedAttribute *>()) {
589       emitAttributeMatch(tree, opName, i, depth);
590     } else {
591       PrintFatalError(loc, "unhandled case when matching op");
592     }
593   }
594   LLVM_DEBUG(llvm::dbgs() << "done emitting match for op '"
595                           << op.getOperationName() << "' at depth " << depth
596                           << '\n');
597 }
598 
emitOperandMatch(DagNode tree,StringRef opName,int argIndex,int operandIndex,int depth)599 void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName,
600                                       int argIndex, int operandIndex,
601                                       int depth) {
602   Operator &op = tree.getDialectOp(opMap);
603   auto *operand = op.getArg(argIndex).get<NamedTypeConstraint *>();
604   auto matcher = tree.getArgAsLeaf(argIndex);
605 
606   // If a constraint is specified, we need to generate C++ statements to
607   // check the constraint.
608   if (!matcher.isUnspecified()) {
609     if (!matcher.isOperandMatcher()) {
610       PrintFatalError(
611           loc, formatv("the {1}-th argument of op '{0}' should be an operand",
612                        op.getOperationName(), argIndex + 1));
613     }
614 
615     // Only need to verify if the matcher's type is different from the one
616     // of op definition.
617     Constraint constraint = matcher.getAsConstraint();
618     if (operand->constraint != constraint) {
619       if (operand->isVariableLength()) {
620         auto error = formatv(
621             "further constrain op {0}'s variadic operand #{1} unsupported now",
622             op.getOperationName(), argIndex);
623         PrintFatalError(loc, error);
624       }
625       auto self = formatv("(*{0}.getODSOperands({1}).begin()).getType()",
626                           opName, operandIndex);
627       emitMatchCheck(
628           opName,
629           tgfmt(constraint.getConditionTemplate(), &fmtCtx.withSelf(self)),
630           formatv("\"operand {0} of op '{1}' failed to satisfy constraint: "
631                   "'{2}'\"",
632                   operand - op.operand_begin(), op.getOperationName(),
633                   constraint.getSummary()));
634     }
635   }
636 
637   // Capture the value
638   auto name = tree.getArgName(argIndex);
639   // `$_` is a special symbol to ignore op argument matching.
640   if (!name.empty() && name != "_") {
641     // We need to subtract the number of attributes before this operand to get
642     // the index in the operand list.
643     auto numPrevAttrs = std::count_if(
644         op.arg_begin(), op.arg_begin() + argIndex,
645         [](const Argument &arg) { return arg.is<NamedAttribute *>(); });
646 
647     auto res = symbolInfoMap.findBoundSymbol(name, tree, op, argIndex);
648     os << formatv("{0} = {1}.getODSOperands({2});\n",
649                   res->second.getVarName(name), opName,
650                   argIndex - numPrevAttrs);
651   }
652 }
653 
emitAttributeMatch(DagNode tree,StringRef opName,int argIndex,int depth)654 void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName,
655                                         int argIndex, int depth) {
656   Operator &op = tree.getDialectOp(opMap);
657   auto *namedAttr = op.getArg(argIndex).get<NamedAttribute *>();
658   const auto &attr = namedAttr->attr;
659 
660   os << "{\n";
661   os.indent() << formatv("auto tblgen_attr = {0}->getAttrOfType<{1}>(\"{2}\");"
662                          "(void)tblgen_attr;\n",
663                          opName, attr.getStorageType(), namedAttr->name);
664 
665   // TODO: This should use getter method to avoid duplication.
666   if (attr.hasDefaultValue()) {
667     os << "if (!tblgen_attr) tblgen_attr = "
668        << std::string(tgfmt(attr.getConstBuilderTemplate(), &fmtCtx,
669                             attr.getDefaultValue()))
670        << ";\n";
671   } else if (attr.isOptional()) {
672     // For a missing attribute that is optional according to definition, we
673     // should just capture a mlir::Attribute() to signal the missing state.
674     // That is precisely what getAttr() returns on missing attributes.
675   } else {
676     emitMatchCheck(opName, tgfmt("tblgen_attr", &fmtCtx),
677                    formatv("\"expected op '{0}' to have attribute '{1}' "
678                            "of type '{2}'\"",
679                            op.getOperationName(), namedAttr->name,
680                            attr.getStorageType()));
681   }
682 
683   auto matcher = tree.getArgAsLeaf(argIndex);
684   if (!matcher.isUnspecified()) {
685     if (!matcher.isAttrMatcher()) {
686       PrintFatalError(
687           loc, formatv("the {1}-th argument of op '{0}' should be an attribute",
688                        op.getOperationName(), argIndex + 1));
689     }
690 
691     // If a constraint is specified, we need to generate C++ statements to
692     // check the constraint.
693     emitMatchCheck(
694         opName,
695         tgfmt(matcher.getConditionTemplate(), &fmtCtx.withSelf("tblgen_attr")),
696         formatv("\"op '{0}' attribute '{1}' failed to satisfy constraint: "
697                 "{2}\"",
698                 op.getOperationName(), namedAttr->name,
699                 matcher.getAsConstraint().getSummary()));
700   }
701 
702   // Capture the value
703   auto name = tree.getArgName(argIndex);
704   // `$_` is a special symbol to ignore op argument matching.
705   if (!name.empty() && name != "_") {
706     os << formatv("{0} = tblgen_attr;\n", name);
707   }
708 
709   os.unindent() << "}\n";
710 }
711 
emitMatchCheck(StringRef opName,const FmtObjectBase & matchFmt,const llvm::formatv_object_base & failureFmt)712 void PatternEmitter::emitMatchCheck(
713     StringRef opName, const FmtObjectBase &matchFmt,
714     const llvm::formatv_object_base &failureFmt) {
715   emitMatchCheck(opName, matchFmt.str(), failureFmt.str());
716 }
717 
emitMatchCheck(StringRef opName,const std::string & matchStr,const std::string & failureStr)718 void PatternEmitter::emitMatchCheck(StringRef opName,
719                                     const std::string &matchStr,
720                                     const std::string &failureStr) {
721 
722   os << "if (!(" << matchStr << "))";
723   os.scope("{\n", "\n}\n").os << "return rewriter.notifyMatchFailure(" << opName
724                               << ", [&](::mlir::Diagnostic &diag) {\n  diag << "
725                               << failureStr << ";\n});";
726 }
727 
emitMatchLogic(DagNode tree,StringRef opName)728 void PatternEmitter::emitMatchLogic(DagNode tree, StringRef opName) {
729   LLVM_DEBUG(llvm::dbgs() << "--- start emitting match logic ---\n");
730   int depth = 0;
731   emitMatch(tree, opName, depth);
732 
733   for (auto &appliedConstraint : pattern.getConstraints()) {
734     auto &constraint = appliedConstraint.constraint;
735     auto &entities = appliedConstraint.entities;
736 
737     auto condition = constraint.getConditionTemplate();
738     if (isa<TypeConstraint>(constraint)) {
739       auto self = formatv("({0}.getType())",
740                           symbolInfoMap.getValueAndRangeUse(entities.front()));
741       emitMatchCheck(
742           opName, tgfmt(condition, &fmtCtx.withSelf(self.str())),
743           formatv("\"value entity '{0}' failed to satisfy constraint: {1}\"",
744                   entities.front(), constraint.getSummary()));
745 
746     } else if (isa<AttrConstraint>(constraint)) {
747       PrintFatalError(
748           loc, "cannot use AttrConstraint in Pattern multi-entity constraints");
749     } else {
750       // TODO: replace formatv arguments with the exact specified
751       // args.
752       if (entities.size() > 4) {
753         PrintFatalError(loc, "only support up to 4-entity constraints now");
754       }
755       SmallVector<std::string, 4> names;
756       int i = 0;
757       for (int e = entities.size(); i < e; ++i)
758         names.push_back(symbolInfoMap.getValueAndRangeUse(entities[i]));
759       std::string self = appliedConstraint.self;
760       if (!self.empty())
761         self = symbolInfoMap.getValueAndRangeUse(self);
762       for (; i < 4; ++i)
763         names.push_back("<unused>");
764       emitMatchCheck(opName,
765                      tgfmt(condition, &fmtCtx.withSelf(self), names[0],
766                            names[1], names[2], names[3]),
767                      formatv("\"entities '{0}' failed to satisfy constraint: "
768                              "{1}\"",
769                              llvm::join(entities, ", "),
770                              constraint.getSummary()));
771     }
772   }
773 
774   // Some of the operands could be bound to the same symbol name, we need
775   // to enforce equality constraint on those.
776   // TODO: we should be able to emit equality checks early
777   // and short circuit unnecessary work if vars are not equal.
778   for (auto symbolInfoIt = symbolInfoMap.begin();
779        symbolInfoIt != symbolInfoMap.end();) {
780     auto range = symbolInfoMap.getRangeOfEqualElements(symbolInfoIt->first);
781     auto startRange = range.first;
782     auto endRange = range.second;
783 
784     auto firstOperand = symbolInfoIt->second.getVarName(symbolInfoIt->first);
785     for (++startRange; startRange != endRange; ++startRange) {
786       auto secondOperand = startRange->second.getVarName(symbolInfoIt->first);
787       emitMatchCheck(
788           opName,
789           formatv("*{0}.begin() == *{1}.begin()", firstOperand, secondOperand),
790           formatv("\"Operands '{0}' and '{1}' must be equal\"", firstOperand,
791                   secondOperand));
792     }
793 
794     symbolInfoIt = endRange;
795   }
796 
797   LLVM_DEBUG(llvm::dbgs() << "--- done emitting match logic ---\n");
798 }
799 
collectOps(DagNode tree,llvm::SmallPtrSetImpl<const Operator * > & ops)800 void PatternEmitter::collectOps(DagNode tree,
801                                 llvm::SmallPtrSetImpl<const Operator *> &ops) {
802   // Check if this tree is an operation.
803   if (tree.isOperation()) {
804     const Operator &op = tree.getDialectOp(opMap);
805     LLVM_DEBUG(llvm::dbgs()
806                << "found operation " << op.getOperationName() << '\n');
807     ops.insert(&op);
808   }
809 
810   // Recurse the arguments of the tree.
811   for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i)
812     if (auto child = tree.getArgAsNestedDag(i))
813       collectOps(child, ops);
814 }
815 
emit(StringRef rewriteName)816 void PatternEmitter::emit(StringRef rewriteName) {
817   // Get the DAG tree for the source pattern.
818   DagNode sourceTree = pattern.getSourcePattern();
819 
820   const Operator &rootOp = pattern.getSourceRootOp();
821   auto rootName = rootOp.getOperationName();
822 
823   // Collect the set of result operations.
824   llvm::SmallPtrSet<const Operator *, 4> resultOps;
825   LLVM_DEBUG(llvm::dbgs() << "start collecting ops used in result patterns\n");
826   for (unsigned i = 0, e = pattern.getNumResultPatterns(); i != e; ++i) {
827     collectOps(pattern.getResultPattern(i), resultOps);
828   }
829   LLVM_DEBUG(llvm::dbgs() << "done collecting ops used in result patterns\n");
830 
831   // Emit RewritePattern for Pattern.
832   auto locs = pattern.getLocation();
833   os << formatv("/* Generated from:\n    {0:$[ instantiating\n    ]}\n*/\n",
834                 make_range(locs.rbegin(), locs.rend()));
835   os << formatv(R"(struct {0} : public ::mlir::RewritePattern {
836   {0}(::mlir::MLIRContext *context)
837       : ::mlir::RewritePattern("{1}", {2}, context, {{)",
838                 rewriteName, rootName, pattern.getBenefit());
839   // Sort result operators by name.
840   llvm::SmallVector<const Operator *, 4> sortedResultOps(resultOps.begin(),
841                                                          resultOps.end());
842   llvm::sort(sortedResultOps, [&](const Operator *lhs, const Operator *rhs) {
843     return lhs->getOperationName() < rhs->getOperationName();
844   });
845   llvm::interleaveComma(sortedResultOps, os, [&](const Operator *op) {
846     os << '"' << op->getOperationName() << '"';
847   });
848   os << "}) {}\n";
849 
850   // Emit matchAndRewrite() function.
851   {
852     auto classScope = os.scope();
853     os.reindent(R"(
854     ::mlir::LogicalResult matchAndRewrite(::mlir::Operation *op0,
855         ::mlir::PatternRewriter &rewriter) const override {)")
856         << '\n';
857     {
858       auto functionScope = os.scope();
859 
860       // Register all symbols bound in the source pattern.
861       pattern.collectSourcePatternBoundSymbols(symbolInfoMap);
862 
863       LLVM_DEBUG(llvm::dbgs()
864                  << "start creating local variables for capturing matches\n");
865       os << "// Variables for capturing values and attributes used while "
866             "creating ops\n";
867       // Create local variables for storing the arguments and results bound
868       // to symbols.
869       for (const auto &symbolInfoPair : symbolInfoMap) {
870         const auto &symbol = symbolInfoPair.first;
871         const auto &info = symbolInfoPair.second;
872 
873         os << info.getVarDecl(symbol);
874       }
875       // TODO: capture ops with consistent numbering so that it can be
876       // reused for fused loc.
877       os << "::llvm::SmallVector<::mlir::Operation *, 4> tblgen_ops;\n\n";
878       LLVM_DEBUG(llvm::dbgs()
879                  << "done creating local variables for capturing matches\n");
880 
881       os << "// Match\n";
882       os << "tblgen_ops.push_back(op0);\n";
883       emitMatchLogic(sourceTree, "op0");
884 
885       os << "\n// Rewrite\n";
886       emitRewriteLogic();
887 
888       os << "return ::mlir::success();\n";
889     }
890     os << "};\n";
891   }
892   os << "};\n\n";
893 }
894 
emitRewriteLogic()895 void PatternEmitter::emitRewriteLogic() {
896   LLVM_DEBUG(llvm::dbgs() << "--- start emitting rewrite logic ---\n");
897   const Operator &rootOp = pattern.getSourceRootOp();
898   int numExpectedResults = rootOp.getNumResults();
899   int numResultPatterns = pattern.getNumResultPatterns();
900 
901   // First register all symbols bound to ops generated in result patterns.
902   pattern.collectResultPatternBoundSymbols(symbolInfoMap);
903 
904   // Only the last N static values generated are used to replace the matched
905   // root N-result op. We need to calculate the starting index (of the results
906   // of the matched op) each result pattern is to replace.
907   SmallVector<int, 4> offsets(numResultPatterns + 1, numExpectedResults);
908   // If we don't need to replace any value at all, set the replacement starting
909   // index as the number of result patterns so we skip all of them when trying
910   // to replace the matched op's results.
911   int replStartIndex = numExpectedResults == 0 ? numResultPatterns : -1;
912   for (int i = numResultPatterns - 1; i >= 0; --i) {
913     auto numValues = getNodeValueCount(pattern.getResultPattern(i));
914     offsets[i] = offsets[i + 1] - numValues;
915     if (offsets[i] == 0) {
916       if (replStartIndex == -1)
917         replStartIndex = i;
918     } else if (offsets[i] < 0 && offsets[i + 1] > 0) {
919       auto error = formatv(
920           "cannot use the same multi-result op '{0}' to generate both "
921           "auxiliary values and values to be used for replacing the matched op",
922           pattern.getResultPattern(i).getSymbol());
923       PrintFatalError(loc, error);
924     }
925   }
926 
927   if (offsets.front() > 0) {
928     const char error[] = "no enough values generated to replace the matched op";
929     PrintFatalError(loc, error);
930   }
931 
932   os << "auto odsLoc = rewriter.getFusedLoc({";
933   for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) {
934     os << (i ? ", " : "") << "tblgen_ops[" << i << "]->getLoc()";
935   }
936   os << "}); (void)odsLoc;\n";
937 
938   // Process auxiliary result patterns.
939   for (int i = 0; i < replStartIndex; ++i) {
940     DagNode resultTree = pattern.getResultPattern(i);
941     auto val = handleResultPattern(resultTree, offsets[i], 0);
942     // Normal op creation will be streamed to `os` by the above call; but
943     // NativeCodeCall will only be materialized to `os` if it is used. Here
944     // we are handling auxiliary patterns so we want the side effect even if
945     // NativeCodeCall is not replacing matched root op's results.
946     if (resultTree.isNativeCodeCall() &&
947         resultTree.getNumReturnsOfNativeCode() == 0)
948       os << val << ";\n";
949   }
950 
951   if (numExpectedResults == 0) {
952     assert(replStartIndex >= numResultPatterns &&
953            "invalid auxiliary vs. replacement pattern division!");
954     // No result to replace. Just erase the op.
955     os << "rewriter.eraseOp(op0);\n";
956   } else {
957     // Process replacement result patterns.
958     os << "::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values;\n";
959     for (int i = replStartIndex; i < numResultPatterns; ++i) {
960       DagNode resultTree = pattern.getResultPattern(i);
961       auto val = handleResultPattern(resultTree, offsets[i], 0);
962       os << "\n";
963       // Resolve each symbol for all range use so that we can loop over them.
964       // We need an explicit cast to `SmallVector` to capture the cases where
965       // `{0}` resolves to an `Operation::result_range` as well as cases that
966       // are not iterable (e.g. vector that gets wrapped in additional braces by
967       // RewriterGen).
968       // TODO: Revisit the need for materializing a vector.
969       os << symbolInfoMap.getAllRangeUse(
970           val,
971           "for (auto v: ::llvm::SmallVector<::mlir::Value, 4>{ {0} }) {{\n"
972           "  tblgen_repl_values.push_back(v);\n}\n",
973           "\n");
974     }
975     os << "\nrewriter.replaceOp(op0, tblgen_repl_values);\n";
976   }
977 
978   LLVM_DEBUG(llvm::dbgs() << "--- done emitting rewrite logic ---\n");
979 }
980 
getUniqueSymbol(const Operator * op)981 std::string PatternEmitter::getUniqueSymbol(const Operator *op) {
982   return std::string(
983       formatv("tblgen_{0}_{1}", op->getCppClassName(), nextValueId++));
984 }
985 
handleResultPattern(DagNode resultTree,int resultIndex,int depth)986 std::string PatternEmitter::handleResultPattern(DagNode resultTree,
987                                                 int resultIndex, int depth) {
988   LLVM_DEBUG(llvm::dbgs() << "handle result pattern: ");
989   LLVM_DEBUG(resultTree.print(llvm::dbgs()));
990   LLVM_DEBUG(llvm::dbgs() << '\n');
991 
992   if (resultTree.isLocationDirective()) {
993     PrintFatalError(loc,
994                     "location directive can only be used with op creation");
995   }
996 
997   if (resultTree.isNativeCodeCall())
998     return handleReplaceWithNativeCodeCall(resultTree, depth);
999 
1000   if (resultTree.isReplaceWithValue())
1001     return handleReplaceWithValue(resultTree).str();
1002 
1003   // Normal op creation.
1004   auto symbol = handleOpCreation(resultTree, resultIndex, depth);
1005   if (resultTree.getSymbol().empty()) {
1006     // This is an op not explicitly bound to a symbol in the rewrite rule.
1007     // Register the auto-generated symbol for it.
1008     symbolInfoMap.bindOpResult(symbol, pattern.getDialectOp(resultTree));
1009   }
1010   return symbol;
1011 }
1012 
handleReplaceWithValue(DagNode tree)1013 StringRef PatternEmitter::handleReplaceWithValue(DagNode tree) {
1014   assert(tree.isReplaceWithValue());
1015 
1016   if (tree.getNumArgs() != 1) {
1017     PrintFatalError(
1018         loc, "replaceWithValue directive must take exactly one argument");
1019   }
1020 
1021   if (!tree.getSymbol().empty()) {
1022     PrintFatalError(loc, "cannot bind symbol to replaceWithValue");
1023   }
1024 
1025   return tree.getArgName(0);
1026 }
1027 
handleLocationDirective(DagNode tree)1028 std::string PatternEmitter::handleLocationDirective(DagNode tree) {
1029   assert(tree.isLocationDirective());
1030   auto lookUpArgLoc = [this, &tree](int idx) {
1031     const auto *const lookupFmt = "(*{0}.begin()).getLoc()";
1032     return symbolInfoMap.getAllRangeUse(tree.getArgName(idx), lookupFmt);
1033   };
1034 
1035   if (tree.getNumArgs() == 0)
1036     llvm::PrintFatalError(
1037         "At least one argument to location directive required");
1038 
1039   if (!tree.getSymbol().empty())
1040     PrintFatalError(loc, "cannot bind symbol to location");
1041 
1042   if (tree.getNumArgs() == 1) {
1043     DagLeaf leaf = tree.getArgAsLeaf(0);
1044     if (leaf.isStringAttr())
1045       return formatv("::mlir::NameLoc::get(rewriter.getIdentifier(\"{0}\"))",
1046                      leaf.getStringAttr())
1047           .str();
1048     return lookUpArgLoc(0);
1049   }
1050 
1051   std::string ret;
1052   llvm::raw_string_ostream os(ret);
1053   std::string strAttr;
1054   os << "rewriter.getFusedLoc({";
1055   bool first = true;
1056   for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
1057     DagLeaf leaf = tree.getArgAsLeaf(i);
1058     // Handle the optional string value.
1059     if (leaf.isStringAttr()) {
1060       if (!strAttr.empty())
1061         llvm::PrintFatalError("Only one string attribute may be specified");
1062       strAttr = leaf.getStringAttr();
1063       continue;
1064     }
1065     os << (first ? "" : ", ") << lookUpArgLoc(i);
1066     first = false;
1067   }
1068   os << "}";
1069   if (!strAttr.empty()) {
1070     os << ", rewriter.getStringAttr(\"" << strAttr << "\")";
1071   }
1072   os << ")";
1073   return os.str();
1074 }
1075 
handleReturnTypeArg(DagNode returnType,int i,int depth)1076 std::string PatternEmitter::handleReturnTypeArg(DagNode returnType, int i,
1077                                                 int depth) {
1078   // Nested NativeCodeCall.
1079   if (auto dagNode = returnType.getArgAsNestedDag(i)) {
1080     if (!dagNode.isNativeCodeCall())
1081       PrintFatalError(loc, "nested DAG in `returnType` must be a native code "
1082                            "call");
1083     return handleReplaceWithNativeCodeCall(dagNode, depth);
1084   }
1085   // String literal.
1086   auto dagLeaf = returnType.getArgAsLeaf(i);
1087   if (dagLeaf.isStringAttr())
1088     return tgfmt(dagLeaf.getStringAttr(), &fmtCtx);
1089   return tgfmt(
1090       "$0.getType()", &fmtCtx,
1091       handleOpArgument(returnType.getArgAsLeaf(i), returnType.getArgName(i)));
1092 }
1093 
handleOpArgument(DagLeaf leaf,StringRef patArgName)1094 std::string PatternEmitter::handleOpArgument(DagLeaf leaf,
1095                                              StringRef patArgName) {
1096   if (leaf.isStringAttr())
1097     PrintFatalError(loc, "raw string not supported as argument");
1098   if (leaf.isConstantAttr()) {
1099     auto constAttr = leaf.getAsConstantAttr();
1100     return handleConstantAttr(constAttr.getAttribute(),
1101                               constAttr.getConstantValue());
1102   }
1103   if (leaf.isEnumAttrCase()) {
1104     auto enumCase = leaf.getAsEnumAttrCase();
1105     if (enumCase.isStrCase())
1106       return handleConstantAttr(enumCase, enumCase.getSymbol());
1107     // This is an enum case backed by an IntegerAttr. We need to get its value
1108     // to build the constant.
1109     std::string val = std::to_string(enumCase.getValue());
1110     return handleConstantAttr(enumCase, val);
1111   }
1112 
1113   LLVM_DEBUG(llvm::dbgs() << "handle argument '" << patArgName << "'\n");
1114   auto argName = symbolInfoMap.getValueAndRangeUse(patArgName);
1115   if (leaf.isUnspecified() || leaf.isOperandMatcher()) {
1116     LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << argName
1117                             << "' (via symbol ref)\n");
1118     return argName;
1119   }
1120   if (leaf.isNativeCodeCall()) {
1121     auto repl = tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(argName));
1122     LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << repl
1123                             << "' (via NativeCodeCall)\n");
1124     return std::string(repl);
1125   }
1126   PrintFatalError(loc, "unhandled case when rewriting op");
1127 }
1128 
handleReplaceWithNativeCodeCall(DagNode tree,int depth)1129 std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree,
1130                                                             int depth) {
1131   LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall pattern: ");
1132   LLVM_DEBUG(tree.print(llvm::dbgs()));
1133   LLVM_DEBUG(llvm::dbgs() << '\n');
1134 
1135   auto fmt = tree.getNativeCodeTemplate();
1136 
1137   SmallVector<std::string, 16> attrs;
1138 
1139   auto tail = getTrailingDirectives(tree);
1140   if (tail.returnType)
1141     PrintFatalError(loc, "`NativeCodeCall` cannot have return type specifier");
1142   auto locToUse = getLocation(tail);
1143 
1144   for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) {
1145     if (tree.isNestedDagArg(i)) {
1146       attrs.push_back(
1147           handleResultPattern(tree.getArgAsNestedDag(i), i, depth + 1));
1148     } else {
1149       attrs.push_back(
1150           handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i)));
1151     }
1152     LLVM_DEBUG(llvm::dbgs() << "NativeCodeCall argument #" << i
1153                             << " replacement: " << attrs[i] << "\n");
1154   }
1155 
1156   std::string symbol = tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse),
1157                              static_cast<ArrayRef<std::string>>(attrs));
1158 
1159   // In general, NativeCodeCall without naming binding don't need this. To
1160   // ensure void helper function has been correctly labeled, i.e., use
1161   // NativeCodeCallVoid, we cache the result to a local variable so that we will
1162   // get a compilation error in the auto-generated file.
1163   // Example.
1164   //   // In the td file
1165   //   Pat<(...), (NativeCodeCall<Foo> ...)>
1166   //
1167   //   ---
1168   //
1169   //   // In the auto-generated .cpp
1170   //   ...
1171   //   // Causes compilation error if Foo() returns void.
1172   //   auto nativeVar = Foo();
1173   //   ...
1174   if (tree.getNumReturnsOfNativeCode() != 0) {
1175     // Determine the local variable name for return value.
1176     std::string varName =
1177         SymbolInfoMap::getValuePackName(tree.getSymbol()).str();
1178     if (varName.empty()) {
1179       varName = formatv("nativeVar_{0}", nextValueId++);
1180       // Register the local variable for later uses.
1181       symbolInfoMap.bindValues(varName, tree.getNumReturnsOfNativeCode());
1182     }
1183 
1184     // Catch the return value of helper function.
1185     os << formatv("auto {0} = {1}; (void){0};\n", varName, symbol);
1186 
1187     if (!tree.getSymbol().empty())
1188       symbol = tree.getSymbol().str();
1189     else
1190       symbol = varName;
1191   }
1192 
1193   return symbol;
1194 }
1195 
getNodeValueCount(DagNode node)1196 int PatternEmitter::getNodeValueCount(DagNode node) {
1197   if (node.isOperation()) {
1198     // If the op is bound to a symbol in the rewrite rule, query its result
1199     // count from the symbol info map.
1200     auto symbol = node.getSymbol();
1201     if (!symbol.empty()) {
1202       return symbolInfoMap.getStaticValueCount(symbol);
1203     }
1204     // Otherwise this is an unbound op; we will use all its results.
1205     return pattern.getDialectOp(node).getNumResults();
1206   }
1207 
1208   if (node.isNativeCodeCall())
1209     return node.getNumReturnsOfNativeCode();
1210 
1211   return 1;
1212 }
1213 
1214 PatternEmitter::TrailingDirectives
getTrailingDirectives(DagNode tree)1215 PatternEmitter::getTrailingDirectives(DagNode tree) {
1216   TrailingDirectives tail = {DagNode(nullptr), DagNode(nullptr), 0};
1217 
1218   // Look backwards through the arguments.
1219   auto numPatArgs = tree.getNumArgs();
1220   for (int i = numPatArgs - 1; i >= 0; --i) {
1221     auto dagArg = tree.getArgAsNestedDag(i);
1222     // A leaf is not a directive. Stop looking.
1223     if (!dagArg)
1224       break;
1225 
1226     auto isLocation = dagArg.isLocationDirective();
1227     auto isReturnType = dagArg.isReturnTypeDirective();
1228     // If encountered a DAG node that isn't a trailing directive, stop looking.
1229     if (!(isLocation || isReturnType))
1230       break;
1231     // Save the directive, but error if one of the same type was already
1232     // found.
1233     ++tail.numDirectives;
1234     if (isLocation) {
1235       if (tail.location)
1236         PrintFatalError(loc, "`location` directive can only be specified "
1237                              "once");
1238       tail.location = dagArg;
1239     } else if (isReturnType) {
1240       if (tail.returnType)
1241         PrintFatalError(loc, "`returnType` directive can only be specified "
1242                              "once");
1243       tail.returnType = dagArg;
1244     }
1245   }
1246 
1247   return tail;
1248 }
1249 
1250 std::string
getLocation(PatternEmitter::TrailingDirectives & tail)1251 PatternEmitter::getLocation(PatternEmitter::TrailingDirectives &tail) {
1252   if (tail.location)
1253     return handleLocationDirective(tail.location);
1254 
1255   // If no explicit location is given, use the default, all fused, location.
1256   return "odsLoc";
1257 }
1258 
handleOpCreation(DagNode tree,int resultIndex,int depth)1259 std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
1260                                              int depth) {
1261   LLVM_DEBUG(llvm::dbgs() << "create op for pattern: ");
1262   LLVM_DEBUG(tree.print(llvm::dbgs()));
1263   LLVM_DEBUG(llvm::dbgs() << '\n');
1264 
1265   Operator &resultOp = tree.getDialectOp(opMap);
1266   auto numOpArgs = resultOp.getNumArgs();
1267   auto numPatArgs = tree.getNumArgs();
1268 
1269   auto tail = getTrailingDirectives(tree);
1270   auto locToUse = getLocation(tail);
1271 
1272   auto inPattern = numPatArgs - tail.numDirectives;
1273   if (numOpArgs != inPattern) {
1274     PrintFatalError(loc,
1275                     formatv("resultant op '{0}' argument number mismatch: "
1276                             "{1} in pattern vs. {2} in definition",
1277                             resultOp.getOperationName(), inPattern, numOpArgs));
1278   }
1279 
1280   // A map to collect all nested DAG child nodes' names, with operand index as
1281   // the key. This includes both bound and unbound child nodes.
1282   ChildNodeIndexNameMap childNodeNames;
1283 
1284   // First go through all the child nodes who are nested DAG constructs to
1285   // create ops for them and remember the symbol names for them, so that we can
1286   // use the results in the current node. This happens in a recursive manner.
1287   for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) {
1288     if (auto child = tree.getArgAsNestedDag(i))
1289       childNodeNames[i] = handleResultPattern(child, i, depth + 1);
1290   }
1291 
1292   // The name of the local variable holding this op.
1293   std::string valuePackName;
1294   // The symbol for holding the result of this pattern. Note that the result of
1295   // this pattern is not necessarily the same as the variable created by this
1296   // pattern because we can use `__N` suffix to refer only a specific result if
1297   // the generated op is a multi-result op.
1298   std::string resultValue;
1299   if (tree.getSymbol().empty()) {
1300     // No symbol is explicitly bound to this op in the pattern. Generate a
1301     // unique name.
1302     valuePackName = resultValue = getUniqueSymbol(&resultOp);
1303   } else {
1304     resultValue = std::string(tree.getSymbol());
1305     // Strip the index to get the name for the value pack and use it to name the
1306     // local variable for the op.
1307     valuePackName = std::string(SymbolInfoMap::getValuePackName(resultValue));
1308   }
1309 
1310   // Create the local variable for this op.
1311   os << formatv("{0} {1};\n{{\n", resultOp.getQualCppClassName(),
1312                 valuePackName);
1313 
1314   // Right now ODS don't have general type inference support. Except a few
1315   // special cases listed below, DRR needs to supply types for all results
1316   // when building an op.
1317   bool isSameOperandsAndResultType =
1318       resultOp.getTrait("::mlir::OpTrait::SameOperandsAndResultType");
1319   bool useFirstAttr =
1320       resultOp.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType");
1321 
1322   if (!tail.returnType && (isSameOperandsAndResultType || useFirstAttr)) {
1323     // We know how to deduce the result type for ops with these traits and we've
1324     // generated builders taking aggregate parameters. Use those builders to
1325     // create the ops.
1326 
1327     // First prepare local variables for op arguments used in builder call.
1328     createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth);
1329 
1330     // Then create the op.
1331     os.scope("", "\n}\n").os << formatv(
1332         "{0} = rewriter.create<{1}>({2}, tblgen_values, tblgen_attrs);",
1333         valuePackName, resultOp.getQualCppClassName(), locToUse);
1334     return resultValue;
1335   }
1336 
1337   bool usePartialResults = valuePackName != resultValue;
1338 
1339   if (!tail.returnType && (usePartialResults || depth > 0 || resultIndex < 0)) {
1340     // For these cases (broadcastable ops, op results used both as auxiliary
1341     // values and replacement values, ops in nested patterns, auxiliary ops), we
1342     // still need to supply the result types when building the op. But because
1343     // we don't generate a builder automatically with ODS for them, it's the
1344     // developer's responsibility to make sure such a builder (with result type
1345     // deduction ability) exists. We go through the separate-parameter builder
1346     // here given that it's easier for developers to write compared to
1347     // aggregate-parameter builders.
1348     createSeparateLocalVarsForOpArgs(tree, childNodeNames);
1349 
1350     os.scope().os << formatv("{0} = rewriter.create<{1}>({2}", valuePackName,
1351                              resultOp.getQualCppClassName(), locToUse);
1352     supplyValuesForOpArgs(tree, childNodeNames, depth);
1353     os << "\n  );\n}\n";
1354     return resultValue;
1355   }
1356 
1357   // If we are provided explicit return types, use them to build the op.
1358   // However, if depth == 0 and resultIndex >= 0, it means we are replacing
1359   // the values generated from the source pattern root op. Then we must use the
1360   // source pattern's value types to determine the value type of the generated
1361   // op here.
1362   if (depth == 0 && resultIndex >= 0 && tail.returnType)
1363     PrintFatalError(loc, "Cannot specify explicit return types in an op whose "
1364                          "return values replace the source pattern's root op");
1365 
1366   // First prepare local variables for op arguments used in builder call.
1367   createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth);
1368 
1369   // Then prepare the result types. We need to specify the types for all
1370   // results.
1371   os.indent() << formatv("::mlir::SmallVector<::mlir::Type, 4> tblgen_types; "
1372                          "(void)tblgen_types;\n");
1373   int numResults = resultOp.getNumResults();
1374   if (tail.returnType) {
1375     auto numRetTys = tail.returnType.getNumArgs();
1376     for (int i = 0; i < numRetTys; ++i) {
1377       auto varName = handleReturnTypeArg(tail.returnType, i, depth + 1);
1378       os << "tblgen_types.push_back(" << varName << ");\n";
1379     }
1380   } else {
1381     if (numResults != 0) {
1382       // Copy the result types from the source pattern.
1383       for (int i = 0; i < numResults; ++i)
1384         os << formatv("for (auto v: castedOp0.getODSResults({0})) {{\n"
1385                       "  tblgen_types.push_back(v.getType());\n}\n",
1386                       resultIndex + i);
1387     }
1388   }
1389   os << formatv("{0} = rewriter.create<{1}>({2}, tblgen_types, "
1390                 "tblgen_values, tblgen_attrs);\n",
1391                 valuePackName, resultOp.getQualCppClassName(), locToUse);
1392   os.unindent() << "}\n";
1393   return resultValue;
1394 }
1395 
createSeparateLocalVarsForOpArgs(DagNode node,ChildNodeIndexNameMap & childNodeNames)1396 void PatternEmitter::createSeparateLocalVarsForOpArgs(
1397     DagNode node, ChildNodeIndexNameMap &childNodeNames) {
1398   Operator &resultOp = node.getDialectOp(opMap);
1399 
1400   // Now prepare operands used for building this op:
1401   // * If the operand is non-variadic, we create a `Value` local variable.
1402   // * If the operand is variadic, we create a `SmallVector<Value>` local
1403   //   variable.
1404 
1405   int valueIndex = 0; // An index for uniquing local variable names.
1406   for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
1407     const auto *operand =
1408         resultOp.getArg(argIndex).dyn_cast<NamedTypeConstraint *>();
1409     // We do not need special handling for attributes.
1410     if (!operand)
1411       continue;
1412 
1413     raw_indented_ostream::DelimitedScope scope(os);
1414     std::string varName;
1415     if (operand->isVariadic()) {
1416       varName = std::string(formatv("tblgen_values_{0}", valueIndex++));
1417       os << formatv("::mlir::SmallVector<::mlir::Value, 4> {0};\n", varName);
1418       std::string range;
1419       if (node.isNestedDagArg(argIndex)) {
1420         range = childNodeNames[argIndex];
1421       } else {
1422         range = std::string(node.getArgName(argIndex));
1423       }
1424       // Resolve the symbol for all range use so that we have a uniform way of
1425       // capturing the values.
1426       range = symbolInfoMap.getValueAndRangeUse(range);
1427       os << formatv("for (auto v: {0}) {{\n  {1}.push_back(v);\n}\n", range,
1428                     varName);
1429     } else {
1430       varName = std::string(formatv("tblgen_value_{0}", valueIndex++));
1431       os << formatv("::mlir::Value {0} = ", varName);
1432       if (node.isNestedDagArg(argIndex)) {
1433         os << symbolInfoMap.getValueAndRangeUse(childNodeNames[argIndex]);
1434       } else {
1435         DagLeaf leaf = node.getArgAsLeaf(argIndex);
1436         auto symbol =
1437             symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex));
1438         if (leaf.isNativeCodeCall()) {
1439           os << std::string(
1440               tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol)));
1441         } else {
1442           os << symbol;
1443         }
1444       }
1445       os << ";\n";
1446     }
1447 
1448     // Update to use the newly created local variable for building the op later.
1449     childNodeNames[argIndex] = varName;
1450   }
1451 }
1452 
supplyValuesForOpArgs(DagNode node,const ChildNodeIndexNameMap & childNodeNames,int depth)1453 void PatternEmitter::supplyValuesForOpArgs(
1454     DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) {
1455   Operator &resultOp = node.getDialectOp(opMap);
1456   for (int argIndex = 0, numOpArgs = resultOp.getNumArgs();
1457        argIndex != numOpArgs; ++argIndex) {
1458     // Start each argument on its own line.
1459     os << ",\n    ";
1460 
1461     Argument opArg = resultOp.getArg(argIndex);
1462     // Handle the case of operand first.
1463     if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) {
1464       if (!operand->name.empty())
1465         os << "/*" << operand->name << "=*/";
1466       os << childNodeNames.lookup(argIndex);
1467       continue;
1468     }
1469 
1470     // The argument in the op definition.
1471     auto opArgName = resultOp.getArgName(argIndex);
1472     if (auto subTree = node.getArgAsNestedDag(argIndex)) {
1473       if (!subTree.isNativeCodeCall())
1474         PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
1475                              "for creating attribute");
1476       os << formatv("/*{0}=*/{1}", opArgName, childNodeNames.lookup(argIndex));
1477     } else {
1478       auto leaf = node.getArgAsLeaf(argIndex);
1479       // The argument in the result DAG pattern.
1480       auto patArgName = node.getArgName(argIndex);
1481       if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) {
1482         // TODO: Refactor out into map to avoid recomputing these.
1483         if (!opArg.is<NamedAttribute *>())
1484           PrintFatalError(loc, Twine("expected attribute ") + Twine(argIndex));
1485         if (!patArgName.empty())
1486           os << "/*" << patArgName << "=*/";
1487       } else {
1488         os << "/*" << opArgName << "=*/";
1489       }
1490       os << handleOpArgument(leaf, patArgName);
1491     }
1492   }
1493 }
1494 
createAggregateLocalVarsForOpArgs(DagNode node,const ChildNodeIndexNameMap & childNodeNames,int depth)1495 void PatternEmitter::createAggregateLocalVarsForOpArgs(
1496     DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) {
1497   Operator &resultOp = node.getDialectOp(opMap);
1498 
1499   auto scope = os.scope();
1500   os << formatv("::mlir::SmallVector<::mlir::Value, 4> "
1501                 "tblgen_values; (void)tblgen_values;\n");
1502   os << formatv("::mlir::SmallVector<::mlir::NamedAttribute, 4> "
1503                 "tblgen_attrs; (void)tblgen_attrs;\n");
1504 
1505   const char *addAttrCmd =
1506       "if (auto tmpAttr = {1}) {\n"
1507       "  tblgen_attrs.emplace_back(rewriter.getIdentifier(\"{0}\"), "
1508       "tmpAttr);\n}\n";
1509   for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
1510     if (resultOp.getArg(argIndex).is<NamedAttribute *>()) {
1511       // The argument in the op definition.
1512       auto opArgName = resultOp.getArgName(argIndex);
1513       if (auto subTree = node.getArgAsNestedDag(argIndex)) {
1514         if (!subTree.isNativeCodeCall())
1515           PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
1516                                "for creating attribute");
1517         os << formatv(addAttrCmd, opArgName, childNodeNames.lookup(argIndex));
1518       } else {
1519         auto leaf = node.getArgAsLeaf(argIndex);
1520         // The argument in the result DAG pattern.
1521         auto patArgName = node.getArgName(argIndex);
1522         os << formatv(addAttrCmd, opArgName,
1523                       handleOpArgument(leaf, patArgName));
1524       }
1525       continue;
1526     }
1527 
1528     const auto *operand =
1529         resultOp.getArg(argIndex).get<NamedTypeConstraint *>();
1530     std::string varName;
1531     if (operand->isVariadic()) {
1532       std::string range;
1533       if (node.isNestedDagArg(argIndex)) {
1534         range = childNodeNames.lookup(argIndex);
1535       } else {
1536         range = std::string(node.getArgName(argIndex));
1537       }
1538       // Resolve the symbol for all range use so that we have a uniform way of
1539       // capturing the values.
1540       range = symbolInfoMap.getValueAndRangeUse(range);
1541       os << formatv("for (auto v: {0}) {{\n  tblgen_values.push_back(v);\n}\n",
1542                     range);
1543     } else {
1544       os << formatv("tblgen_values.push_back(");
1545       if (node.isNestedDagArg(argIndex)) {
1546         os << symbolInfoMap.getValueAndRangeUse(
1547             childNodeNames.lookup(argIndex));
1548       } else {
1549         DagLeaf leaf = node.getArgAsLeaf(argIndex);
1550         if (leaf.isConstantAttr())
1551           // TODO: Use better location
1552           PrintFatalError(
1553               loc,
1554               "attribute found where value was expected, if attempting to use "
1555               "constant value, construct a constant op with given attribute "
1556               "instead");
1557 
1558         auto symbol =
1559             symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex));
1560         if (leaf.isNativeCodeCall()) {
1561           os << std::string(
1562               tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol)));
1563         } else {
1564           os << symbol;
1565         }
1566       }
1567       os << ");\n";
1568     }
1569   }
1570 }
1571 
StaticMatcherHelper(RecordOperatorMap & mapper)1572 StaticMatcherHelper::StaticMatcherHelper(RecordOperatorMap &mapper)
1573     : opMap(mapper) {}
1574 
populateStaticMatchers(raw_ostream & os)1575 void StaticMatcherHelper::populateStaticMatchers(raw_ostream &os) {
1576   // PatternEmitter will use the static matcher if there's one generated. To
1577   // ensure that all the dependent static matchers are generated before emitting
1578   // the matching logic of the DagNode, we use topological order to achieve it.
1579   for (auto &dagInfo : topologicalOrder) {
1580     DagNode node = dagInfo.first;
1581     if (!useStaticMatcher(node))
1582       continue;
1583 
1584     std::string funcName =
1585         formatv("static_dag_matcher_{0}", staticMatcherCounter++);
1586     assert(matcherNames.find(node) == matcherNames.end());
1587     PatternEmitter(dagInfo.second, &opMap, os, *this)
1588         .emitStaticMatcher(node, funcName);
1589     matcherNames[node] = funcName;
1590   }
1591 }
1592 
addPattern(Record * record)1593 void StaticMatcherHelper::addPattern(Record *record) {
1594   Pattern pat(record, &opMap);
1595 
1596   // While generating the function body of the DAG matcher, it may depends on
1597   // other DAG matchers. To ensure the dependent matchers are ready, we compute
1598   // the topological order for all the DAGs and emit the DAG matchers in this
1599   // order.
1600   llvm::unique_function<void(DagNode)> dfs = [&](DagNode node) {
1601     ++refStats[node];
1602 
1603     if (refStats[node] != 1)
1604       return;
1605 
1606     for (unsigned i = 0, e = node.getNumArgs(); i < e; ++i)
1607       if (DagNode sibling = node.getArgAsNestedDag(i))
1608         dfs(sibling);
1609 
1610     topologicalOrder.push_back(std::make_pair(node, record));
1611   };
1612 
1613   dfs(pat.getSourcePattern());
1614 }
1615 
emitRewriters(const RecordKeeper & recordKeeper,raw_ostream & os)1616 static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
1617   emitSourceFileHeader("Rewriters", os);
1618 
1619   const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern");
1620 
1621   // We put the map here because it can be shared among multiple patterns.
1622   RecordOperatorMap recordOpMap;
1623 
1624   // Exam all the patterns and generate static matcher for the duplicated
1625   // DagNode.
1626   StaticMatcherHelper staticMatcher(recordOpMap);
1627   for (Record *p : patterns)
1628     staticMatcher.addPattern(p);
1629   staticMatcher.populateStaticMatchers(os);
1630 
1631   std::vector<std::string> rewriterNames;
1632   rewriterNames.reserve(patterns.size());
1633 
1634   std::string baseRewriterName = "GeneratedConvert";
1635   int rewriterIndex = 0;
1636 
1637   for (Record *p : patterns) {
1638     std::string name;
1639     if (p->isAnonymous()) {
1640       // If no name is provided, ensure unique rewriter names simply by
1641       // appending unique suffix.
1642       name = baseRewriterName + llvm::utostr(rewriterIndex++);
1643     } else {
1644       name = std::string(p->getName());
1645     }
1646     LLVM_DEBUG(llvm::dbgs()
1647                << "=== start generating pattern '" << name << "' ===\n");
1648     PatternEmitter(p, &recordOpMap, os, staticMatcher).emit(name);
1649     LLVM_DEBUG(llvm::dbgs()
1650                << "=== done generating pattern '" << name << "' ===\n");
1651     rewriterNames.push_back(std::move(name));
1652   }
1653 
1654   // Emit function to add the generated matchers to the pattern list.
1655   os << "void LLVM_ATTRIBUTE_UNUSED populateWithGenerated("
1656         "::mlir::RewritePatternSet &patterns) {\n";
1657   for (const auto &name : rewriterNames) {
1658     os << "  patterns.add<" << name << ">(patterns.getContext());\n";
1659   }
1660   os << "}\n";
1661 }
1662 
1663 static mlir::GenRegistration
1664     genRewriters("gen-rewriters", "Generate pattern rewriters",
__anon78740d160802(const RecordKeeper &records, raw_ostream &os) 1665                  [](const RecordKeeper &records, raw_ostream &os) {
1666                    emitRewriters(records, os);
1667                    return false;
1668                  });
1669