1 //===- PDLToPDLInterp.cpp - Lower a PDL module to the interpreter ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
10 #include "../PassDetail.h"
11 #include "PredicateTree.h"
12 #include "mlir/Dialect/PDL/IR/PDL.h"
13 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
14 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
15 #include "mlir/Pass/Pass.h"
16 #include "llvm/ADT/MapVector.h"
17 #include "llvm/ADT/ScopedHashTable.h"
18 #include "llvm/ADT/Sequence.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/SmallVector.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 
23 using namespace mlir;
24 using namespace mlir::pdl_to_pdl_interp;
25 
26 //===----------------------------------------------------------------------===//
27 // PatternLowering
28 //===----------------------------------------------------------------------===//
29 
namespace {
/// This class generates operations within the PDL Interpreter dialect from a
/// given module containing PDL pattern operations. Matching logic for all
/// patterns is emitted into a single matcher function, while the rewrite logic
/// of each pattern is emitted as its own function inside a rewriter module.
struct PatternLowering {
public:
  /// Construct a lowering that emits matching code into `matcherFunc` and
  /// rewrite functions into `rewriterModule`.
  PatternLowering(FuncOp matcherFunc, ModuleOp rewriterModule);

  /// Generate code for matching and rewriting based on the pattern operations
  /// within the module.
  void lower(ModuleOp module);

private:
  /// Scoped mapping from a predicate position to the interpreter value that
  /// currently holds it. Scopes are pushed per generated matcher node.
  using ValueMap = llvm::ScopedHashTable<Position *, Value>;
  using ValueMapScope = llvm::ScopedHashTableScope<Position *, Value>;

  /// Generate interpreter operations for the tree rooted at the given matcher
  /// node, returning the block that starts the generated code.
  Block *generateMatcher(MatcherNode &node);

  /// Get or create an access to the provided positional value within the
  /// current block. Parent positions are materialized first when missing.
  Value getValueAt(Block *cur, Position *pos);

  /// Create an interpreter predicate operation, branching to the provided true
  /// and false destinations.
  void generatePredicate(Block *currentBlock, Qualifier *question,
                         Qualifier *answer, Value val, Block *trueDest,
                         Block *falseDest);

  /// Create an interpreter switch predicate operation, with a provided default
  /// and several case destinations.
  void generateSwitch(SwitchNode *switchNode, Block *currentBlock,
                      Qualifier *question, Value val, Block *defaultDest);

  /// Create the interpreter operations to record a successful pattern match.
  void generateRecordMatch(Block *currentBlock, Block *nextBlock,
                           pdl::PatternOp pattern);

  /// Generate a rewriter function for the given pattern operation, and return
  /// a reference to that function. The positions of match values consumed by
  /// the rewriter are appended to `usedMatchValues`, in the same order as the
  /// rewriter function's arguments.
  SymbolRefAttr generateRewriter(pdl::PatternOp pattern,
                                 SmallVectorImpl<Position *> &usedMatchValues);

  /// Generate the rewriter code for the given operation. Values produced for
  /// the PDL operation are recorded in `rewriteValues`; PDL operands are
  /// translated through `mapRewriteValue`.
  void generateRewriter(pdl::ApplyNativeRewriteOp rewriteOp,
                        DenseMap<Value, Value> &rewriteValues,
                        function_ref<Value(Value)> mapRewriteValue);
  void generateRewriter(pdl::AttributeOp attrOp,
                        DenseMap<Value, Value> &rewriteValues,
                        function_ref<Value(Value)> mapRewriteValue);
  void generateRewriter(pdl::EraseOp eraseOp,
                        DenseMap<Value, Value> &rewriteValues,
                        function_ref<Value(Value)> mapRewriteValue);
  void generateRewriter(pdl::OperationOp operationOp,
                        DenseMap<Value, Value> &rewriteValues,
                        function_ref<Value(Value)> mapRewriteValue);
  void generateRewriter(pdl::ReplaceOp replaceOp,
                        DenseMap<Value, Value> &rewriteValues,
                        function_ref<Value(Value)> mapRewriteValue);
  void generateRewriter(pdl::ResultOp resultOp,
                        DenseMap<Value, Value> &rewriteValues,
                        function_ref<Value(Value)> mapRewriteValue);
  void generateRewriter(pdl::ResultsOp resultOp,
                        DenseMap<Value, Value> &rewriteValues,
                        function_ref<Value(Value)> mapRewriteValue);
  void generateRewriter(pdl::TypeOp typeOp,
                        DenseMap<Value, Value> &rewriteValues,
                        function_ref<Value(Value)> mapRewriteValue);
  void generateRewriter(pdl::TypesOp typeOp,
                        DenseMap<Value, Value> &rewriteValues,
                        function_ref<Value(Value)> mapRewriteValue);

  /// Generate the values used for resolving the result types of an operation
  /// created within a dag rewriter region. The resolved type values are
  /// appended to `types`.
  void generateOperationResultTypeRewriter(
      pdl::OperationOp op, SmallVectorImpl<Value> &types,
      DenseMap<Value, Value> &rewriteValues,
      function_ref<Value(Value)> mapRewriteValue);

