1 //===- PDLToPDLInterp.cpp - Lower a PDL module to the interpreter ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
10 #include "../PassDetail.h"
11 #include "PredicateTree.h"
12 #include "mlir/Dialect/PDL/IR/PDL.h"
13 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
14 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
15 #include "mlir/Pass/Pass.h"
16 #include "llvm/ADT/MapVector.h"
17 #include "llvm/ADT/ScopedHashTable.h"
18 #include "llvm/ADT/SetVector.h"
19 #include "llvm/ADT/TypeSwitch.h"
20 
21 using namespace mlir;
22 using namespace mlir::pdl_to_pdl_interp;
23 
24 //===----------------------------------------------------------------------===//
25 // PatternLowering
26 //===----------------------------------------------------------------------===//
27 
28 namespace {
29 /// This class generators operations within the PDL Interpreter dialect from a
30 /// given module containing PDL pattern operations.
31 struct PatternLowering {
32 public:
33   PatternLowering(FuncOp matcherFunc, ModuleOp rewriterModule);
34 
35   /// Generate code for matching and rewriting based on the pattern operations
36   /// within the module.
37   void lower(ModuleOp module);
38 
39 private:
40   using ValueMap = llvm::ScopedHashTable<Position *, Value>;
41   using ValueMapScope = llvm::ScopedHashTableScope<Position *, Value>;
42 
43   /// Generate interpreter operations for the tree rooted at the given matcher
44   /// node.
45   Block *generateMatcher(MatcherNode &node);
46 
47   /// Get or create an access to the provided positional value within the
48   /// current block.
49   Value getValueAt(Block *cur, Position *pos);
50 
51   /// Create an interpreter predicate operation, branching to the provided true
52   /// and false destinations.
53   void generatePredicate(Block *currentBlock, Qualifier *question,
54                          Qualifier *answer, Value val, Block *trueDest,
55                          Block *falseDest);
56 
57   /// Create an interpreter switch predicate operation, with a provided default
58   /// and several case destinations.
59   void generateSwitch(Block *currentBlock, Qualifier *question, Value val,
60                       Block *defaultDest,
61                       ArrayRef<std::pair<Qualifier *, Block *>> dests);
62 
63   /// Create the interpreter operations to record a successful pattern match.
64   void generateRecordMatch(Block *currentBlock, Block *nextBlock,
65                            pdl::PatternOp pattern);
66 
67   /// Generate a rewriter function for the given pattern operation, and returns
68   /// a reference to that function.
69   SymbolRefAttr generateRewriter(pdl::PatternOp pattern,
70                                  SmallVectorImpl<Position *> &usedMatchValues);
71 
72   /// Generate the rewriter code for the given operation.
73   void generateRewriter(pdl::AttributeOp attrOp,
74                         DenseMap<Value, Value> &rewriteValues,
75                         function_ref<Value(Value)> mapRewriteValue);
76   void generateRewriter(pdl::EraseOp eraseOp,
77                         DenseMap<Value, Value> &rewriteValues,
78                         function_ref<Value(Value)> mapRewriteValue);
79   void generateRewriter(pdl::OperationOp operationOp,
80                         DenseMap<Value, Value> &rewriteValues,
81                         function_ref<Value(Value)> mapRewriteValue);
82   void generateRewriter(pdl::CreateNativeOp createNativeOp,
83                         DenseMap<Value, Value> &rewriteValues,
84                         function_ref<Value(Value)> mapRewriteValue);
85   void generateRewriter(pdl::ReplaceOp replaceOp,
86                         DenseMap<Value, Value> &rewriteValues,
87                         function_ref<Value(Value)> mapRewriteValue);
88   void generateRewriter(pdl::TypeOp typeOp,
89                         DenseMap<Value, Value> &rewriteValues,
90                         function_ref<Value(Value)> mapRewriteValue);
91 
92   /// Generate the values used for resolving the result types of an operation
93   /// created within a dag rewriter region.
94   void generateOperationResultTypeRewriter(
95       pdl::OperationOp op, SmallVectorImpl<Value> &types,
96       DenseMap<Value, Value> &rewriteValues,
97       function_ref<Value(Value)> mapRewriteValue);
98 
99   /// A builder to use when generating interpreter operations.
100   OpBuilder builder;
101 
102   /// The matcher function used for all match related logic within PDL patterns.
103   FuncOp matcherFunc;
104 
105   /// The rewriter module containing the all rewrite related logic within PDL
106   /// patterns.
107   ModuleOp rewriterModule;
108 
109   /// The symbol table of the rewriter module used for insertion.
110   SymbolTable rewriterSymbolTable;
111 
112   /// A scoped map connecting a position with the corresponding interpreter
113   /// value.
114   ValueMap values;
115 
116   /// A stack of blocks used as the failure destination for matcher nodes that
117   /// don't have an explicit failure path.
118   SmallVector<Block *, 8> failureBlockStack;
119 
120   /// A mapping between values defined in a pattern match, and the corresponding
121   /// positional value.
122   DenseMap<Value, Position *> valueToPosition;
123 
124   /// The set of operation values whose whose location will be used for newly
125   /// generated operations.
126   llvm::SetVector<Value> locOps;
127 };
128 } // end anonymous namespace
129 
PatternLowering(FuncOp matcherFunc,ModuleOp rewriterModule)130 PatternLowering::PatternLowering(FuncOp matcherFunc, ModuleOp rewriterModule)
131     : builder(matcherFunc.getContext()), matcherFunc(matcherFunc),
132       rewriterModule(rewriterModule), rewriterSymbolTable(rewriterModule) {}
133 
lower(ModuleOp module)134 void PatternLowering::lower(ModuleOp module) {
135   PredicateUniquer predicateUniquer;
136   PredicateBuilder predicateBuilder(predicateUniquer, module.getContext());
137 
138   // Define top-level scope for the arguments to the matcher function.
139   ValueMapScope topLevelValueScope(values);
140 
141   // Insert the root operation, i.e. argument to the matcher, at the root
142   // position.
143   Block *matcherEntryBlock = matcherFunc.addEntryBlock();
144   values.insert(predicateBuilder.getRoot(), matcherEntryBlock->getArgument(0));
145 
146   // Generate a root matcher node from the provided PDL module.
147   std::unique_ptr<MatcherNode> root = MatcherNode::generateMatcherTree(
148       module, predicateBuilder, valueToPosition);
149   Block *firstMatcherBlock = generateMatcher(*root);
150 
151   // After generation, merged the first matched block into the entry.
152   matcherEntryBlock->getOperations().splice(matcherEntryBlock->end(),
153                                             firstMatcherBlock->getOperations());
154   firstMatcherBlock->erase();
155 }
156 
generateMatcher(MatcherNode & node)157 Block *PatternLowering::generateMatcher(MatcherNode &node) {
158   // Push a new scope for the values used by this matcher.
159   Block *block = matcherFunc.addBlock();
160   ValueMapScope scope(values);
161 
162   // If this is the return node, simply insert the corresponding interpreter
163   // finalize.
164   if (isa<ExitNode>(node)) {
165     builder.setInsertionPointToEnd(block);
166     builder.create<pdl_interp::FinalizeOp>(matcherFunc.getLoc());
167     return block;
168   }
169 
170   // If this node contains a position, get the corresponding value for this
171   // block.
172   Position *position = node.getPosition();
173   Value val = position ? getValueAt(block, position) : Value();
174 
175   // Get the next block in the match sequence.
176   std::unique_ptr<MatcherNode> &failureNode = node.getFailureNode();
177   Block *nextBlock;
178   if (failureNode) {
179     nextBlock = generateMatcher(*failureNode);
180     failureBlockStack.push_back(nextBlock);
181   } else {
182     assert(!failureBlockStack.empty() && "expected valid failure block");
183     nextBlock = failureBlockStack.back();
184   }
185 
186   // If this value corresponds to an operation, record that we are going to use
187   // its location as part of a fused location.
188   bool isOperationValue = val && val.getType().isa<pdl::OperationType>();
189   if (isOperationValue)
190     locOps.insert(val);
191 
192   // Generate code for a boolean predicate node.
193   if (auto *boolNode = dyn_cast<BoolNode>(&node)) {
194     auto *child = generateMatcher(*boolNode->getSuccessNode());
195     generatePredicate(block, node.getQuestion(), boolNode->getAnswer(), val,
196                       child, nextBlock);
197 
198     // Generate code for a switch node.
199   } else if (auto *switchNode = dyn_cast<SwitchNode>(&node)) {
200     // Collect the next blocks for all of the children and generate a switch.
201     llvm::MapVector<Qualifier *, Block *> children;
202     for (auto &it : switchNode->getChildren())
203       children.insert({it.first, generateMatcher(*it.second)});
204     generateSwitch(block, node.getQuestion(), val, nextBlock,
205                    children.takeVector());
206 
207     // Generate code for a success node.
208   } else if (auto *successNode = dyn_cast<SuccessNode>(&node)) {
209     generateRecordMatch(block, nextBlock, successNode->getPattern());
210   }
211 
212   if (failureNode)
213     failureBlockStack.pop_back();
214   if (isOperationValue)
215     locOps.remove(val);
216   return block;
217 }
218 
getValueAt(Block * cur,Position * pos)219 Value PatternLowering::getValueAt(Block *cur, Position *pos) {
220   if (Value val = values.lookup(pos))
221     return val;
222 
223   // Get the value for the parent position.
224   Value parentVal = getValueAt(cur, pos->getParent());
225 
226   // TODO: Use a location from the position.
227   Location loc = parentVal.getLoc();
228   builder.setInsertionPointToEnd(cur);
229   Value value;
230   switch (pos->getKind()) {
231   case Predicates::OperationPos:
232     value = builder.create<pdl_interp::GetDefiningOpOp>(
233         loc, builder.getType<pdl::OperationType>(), parentVal);
234     break;
235   case Predicates::OperandPos: {
236     auto *operandPos = cast<OperandPosition>(pos);
237     value = builder.create<pdl_interp::GetOperandOp>(
238         loc, builder.getType<pdl::ValueType>(), parentVal,
239         operandPos->getOperandNumber());
240     break;
241   }
242   case Predicates::AttributePos: {
243     auto *attrPos = cast<AttributePosition>(pos);
244     value = builder.create<pdl_interp::GetAttributeOp>(
245         loc, builder.getType<pdl::AttributeType>(), parentVal,
246         attrPos->getName().strref());
247     break;
248   }
249   case Predicates::TypePos: {
250     if (parentVal.getType().isa<pdl::ValueType>())
251       value = builder.create<pdl_interp::GetValueTypeOp>(loc, parentVal);
252     else
253       value = builder.create<pdl_interp::GetAttributeTypeOp>(loc, parentVal);
254     break;
255   }
256   case Predicates::ResultPos: {
257     auto *resPos = cast<ResultPosition>(pos);
258     value = builder.create<pdl_interp::GetResultOp>(
259         loc, builder.getType<pdl::ValueType>(), parentVal,
260         resPos->getResultNumber());
261     break;
262   }
263   default:
264     llvm_unreachable("Generating unknown Position getter");
265     break;
266   }
267   values.insert(pos, value);
268   return value;
269 }
270 
generatePredicate(Block * currentBlock,Qualifier * question,Qualifier * answer,Value val,Block * trueDest,Block * falseDest)271 void PatternLowering::generatePredicate(Block *currentBlock,
272                                         Qualifier *question, Qualifier *answer,
273                                         Value val, Block *trueDest,
274                                         Block *falseDest) {
275   builder.setInsertionPointToEnd(currentBlock);
276   Location loc = val.getLoc();
277   switch (question->getKind()) {
278   case Predicates::IsNotNullQuestion:
279     builder.create<pdl_interp::IsNotNullOp>(loc, val, trueDest, falseDest);
280     break;
281   case Predicates::OperationNameQuestion: {
282     auto *opNameAnswer = cast<OperationNameAnswer>(answer);
283     builder.create<pdl_interp::CheckOperationNameOp>(
284         loc, val, opNameAnswer->getValue().getStringRef(), trueDest, falseDest);
285     break;
286   }
287   case Predicates::TypeQuestion: {
288     auto *ans = cast<TypeAnswer>(answer);
289     builder.create<pdl_interp::CheckTypeOp>(
290         loc, val, TypeAttr::get(ans->getValue()), trueDest, falseDest);
291     break;
292   }
293   case Predicates::AttributeQuestion: {
294     auto *ans = cast<AttributeAnswer>(answer);
295     builder.create<pdl_interp::CheckAttributeOp>(loc, val, ans->getValue(),
296                                                  trueDest, falseDest);
297     break;
298   }
299   case Predicates::OperandCountQuestion: {
300     auto *unsignedAnswer = cast<UnsignedAnswer>(answer);
301     builder.create<pdl_interp::CheckOperandCountOp>(
302         loc, val, unsignedAnswer->getValue(), trueDest, falseDest);
303     break;
304   }
305   case Predicates::ResultCountQuestion: {
306     auto *unsignedAnswer = cast<UnsignedAnswer>(answer);
307     builder.create<pdl_interp::CheckResultCountOp>(
308         loc, val, unsignedAnswer->getValue(), trueDest, falseDest);
309     break;
310   }
311   case Predicates::EqualToQuestion: {
312     auto *equalToQuestion = cast<EqualToQuestion>(question);
313     builder.create<pdl_interp::AreEqualOp>(
314         loc, val, getValueAt(currentBlock, equalToQuestion->getValue()),
315         trueDest, falseDest);
316     break;
317   }
318   case Predicates::ConstraintQuestion: {
319     auto *cstQuestion = cast<ConstraintQuestion>(question);
320     SmallVector<Value, 2> args;
321     for (Position *position : std::get<1>(cstQuestion->getValue()))
322       args.push_back(getValueAt(currentBlock, position));
323     builder.create<pdl_interp::ApplyConstraintOp>(
324         loc, std::get<0>(cstQuestion->getValue()), args,
325         std::get<2>(cstQuestion->getValue()).cast<ArrayAttr>(), trueDest,
326         falseDest);
327     break;
328   }
329   default:
330     llvm_unreachable("Generating unknown Predicate operation");
331   }
332 }
333 
334 template <typename OpT, typename PredT, typename ValT = typename PredT::KeyTy>
createSwitchOp(Value val,Block * defaultDest,OpBuilder & builder,ArrayRef<std::pair<Qualifier *,Block * >> dests)335 static void createSwitchOp(Value val, Block *defaultDest, OpBuilder &builder,
336                            ArrayRef<std::pair<Qualifier *, Block *>> dests) {
337   std::vector<ValT> values;
338   std::vector<Block *> blocks;
339   values.reserve(dests.size());
340   blocks.reserve(dests.size());
341   for (const auto &it : dests) {
342     blocks.push_back(it.second);
343     values.push_back(cast<PredT>(it.first)->getValue());
344   }
345   builder.create<OpT>(val.getLoc(), val, values, defaultDest, blocks);
346 }
347 
generateSwitch(Block * currentBlock,Qualifier * question,Value val,Block * defaultDest,ArrayRef<std::pair<Qualifier *,Block * >> dests)348 void PatternLowering::generateSwitch(
349     Block *currentBlock, Qualifier *question, Value val, Block *defaultDest,
350     ArrayRef<std::pair<Qualifier *, Block *>> dests) {
351   builder.setInsertionPointToEnd(currentBlock);
352   switch (question->getKind()) {
353   case Predicates::OperandCountQuestion:
354     return createSwitchOp<pdl_interp::SwitchOperandCountOp, UnsignedAnswer,
355                           int32_t>(val, defaultDest, builder, dests);
356   case Predicates::ResultCountQuestion:
357     return createSwitchOp<pdl_interp::SwitchResultCountOp, UnsignedAnswer,
358                           int32_t>(val, defaultDest, builder, dests);
359   case Predicates::OperationNameQuestion:
360     return createSwitchOp<pdl_interp::SwitchOperationNameOp,
361                           OperationNameAnswer>(val, defaultDest, builder,
362                                                dests);
363   case Predicates::TypeQuestion:
364     return createSwitchOp<pdl_interp::SwitchTypeOp, TypeAnswer>(
365         val, defaultDest, builder, dests);
366   case Predicates::AttributeQuestion:
367     return createSwitchOp<pdl_interp::SwitchAttributeOp, AttributeAnswer>(
368         val, defaultDest, builder, dests);
369   default:
370     llvm_unreachable("Generating unknown switch predicate.");
371   }
372 }
373 
generateRecordMatch(Block * currentBlock,Block * nextBlock,pdl::PatternOp pattern)374 void PatternLowering::generateRecordMatch(Block *currentBlock, Block *nextBlock,
375                                           pdl::PatternOp pattern) {
376   // Generate a rewriter for the pattern this success node represents, and track
377   // any values used from the match region.
378   SmallVector<Position *, 8> usedMatchValues;
379   SymbolRefAttr rewriterFuncRef = generateRewriter(pattern, usedMatchValues);
380 
381   // Process any values used in the rewrite that are defined in the match.
382   std::vector<Value> mappedMatchValues;
383   mappedMatchValues.reserve(usedMatchValues.size());
384   for (Position *position : usedMatchValues)
385     mappedMatchValues.push_back(getValueAt(currentBlock, position));
386 
387   // Collect the set of operations generated by the rewriter.
388   SmallVector<StringRef, 4> generatedOps;
389   for (auto op : pattern.getRewriter().body().getOps<pdl::OperationOp>())
390     generatedOps.push_back(*op.name());
391   ArrayAttr generatedOpsAttr;
392   if (!generatedOps.empty())
393     generatedOpsAttr = builder.getStrArrayAttr(generatedOps);
394 
395   // Grab the root kind if present.
396   StringAttr rootKindAttr;
397   if (Optional<StringRef> rootKind = pattern.getRootKind())
398     rootKindAttr = builder.getStringAttr(*rootKind);
399 
400   builder.setInsertionPointToEnd(currentBlock);
401   builder.create<pdl_interp::RecordMatchOp>(
402       pattern.getLoc(), mappedMatchValues, locOps.getArrayRef(),
403       rewriterFuncRef, rootKindAttr, generatedOpsAttr, pattern.benefitAttr(),
404       nextBlock);
405 }
406 
generateRewriter(pdl::PatternOp pattern,SmallVectorImpl<Position * > & usedMatchValues)407 SymbolRefAttr PatternLowering::generateRewriter(
408     pdl::PatternOp pattern, SmallVectorImpl<Position *> &usedMatchValues) {
409   FuncOp rewriterFunc =
410       FuncOp::create(pattern.getLoc(), "pdl_generated_rewriter",
411                      builder.getFunctionType(llvm::None, llvm::None));
412   rewriterSymbolTable.insert(rewriterFunc);
413 
414   // Generate the rewriter function body.
415   builder.setInsertionPointToEnd(rewriterFunc.addEntryBlock());
416 
417   // Map an input operand of the pattern to a generated interpreter value.
418   DenseMap<Value, Value> rewriteValues;
419   auto mapRewriteValue = [&](Value oldValue) {
420     Value &newValue = rewriteValues[oldValue];
421     if (newValue)
422       return newValue;
423 
424     // Prefer materializing constants directly when possible.
425     Operation *oldOp = oldValue.getDefiningOp();
426     if (pdl::AttributeOp attrOp = dyn_cast<pdl::AttributeOp>(oldOp)) {
427       if (Attribute value = attrOp.valueAttr()) {
428         return newValue = builder.create<pdl_interp::CreateAttributeOp>(
429                    attrOp.getLoc(), value);
430       }
431     } else if (pdl::TypeOp typeOp = dyn_cast<pdl::TypeOp>(oldOp)) {
432       if (TypeAttr type = typeOp.typeAttr()) {
433         return newValue = builder.create<pdl_interp::CreateTypeOp>(
434                    typeOp.getLoc(), type);
435       }
436     }
437 
438     // Otherwise, add this as an input to the rewriter.
439     Position *inputPos = valueToPosition.lookup(oldValue);
440     assert(inputPos && "expected value to be a pattern input");
441     usedMatchValues.push_back(inputPos);
442     return newValue = rewriterFunc.front().addArgument(oldValue.getType());
443   };
444 
445   // If this is a custom rewriter, simply dispatch to the registered rewrite
446   // method.
447   pdl::RewriteOp rewriter = pattern.getRewriter();
448   if (StringAttr rewriteName = rewriter.nameAttr()) {
449     Value root = mapRewriteValue(rewriter.root());
450     SmallVector<Value, 4> args = llvm::to_vector<4>(
451         llvm::map_range(rewriter.externalArgs(), mapRewriteValue));
452     builder.create<pdl_interp::ApplyRewriteOp>(
453         rewriter.getLoc(), rewriteName, root, args,
454         rewriter.externalConstParamsAttr());
455   } else {
456     // Otherwise this is a dag rewriter defined using PDL operations.
457     for (Operation &rewriteOp : *rewriter.getBody()) {
458       llvm::TypeSwitch<Operation *>(&rewriteOp)
459           .Case<pdl::AttributeOp, pdl::CreateNativeOp, pdl::EraseOp,
460                 pdl::OperationOp, pdl::ReplaceOp, pdl::TypeOp>([&](auto op) {
461             this->generateRewriter(op, rewriteValues, mapRewriteValue);
462           });
463     }
464   }
465 
466   // Update the signature of the rewrite function.
467   rewriterFunc.setType(builder.getFunctionType(
468       llvm::to_vector<8>(rewriterFunc.front().getArgumentTypes()),
469       /*results=*/llvm::None));
470 
471   builder.create<pdl_interp::FinalizeOp>(rewriter.getLoc());
472   return builder.getSymbolRefAttr(
473       pdl_interp::PDLInterpDialect::getRewriterModuleName(),
474       builder.getSymbolRefAttr(rewriterFunc));
475 }
476 
generateRewriter(pdl::AttributeOp attrOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)477 void PatternLowering::generateRewriter(
478     pdl::AttributeOp attrOp, DenseMap<Value, Value> &rewriteValues,
479     function_ref<Value(Value)> mapRewriteValue) {
480   Value newAttr = builder.create<pdl_interp::CreateAttributeOp>(
481       attrOp.getLoc(), attrOp.valueAttr());
482   rewriteValues[attrOp] = newAttr;
483 }
484 
generateRewriter(pdl::EraseOp eraseOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)485 void PatternLowering::generateRewriter(
486     pdl::EraseOp eraseOp, DenseMap<Value, Value> &rewriteValues,
487     function_ref<Value(Value)> mapRewriteValue) {
488   builder.create<pdl_interp::EraseOp>(eraseOp.getLoc(),
489                                       mapRewriteValue(eraseOp.operation()));
490 }
491 
generateRewriter(pdl::OperationOp operationOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)492 void PatternLowering::generateRewriter(
493     pdl::OperationOp operationOp, DenseMap<Value, Value> &rewriteValues,
494     function_ref<Value(Value)> mapRewriteValue) {
495   SmallVector<Value, 4> operands;
496   for (Value operand : operationOp.operands())
497     operands.push_back(mapRewriteValue(operand));
498 
499   SmallVector<Value, 4> attributes;
500   for (Value attr : operationOp.attributes())
501     attributes.push_back(mapRewriteValue(attr));
502 
503   SmallVector<Value, 2> types;
504   generateOperationResultTypeRewriter(operationOp, types, rewriteValues,
505                                       mapRewriteValue);
506 
507   // Create the new operation.
508   Location loc = operationOp.getLoc();
509   Value createdOp = builder.create<pdl_interp::CreateOperationOp>(
510       loc, *operationOp.name(), types, operands, attributes,
511       operationOp.attributeNames());
512   rewriteValues[operationOp.op()] = createdOp;
513 
514   // Make all of the new operation results available.
515   OperandRange resultTypes = operationOp.types();
516   for (auto it : llvm::enumerate(operationOp.results())) {
517     Value getResultVal = builder.create<pdl_interp::GetResultOp>(
518         loc, builder.getType<pdl::ValueType>(), createdOp, it.index());
519     rewriteValues[it.value()] = getResultVal;
520 
521     // If any of the types have not been resolved, make those available as well.
522     Value &type = rewriteValues[resultTypes[it.index()]];
523     if (!type)
524       type = builder.create<pdl_interp::GetValueTypeOp>(loc, getResultVal);
525   }
526 }
527 
generateRewriter(pdl::CreateNativeOp createNativeOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)528 void PatternLowering::generateRewriter(
529     pdl::CreateNativeOp createNativeOp, DenseMap<Value, Value> &rewriteValues,
530     function_ref<Value(Value)> mapRewriteValue) {
531   SmallVector<Value, 2> arguments;
532   for (Value argument : createNativeOp.args())
533     arguments.push_back(mapRewriteValue(argument));
534   Value result = builder.create<pdl_interp::CreateNativeOp>(
535       createNativeOp.getLoc(), createNativeOp.result().getType(),
536       createNativeOp.nameAttr(), arguments, createNativeOp.constParamsAttr());
537   rewriteValues[createNativeOp] = result;
538 }
539 
generateRewriter(pdl::ReplaceOp replaceOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)540 void PatternLowering::generateRewriter(
541     pdl::ReplaceOp replaceOp, DenseMap<Value, Value> &rewriteValues,
542     function_ref<Value(Value)> mapRewriteValue) {
543   // If the replacement was another operation, get its results. `pdl` allows
544   // for using an operation for simplicitly, but the interpreter isn't as
545   // user facing.
546   ValueRange origOperands;
547   if (Value replOp = replaceOp.replOperation())
548     origOperands = cast<pdl::OperationOp>(replOp.getDefiningOp()).results();
549   else
550     origOperands = replaceOp.replValues();
551 
552   // If there are no replacement values, just create an erase instead.
553   if (origOperands.empty()) {
554     builder.create<pdl_interp::EraseOp>(replaceOp.getLoc(),
555                                         mapRewriteValue(replaceOp.operation()));
556     return;
557   }
558 
559   SmallVector<Value, 4> replOperands;
560   for (Value operand : origOperands)
561     replOperands.push_back(mapRewriteValue(operand));
562   builder.create<pdl_interp::ReplaceOp>(
563       replaceOp.getLoc(), mapRewriteValue(replaceOp.operation()), replOperands);
564 }
565 
generateRewriter(pdl::TypeOp typeOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)566 void PatternLowering::generateRewriter(
567     pdl::TypeOp typeOp, DenseMap<Value, Value> &rewriteValues,
568     function_ref<Value(Value)> mapRewriteValue) {
569   // If the type isn't constant, the users (e.g. OperationOp) will resolve this
570   // type.
571   if (TypeAttr typeAttr = typeOp.typeAttr()) {
572     Value newType =
573         builder.create<pdl_interp::CreateTypeOp>(typeOp.getLoc(), typeAttr);
574     rewriteValues[typeOp] = newType;
575   }
576 }
577 
generateOperationResultTypeRewriter(pdl::OperationOp op,SmallVectorImpl<Value> & types,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)578 void PatternLowering::generateOperationResultTypeRewriter(
579     pdl::OperationOp op, SmallVectorImpl<Value> &types,
580     DenseMap<Value, Value> &rewriteValues,
581     function_ref<Value(Value)> mapRewriteValue) {
582   // Functor that returns if the given use can be used to infer a type.
583   Block *rewriterBlock = op->getBlock();
584   auto getReplacedOperationFrom = [&](OpOperand &use) -> Operation * {
585     // Check that the use corresponds to a ReplaceOp and that it is the
586     // replacement value, not the operation being replaced.
587     pdl::ReplaceOp replOpUser = dyn_cast<pdl::ReplaceOp>(use.getOwner());
588     if (!replOpUser || use.getOperandNumber() == 0)
589       return nullptr;
590     // Make sure the replaced operation was defined before this one.
591     Operation *replacedOp = replOpUser.operation().getDefiningOp();
592     if (replacedOp->getBlock() != rewriterBlock ||
593         replacedOp->isBeforeInBlock(op))
594       return replacedOp;
595     return nullptr;
596   };
597 
598   // If non-None/non-Null, this is an operation that is replaced by `op`.
599   // If Null, there is no full replacement operation for `op`.
600   // If None, a replacement operation hasn't been searched for.
601   Optional<Operation *> fullReplacedOperation;
602   bool hasTypeInference = op.hasTypeInference();
603   auto resultTypeValues = op.types();
604   types.reserve(resultTypeValues.size());
605   for (auto it : llvm::enumerate(op.results())) {
606     Value result = it.value(), resultType = resultTypeValues[it.index()];
607 
608     // Check for an already translated value.
609     if (Value existingRewriteValue = rewriteValues.lookup(resultType)) {
610       types.push_back(existingRewriteValue);
611       continue;
612     }
613 
614     // Check for an input from the matcher.
615     if (resultType.getDefiningOp()->getBlock() != rewriterBlock) {
616       types.push_back(mapRewriteValue(resultType));
617       continue;
618     }
619 
620     // Check if the operation has type inference support.
621     if (hasTypeInference) {
622       types.push_back(builder.create<pdl_interp::InferredTypeOp>(op.getLoc()));
623       continue;
624     }
625 
626     // Look for an operation that was replaced by `op`. The result type will be
627     // inferred from the result that was replaced. There is guaranteed to be a
628     // replacement for either the op, or this specific result. Note that this is
629     // guaranteed by the verifier of `pdl::OperationOp`.
630     Operation *replacedOp = nullptr;
631     if (!fullReplacedOperation.hasValue()) {
632       for (OpOperand &use : op.op().getUses())
633         if ((replacedOp = getReplacedOperationFrom(use)))
634           break;
635       fullReplacedOperation = replacedOp;
636     } else {
637       replacedOp = fullReplacedOperation.getValue();
638     }
639     // Infer from the result, as there was no fully replaced op.
640     if (!replacedOp) {
641       for (OpOperand &use : result.getUses())
642         if ((replacedOp = getReplacedOperationFrom(use)))
643           break;
644       assert(replacedOp && "expected replaced op to infer a result type from");
645     }
646 
647     auto replOpOp = cast<pdl::OperationOp>(replacedOp);
648     types.push_back(mapRewriteValue(replOpOp.types()[it.index()]));
649   }
650 }
651 
652 //===----------------------------------------------------------------------===//
653 // Conversion Pass
654 //===----------------------------------------------------------------------===//
655 
656 namespace {
657 struct PDLToPDLInterpPass
658     : public ConvertPDLToPDLInterpBase<PDLToPDLInterpPass> {
659   void runOnOperation() final;
660 };
661 } // namespace
662 
663 /// Convert the given module containing PDL pattern operations into a PDL
664 /// Interpreter operations.
runOnOperation()665 void PDLToPDLInterpPass::runOnOperation() {
666   ModuleOp module = getOperation();
667 
668   // Create the main matcher function This function contains all of the match
669   // related functionality from patterns in the module.
670   OpBuilder builder = OpBuilder::atBlockBegin(module.getBody());
671   FuncOp matcherFunc = builder.create<FuncOp>(
672       module.getLoc(), pdl_interp::PDLInterpDialect::getMatcherFunctionName(),
673       builder.getFunctionType(builder.getType<pdl::OperationType>(),
674                               /*results=*/llvm::None),
675       /*attrs=*/llvm::None);
676 
677   // Create a nested module to hold the functions invoked for rewriting the IR
678   // after a successful match.
679   ModuleOp rewriterModule = builder.create<ModuleOp>(
680       module.getLoc(), pdl_interp::PDLInterpDialect::getRewriterModuleName());
681 
682   // Generate the code for the patterns within the module.
683   PatternLowering generator(matcherFunc, rewriterModule);
684   generator.lower(module);
685 
686   // After generation, delete all of the pattern operations.
687   for (pdl::PatternOp pattern :
688        llvm::make_early_inc_range(module.getOps<pdl::PatternOp>()))
689     pattern.erase();
690 }
691 
createPDLToPDLInterpPass()692 std::unique_ptr<OperationPass<ModuleOp>> mlir::createPDLToPDLInterpPass() {
693   return std::make_unique<PDLToPDLInterpPass>();
694 }
695