1 //===- PDLToPDLInterp.cpp - Lower a PDL module to the interpreter ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
10 #include "../PassDetail.h"
11 #include "PredicateTree.h"
12 #include "mlir/Dialect/PDL/IR/PDL.h"
13 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
14 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
15 #include "mlir/Pass/Pass.h"
16 #include "llvm/ADT/MapVector.h"
17 #include "llvm/ADT/ScopedHashTable.h"
18 #include "llvm/ADT/Sequence.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/SmallVector.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 
23 using namespace mlir;
24 using namespace mlir::pdl_to_pdl_interp;
25 
26 //===----------------------------------------------------------------------===//
27 // PatternLowering
28 //===----------------------------------------------------------------------===//
29 
namespace {
/// This class generates operations within the PDL Interpreter dialect from a
/// given module containing PDL pattern operations.
///
/// All matching logic is emitted into a single matcher function, while the
/// rewrite logic is emitted as separate functions inside a rewriter module.
struct PatternLowering {
public:
  PatternLowering(FuncOp matcherFunc, ModuleOp rewriterModule);

  /// Generate code for matching and rewriting based on the pattern operations
  /// within the module.
  void lower(ModuleOp module);

private:
  /// Scoped map from a matcher position to the interpreter value holding it.
  /// A new scope is pushed for each generated matcher node, so values are only
  /// visible to the node that created them and to its nested nodes.
  using ValueMap = llvm::ScopedHashTable<Position *, Value>;
  using ValueMapScope = llvm::ScopedHashTableScope<Position *, Value>;

  /// Generate interpreter operations for the tree rooted at the given matcher
  /// node, returning the block that begins the generated code.
  Block *generateMatcher(MatcherNode &node);

  /// Get, or create at the end of `cur`, an access to the value at the provided
  /// position. Parent positions are materialized first as needed.
  Value getValueAt(Block *cur, Position *pos);

  /// Create an interpreter predicate operation, branching to the provided true
  /// and false destinations.
  void generatePredicate(Block *currentBlock, Qualifier *question,
                         Qualifier *answer, Value val, Block *trueDest,
                         Block *falseDest);

  /// Create an interpreter switch predicate operation, with a provided default
  /// and several case destinations.
  void generateSwitch(SwitchNode *switchNode, Block *currentBlock,
                      Qualifier *question, Value val, Block *defaultDest);

  /// Create the interpreter operations to record a successful pattern match.
  void generateRecordMatch(Block *currentBlock, Block *nextBlock,
                           pdl::PatternOp pattern);

  /// Generate a rewriter function for the given pattern operation, and return
  /// a reference to that function. The matcher positions of the match values
  /// the rewriter consumes are appended to `usedMatchValues`.
  SymbolRefAttr generateRewriter(pdl::PatternOp pattern,
                                 SmallVectorImpl<Position *> &usedMatchValues);

  /// Generate the rewriter code for the given operation. `rewriteValues` maps
  /// PDL values to their interpreter counterparts, and `mapRewriteValue`
  /// resolves (and, if needed, materializes) the interpreter value for a PDL
  /// value.
  void generateRewriter(pdl::ApplyNativeRewriteOp rewriteOp,
                        DenseMap<Value, Value> &rewriteValues,
                        function_ref<Value(Value)> mapRewriteValue);
  void generateRewriter(pdl::AttributeOp attrOp,
                        DenseMap<Value, Value> &rewriteValues,
                        function_ref<Value(Value)> mapRewriteValue);
  void generateRewriter(pdl::EraseOp eraseOp,
                        DenseMap<Value, Value> &rewriteValues,
                        function_ref<Value(Value)> mapRewriteValue);
  void generateRewriter(pdl::OperationOp operationOp,
                        DenseMap<Value, Value> &rewriteValues,
                        function_ref<Value(Value)> mapRewriteValue);
  void generateRewriter(pdl::ReplaceOp replaceOp,
                        DenseMap<Value, Value> &rewriteValues,
                        function_ref<Value(Value)> mapRewriteValue);
  void generateRewriter(pdl::ResultOp resultOp,
                        DenseMap<Value, Value> &rewriteValues,
                        function_ref<Value(Value)> mapRewriteValue);
  void generateRewriter(pdl::ResultsOp resultOp,
                        DenseMap<Value, Value> &rewriteValues,
                        function_ref<Value(Value)> mapRewriteValue);
  void generateRewriter(pdl::TypeOp typeOp,
                        DenseMap<Value, Value> &rewriteValues,
                        function_ref<Value(Value)> mapRewriteValue);
  void generateRewriter(pdl::TypesOp typeOp,
                        DenseMap<Value, Value> &rewriteValues,
                        function_ref<Value(Value)> mapRewriteValue);

  /// Generate the values used for resolving the result types of an operation
  /// created within a dag rewriter region.
  void generateOperationResultTypeRewriter(
      pdl::OperationOp op, SmallVectorImpl<Value> &types,
      DenseMap<Value, Value> &rewriteValues,
      function_ref<Value(Value)> mapRewriteValue);