  /// A builder to use when generating interpreter operations. Its insertion
  /// point moves between the matcher function and the rewriter functions as
  /// code is generated.
  OpBuilder builder;

  /// The matcher function used for all match related logic within PDL patterns.
  FuncOp matcherFunc;

  /// The rewriter module containing all rewrite related logic within PDL
  /// patterns.
  ModuleOp rewriterModule;

  /// The symbol table of the rewriter module used for insertion.
  SymbolTable rewriterSymbolTable;

  /// A scoped map connecting a position with the corresponding interpreter
  /// value.
  ValueMap values;

  /// A stack of blocks used as the failure destination for matcher nodes that
  /// don't have an explicit failure path.
  SmallVector<Block *, 8> failureBlockStack;

  /// A mapping between values defined in a pattern match, and the corresponding
  /// positional value.
  DenseMap<Value, Position *> valueToPosition;

  /// The set of operation values whose location will be used for newly
  /// generated operations. Values are added and removed as the matcher tree is
  /// traversed; the current contents are attached to each generated
  /// `pdl_interp.record_match`.
  SetVector<Value> locOps;
};
} // end anonymous namespace
139 
PatternLowering(FuncOp matcherFunc,ModuleOp rewriterModule)140 PatternLowering::PatternLowering(FuncOp matcherFunc, ModuleOp rewriterModule)
141     : builder(matcherFunc.getContext()), matcherFunc(matcherFunc),
142       rewriterModule(rewriterModule), rewriterSymbolTable(rewriterModule) {}
143 
lower(ModuleOp module)144 void PatternLowering::lower(ModuleOp module) {
145   PredicateUniquer predicateUniquer;
146   PredicateBuilder predicateBuilder(predicateUniquer, module.getContext());
147 
148   // Define top-level scope for the arguments to the matcher function.
149   ValueMapScope topLevelValueScope(values);
150 
151   // Insert the root operation, i.e. argument to the matcher, at the root
152   // position.
153   Block *matcherEntryBlock = matcherFunc.addEntryBlock();
154   values.insert(predicateBuilder.getRoot(), matcherEntryBlock->getArgument(0));
155 
156   // Generate a root matcher node from the provided PDL module.
157   std::unique_ptr<MatcherNode> root = MatcherNode::generateMatcherTree(
158       module, predicateBuilder, valueToPosition);
159   Block *firstMatcherBlock = generateMatcher(*root);
160 
161   // After generation, merged the first matched block into the entry.
162   matcherEntryBlock->getOperations().splice(matcherEntryBlock->end(),
163                                             firstMatcherBlock->getOperations());
164   firstMatcherBlock->erase();
165 }
166 
/// Emit interpreter code for the matcher subtree rooted at `node` and return
/// the freshly created block where that code begins.
///
/// Ordering matters here:
///  - The failure sibling (if any) is generated before this node's own
///    successors. Its block is pushed on `failureBlockStack` so that
///    descendants without an explicit failure path branch to it. It is popped
///    again once this node is done.
///  - The value scope pushed below is still active while the failure node and
///    the successors are generated. Values materialized for this node's
///    position are therefore visible to them.
Block *PatternLowering::generateMatcher(MatcherNode &node) {
  // Create the block for this node and push a new scope for the values used
  // by this matcher.
  Block *block = matcherFunc.addBlock();
  ValueMapScope scope(values);

  // If this is the exit node, simply insert the corresponding interpreter
  // finalize.
  if (isa<ExitNode>(node)) {
    builder.setInsertionPointToEnd(block);
    builder.create<pdl_interp::FinalizeOp>(matcherFunc.getLoc());
    return block;
  }

  // If this node contains a position, get the corresponding value for this
  // block (a null Value otherwise).
  Position *position = node.getPosition();
  Value val = position ? getValueAt(block, position) : Value();

  // Get the next block in the match sequence: either a freshly generated
  // failure node, or the innermost inherited failure destination.
  std::unique_ptr<MatcherNode> &failureNode = node.getFailureNode();
  Block *nextBlock;
  if (failureNode) {
    nextBlock = generateMatcher(*failureNode);
    failureBlockStack.push_back(nextBlock);
  } else {
    assert(!failureBlockStack.empty() && "expected valid failure block");
    nextBlock = failureBlockStack.back();
  }

  // If this value corresponds to an operation, record that we are going to use
  // its location as part of a fused location. It is removed again below, once
  // this subtree has been generated.
  bool isOperationValue = val && val.getType().isa<pdl::OperationType>();
  if (isOperationValue)
    locOps.insert(val);

  // Generate code for a boolean predicate node.
  if (auto *boolNode = dyn_cast<BoolNode>(&node)) {
    auto *child = generateMatcher(*boolNode->getSuccessNode());
    generatePredicate(block, node.getQuestion(), boolNode->getAnswer(), val,
                      child, nextBlock);

    // Generate code for a switch node.
  } else if (auto *switchNode = dyn_cast<SwitchNode>(&node)) {
    generateSwitch(switchNode, block, node.getQuestion(), val, nextBlock);

    // Generate code for a success node.
  } else if (auto *successNode = dyn_cast<SuccessNode>(&node)) {
    generateRecordMatch(block, nextBlock, successNode->getPattern());
  }

  // Undo the bookkeeping done for this node so that siblings and ancestors see
  // the state they had before it.
  if (failureNode)
    failureBlockStack.pop_back();
  if (isOperationValue)
    locOps.remove(val);
  return block;
}
223 
getValueAt(Block * cur,Position * pos)224 Value PatternLowering::getValueAt(Block *cur, Position *pos) {
225   if (Value val = values.lookup(pos))
226     return val;
227 
228   // Get the value for the parent position.
229   Value parentVal = getValueAt(cur, pos->getParent());
230 
231   // TODO: Use a location from the position.
232   Location loc = parentVal.getLoc();
233   builder.setInsertionPointToEnd(cur);
234   Value value;
235   switch (pos->getKind()) {
236   case Predicates::OperationPos:
237     value = builder.create<pdl_interp::GetDefiningOpOp>(
238         loc, builder.getType<pdl::OperationType>(), parentVal);
239     break;
240   case Predicates::OperandPos: {
241     auto *operandPos = cast<OperandPosition>(pos);
242     value = builder.create<pdl_interp::GetOperandOp>(
243         loc, builder.getType<pdl::ValueType>(), parentVal,
244         operandPos->getOperandNumber());
245     break;
246   }
247   case Predicates::OperandGroupPos: {
248     auto *operandPos = cast<OperandGroupPosition>(pos);
249     Type valueTy = builder.getType<pdl::ValueType>();
250     value = builder.create<pdl_interp::GetOperandsOp>(
251         loc, operandPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy,
252         parentVal, operandPos->getOperandGroupNumber());
253     break;
254   }
255   case Predicates::AttributePos: {
256     auto *attrPos = cast<AttributePosition>(pos);
257     value = builder.create<pdl_interp::GetAttributeOp>(
258         loc, builder.getType<pdl::AttributeType>(), parentVal,
259         attrPos->getName().strref());
260     break;
261   }
262   case Predicates::TypePos: {
263     if (parentVal.getType().isa<pdl::AttributeType>())
264       value = builder.create<pdl_interp::GetAttributeTypeOp>(loc, parentVal);
265     else
266       value = builder.create<pdl_interp::GetValueTypeOp>(loc, parentVal);
267     break;
268   }
269   case Predicates::ResultPos: {
270     auto *resPos = cast<ResultPosition>(pos);
271     value = builder.create<pdl_interp::GetResultOp>(
272         loc, builder.getType<pdl::ValueType>(), parentVal,
273         resPos->getResultNumber());
274     break;
275   }
276   case Predicates::ResultGroupPos: {
277     auto *resPos = cast<ResultGroupPosition>(pos);
278     Type valueTy = builder.getType<pdl::ValueType>();
279     value = builder.create<pdl_interp::GetResultsOp>(
280         loc, resPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy,
281         parentVal, resPos->getResultGroupNumber());
282     break;
283   }
284   default:
285     llvm_unreachable("Generating unknown Position getter");
286     break;
287   }
288   values.insert(pos, value);
289   return value;
290 }
291 
generatePredicate(Block * currentBlock,Qualifier * question,Qualifier * answer,Value val,Block * trueDest,Block * falseDest)292 void PatternLowering::generatePredicate(Block *currentBlock,
293                                         Qualifier *question, Qualifier *answer,
294                                         Value val, Block *trueDest,
295                                         Block *falseDest) {
296   builder.setInsertionPointToEnd(currentBlock);
297   Location loc = val.getLoc();
298   Predicates::Kind kind = question->getKind();
299   switch (kind) {
300   case Predicates::IsNotNullQuestion:
301     builder.create<pdl_interp::IsNotNullOp>(loc, val, trueDest, falseDest);
302     break;
303   case Predicates::OperationNameQuestion: {
304     auto *opNameAnswer = cast<OperationNameAnswer>(answer);
305     builder.create<pdl_interp::CheckOperationNameOp>(
306         loc, val, opNameAnswer->getValue().getStringRef(), trueDest, falseDest);
307     break;
308   }
309   case Predicates::TypeQuestion: {
310     auto *ans = cast<TypeAnswer>(answer);
311     if (val.getType().isa<pdl::RangeType>())
312       builder.create<pdl_interp::CheckTypesOp>(
313           loc, val, ans->getValue().cast<ArrayAttr>(), trueDest, falseDest);
314     else
315       builder.create<pdl_interp::CheckTypeOp>(
316           loc, val, ans->getValue().cast<TypeAttr>(), trueDest, falseDest);
317     break;
318   }
319   case Predicates::AttributeQuestion: {
320     auto *ans = cast<AttributeAnswer>(answer);
321     builder.create<pdl_interp::CheckAttributeOp>(loc, val, ans->getValue(),
322                                                  trueDest, falseDest);
323     break;
324   }
325   case Predicates::OperandCountAtLeastQuestion:
326   case Predicates::OperandCountQuestion:
327     builder.create<pdl_interp::CheckOperandCountOp>(
328         loc, val, cast<UnsignedAnswer>(answer)->getValue(),
329         /*compareAtLeast=*/kind == Predicates::OperandCountAtLeastQuestion,
330         trueDest, falseDest);
331     break;
332   case Predicates::ResultCountAtLeastQuestion:
333   case Predicates::ResultCountQuestion:
334     builder.create<pdl_interp::CheckResultCountOp>(
335         loc, val, cast<UnsignedAnswer>(answer)->getValue(),
336         /*compareAtLeast=*/kind == Predicates::ResultCountAtLeastQuestion,
337         trueDest, falseDest);
338     break;
339   case Predicates::EqualToQuestion: {
340     auto *equalToQuestion = cast<EqualToQuestion>(question);
341     builder.create<pdl_interp::AreEqualOp>(
342         loc, val, getValueAt(currentBlock, equalToQuestion->getValue()),
343         trueDest, falseDest);
344     break;
345   }
346   case Predicates::ConstraintQuestion: {
347     auto *cstQuestion = cast<ConstraintQuestion>(question);
348     SmallVector<Value, 2> args;
349     for (Position *position : std::get<1>(cstQuestion->getValue()))
350       args.push_back(getValueAt(currentBlock, position));
351     builder.create<pdl_interp::ApplyConstraintOp>(
352         loc, std::get<0>(cstQuestion->getValue()), args,
353         std::get<2>(cstQuestion->getValue()).cast<ArrayAttr>(), trueDest,
354         falseDest);
355     break;
356   }
357   default:
358     llvm_unreachable("Generating unknown Predicate operation");
359   }
360 }
361 
362 template <typename OpT, typename PredT, typename ValT = typename PredT::KeyTy>
createSwitchOp(Value val,Block * defaultDest,OpBuilder & builder,llvm::MapVector<Qualifier *,Block * > & dests)363 static void createSwitchOp(Value val, Block *defaultDest, OpBuilder &builder,
364                            llvm::MapVector<Qualifier *, Block *> &dests) {
365   std::vector<ValT> values;
366   std::vector<Block *> blocks;
367   values.reserve(dests.size());
368   blocks.reserve(dests.size());
369   for (const auto &it : dests) {
370     blocks.push_back(it.second);
371     values.push_back(cast<PredT>(it.first)->getValue());
372   }
373   builder.create<OpT>(val.getLoc(), val, values, defaultDest, blocks);
374 }
375 
generateSwitch(SwitchNode * switchNode,Block * currentBlock,Qualifier * question,Value val,Block * defaultDest)376 void PatternLowering::generateSwitch(SwitchNode *switchNode,
377                                      Block *currentBlock, Qualifier *question,
378                                      Value val, Block *defaultDest) {
379   // If the switch question is not an exact answer, i.e. for the `at_least`
380   // cases, we generate a special block sequence.
381   Predicates::Kind kind = question->getKind();
382   if (kind == Predicates::OperandCountAtLeastQuestion ||
383       kind == Predicates::ResultCountAtLeastQuestion) {
384     // Order the children such that the cases are in reverse numerical order.
385     SmallVector<unsigned> sortedChildren = llvm::to_vector<16>(
386         llvm::seq<unsigned>(0, switchNode->getChildren().size()));
387     llvm::sort(sortedChildren, [&](unsigned lhs, unsigned rhs) {
388       return cast<UnsignedAnswer>(switchNode->getChild(lhs).first)->getValue() >
389              cast<UnsignedAnswer>(switchNode->getChild(rhs).first)->getValue();
390     });
391 
392     // Build the destination for each child using the next highest child as a
393     // a failure destination. This essentially creates the following control
394     // flow:
395     //
396     // if (operand_count < 1)
397     //   goto failure
398     // if (child1.match())
399     //   ...
400     //
401     // if (operand_count < 2)
402     //   goto failure
403     // if (child2.match())
404     //   ...
405     //
406     // failure:
407     //   ...
408     //
409     failureBlockStack.push_back(defaultDest);
410     for (unsigned idx : sortedChildren) {
411       auto &child = switchNode->getChild(idx);
412       Block *childBlock = generateMatcher(*child.second);
413       Block *predicateBlock = builder.createBlock(childBlock);
414       generatePredicate(predicateBlock, question, child.first, val, childBlock,
415                         defaultDest);
416       failureBlockStack.back() = predicateBlock;
417     }
418     Block *firstPredicateBlock = failureBlockStack.pop_back_val();
419     currentBlock->getOperations().splice(currentBlock->end(),
420                                          firstPredicateBlock->getOperations());
421     firstPredicateBlock->erase();
422     return;
423   }
424 
425   // Otherwise, generate each of the children and generate an interpreter
426   // switch.
427   llvm::MapVector<Qualifier *, Block *> children;
428   for (auto &it : switchNode->getChildren())
429     children.insert({it.first, generateMatcher(*it.second)});
430   builder.setInsertionPointToEnd(currentBlock);
431 
432   switch (question->getKind()) {
433   case Predicates::OperandCountQuestion:
434     return createSwitchOp<pdl_interp::SwitchOperandCountOp, UnsignedAnswer,
435                           int32_t>(val, defaultDest, builder, children);
436   case Predicates::ResultCountQuestion:
437     return createSwitchOp<pdl_interp::SwitchResultCountOp, UnsignedAnswer,
438                           int32_t>(val, defaultDest, builder, children);
439   case Predicates::OperationNameQuestion:
440     return createSwitchOp<pdl_interp::SwitchOperationNameOp,
441                           OperationNameAnswer>(val, defaultDest, builder,
442                                                children);
443   case Predicates::TypeQuestion:
444     if (val.getType().isa<pdl::RangeType>()) {
445       return createSwitchOp<pdl_interp::SwitchTypesOp, TypeAnswer>(
446           val, defaultDest, builder, children);
447     }
448     return createSwitchOp<pdl_interp::SwitchTypeOp, TypeAnswer>(
449         val, defaultDest, builder, children);
450   case Predicates::AttributeQuestion:
451     return createSwitchOp<pdl_interp::SwitchAttributeOp, AttributeAnswer>(
452         val, defaultDest, builder, children);
453   default:
454     llvm_unreachable("Generating unknown switch predicate.");
455   }
456 }
457 
generateRecordMatch(Block * currentBlock,Block * nextBlock,pdl::PatternOp pattern)458 void PatternLowering::generateRecordMatch(Block *currentBlock, Block *nextBlock,
459                                           pdl::PatternOp pattern) {
460   // Generate a rewriter for the pattern this success node represents, and track
461   // any values used from the match region.
462   SmallVector<Position *, 8> usedMatchValues;
463   SymbolRefAttr rewriterFuncRef = generateRewriter(pattern, usedMatchValues);
464 
465   // Process any values used in the rewrite that are defined in the match.
466   std::vector<Value> mappedMatchValues;
467   mappedMatchValues.reserve(usedMatchValues.size());
468   for (Position *position : usedMatchValues)
469     mappedMatchValues.push_back(getValueAt(currentBlock, position));
470 
471   // Collect the set of operations generated by the rewriter.
472   SmallVector<StringRef, 4> generatedOps;
473   for (auto op : pattern.getRewriter().body().getOps<pdl::OperationOp>())
474     generatedOps.push_back(*op.name());
475   ArrayAttr generatedOpsAttr;
476   if (!generatedOps.empty())
477     generatedOpsAttr = builder.getStrArrayAttr(generatedOps);
478 
479   // Grab the root kind if present.
480   StringAttr rootKindAttr;
481   if (Optional<StringRef> rootKind = pattern.getRootKind())
482     rootKindAttr = builder.getStringAttr(*rootKind);
483 
484   builder.setInsertionPointToEnd(currentBlock);
485   builder.create<pdl_interp::RecordMatchOp>(
486       pattern.getLoc(), mappedMatchValues, locOps.getArrayRef(),
487       rewriterFuncRef, rootKindAttr, generatedOpsAttr, pattern.benefitAttr(),
488       nextBlock);
489 }
490 
generateRewriter(pdl::PatternOp pattern,SmallVectorImpl<Position * > & usedMatchValues)491 SymbolRefAttr PatternLowering::generateRewriter(
492     pdl::PatternOp pattern, SmallVectorImpl<Position *> &usedMatchValues) {
493   FuncOp rewriterFunc =
494       FuncOp::create(pattern.getLoc(), "pdl_generated_rewriter",
495                      builder.getFunctionType(llvm::None, llvm::None));
496   rewriterSymbolTable.insert(rewriterFunc);
497 
498   // Generate the rewriter function body.
499   builder.setInsertionPointToEnd(rewriterFunc.addEntryBlock());
500 
501   // Map an input operand of the pattern to a generated interpreter value.
502   DenseMap<Value, Value> rewriteValues;
503   auto mapRewriteValue = [&](Value oldValue) {
504     Value &newValue = rewriteValues[oldValue];
505     if (newValue)
506       return newValue;
507 
508     // Prefer materializing constants directly when possible.
509     Operation *oldOp = oldValue.getDefiningOp();
510     if (pdl::AttributeOp attrOp = dyn_cast<pdl::AttributeOp>(oldOp)) {
511       if (Attribute value = attrOp.valueAttr()) {
512         return newValue = builder.create<pdl_interp::CreateAttributeOp>(
513                    attrOp.getLoc(), value);
514       }
515     } else if (pdl::TypeOp typeOp = dyn_cast<pdl::TypeOp>(oldOp)) {
516       if (TypeAttr type = typeOp.typeAttr()) {
517         return newValue = builder.create<pdl_interp::CreateTypeOp>(
518                    typeOp.getLoc(), type);
519       }
520     } else if (pdl::TypesOp typeOp = dyn_cast<pdl::TypesOp>(oldOp)) {
521       if (ArrayAttr type = typeOp.typesAttr()) {
522         return newValue = builder.create<pdl_interp::CreateTypesOp>(
523                    typeOp.getLoc(), typeOp.getType(), type);
524       }
525     }
526 
527     // Otherwise, add this as an input to the rewriter.
528     Position *inputPos = valueToPosition.lookup(oldValue);
529     assert(inputPos && "expected value to be a pattern input");
530     usedMatchValues.push_back(inputPos);
531     return newValue = rewriterFunc.front().addArgument(oldValue.getType());
532   };
533 
534   // If this is a custom rewriter, simply dispatch to the registered rewrite
535   // method.
536   pdl::RewriteOp rewriter = pattern.getRewriter();
537   if (StringAttr rewriteName = rewriter.nameAttr()) {
538     auto mappedArgs = llvm::map_range(rewriter.externalArgs(), mapRewriteValue);
539     SmallVector<Value, 4> args(1, mapRewriteValue(rewriter.root()));
540     args.append(mappedArgs.begin(), mappedArgs.end());
541     builder.create<pdl_interp::ApplyRewriteOp>(
542         rewriter.getLoc(), /*resultTypes=*/TypeRange(), rewriteName, args,
543         rewriter.externalConstParamsAttr());
544   } else {
545     // Otherwise this is a dag rewriter defined using PDL operations.
546     for (Operation &rewriteOp : *rewriter.getBody()) {
547       llvm::TypeSwitch<Operation *>(&rewriteOp)
548           .Case<pdl::ApplyNativeRewriteOp, pdl::AttributeOp, pdl::EraseOp,
549                 pdl::OperationOp, pdl::ReplaceOp, pdl::ResultOp, pdl::ResultsOp,
550                 pdl::TypeOp, pdl::TypesOp>([&](auto op) {
551             this->generateRewriter(op, rewriteValues, mapRewriteValue);
552           });
553     }
554   }
555 
556   // Update the signature of the rewrite function.
557   rewriterFunc.setType(builder.getFunctionType(
558       llvm::to_vector<8>(rewriterFunc.front().getArgumentTypes()),
559       /*results=*/llvm::None));
560 
561   builder.create<pdl_interp::FinalizeOp>(rewriter.getLoc());
562   return SymbolRefAttr::get(
563       builder.getContext(),
564       pdl_interp::PDLInterpDialect::getRewriterModuleName(),
565       SymbolRefAttr::get(rewriterFunc));
566 }
567 
generateRewriter(pdl::ApplyNativeRewriteOp rewriteOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)568 void PatternLowering::generateRewriter(
569     pdl::ApplyNativeRewriteOp rewriteOp, DenseMap<Value, Value> &rewriteValues,
570     function_ref<Value(Value)> mapRewriteValue) {
571   SmallVector<Value, 2> arguments;
572   for (Value argument : rewriteOp.args())
573     arguments.push_back(mapRewriteValue(argument));
574   auto interpOp = builder.create<pdl_interp::ApplyRewriteOp>(
575       rewriteOp.getLoc(), rewriteOp.getResultTypes(), rewriteOp.nameAttr(),
576       arguments, rewriteOp.constParamsAttr());
577   for (auto it : llvm::zip(rewriteOp.results(), interpOp.results()))
578     rewriteValues[std::get<0>(it)] = std::get<1>(it);
579 }
580 
generateRewriter(pdl::AttributeOp attrOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)581 void PatternLowering::generateRewriter(
582     pdl::AttributeOp attrOp, DenseMap<Value, Value> &rewriteValues,
583     function_ref<Value(Value)> mapRewriteValue) {
584   Value newAttr = builder.create<pdl_interp::CreateAttributeOp>(
585       attrOp.getLoc(), attrOp.valueAttr());
586   rewriteValues[attrOp] = newAttr;
587 }
588 
generateRewriter(pdl::EraseOp eraseOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)589 void PatternLowering::generateRewriter(
590     pdl::EraseOp eraseOp, DenseMap<Value, Value> &rewriteValues,
591     function_ref<Value(Value)> mapRewriteValue) {
592   builder.create<pdl_interp::EraseOp>(eraseOp.getLoc(),
593                                       mapRewriteValue(eraseOp.operation()));
594 }
595 
generateRewriter(pdl::OperationOp operationOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)596 void PatternLowering::generateRewriter(
597     pdl::OperationOp operationOp, DenseMap<Value, Value> &rewriteValues,
598     function_ref<Value(Value)> mapRewriteValue) {
599   SmallVector<Value, 4> operands;
600   for (Value operand : operationOp.operands())
601     operands.push_back(mapRewriteValue(operand));
602 
603   SmallVector<Value, 4> attributes;
604   for (Value attr : operationOp.attributes())
605     attributes.push_back(mapRewriteValue(attr));
606 
607   SmallVector<Value, 2> types;
608   generateOperationResultTypeRewriter(operationOp, types, rewriteValues,
609                                       mapRewriteValue);
610 
611   // Create the new operation.
612   Location loc = operationOp.getLoc();
613   Value createdOp = builder.create<pdl_interp::CreateOperationOp>(
614       loc, *operationOp.name(), types, operands, attributes,
615       operationOp.attributeNames());
616   rewriteValues[operationOp.op()] = createdOp;
617 
618   // Generate accesses for any results that have their types constrained.
619   // Handle the case where there is a single range representing all of the
620   // result types.
621   OperandRange resultTys = operationOp.types();
622   if (resultTys.size() == 1 && resultTys[0].getType().isa<pdl::RangeType>()) {
623     Value &type = rewriteValues[resultTys[0]];
624     if (!type) {
625       auto results = builder.create<pdl_interp::GetResultsOp>(loc, createdOp);
626       type = builder.create<pdl_interp::GetValueTypeOp>(loc, results);
627     }
628     return;
629   }
630 
631   // Otherwise, populate the individual results.
632   bool seenVariableLength = false;
633   Type valueTy = builder.getType<pdl::ValueType>();
634   Type valueRangeTy = pdl::RangeType::get(valueTy);
635   for (auto it : llvm::enumerate(resultTys)) {
636     Value &type = rewriteValues[it.value()];
637     if (type)
638       continue;
639     bool isVariadic = it.value().getType().isa<pdl::RangeType>();
640     seenVariableLength |= isVariadic;
641 
642     // After a variable length result has been seen, we need to use result
643     // groups because the exact index of the result is not statically known.
644     Value resultVal;
645     if (seenVariableLength)
646       resultVal = builder.create<pdl_interp::GetResultsOp>(
647           loc, isVariadic ? valueRangeTy : valueTy, createdOp, it.index());
648     else
649       resultVal = builder.create<pdl_interp::GetResultOp>(
650           loc, valueTy, createdOp, it.index());
651     type = builder.create<pdl_interp::GetValueTypeOp>(loc, resultVal);
652   }
653 }
654 
generateRewriter(pdl::ReplaceOp replaceOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)655 void PatternLowering::generateRewriter(
656     pdl::ReplaceOp replaceOp, DenseMap<Value, Value> &rewriteValues,
657     function_ref<Value(Value)> mapRewriteValue) {
658   SmallVector<Value, 4> replOperands;
659 
660   // If the replacement was another operation, get its results. `pdl` allows
661   // for using an operation for simplicitly, but the interpreter isn't as
662   // user facing.
663   if (Value replOp = replaceOp.replOperation()) {
664     // Don't use replace if we know the replaced operation has no results.
665     auto opOp = replaceOp.operation().getDefiningOp<pdl::OperationOp>();
666     if (!opOp || !opOp.types().empty()) {
667       replOperands.push_back(builder.create<pdl_interp::GetResultsOp>(
668           replOp.getLoc(), mapRewriteValue(replOp)));
669     }
670   } else {
671     for (Value operand : replaceOp.replValues())
672       replOperands.push_back(mapRewriteValue(operand));
673   }
674 
675   // If there are no replacement values, just create an erase instead.
676   if (replOperands.empty()) {
677     builder.create<pdl_interp::EraseOp>(replaceOp.getLoc(),
678                                         mapRewriteValue(replaceOp.operation()));
679     return;
680   }
681 
682   builder.create<pdl_interp::ReplaceOp>(
683       replaceOp.getLoc(), mapRewriteValue(replaceOp.operation()), replOperands);
684 }
685 
generateRewriter(pdl::ResultOp resultOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)686 void PatternLowering::generateRewriter(
687     pdl::ResultOp resultOp, DenseMap<Value, Value> &rewriteValues,
688     function_ref<Value(Value)> mapRewriteValue) {
689   rewriteValues[resultOp] = builder.create<pdl_interp::GetResultOp>(
690       resultOp.getLoc(), builder.getType<pdl::ValueType>(),
691       mapRewriteValue(resultOp.parent()), resultOp.index());
692 }
693 
generateRewriter(pdl::ResultsOp resultOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)694 void PatternLowering::generateRewriter(
695     pdl::ResultsOp resultOp, DenseMap<Value, Value> &rewriteValues,
696     function_ref<Value(Value)> mapRewriteValue) {
697   rewriteValues[resultOp] = builder.create<pdl_interp::GetResultsOp>(
698       resultOp.getLoc(), resultOp.getType(), mapRewriteValue(resultOp.parent()),
699       resultOp.index());
700 }
701 
generateRewriter(pdl::TypeOp typeOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)702 void PatternLowering::generateRewriter(
703     pdl::TypeOp typeOp, DenseMap<Value, Value> &rewriteValues,
704     function_ref<Value(Value)> mapRewriteValue) {
705   // If the type isn't constant, the users (e.g. OperationOp) will resolve this
706   // type.
707   if (TypeAttr typeAttr = typeOp.typeAttr()) {
708     rewriteValues[typeOp] =
709         builder.create<pdl_interp::CreateTypeOp>(typeOp.getLoc(), typeAttr);
710   }
711 }
712 
generateRewriter(pdl::TypesOp typeOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)713 void PatternLowering::generateRewriter(
714     pdl::TypesOp typeOp, DenseMap<Value, Value> &rewriteValues,
715     function_ref<Value(Value)> mapRewriteValue) {
716   // If the type isn't constant, the users (e.g. OperationOp) will resolve this
717   // type.
718   if (ArrayAttr typeAttr = typeOp.typesAttr()) {
719     rewriteValues[typeOp] = builder.create<pdl_interp::CreateTypesOp>(
720         typeOp.getLoc(), typeOp.getType(), typeAttr);
721   }
722 }
723 
generateOperationResultTypeRewriter(pdl::OperationOp op,SmallVectorImpl<Value> & types,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)724 void PatternLowering::generateOperationResultTypeRewriter(
725     pdl::OperationOp op, SmallVectorImpl<Value> &types,
726     DenseMap<Value, Value> &rewriteValues,
727     function_ref<Value(Value)> mapRewriteValue) {
728   // Look for an operation that was replaced by `op`. The result types will be
729   // inferred from the results that were replaced.
730   Block *rewriterBlock = op->getBlock();
731   Value replacedOp;
732   for (OpOperand &use : op.op().getUses()) {
733     // Check that the use corresponds to a ReplaceOp and that it is the
734     // replacement value, not the operation being replaced.
735     pdl::ReplaceOp replOpUser = dyn_cast<pdl::ReplaceOp>(use.getOwner());
736     if (!replOpUser || use.getOperandNumber() == 0)
737       continue;
738     // Make sure the replaced operation was defined before this one.
739     Value replOpVal = replOpUser.operation();
740     Operation *replacedOp = replOpVal.getDefiningOp();
741     if (replacedOp->getBlock() == rewriterBlock &&
742         !replacedOp->isBeforeInBlock(op))
743       continue;
744 
745     Value replacedOpResults = builder.create<pdl_interp::GetResultsOp>(
746         replacedOp->getLoc(), mapRewriteValue(replOpVal));
747     types.push_back(builder.create<pdl_interp::GetValueTypeOp>(
748         replacedOp->getLoc(), replacedOpResults));
749     return;
750   }
751 
752   // Check if the operation has type inference support.
753   if (op.hasTypeInference()) {
754     types.push_back(builder.create<pdl_interp::InferredTypesOp>(op.getLoc()));
755     return;
756   }
757 
758   // Otherwise, handle inference for each of the result types individually.
759   OperandRange resultTypeValues = op.types();
760   types.reserve(resultTypeValues.size());
761   for (auto it : llvm::enumerate(resultTypeValues)) {
762     Value resultType = it.value();
763 
764     // Check for an already translated value.
765     if (Value existingRewriteValue = rewriteValues.lookup(resultType)) {
766       types.push_back(existingRewriteValue);
767       continue;
768     }
769 
770     // Check for an input from the matcher.
771     if (resultType.getDefiningOp()->getBlock() != rewriterBlock) {
772       types.push_back(mapRewriteValue(resultType));
773       continue;
774     }
775 
776     // The verifier asserts that the result types of each pdl.operation can be
777     // inferred. If we reach here, there is a bug either in the logic above or
778     // in the verifier for pdl.operation.
779     op->emitOpError() << "unable to infer result type for operation";
780     llvm_unreachable("unable to infer result type for operation");
781   }
782 }
783 
784 //===----------------------------------------------------------------------===//
785 // Conversion Pass
786 //===----------------------------------------------------------------------===//
787 
788 namespace {
789 struct PDLToPDLInterpPass
790     : public ConvertPDLToPDLInterpBase<PDLToPDLInterpPass> {
791   void runOnOperation() final;
792 };
793 } // namespace
794 
795 /// Convert the given module containing PDL pattern operations into a PDL
796 /// Interpreter operations.
runOnOperation()797 void PDLToPDLInterpPass::runOnOperation() {
798   ModuleOp module = getOperation();
799 
800   // Create the main matcher function This function contains all of the match
801   // related functionality from patterns in the module.
802   OpBuilder builder = OpBuilder::atBlockBegin(module.getBody());
803   FuncOp matcherFunc = builder.create<FuncOp>(
804       module.getLoc(), pdl_interp::PDLInterpDialect::getMatcherFunctionName(),
805       builder.getFunctionType(builder.getType<pdl::OperationType>(),
806                               /*results=*/llvm::None),
807       /*attrs=*/llvm::None);
808 
809   // Create a nested module to hold the functions invoked for rewriting the IR
810   // after a successful match.
811   ModuleOp rewriterModule = builder.create<ModuleOp>(
812       module.getLoc(), pdl_interp::PDLInterpDialect::getRewriterModuleName());
813 
814   // Generate the code for the patterns within the module.
815   PatternLowering generator(matcherFunc, rewriterModule);
816   generator.lower(module);
817 
818   // After generation, delete all of the pattern operations.
819   for (pdl::PatternOp pattern :
820        llvm::make_early_inc_range(module.getOps<pdl::PatternOp>()))
821     pattern.erase();
822 }
823 
createPDLToPDLInterpPass()824 std::unique_ptr<OperationPass<ModuleOp>> mlir::createPDLToPDLInterpPass() {
825   return std::make_unique<PDLToPDLInterpPass>();
826 }
827