  /// A builder to use when generating interpreter operations.
  OpBuilder builder;

  /// The matcher function used for all match related logic within PDL patterns.
  FuncOp matcherFunc;

  /// The rewriter module containing all of the rewrite-related logic within
  /// PDL patterns.
  ModuleOp rewriterModule;

  /// The symbol table of the rewriter module used for insertion.
  SymbolTable rewriterSymbolTable;

  /// A scoped map connecting a position with the corresponding interpreter
  /// value.
  ValueMap values;

  /// A stack of blocks used as the failure destination for matcher nodes that
  /// don't have an explicit failure path.
  SmallVector<Block *, 8> failureBlockStack;

  /// A mapping between values defined in a pattern match, and the corresponding
  /// positional value.
  DenseMap<Value, Position *> valueToPosition;

  /// The set of operation values whose location will be used for newly
  /// generated operations. Kept in insertion order; it is passed along when a
  /// match is recorded.
  SetVector<Value> locOps;
};
} // end anonymous namespace
139 
PatternLowering(FuncOp matcherFunc,ModuleOp rewriterModule)140 PatternLowering::PatternLowering(FuncOp matcherFunc, ModuleOp rewriterModule)
141     : builder(matcherFunc.getContext()), matcherFunc(matcherFunc),
142       rewriterModule(rewriterModule), rewriterSymbolTable(rewriterModule) {}
143 
lower(ModuleOp module)144 void PatternLowering::lower(ModuleOp module) {
145   PredicateUniquer predicateUniquer;
146   PredicateBuilder predicateBuilder(predicateUniquer, module.getContext());
147 
148   // Define top-level scope for the arguments to the matcher function.
149   ValueMapScope topLevelValueScope(values);
150 
151   // Insert the root operation, i.e. argument to the matcher, at the root
152   // position.
153   Block *matcherEntryBlock = matcherFunc.addEntryBlock();
154   values.insert(predicateBuilder.getRoot(), matcherEntryBlock->getArgument(0));
155 
156   // Generate a root matcher node from the provided PDL module.
157   std::unique_ptr<MatcherNode> root = MatcherNode::generateMatcherTree(
158       module, predicateBuilder, valueToPosition);
159   Block *firstMatcherBlock = generateMatcher(*root);
160 
161   // After generation, merged the first matched block into the entry.
162   matcherEntryBlock->getOperations().splice(matcherEntryBlock->end(),
163                                             firstMatcherBlock->getOperations());
164   firstMatcherBlock->erase();
165 }
166 
generateMatcher(MatcherNode & node)167 Block *PatternLowering::generateMatcher(MatcherNode &node) {
168   // Push a new scope for the values used by this matcher.
169   Block *block = matcherFunc.addBlock();
170   ValueMapScope scope(values);
171 
172   // If this is the return node, simply insert the corresponding interpreter
173   // finalize.
174   if (isa<ExitNode>(node)) {
175     builder.setInsertionPointToEnd(block);
176     builder.create<pdl_interp::FinalizeOp>(matcherFunc.getLoc());
177     return block;
178   }
179 
180   // If this node contains a position, get the corresponding value for this
181   // block.
182   Position *position = node.getPosition();
183   Value val = position ? getValueAt(block, position) : Value();
184 
185   // Get the next block in the match sequence.
186   std::unique_ptr<MatcherNode> &failureNode = node.getFailureNode();
187   Block *nextBlock;
188   if (failureNode) {
189     nextBlock = generateMatcher(*failureNode);
190     failureBlockStack.push_back(nextBlock);
191   } else {
192     assert(!failureBlockStack.empty() && "expected valid failure block");
193     nextBlock = failureBlockStack.back();
194   }
195 
196   // If this value corresponds to an operation, record that we are going to use
197   // its location as part of a fused location.
198   bool isOperationValue = val && val.getType().isa<pdl::OperationType>();
199   if (isOperationValue)
200     locOps.insert(val);
201 
202   // Generate code for a boolean predicate node.
203   if (auto *boolNode = dyn_cast<BoolNode>(&node)) {
204     auto *child = generateMatcher(*boolNode->getSuccessNode());
205     generatePredicate(block, node.getQuestion(), boolNode->getAnswer(), val,
206                       child, nextBlock);
207 
208     // Generate code for a switch node.
209   } else if (auto *switchNode = dyn_cast<SwitchNode>(&node)) {
210     generateSwitch(switchNode, block, node.getQuestion(), val, nextBlock);
211 
212     // Generate code for a success node.
213   } else if (auto *successNode = dyn_cast<SuccessNode>(&node)) {
214     generateRecordMatch(block, nextBlock, successNode->getPattern());
215   }
216 
217   if (failureNode)
218     failureBlockStack.pop_back();
219   if (isOperationValue)
220     locOps.remove(val);
221   return block;
222 }
223 
getValueAt(Block * cur,Position * pos)224 Value PatternLowering::getValueAt(Block *cur, Position *pos) {
225   if (Value val = values.lookup(pos))
226     return val;
227 
228   // Get the value for the parent position.
229   Value parentVal = getValueAt(cur, pos->getParent());
230 
231   // TODO: Use a location from the position.
232   Location loc = parentVal.getLoc();
233   builder.setInsertionPointToEnd(cur);
234   Value value;
235   switch (pos->getKind()) {
236   case Predicates::OperationPos:
237     value = builder.create<pdl_interp::GetDefiningOpOp>(
238         loc, builder.getType<pdl::OperationType>(), parentVal);
239     break;
240   case Predicates::OperandPos: {
241     auto *operandPos = cast<OperandPosition>(pos);
242     value = builder.create<pdl_interp::GetOperandOp>(
243         loc, builder.getType<pdl::ValueType>(), parentVal,
244         operandPos->getOperandNumber());
245     break;
246   }
247   case Predicates::OperandGroupPos: {
248     auto *operandPos = cast<OperandGroupPosition>(pos);
249     Type valueTy = builder.getType<pdl::ValueType>();
250     value = builder.create<pdl_interp::GetOperandsOp>(
251         loc, operandPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy,
252         parentVal, operandPos->getOperandGroupNumber());
253     break;
254   }
255   case Predicates::AttributePos: {
256     auto *attrPos = cast<AttributePosition>(pos);
257     value = builder.create<pdl_interp::GetAttributeOp>(
258         loc, builder.getType<pdl::AttributeType>(), parentVal,
259         attrPos->getName().strref());
260     break;
261   }
262   case Predicates::TypePos: {
263     if (parentVal.getType().isa<pdl::AttributeType>())
264       value = builder.create<pdl_interp::GetAttributeTypeOp>(loc, parentVal);
265     else
266       value = builder.create<pdl_interp::GetValueTypeOp>(loc, parentVal);
267     break;
268   }
269   case Predicates::ResultPos: {
270     auto *resPos = cast<ResultPosition>(pos);
271     value = builder.create<pdl_interp::GetResultOp>(
272         loc, builder.getType<pdl::ValueType>(), parentVal,
273         resPos->getResultNumber());
274     break;
275   }
276   case Predicates::ResultGroupPos: {
277     auto *resPos = cast<ResultGroupPosition>(pos);
278     Type valueTy = builder.getType<pdl::ValueType>();
279     value = builder.create<pdl_interp::GetResultsOp>(
280         loc, resPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy,
281         parentVal, resPos->getResultGroupNumber());
282     break;
283   }
284   default:
285     llvm_unreachable("Generating unknown Position getter");
286     break;
287   }
288   values.insert(pos, value);
289   return value;
290 }
291 
generatePredicate(Block * currentBlock,Qualifier * question,Qualifier * answer,Value val,Block * trueDest,Block * falseDest)292 void PatternLowering::generatePredicate(Block *currentBlock,
293                                         Qualifier *question, Qualifier *answer,
294                                         Value val, Block *trueDest,
295                                         Block *falseDest) {
296   builder.setInsertionPointToEnd(currentBlock);
297   Location loc = val.getLoc();
298   Predicates::Kind kind = question->getKind();
299   switch (kind) {
300   case Predicates::IsNotNullQuestion:
301     builder.create<pdl_interp::IsNotNullOp>(loc, val, trueDest, falseDest);
302     break;
303   case Predicates::OperationNameQuestion: {
304     auto *opNameAnswer = cast<OperationNameAnswer>(answer);
305     builder.create<pdl_interp::CheckOperationNameOp>(
306         loc, val, opNameAnswer->getValue().getStringRef(), trueDest, falseDest);
307     break;
308   }
309   case Predicates::TypeQuestion: {
310     auto *ans = cast<TypeAnswer>(answer);
311     if (val.getType().isa<pdl::RangeType>())
312       builder.create<pdl_interp::CheckTypesOp>(
313           loc, val, ans->getValue().cast<ArrayAttr>(), trueDest, falseDest);
314     else
315       builder.create<pdl_interp::CheckTypeOp>(
316           loc, val, ans->getValue().cast<TypeAttr>(), trueDest, falseDest);
317     break;
318   }
319   case Predicates::AttributeQuestion: {
320     auto *ans = cast<AttributeAnswer>(answer);
321     builder.create<pdl_interp::CheckAttributeOp>(loc, val, ans->getValue(),
322                                                  trueDest, falseDest);
323     break;
324   }
325   case Predicates::OperandCountAtLeastQuestion:
326   case Predicates::OperandCountQuestion:
327     builder.create<pdl_interp::CheckOperandCountOp>(
328         loc, val, cast<UnsignedAnswer>(answer)->getValue(),
329         /*compareAtLeast=*/kind == Predicates::OperandCountAtLeastQuestion,
330         trueDest, falseDest);
331     break;
332   case Predicates::ResultCountAtLeastQuestion:
333   case Predicates::ResultCountQuestion:
334     builder.create<pdl_interp::CheckResultCountOp>(
335         loc, val, cast<UnsignedAnswer>(answer)->getValue(),
336         /*compareAtLeast=*/kind == Predicates::ResultCountAtLeastQuestion,
337         trueDest, falseDest);
338     break;
339   case Predicates::EqualToQuestion: {
340     auto *equalToQuestion = cast<EqualToQuestion>(question);
341     builder.create<pdl_interp::AreEqualOp>(
342         loc, val, getValueAt(currentBlock, equalToQuestion->getValue()),
343         trueDest, falseDest);
344     break;
345   }
346   case Predicates::ConstraintQuestion: {
347     auto *cstQuestion = cast<ConstraintQuestion>(question);
348     SmallVector<Value, 2> args;
349     for (Position *position : std::get<1>(cstQuestion->getValue()))
350       args.push_back(getValueAt(currentBlock, position));
351     builder.create<pdl_interp::ApplyConstraintOp>(
352         loc, std::get<0>(cstQuestion->getValue()), args,
353         std::get<2>(cstQuestion->getValue()).cast<ArrayAttr>(), trueDest,
354         falseDest);
355     break;
356   }
357   default:
358     llvm_unreachable("Generating unknown Predicate operation");
359   }
360 }
361 
362 template <typename OpT, typename PredT, typename ValT = typename PredT::KeyTy>
createSwitchOp(Value val,Block * defaultDest,OpBuilder & builder,llvm::MapVector<Qualifier *,Block * > & dests)363 static void createSwitchOp(Value val, Block *defaultDest, OpBuilder &builder,
364                            llvm::MapVector<Qualifier *, Block *> &dests) {
365   std::vector<ValT> values;
366   std::vector<Block *> blocks;
367   values.reserve(dests.size());
368   blocks.reserve(dests.size());
369   for (const auto &it : dests) {
370     blocks.push_back(it.second);
371     values.push_back(cast<PredT>(it.first)->getValue());
372   }
373   builder.create<OpT>(val.getLoc(), val, values, defaultDest, blocks);
374 }
375 
generateSwitch(SwitchNode * switchNode,Block * currentBlock,Qualifier * question,Value val,Block * defaultDest)376 void PatternLowering::generateSwitch(SwitchNode *switchNode,
377                                      Block *currentBlock, Qualifier *question,
378                                      Value val, Block *defaultDest) {
379   // If the switch question is not an exact answer, i.e. for the `at_least`
380   // cases, we generate a special block sequence.
381   Predicates::Kind kind = question->getKind();
382   if (kind == Predicates::OperandCountAtLeastQuestion ||
383       kind == Predicates::ResultCountAtLeastQuestion) {
384     // Order the children such that the cases are in reverse numerical order.
385     SmallVector<unsigned> sortedChildren = llvm::to_vector<16>(
386         llvm::seq<unsigned>(0, switchNode->getChildren().size()));
387     llvm::sort(sortedChildren, [&](unsigned lhs, unsigned rhs) {
388       return cast<UnsignedAnswer>(switchNode->getChild(lhs).first)->getValue() >
389              cast<UnsignedAnswer>(switchNode->getChild(rhs).first)->getValue();
390     });
391 
392     // Build the destination for each child using the next highest child as a
393     // a failure destination. This essentially creates the following control
394     // flow:
395     //
396     // if (operand_count < 1)
397     //   goto failure
398     // if (child1.match())
399     //   ...
400     //
401     // if (operand_count < 2)
402     //   goto failure
403     // if (child2.match())
404     //   ...
405     //
406     // failure:
407     //   ...
408     //
409     failureBlockStack.push_back(defaultDest);
410     for (unsigned idx : sortedChildren) {
411       auto &child = switchNode->getChild(idx);
412       Block *childBlock = generateMatcher(*child.second);
413       Block *predicateBlock = builder.createBlock(childBlock);
414       generatePredicate(predicateBlock, question, child.first, val, childBlock,
415                         defaultDest);
416       failureBlockStack.back() = predicateBlock;
417     }
418     Block *firstPredicateBlock = failureBlockStack.pop_back_val();
419     currentBlock->getOperations().splice(currentBlock->end(),
420                                          firstPredicateBlock->getOperations());
421     firstPredicateBlock->erase();
422     return;
423   }
424 
425   // Otherwise, generate each of the children and generate an interpreter
426   // switch.
427   llvm::MapVector<Qualifier *, Block *> children;
428   for (auto &it : switchNode->getChildren())
429     children.insert({it.first, generateMatcher(*it.second)});
430   builder.setInsertionPointToEnd(currentBlock);
431 
432   switch (question->getKind()) {
433   case Predicates::OperandCountQuestion:
434     return createSwitchOp<pdl_interp::SwitchOperandCountOp, UnsignedAnswer,
435                           int32_t>(val, defaultDest, builder, children);
436   case Predicates::ResultCountQuestion:
437     return createSwitchOp<pdl_interp::SwitchResultCountOp, UnsignedAnswer,
438                           int32_t>(val, defaultDest, builder, children);
439   case Predicates::OperationNameQuestion:
440     return createSwitchOp<pdl_interp::SwitchOperationNameOp,
441                           OperationNameAnswer>(val, defaultDest, builder,
442                                                children);
443   case Predicates::TypeQuestion:
444     if (val.getType().isa<pdl::RangeType>()) {
445       return createSwitchOp<pdl_interp::SwitchTypesOp, TypeAnswer>(
446           val, defaultDest, builder, children);
447     }
448     return createSwitchOp<pdl_interp::SwitchTypeOp, TypeAnswer>(
449         val, defaultDest, builder, children);
450   case Predicates::AttributeQuestion:
451     return createSwitchOp<pdl_interp::SwitchAttributeOp, AttributeAnswer>(
452         val, defaultDest, builder, children);
453   default:
454     llvm_unreachable("Generating unknown switch predicate.");
455   }
456 }
457 
generateRecordMatch(Block * currentBlock,Block * nextBlock,pdl::PatternOp pattern)458 void PatternLowering::generateRecordMatch(Block *currentBlock, Block *nextBlock,
459                                           pdl::PatternOp pattern) {
460   // Generate a rewriter for the pattern this success node represents, and track
461   // any values used from the match region.
462   SmallVector<Position *, 8> usedMatchValues;
463   SymbolRefAttr rewriterFuncRef = generateRewriter(pattern, usedMatchValues);
464 
465   // Process any values used in the rewrite that are defined in the match.
466   std::vector<Value> mappedMatchValues;
467   mappedMatchValues.reserve(usedMatchValues.size());
468   for (Position *position : usedMatchValues)
469     mappedMatchValues.push_back(getValueAt(currentBlock, position));
470 
471   // Collect the set of operations generated by the rewriter.
472   SmallVector<StringRef, 4> generatedOps;
473   for (auto op : pattern.getRewriter().body().getOps<pdl::OperationOp>())
474     generatedOps.push_back(*op.name());
475   ArrayAttr generatedOpsAttr;
476   if (!generatedOps.empty())
477     generatedOpsAttr = builder.getStrArrayAttr(generatedOps);
478 
479   // Grab the root kind if present.
480   StringAttr rootKindAttr;
481   if (Optional<StringRef> rootKind = pattern.getRootKind())
482     rootKindAttr = builder.getStringAttr(*rootKind);
483 
484   builder.setInsertionPointToEnd(currentBlock);
485   builder.create<pdl_interp::RecordMatchOp>(
486       pattern.getLoc(), mappedMatchValues, locOps.getArrayRef(),
487       rewriterFuncRef, rootKindAttr, generatedOpsAttr, pattern.benefitAttr(),
488       nextBlock);
489 }
490 
/// Generate a rewriter function for `pattern` inside the rewriter module, and
/// return a nested symbol reference to it (the rewriter module name, then the
/// function's symbol).
///
/// Any value the rewrite needs that is neither already mapped nor a constant
/// `pdl.attribute` / `pdl.type` / `pdl.types` becomes a new argument of the
/// generated function. Its matcher position is appended to `usedMatchValues`,
/// so `usedMatchValues` lists the function's arguments in order.
SymbolRefAttr PatternLowering::generateRewriter(
    pdl::PatternOp pattern, SmallVectorImpl<Position *> &usedMatchValues) {
  // Start with an argument-less signature; it is rewritten below once all
  // inputs are known.
  FuncOp rewriterFunc =
      FuncOp::create(pattern.getLoc(), "pdl_generated_rewriter",
                     builder.getFunctionType(llvm::None, llvm::None));
  rewriterSymbolTable.insert(rewriterFunc);

  // Generate the rewriter function body.
  builder.setInsertionPointToEnd(rewriterFunc.addEntryBlock());

  // Map an input operand of the pattern to a generated interpreter value.
  // Results are cached in `rewriteValues`.
  DenseMap<Value, Value> rewriteValues;
  auto mapRewriteValue = [&](Value oldValue) {
    // The reference stays valid below: nothing else in this lambda inserts
    // into `rewriteValues`.
    Value &newValue = rewriteValues[oldValue];
    if (newValue)
      return newValue;

    // Prefer materializing constants directly when possible.
    // NOTE(review): assumes `oldValue` always has a defining op (i.e. is not a
    // block argument); dyn_cast on a null op would assert.
    Operation *oldOp = oldValue.getDefiningOp();
    if (pdl::AttributeOp attrOp = dyn_cast<pdl::AttributeOp>(oldOp)) {
      if (Attribute value = attrOp.valueAttr()) {
        return newValue = builder.create<pdl_interp::CreateAttributeOp>(
                   attrOp.getLoc(), value);
      }
    } else if (pdl::TypeOp typeOp = dyn_cast<pdl::TypeOp>(oldOp)) {
      if (TypeAttr type = typeOp.typeAttr()) {
        return newValue = builder.create<pdl_interp::CreateTypeOp>(
                   typeOp.getLoc(), type);
      }
    } else if (pdl::TypesOp typeOp = dyn_cast<pdl::TypesOp>(oldOp)) {
      if (ArrayAttr type = typeOp.typesAttr()) {
        return newValue = builder.create<pdl_interp::CreateTypesOp>(
                   typeOp.getLoc(), typeOp.getType(), type);
      }
    }

    // Otherwise, add this as an input to the rewriter: record its matcher
    // position and append a matching function argument.
    Position *inputPos = valueToPosition.lookup(oldValue);
    assert(inputPos && "expected value to be a pattern input");
    usedMatchValues.push_back(inputPos);
    return newValue = rewriterFunc.front().addArgument(oldValue.getType());
  };

  // If this is a custom rewriter, simply dispatch to the registered rewrite
  // method.
  pdl::RewriteOp rewriter = pattern.getRewriter();
  if (StringAttr rewriteName = rewriter.nameAttr()) {
    // `mappedArgs` is lazy: the root is mapped first (while building `args`),
    // and the external arguments are mapped during `append`. This fixes the
    // order in which rewriter function arguments are added.
    auto mappedArgs = llvm::map_range(rewriter.externalArgs(), mapRewriteValue);
    SmallVector<Value, 4> args(1, mapRewriteValue(rewriter.root()));
    args.append(mappedArgs.begin(), mappedArgs.end());
    builder.create<pdl_interp::ApplyRewriteOp>(
        rewriter.getLoc(), /*resultTypes=*/TypeRange(), rewriteName, args,
        rewriter.externalConstParamsAttr());
  } else {
    // Otherwise this is a dag rewriter defined using PDL operations. Ops of
    // any kind not listed in the `Case` below produce no code here.
    for (Operation &rewriteOp : *rewriter.getBody()) {
      llvm::TypeSwitch<Operation *>(&rewriteOp)
          .Case<pdl::ApplyNativeRewriteOp, pdl::AttributeOp, pdl::EraseOp,
                pdl::OperationOp, pdl::ReplaceOp, pdl::ResultOp, pdl::ResultsOp,
                pdl::TypeOp, pdl::TypesOp>([&](auto op) {
            this->generateRewriter(op, rewriteValues, mapRewriteValue);
          });
    }
  }

  // Update the signature of the rewrite function to match the arguments that
  // were added while mapping values.
  rewriterFunc.setType(builder.getFunctionType(
      llvm::to_vector<8>(rewriterFunc.front().getArgumentTypes()),
      /*results=*/llvm::None));

  builder.create<pdl_interp::FinalizeOp>(rewriter.getLoc());
  return builder.getSymbolRefAttr(
      pdl_interp::PDLInterpDialect::getRewriterModuleName(),
      builder.getSymbolRefAttr(rewriterFunc));
}
566 
generateRewriter(pdl::ApplyNativeRewriteOp rewriteOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)567 void PatternLowering::generateRewriter(
568     pdl::ApplyNativeRewriteOp rewriteOp, DenseMap<Value, Value> &rewriteValues,
569     function_ref<Value(Value)> mapRewriteValue) {
570   SmallVector<Value, 2> arguments;
571   for (Value argument : rewriteOp.args())
572     arguments.push_back(mapRewriteValue(argument));
573   auto interpOp = builder.create<pdl_interp::ApplyRewriteOp>(
574       rewriteOp.getLoc(), rewriteOp.getResultTypes(), rewriteOp.nameAttr(),
575       arguments, rewriteOp.constParamsAttr());
576   for (auto it : llvm::zip(rewriteOp.results(), interpOp.results()))
577     rewriteValues[std::get<0>(it)] = std::get<1>(it);
578 }
579 
generateRewriter(pdl::AttributeOp attrOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)580 void PatternLowering::generateRewriter(
581     pdl::AttributeOp attrOp, DenseMap<Value, Value> &rewriteValues,
582     function_ref<Value(Value)> mapRewriteValue) {
583   Value newAttr = builder.create<pdl_interp::CreateAttributeOp>(
584       attrOp.getLoc(), attrOp.valueAttr());
585   rewriteValues[attrOp] = newAttr;
586 }
587 
generateRewriter(pdl::EraseOp eraseOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)588 void PatternLowering::generateRewriter(
589     pdl::EraseOp eraseOp, DenseMap<Value, Value> &rewriteValues,
590     function_ref<Value(Value)> mapRewriteValue) {
591   builder.create<pdl_interp::EraseOp>(eraseOp.getLoc(),
592                                       mapRewriteValue(eraseOp.operation()));
593 }
594 
generateRewriter(pdl::OperationOp operationOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)595 void PatternLowering::generateRewriter(
596     pdl::OperationOp operationOp, DenseMap<Value, Value> &rewriteValues,
597     function_ref<Value(Value)> mapRewriteValue) {
598   SmallVector<Value, 4> operands;
599   for (Value operand : operationOp.operands())
600     operands.push_back(mapRewriteValue(operand));
601 
602   SmallVector<Value, 4> attributes;
603   for (Value attr : operationOp.attributes())
604     attributes.push_back(mapRewriteValue(attr));
605 
606   SmallVector<Value, 2> types;
607   generateOperationResultTypeRewriter(operationOp, types, rewriteValues,
608                                       mapRewriteValue);
609 
610   // Create the new operation.
611   Location loc = operationOp.getLoc();
612   Value createdOp = builder.create<pdl_interp::CreateOperationOp>(
613       loc, *operationOp.name(), types, operands, attributes,
614       operationOp.attributeNames());
615   rewriteValues[operationOp.op()] = createdOp;
616 
617   // Generate accesses for any results that have their types constrained.
618   // Handle the case where there is a single range representing all of the
619   // result types.
620   OperandRange resultTys = operationOp.types();
621   if (resultTys.size() == 1 && resultTys[0].getType().isa<pdl::RangeType>()) {
622     Value &type = rewriteValues[resultTys[0]];
623     if (!type) {
624       auto results = builder.create<pdl_interp::GetResultsOp>(loc, createdOp);
625       type = builder.create<pdl_interp::GetValueTypeOp>(loc, results);
626     }
627     return;
628   }
629 
630   // Otherwise, populate the individual results.
631   bool seenVariableLength = false;
632   Type valueTy = builder.getType<pdl::ValueType>();
633   Type valueRangeTy = pdl::RangeType::get(valueTy);
634   for (auto it : llvm::enumerate(resultTys)) {
635     Value &type = rewriteValues[it.value()];
636     if (type)
637       continue;
638     bool isVariadic = it.value().getType().isa<pdl::RangeType>();
639     seenVariableLength |= isVariadic;
640 
641     // After a variable length result has been seen, we need to use result
642     // groups because the exact index of the result is not statically known.
643     Value resultVal;
644     if (seenVariableLength)
645       resultVal = builder.create<pdl_interp::GetResultsOp>(
646           loc, isVariadic ? valueRangeTy : valueTy, createdOp, it.index());
647     else
648       resultVal = builder.create<pdl_interp::GetResultOp>(
649           loc, valueTy, createdOp, it.index());
650     type = builder.create<pdl_interp::GetValueTypeOp>(loc, resultVal);
651   }
652 }
653 
generateRewriter(pdl::ReplaceOp replaceOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)654 void PatternLowering::generateRewriter(
655     pdl::ReplaceOp replaceOp, DenseMap<Value, Value> &rewriteValues,
656     function_ref<Value(Value)> mapRewriteValue) {
657   SmallVector<Value, 4> replOperands;
658 
659   // If the replacement was another operation, get its results. `pdl` allows
660   // for using an operation for simplicitly, but the interpreter isn't as
661   // user facing.
662   if (Value replOp = replaceOp.replOperation()) {
663     // Don't use replace if we know the replaced operation has no results.
664     auto opOp = replaceOp.operation().getDefiningOp<pdl::OperationOp>();
665     if (!opOp || !opOp.types().empty()) {
666       replOperands.push_back(builder.create<pdl_interp::GetResultsOp>(
667           replOp.getLoc(), mapRewriteValue(replOp)));
668     }
669   } else {
670     for (Value operand : replaceOp.replValues())
671       replOperands.push_back(mapRewriteValue(operand));
672   }
673 
674   // If there are no replacement values, just create an erase instead.
675   if (replOperands.empty()) {
676     builder.create<pdl_interp::EraseOp>(replaceOp.getLoc(),
677                                         mapRewriteValue(replaceOp.operation()));
678     return;
679   }
680 
681   builder.create<pdl_interp::ReplaceOp>(
682       replaceOp.getLoc(), mapRewriteValue(replaceOp.operation()), replOperands);
683 }
684 
generateRewriter(pdl::ResultOp resultOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)685 void PatternLowering::generateRewriter(
686     pdl::ResultOp resultOp, DenseMap<Value, Value> &rewriteValues,
687     function_ref<Value(Value)> mapRewriteValue) {
688   rewriteValues[resultOp] = builder.create<pdl_interp::GetResultOp>(
689       resultOp.getLoc(), builder.getType<pdl::ValueType>(),
690       mapRewriteValue(resultOp.parent()), resultOp.index());
691 }
692 
generateRewriter(pdl::ResultsOp resultOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)693 void PatternLowering::generateRewriter(
694     pdl::ResultsOp resultOp, DenseMap<Value, Value> &rewriteValues,
695     function_ref<Value(Value)> mapRewriteValue) {
696   rewriteValues[resultOp] = builder.create<pdl_interp::GetResultsOp>(
697       resultOp.getLoc(), resultOp.getType(), mapRewriteValue(resultOp.parent()),
698       resultOp.index());
699 }
700 
generateRewriter(pdl::TypeOp typeOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)701 void PatternLowering::generateRewriter(
702     pdl::TypeOp typeOp, DenseMap<Value, Value> &rewriteValues,
703     function_ref<Value(Value)> mapRewriteValue) {
704   // If the type isn't constant, the users (e.g. OperationOp) will resolve this
705   // type.
706   if (TypeAttr typeAttr = typeOp.typeAttr()) {
707     rewriteValues[typeOp] =
708         builder.create<pdl_interp::CreateTypeOp>(typeOp.getLoc(), typeAttr);
709   }
710 }
711 
generateRewriter(pdl::TypesOp typeOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)712 void PatternLowering::generateRewriter(
713     pdl::TypesOp typeOp, DenseMap<Value, Value> &rewriteValues,
714     function_ref<Value(Value)> mapRewriteValue) {
715   // If the type isn't constant, the users (e.g. OperationOp) will resolve this
716   // type.
717   if (ArrayAttr typeAttr = typeOp.typesAttr()) {
718     rewriteValues[typeOp] = builder.create<pdl_interp::CreateTypesOp>(
719         typeOp.getLoc(), typeOp.getType(), typeAttr);
720   }
721 }
722 
generateOperationResultTypeRewriter(pdl::OperationOp op,SmallVectorImpl<Value> & types,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)723 void PatternLowering::generateOperationResultTypeRewriter(
724     pdl::OperationOp op, SmallVectorImpl<Value> &types,
725     DenseMap<Value, Value> &rewriteValues,
726     function_ref<Value(Value)> mapRewriteValue) {
727   // Look for an operation that was replaced by `op`. The result types will be
728   // inferred from the results that were replaced.
729   Block *rewriterBlock = op->getBlock();
730   Value replacedOp;
731   for (OpOperand &use : op.op().getUses()) {
732     // Check that the use corresponds to a ReplaceOp and that it is the
733     // replacement value, not the operation being replaced.
734     pdl::ReplaceOp replOpUser = dyn_cast<pdl::ReplaceOp>(use.getOwner());
735     if (!replOpUser || use.getOperandNumber() == 0)
736       continue;
737     // Make sure the replaced operation was defined before this one.
738     Value replOpVal = replOpUser.operation();
739     Operation *replacedOp = replOpVal.getDefiningOp();
740     if (replacedOp->getBlock() == rewriterBlock &&
741         !replacedOp->isBeforeInBlock(op))
742       continue;
743 
744     Value replacedOpResults = builder.create<pdl_interp::GetResultsOp>(
745         replacedOp->getLoc(), mapRewriteValue(replOpVal));
746     types.push_back(builder.create<pdl_interp::GetValueTypeOp>(
747         replacedOp->getLoc(), replacedOpResults));
748     return;
749   }
750 
751   // Check if the operation has type inference support.
752   if (op.hasTypeInference()) {
753     types.push_back(builder.create<pdl_interp::InferredTypesOp>(op.getLoc()));
754     return;
755   }
756 
757   // Otherwise, handle inference for each of the result types individually.
758   OperandRange resultTypeValues = op.types();
759   types.reserve(resultTypeValues.size());
760   for (auto it : llvm::enumerate(resultTypeValues)) {
761     Value resultType = it.value();
762 
763     // Check for an already translated value.
764     if (Value existingRewriteValue = rewriteValues.lookup(resultType)) {
765       types.push_back(existingRewriteValue);
766       continue;
767     }
768 
769     // Check for an input from the matcher.
770     if (resultType.getDefiningOp()->getBlock() != rewriterBlock) {
771       types.push_back(mapRewriteValue(resultType));
772       continue;
773     }
774 
775     // The verifier asserts that the result types of each pdl.operation can be
776     // inferred. If we reach here, there is a bug either in the logic above or
777     // in the verifier for pdl.operation.
778     op->emitOpError() << "unable to infer result type for operation";
779     llvm_unreachable("unable to infer result type for operation");
780   }
781 }
782 
783 //===----------------------------------------------------------------------===//
784 // Conversion Pass
785 //===----------------------------------------------------------------------===//
786 
787 namespace {
788 struct PDLToPDLInterpPass
789     : public ConvertPDLToPDLInterpBase<PDLToPDLInterpPass> {
790   void runOnOperation() final;
791 };
792 } // namespace
793 
794 /// Convert the given module containing PDL pattern operations into a PDL
795 /// Interpreter operations.
runOnOperation()796 void PDLToPDLInterpPass::runOnOperation() {
797   ModuleOp module = getOperation();
798 
799   // Create the main matcher function This function contains all of the match
800   // related functionality from patterns in the module.
801   OpBuilder builder = OpBuilder::atBlockBegin(module.getBody());
802   FuncOp matcherFunc = builder.create<FuncOp>(
803       module.getLoc(), pdl_interp::PDLInterpDialect::getMatcherFunctionName(),
804       builder.getFunctionType(builder.getType<pdl::OperationType>(),
805                               /*results=*/llvm::None),
806       /*attrs=*/llvm::None);
807 
808   // Create a nested module to hold the functions invoked for rewriting the IR
809   // after a successful match.
810   ModuleOp rewriterModule = builder.create<ModuleOp>(
811       module.getLoc(), pdl_interp::PDLInterpDialect::getRewriterModuleName());
812 
813   // Generate the code for the patterns within the module.
814   PatternLowering generator(matcherFunc, rewriterModule);
815   generator.lower(module);
816 
817   // After generation, delete all of the pattern operations.
818   for (pdl::PatternOp pattern :
819        llvm::make_early_inc_range(module.getOps<pdl::PatternOp>()))
820     pattern.erase();
821 }
822 
createPDLToPDLInterpPass()823 std::unique_ptr<OperationPass<ModuleOp>> mlir::createPDLToPDLInterpPass() {
824   return std::make_unique<PDLToPDLInterpPass>();
825 }
826