1 //===- PDLToPDLInterp.cpp - Lower a PDL module to the interpreter ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
10 #include "../PassDetail.h"
11 #include "PredicateTree.h"
12 #include "mlir/Dialect/PDL/IR/PDL.h"
13 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
14 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
15 #include "mlir/Pass/Pass.h"
16 #include "llvm/ADT/MapVector.h"
17 #include "llvm/ADT/ScopedHashTable.h"
18 #include "llvm/ADT/Sequence.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/SmallVector.h"
21 #include "llvm/ADT/TypeSwitch.h"
22
23 using namespace mlir;
24 using namespace mlir::pdl_to_pdl_interp;
25
26 //===----------------------------------------------------------------------===//
27 // PatternLowering
28 //===----------------------------------------------------------------------===//
29
30 namespace {
31 /// This class generators operations within the PDL Interpreter dialect from a
32 /// given module containing PDL pattern operations.
33 struct PatternLowering {
34 public:
35 PatternLowering(FuncOp matcherFunc, ModuleOp rewriterModule);
36
37 /// Generate code for matching and rewriting based on the pattern operations
38 /// within the module.
39 void lower(ModuleOp module);
40
41 private:
42 using ValueMap = llvm::ScopedHashTable<Position *, Value>;
43 using ValueMapScope = llvm::ScopedHashTableScope<Position *, Value>;
44
45 /// Generate interpreter operations for the tree rooted at the given matcher
46 /// node.
47 Block *generateMatcher(MatcherNode &node);
48
49 /// Get or create an access to the provided positional value within the
50 /// current block.
51 Value getValueAt(Block *cur, Position *pos);
52
53 /// Create an interpreter predicate operation, branching to the provided true
54 /// and false destinations.
55 void generatePredicate(Block *currentBlock, Qualifier *question,
56 Qualifier *answer, Value val, Block *trueDest,
57 Block *falseDest);
58
59 /// Create an interpreter switch predicate operation, with a provided default
60 /// and several case destinations.
61 void generateSwitch(SwitchNode *switchNode, Block *currentBlock,
62 Qualifier *question, Value val, Block *defaultDest);
63
64 /// Create the interpreter operations to record a successful pattern match.
65 void generateRecordMatch(Block *currentBlock, Block *nextBlock,
66 pdl::PatternOp pattern);
67
68 /// Generate a rewriter function for the given pattern operation, and returns
69 /// a reference to that function.
70 SymbolRefAttr generateRewriter(pdl::PatternOp pattern,
71 SmallVectorImpl<Position *> &usedMatchValues);
72
73 /// Generate the rewriter code for the given operation.
74 void generateRewriter(pdl::ApplyNativeRewriteOp rewriteOp,
75 DenseMap<Value, Value> &rewriteValues,
76 function_ref<Value(Value)> mapRewriteValue);
77 void generateRewriter(pdl::AttributeOp attrOp,
78 DenseMap<Value, Value> &rewriteValues,
79 function_ref<Value(Value)> mapRewriteValue);
80 void generateRewriter(pdl::EraseOp eraseOp,
81 DenseMap<Value, Value> &rewriteValues,
82 function_ref<Value(Value)> mapRewriteValue);
83 void generateRewriter(pdl::OperationOp operationOp,
84 DenseMap<Value, Value> &rewriteValues,
85 function_ref<Value(Value)> mapRewriteValue);
86 void generateRewriter(pdl::ReplaceOp replaceOp,
87 DenseMap<Value, Value> &rewriteValues,
88 function_ref<Value(Value)> mapRewriteValue);
89 void generateRewriter(pdl::ResultOp resultOp,
90 DenseMap<Value, Value> &rewriteValues,
91 function_ref<Value(Value)> mapRewriteValue);
92 void generateRewriter(pdl::ResultsOp resultOp,
93 DenseMap<Value, Value> &rewriteValues,
94 function_ref<Value(Value)> mapRewriteValue);
95 void generateRewriter(pdl::TypeOp typeOp,
96 DenseMap<Value, Value> &rewriteValues,
97 function_ref<Value(Value)> mapRewriteValue);
98 void generateRewriter(pdl::TypesOp typeOp,
99 DenseMap<Value, Value> &rewriteValues,
100 function_ref<Value(Value)> mapRewriteValue);
101
102 /// Generate the values used for resolving the result types of an operation
103 /// created within a dag rewriter region.
104 void generateOperationResultTypeRewriter(
105 pdl::OperationOp op, SmallVectorImpl<Value> &types,
106 DenseMap<Value, Value> &rewriteValues,
107 function_ref<Value(Value)> mapRewriteValue);
108
109 /// A builder to use when generating interpreter operations.
110 OpBuilder builder;
111
112 /// The matcher function used for all match related logic within PDL patterns.
113 FuncOp matcherFunc;
114
115 /// The rewriter module containing the all rewrite related logic within PDL
116 /// patterns.
117 ModuleOp rewriterModule;
118
119 /// The symbol table of the rewriter module used for insertion.
120 SymbolTable rewriterSymbolTable;
121
122 /// A scoped map connecting a position with the corresponding interpreter
123 /// value.
124 ValueMap values;
125
126 /// A stack of blocks used as the failure destination for matcher nodes that
127 /// don't have an explicit failure path.
128 SmallVector<Block *, 8> failureBlockStack;
129
130 /// A mapping between values defined in a pattern match, and the corresponding
131 /// positional value.
132 DenseMap<Value, Position *> valueToPosition;
133
134 /// The set of operation values whose whose location will be used for newly
135 /// generated operations.
136 SetVector<Value> locOps;
137 };
138 } // end anonymous namespace
139
PatternLowering(FuncOp matcherFunc,ModuleOp rewriterModule)140 PatternLowering::PatternLowering(FuncOp matcherFunc, ModuleOp rewriterModule)
141 : builder(matcherFunc.getContext()), matcherFunc(matcherFunc),
142 rewriterModule(rewriterModule), rewriterSymbolTable(rewriterModule) {}
143
lower(ModuleOp module)144 void PatternLowering::lower(ModuleOp module) {
145 PredicateUniquer predicateUniquer;
146 PredicateBuilder predicateBuilder(predicateUniquer, module.getContext());
147
148 // Define top-level scope for the arguments to the matcher function.
149 ValueMapScope topLevelValueScope(values);
150
151 // Insert the root operation, i.e. argument to the matcher, at the root
152 // position.
153 Block *matcherEntryBlock = matcherFunc.addEntryBlock();
154 values.insert(predicateBuilder.getRoot(), matcherEntryBlock->getArgument(0));
155
156 // Generate a root matcher node from the provided PDL module.
157 std::unique_ptr<MatcherNode> root = MatcherNode::generateMatcherTree(
158 module, predicateBuilder, valueToPosition);
159 Block *firstMatcherBlock = generateMatcher(*root);
160
161 // After generation, merged the first matched block into the entry.
162 matcherEntryBlock->getOperations().splice(matcherEntryBlock->end(),
163 firstMatcherBlock->getOperations());
164 firstMatcherBlock->erase();
165 }
166
generateMatcher(MatcherNode & node)167 Block *PatternLowering::generateMatcher(MatcherNode &node) {
168 // Push a new scope for the values used by this matcher.
169 Block *block = matcherFunc.addBlock();
170 ValueMapScope scope(values);
171
172 // If this is the return node, simply insert the corresponding interpreter
173 // finalize.
174 if (isa<ExitNode>(node)) {
175 builder.setInsertionPointToEnd(block);
176 builder.create<pdl_interp::FinalizeOp>(matcherFunc.getLoc());
177 return block;
178 }
179
180 // If this node contains a position, get the corresponding value for this
181 // block.
182 Position *position = node.getPosition();
183 Value val = position ? getValueAt(block, position) : Value();
184
185 // Get the next block in the match sequence.
186 std::unique_ptr<MatcherNode> &failureNode = node.getFailureNode();
187 Block *nextBlock;
188 if (failureNode) {
189 nextBlock = generateMatcher(*failureNode);
190 failureBlockStack.push_back(nextBlock);
191 } else {
192 assert(!failureBlockStack.empty() && "expected valid failure block");
193 nextBlock = failureBlockStack.back();
194 }
195
196 // If this value corresponds to an operation, record that we are going to use
197 // its location as part of a fused location.
198 bool isOperationValue = val && val.getType().isa<pdl::OperationType>();
199 if (isOperationValue)
200 locOps.insert(val);
201
202 // Generate code for a boolean predicate node.
203 if (auto *boolNode = dyn_cast<BoolNode>(&node)) {
204 auto *child = generateMatcher(*boolNode->getSuccessNode());
205 generatePredicate(block, node.getQuestion(), boolNode->getAnswer(), val,
206 child, nextBlock);
207
208 // Generate code for a switch node.
209 } else if (auto *switchNode = dyn_cast<SwitchNode>(&node)) {
210 generateSwitch(switchNode, block, node.getQuestion(), val, nextBlock);
211
212 // Generate code for a success node.
213 } else if (auto *successNode = dyn_cast<SuccessNode>(&node)) {
214 generateRecordMatch(block, nextBlock, successNode->getPattern());
215 }
216
217 if (failureNode)
218 failureBlockStack.pop_back();
219 if (isOperationValue)
220 locOps.remove(val);
221 return block;
222 }
223
getValueAt(Block * cur,Position * pos)224 Value PatternLowering::getValueAt(Block *cur, Position *pos) {
225 if (Value val = values.lookup(pos))
226 return val;
227
228 // Get the value for the parent position.
229 Value parentVal = getValueAt(cur, pos->getParent());
230
231 // TODO: Use a location from the position.
232 Location loc = parentVal.getLoc();
233 builder.setInsertionPointToEnd(cur);
234 Value value;
235 switch (pos->getKind()) {
236 case Predicates::OperationPos:
237 value = builder.create<pdl_interp::GetDefiningOpOp>(
238 loc, builder.getType<pdl::OperationType>(), parentVal);
239 break;
240 case Predicates::OperandPos: {
241 auto *operandPos = cast<OperandPosition>(pos);
242 value = builder.create<pdl_interp::GetOperandOp>(
243 loc, builder.getType<pdl::ValueType>(), parentVal,
244 operandPos->getOperandNumber());
245 break;
246 }
247 case Predicates::OperandGroupPos: {
248 auto *operandPos = cast<OperandGroupPosition>(pos);
249 Type valueTy = builder.getType<pdl::ValueType>();
250 value = builder.create<pdl_interp::GetOperandsOp>(
251 loc, operandPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy,
252 parentVal, operandPos->getOperandGroupNumber());
253 break;
254 }
255 case Predicates::AttributePos: {
256 auto *attrPos = cast<AttributePosition>(pos);
257 value = builder.create<pdl_interp::GetAttributeOp>(
258 loc, builder.getType<pdl::AttributeType>(), parentVal,
259 attrPos->getName().strref());
260 break;
261 }
262 case Predicates::TypePos: {
263 if (parentVal.getType().isa<pdl::AttributeType>())
264 value = builder.create<pdl_interp::GetAttributeTypeOp>(loc, parentVal);
265 else
266 value = builder.create<pdl_interp::GetValueTypeOp>(loc, parentVal);
267 break;
268 }
269 case Predicates::ResultPos: {
270 auto *resPos = cast<ResultPosition>(pos);
271 value = builder.create<pdl_interp::GetResultOp>(
272 loc, builder.getType<pdl::ValueType>(), parentVal,
273 resPos->getResultNumber());
274 break;
275 }
276 case Predicates::ResultGroupPos: {
277 auto *resPos = cast<ResultGroupPosition>(pos);
278 Type valueTy = builder.getType<pdl::ValueType>();
279 value = builder.create<pdl_interp::GetResultsOp>(
280 loc, resPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy,
281 parentVal, resPos->getResultGroupNumber());
282 break;
283 }
284 default:
285 llvm_unreachable("Generating unknown Position getter");
286 break;
287 }
288 values.insert(pos, value);
289 return value;
290 }
291
generatePredicate(Block * currentBlock,Qualifier * question,Qualifier * answer,Value val,Block * trueDest,Block * falseDest)292 void PatternLowering::generatePredicate(Block *currentBlock,
293 Qualifier *question, Qualifier *answer,
294 Value val, Block *trueDest,
295 Block *falseDest) {
296 builder.setInsertionPointToEnd(currentBlock);
297 Location loc = val.getLoc();
298 Predicates::Kind kind = question->getKind();
299 switch (kind) {
300 case Predicates::IsNotNullQuestion:
301 builder.create<pdl_interp::IsNotNullOp>(loc, val, trueDest, falseDest);
302 break;
303 case Predicates::OperationNameQuestion: {
304 auto *opNameAnswer = cast<OperationNameAnswer>(answer);
305 builder.create<pdl_interp::CheckOperationNameOp>(
306 loc, val, opNameAnswer->getValue().getStringRef(), trueDest, falseDest);
307 break;
308 }
309 case Predicates::TypeQuestion: {
310 auto *ans = cast<TypeAnswer>(answer);
311 if (val.getType().isa<pdl::RangeType>())
312 builder.create<pdl_interp::CheckTypesOp>(
313 loc, val, ans->getValue().cast<ArrayAttr>(), trueDest, falseDest);
314 else
315 builder.create<pdl_interp::CheckTypeOp>(
316 loc, val, ans->getValue().cast<TypeAttr>(), trueDest, falseDest);
317 break;
318 }
319 case Predicates::AttributeQuestion: {
320 auto *ans = cast<AttributeAnswer>(answer);
321 builder.create<pdl_interp::CheckAttributeOp>(loc, val, ans->getValue(),
322 trueDest, falseDest);
323 break;
324 }
325 case Predicates::OperandCountAtLeastQuestion:
326 case Predicates::OperandCountQuestion:
327 builder.create<pdl_interp::CheckOperandCountOp>(
328 loc, val, cast<UnsignedAnswer>(answer)->getValue(),
329 /*compareAtLeast=*/kind == Predicates::OperandCountAtLeastQuestion,
330 trueDest, falseDest);
331 break;
332 case Predicates::ResultCountAtLeastQuestion:
333 case Predicates::ResultCountQuestion:
334 builder.create<pdl_interp::CheckResultCountOp>(
335 loc, val, cast<UnsignedAnswer>(answer)->getValue(),
336 /*compareAtLeast=*/kind == Predicates::ResultCountAtLeastQuestion,
337 trueDest, falseDest);
338 break;
339 case Predicates::EqualToQuestion: {
340 auto *equalToQuestion = cast<EqualToQuestion>(question);
341 builder.create<pdl_interp::AreEqualOp>(
342 loc, val, getValueAt(currentBlock, equalToQuestion->getValue()),
343 trueDest, falseDest);
344 break;
345 }
346 case Predicates::ConstraintQuestion: {
347 auto *cstQuestion = cast<ConstraintQuestion>(question);
348 SmallVector<Value, 2> args;
349 for (Position *position : std::get<1>(cstQuestion->getValue()))
350 args.push_back(getValueAt(currentBlock, position));
351 builder.create<pdl_interp::ApplyConstraintOp>(
352 loc, std::get<0>(cstQuestion->getValue()), args,
353 std::get<2>(cstQuestion->getValue()).cast<ArrayAttr>(), trueDest,
354 falseDest);
355 break;
356 }
357 default:
358 llvm_unreachable("Generating unknown Predicate operation");
359 }
360 }
361
362 template <typename OpT, typename PredT, typename ValT = typename PredT::KeyTy>
createSwitchOp(Value val,Block * defaultDest,OpBuilder & builder,llvm::MapVector<Qualifier *,Block * > & dests)363 static void createSwitchOp(Value val, Block *defaultDest, OpBuilder &builder,
364 llvm::MapVector<Qualifier *, Block *> &dests) {
365 std::vector<ValT> values;
366 std::vector<Block *> blocks;
367 values.reserve(dests.size());
368 blocks.reserve(dests.size());
369 for (const auto &it : dests) {
370 blocks.push_back(it.second);
371 values.push_back(cast<PredT>(it.first)->getValue());
372 }
373 builder.create<OpT>(val.getLoc(), val, values, defaultDest, blocks);
374 }
375
generateSwitch(SwitchNode * switchNode,Block * currentBlock,Qualifier * question,Value val,Block * defaultDest)376 void PatternLowering::generateSwitch(SwitchNode *switchNode,
377 Block *currentBlock, Qualifier *question,
378 Value val, Block *defaultDest) {
379 // If the switch question is not an exact answer, i.e. for the `at_least`
380 // cases, we generate a special block sequence.
381 Predicates::Kind kind = question->getKind();
382 if (kind == Predicates::OperandCountAtLeastQuestion ||
383 kind == Predicates::ResultCountAtLeastQuestion) {
384 // Order the children such that the cases are in reverse numerical order.
385 SmallVector<unsigned> sortedChildren = llvm::to_vector<16>(
386 llvm::seq<unsigned>(0, switchNode->getChildren().size()));
387 llvm::sort(sortedChildren, [&](unsigned lhs, unsigned rhs) {
388 return cast<UnsignedAnswer>(switchNode->getChild(lhs).first)->getValue() >
389 cast<UnsignedAnswer>(switchNode->getChild(rhs).first)->getValue();
390 });
391
392 // Build the destination for each child using the next highest child as a
393 // a failure destination. This essentially creates the following control
394 // flow:
395 //
396 // if (operand_count < 1)
397 // goto failure
398 // if (child1.match())
399 // ...
400 //
401 // if (operand_count < 2)
402 // goto failure
403 // if (child2.match())
404 // ...
405 //
406 // failure:
407 // ...
408 //
409 failureBlockStack.push_back(defaultDest);
410 for (unsigned idx : sortedChildren) {
411 auto &child = switchNode->getChild(idx);
412 Block *childBlock = generateMatcher(*child.second);
413 Block *predicateBlock = builder.createBlock(childBlock);
414 generatePredicate(predicateBlock, question, child.first, val, childBlock,
415 defaultDest);
416 failureBlockStack.back() = predicateBlock;
417 }
418 Block *firstPredicateBlock = failureBlockStack.pop_back_val();
419 currentBlock->getOperations().splice(currentBlock->end(),
420 firstPredicateBlock->getOperations());
421 firstPredicateBlock->erase();
422 return;
423 }
424
425 // Otherwise, generate each of the children and generate an interpreter
426 // switch.
427 llvm::MapVector<Qualifier *, Block *> children;
428 for (auto &it : switchNode->getChildren())
429 children.insert({it.first, generateMatcher(*it.second)});
430 builder.setInsertionPointToEnd(currentBlock);
431
432 switch (question->getKind()) {
433 case Predicates::OperandCountQuestion:
434 return createSwitchOp<pdl_interp::SwitchOperandCountOp, UnsignedAnswer,
435 int32_t>(val, defaultDest, builder, children);
436 case Predicates::ResultCountQuestion:
437 return createSwitchOp<pdl_interp::SwitchResultCountOp, UnsignedAnswer,
438 int32_t>(val, defaultDest, builder, children);
439 case Predicates::OperationNameQuestion:
440 return createSwitchOp<pdl_interp::SwitchOperationNameOp,
441 OperationNameAnswer>(val, defaultDest, builder,
442 children);
443 case Predicates::TypeQuestion:
444 if (val.getType().isa<pdl::RangeType>()) {
445 return createSwitchOp<pdl_interp::SwitchTypesOp, TypeAnswer>(
446 val, defaultDest, builder, children);
447 }
448 return createSwitchOp<pdl_interp::SwitchTypeOp, TypeAnswer>(
449 val, defaultDest, builder, children);
450 case Predicates::AttributeQuestion:
451 return createSwitchOp<pdl_interp::SwitchAttributeOp, AttributeAnswer>(
452 val, defaultDest, builder, children);
453 default:
454 llvm_unreachable("Generating unknown switch predicate.");
455 }
456 }
457
generateRecordMatch(Block * currentBlock,Block * nextBlock,pdl::PatternOp pattern)458 void PatternLowering::generateRecordMatch(Block *currentBlock, Block *nextBlock,
459 pdl::PatternOp pattern) {
460 // Generate a rewriter for the pattern this success node represents, and track
461 // any values used from the match region.
462 SmallVector<Position *, 8> usedMatchValues;
463 SymbolRefAttr rewriterFuncRef = generateRewriter(pattern, usedMatchValues);
464
465 // Process any values used in the rewrite that are defined in the match.
466 std::vector<Value> mappedMatchValues;
467 mappedMatchValues.reserve(usedMatchValues.size());
468 for (Position *position : usedMatchValues)
469 mappedMatchValues.push_back(getValueAt(currentBlock, position));
470
471 // Collect the set of operations generated by the rewriter.
472 SmallVector<StringRef, 4> generatedOps;
473 for (auto op : pattern.getRewriter().body().getOps<pdl::OperationOp>())
474 generatedOps.push_back(*op.name());
475 ArrayAttr generatedOpsAttr;
476 if (!generatedOps.empty())
477 generatedOpsAttr = builder.getStrArrayAttr(generatedOps);
478
479 // Grab the root kind if present.
480 StringAttr rootKindAttr;
481 if (Optional<StringRef> rootKind = pattern.getRootKind())
482 rootKindAttr = builder.getStringAttr(*rootKind);
483
484 builder.setInsertionPointToEnd(currentBlock);
485 builder.create<pdl_interp::RecordMatchOp>(
486 pattern.getLoc(), mappedMatchValues, locOps.getArrayRef(),
487 rewriterFuncRef, rootKindAttr, generatedOpsAttr, pattern.benefitAttr(),
488 nextBlock);
489 }
490
generateRewriter(pdl::PatternOp pattern,SmallVectorImpl<Position * > & usedMatchValues)491 SymbolRefAttr PatternLowering::generateRewriter(
492 pdl::PatternOp pattern, SmallVectorImpl<Position *> &usedMatchValues) {
493 FuncOp rewriterFunc =
494 FuncOp::create(pattern.getLoc(), "pdl_generated_rewriter",
495 builder.getFunctionType(llvm::None, llvm::None));
496 rewriterSymbolTable.insert(rewriterFunc);
497
498 // Generate the rewriter function body.
499 builder.setInsertionPointToEnd(rewriterFunc.addEntryBlock());
500
501 // Map an input operand of the pattern to a generated interpreter value.
502 DenseMap<Value, Value> rewriteValues;
503 auto mapRewriteValue = [&](Value oldValue) {
504 Value &newValue = rewriteValues[oldValue];
505 if (newValue)
506 return newValue;
507
508 // Prefer materializing constants directly when possible.
509 Operation *oldOp = oldValue.getDefiningOp();
510 if (pdl::AttributeOp attrOp = dyn_cast<pdl::AttributeOp>(oldOp)) {
511 if (Attribute value = attrOp.valueAttr()) {
512 return newValue = builder.create<pdl_interp::CreateAttributeOp>(
513 attrOp.getLoc(), value);
514 }
515 } else if (pdl::TypeOp typeOp = dyn_cast<pdl::TypeOp>(oldOp)) {
516 if (TypeAttr type = typeOp.typeAttr()) {
517 return newValue = builder.create<pdl_interp::CreateTypeOp>(
518 typeOp.getLoc(), type);
519 }
520 } else if (pdl::TypesOp typeOp = dyn_cast<pdl::TypesOp>(oldOp)) {
521 if (ArrayAttr type = typeOp.typesAttr()) {
522 return newValue = builder.create<pdl_interp::CreateTypesOp>(
523 typeOp.getLoc(), typeOp.getType(), type);
524 }
525 }
526
527 // Otherwise, add this as an input to the rewriter.
528 Position *inputPos = valueToPosition.lookup(oldValue);
529 assert(inputPos && "expected value to be a pattern input");
530 usedMatchValues.push_back(inputPos);
531 return newValue = rewriterFunc.front().addArgument(oldValue.getType());
532 };
533
534 // If this is a custom rewriter, simply dispatch to the registered rewrite
535 // method.
536 pdl::RewriteOp rewriter = pattern.getRewriter();
537 if (StringAttr rewriteName = rewriter.nameAttr()) {
538 auto mappedArgs = llvm::map_range(rewriter.externalArgs(), mapRewriteValue);
539 SmallVector<Value, 4> args(1, mapRewriteValue(rewriter.root()));
540 args.append(mappedArgs.begin(), mappedArgs.end());
541 builder.create<pdl_interp::ApplyRewriteOp>(
542 rewriter.getLoc(), /*resultTypes=*/TypeRange(), rewriteName, args,
543 rewriter.externalConstParamsAttr());
544 } else {
545 // Otherwise this is a dag rewriter defined using PDL operations.
546 for (Operation &rewriteOp : *rewriter.getBody()) {
547 llvm::TypeSwitch<Operation *>(&rewriteOp)
548 .Case<pdl::ApplyNativeRewriteOp, pdl::AttributeOp, pdl::EraseOp,
549 pdl::OperationOp, pdl::ReplaceOp, pdl::ResultOp, pdl::ResultsOp,
550 pdl::TypeOp, pdl::TypesOp>([&](auto op) {
551 this->generateRewriter(op, rewriteValues, mapRewriteValue);
552 });
553 }
554 }
555
556 // Update the signature of the rewrite function.
557 rewriterFunc.setType(builder.getFunctionType(
558 llvm::to_vector<8>(rewriterFunc.front().getArgumentTypes()),
559 /*results=*/llvm::None));
560
561 builder.create<pdl_interp::FinalizeOp>(rewriter.getLoc());
562 return builder.getSymbolRefAttr(
563 pdl_interp::PDLInterpDialect::getRewriterModuleName(),
564 builder.getSymbolRefAttr(rewriterFunc));
565 }
566
generateRewriter(pdl::ApplyNativeRewriteOp rewriteOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)567 void PatternLowering::generateRewriter(
568 pdl::ApplyNativeRewriteOp rewriteOp, DenseMap<Value, Value> &rewriteValues,
569 function_ref<Value(Value)> mapRewriteValue) {
570 SmallVector<Value, 2> arguments;
571 for (Value argument : rewriteOp.args())
572 arguments.push_back(mapRewriteValue(argument));
573 auto interpOp = builder.create<pdl_interp::ApplyRewriteOp>(
574 rewriteOp.getLoc(), rewriteOp.getResultTypes(), rewriteOp.nameAttr(),
575 arguments, rewriteOp.constParamsAttr());
576 for (auto it : llvm::zip(rewriteOp.results(), interpOp.results()))
577 rewriteValues[std::get<0>(it)] = std::get<1>(it);
578 }
579
generateRewriter(pdl::AttributeOp attrOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)580 void PatternLowering::generateRewriter(
581 pdl::AttributeOp attrOp, DenseMap<Value, Value> &rewriteValues,
582 function_ref<Value(Value)> mapRewriteValue) {
583 Value newAttr = builder.create<pdl_interp::CreateAttributeOp>(
584 attrOp.getLoc(), attrOp.valueAttr());
585 rewriteValues[attrOp] = newAttr;
586 }
587
generateRewriter(pdl::EraseOp eraseOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)588 void PatternLowering::generateRewriter(
589 pdl::EraseOp eraseOp, DenseMap<Value, Value> &rewriteValues,
590 function_ref<Value(Value)> mapRewriteValue) {
591 builder.create<pdl_interp::EraseOp>(eraseOp.getLoc(),
592 mapRewriteValue(eraseOp.operation()));
593 }
594
generateRewriter(pdl::OperationOp operationOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)595 void PatternLowering::generateRewriter(
596 pdl::OperationOp operationOp, DenseMap<Value, Value> &rewriteValues,
597 function_ref<Value(Value)> mapRewriteValue) {
598 SmallVector<Value, 4> operands;
599 for (Value operand : operationOp.operands())
600 operands.push_back(mapRewriteValue(operand));
601
602 SmallVector<Value, 4> attributes;
603 for (Value attr : operationOp.attributes())
604 attributes.push_back(mapRewriteValue(attr));
605
606 SmallVector<Value, 2> types;
607 generateOperationResultTypeRewriter(operationOp, types, rewriteValues,
608 mapRewriteValue);
609
610 // Create the new operation.
611 Location loc = operationOp.getLoc();
612 Value createdOp = builder.create<pdl_interp::CreateOperationOp>(
613 loc, *operationOp.name(), types, operands, attributes,
614 operationOp.attributeNames());
615 rewriteValues[operationOp.op()] = createdOp;
616
617 // Generate accesses for any results that have their types constrained.
618 // Handle the case where there is a single range representing all of the
619 // result types.
620 OperandRange resultTys = operationOp.types();
621 if (resultTys.size() == 1 && resultTys[0].getType().isa<pdl::RangeType>()) {
622 Value &type = rewriteValues[resultTys[0]];
623 if (!type) {
624 auto results = builder.create<pdl_interp::GetResultsOp>(loc, createdOp);
625 type = builder.create<pdl_interp::GetValueTypeOp>(loc, results);
626 }
627 return;
628 }
629
630 // Otherwise, populate the individual results.
631 bool seenVariableLength = false;
632 Type valueTy = builder.getType<pdl::ValueType>();
633 Type valueRangeTy = pdl::RangeType::get(valueTy);
634 for (auto it : llvm::enumerate(resultTys)) {
635 Value &type = rewriteValues[it.value()];
636 if (type)
637 continue;
638 bool isVariadic = it.value().getType().isa<pdl::RangeType>();
639 seenVariableLength |= isVariadic;
640
641 // After a variable length result has been seen, we need to use result
642 // groups because the exact index of the result is not statically known.
643 Value resultVal;
644 if (seenVariableLength)
645 resultVal = builder.create<pdl_interp::GetResultsOp>(
646 loc, isVariadic ? valueRangeTy : valueTy, createdOp, it.index());
647 else
648 resultVal = builder.create<pdl_interp::GetResultOp>(
649 loc, valueTy, createdOp, it.index());
650 type = builder.create<pdl_interp::GetValueTypeOp>(loc, resultVal);
651 }
652 }
653
generateRewriter(pdl::ReplaceOp replaceOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)654 void PatternLowering::generateRewriter(
655 pdl::ReplaceOp replaceOp, DenseMap<Value, Value> &rewriteValues,
656 function_ref<Value(Value)> mapRewriteValue) {
657 SmallVector<Value, 4> replOperands;
658
659 // If the replacement was another operation, get its results. `pdl` allows
660 // for using an operation for simplicitly, but the interpreter isn't as
661 // user facing.
662 if (Value replOp = replaceOp.replOperation()) {
663 // Don't use replace if we know the replaced operation has no results.
664 auto opOp = replaceOp.operation().getDefiningOp<pdl::OperationOp>();
665 if (!opOp || !opOp.types().empty()) {
666 replOperands.push_back(builder.create<pdl_interp::GetResultsOp>(
667 replOp.getLoc(), mapRewriteValue(replOp)));
668 }
669 } else {
670 for (Value operand : replaceOp.replValues())
671 replOperands.push_back(mapRewriteValue(operand));
672 }
673
674 // If there are no replacement values, just create an erase instead.
675 if (replOperands.empty()) {
676 builder.create<pdl_interp::EraseOp>(replaceOp.getLoc(),
677 mapRewriteValue(replaceOp.operation()));
678 return;
679 }
680
681 builder.create<pdl_interp::ReplaceOp>(
682 replaceOp.getLoc(), mapRewriteValue(replaceOp.operation()), replOperands);
683 }
684
generateRewriter(pdl::ResultOp resultOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)685 void PatternLowering::generateRewriter(
686 pdl::ResultOp resultOp, DenseMap<Value, Value> &rewriteValues,
687 function_ref<Value(Value)> mapRewriteValue) {
688 rewriteValues[resultOp] = builder.create<pdl_interp::GetResultOp>(
689 resultOp.getLoc(), builder.getType<pdl::ValueType>(),
690 mapRewriteValue(resultOp.parent()), resultOp.index());
691 }
692
generateRewriter(pdl::ResultsOp resultOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)693 void PatternLowering::generateRewriter(
694 pdl::ResultsOp resultOp, DenseMap<Value, Value> &rewriteValues,
695 function_ref<Value(Value)> mapRewriteValue) {
696 rewriteValues[resultOp] = builder.create<pdl_interp::GetResultsOp>(
697 resultOp.getLoc(), resultOp.getType(), mapRewriteValue(resultOp.parent()),
698 resultOp.index());
699 }
700
generateRewriter(pdl::TypeOp typeOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)701 void PatternLowering::generateRewriter(
702 pdl::TypeOp typeOp, DenseMap<Value, Value> &rewriteValues,
703 function_ref<Value(Value)> mapRewriteValue) {
704 // If the type isn't constant, the users (e.g. OperationOp) will resolve this
705 // type.
706 if (TypeAttr typeAttr = typeOp.typeAttr()) {
707 rewriteValues[typeOp] =
708 builder.create<pdl_interp::CreateTypeOp>(typeOp.getLoc(), typeAttr);
709 }
710 }
711
generateRewriter(pdl::TypesOp typeOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)712 void PatternLowering::generateRewriter(
713 pdl::TypesOp typeOp, DenseMap<Value, Value> &rewriteValues,
714 function_ref<Value(Value)> mapRewriteValue) {
715 // If the type isn't constant, the users (e.g. OperationOp) will resolve this
716 // type.
717 if (ArrayAttr typeAttr = typeOp.typesAttr()) {
718 rewriteValues[typeOp] = builder.create<pdl_interp::CreateTypesOp>(
719 typeOp.getLoc(), typeOp.getType(), typeAttr);
720 }
721 }
722
generateOperationResultTypeRewriter(pdl::OperationOp op,SmallVectorImpl<Value> & types,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)723 void PatternLowering::generateOperationResultTypeRewriter(
724 pdl::OperationOp op, SmallVectorImpl<Value> &types,
725 DenseMap<Value, Value> &rewriteValues,
726 function_ref<Value(Value)> mapRewriteValue) {
727 // Look for an operation that was replaced by `op`. The result types will be
728 // inferred from the results that were replaced.
729 Block *rewriterBlock = op->getBlock();
730 Value replacedOp;
731 for (OpOperand &use : op.op().getUses()) {
732 // Check that the use corresponds to a ReplaceOp and that it is the
733 // replacement value, not the operation being replaced.
734 pdl::ReplaceOp replOpUser = dyn_cast<pdl::ReplaceOp>(use.getOwner());
735 if (!replOpUser || use.getOperandNumber() == 0)
736 continue;
737 // Make sure the replaced operation was defined before this one.
738 Value replOpVal = replOpUser.operation();
739 Operation *replacedOp = replOpVal.getDefiningOp();
740 if (replacedOp->getBlock() == rewriterBlock &&
741 !replacedOp->isBeforeInBlock(op))
742 continue;
743
744 Value replacedOpResults = builder.create<pdl_interp::GetResultsOp>(
745 replacedOp->getLoc(), mapRewriteValue(replOpVal));
746 types.push_back(builder.create<pdl_interp::GetValueTypeOp>(
747 replacedOp->getLoc(), replacedOpResults));
748 return;
749 }
750
751 // Check if the operation has type inference support.
752 if (op.hasTypeInference()) {
753 types.push_back(builder.create<pdl_interp::InferredTypesOp>(op.getLoc()));
754 return;
755 }
756
757 // Otherwise, handle inference for each of the result types individually.
758 OperandRange resultTypeValues = op.types();
759 types.reserve(resultTypeValues.size());
760 for (auto it : llvm::enumerate(resultTypeValues)) {
761 Value resultType = it.value();
762
763 // Check for an already translated value.
764 if (Value existingRewriteValue = rewriteValues.lookup(resultType)) {
765 types.push_back(existingRewriteValue);
766 continue;
767 }
768
769 // Check for an input from the matcher.
770 if (resultType.getDefiningOp()->getBlock() != rewriterBlock) {
771 types.push_back(mapRewriteValue(resultType));
772 continue;
773 }
774
775 // The verifier asserts that the result types of each pdl.operation can be
776 // inferred. If we reach here, there is a bug either in the logic above or
777 // in the verifier for pdl.operation.
778 op->emitOpError() << "unable to infer result type for operation";
779 llvm_unreachable("unable to infer result type for operation");
780 }
781 }
782
783 //===----------------------------------------------------------------------===//
784 // Conversion Pass
785 //===----------------------------------------------------------------------===//
786
787 namespace {
788 struct PDLToPDLInterpPass
789 : public ConvertPDLToPDLInterpBase<PDLToPDLInterpPass> {
790 void runOnOperation() final;
791 };
792 } // namespace
793
794 /// Convert the given module containing PDL pattern operations into a PDL
795 /// Interpreter operations.
runOnOperation()796 void PDLToPDLInterpPass::runOnOperation() {
797 ModuleOp module = getOperation();
798
799 // Create the main matcher function This function contains all of the match
800 // related functionality from patterns in the module.
801 OpBuilder builder = OpBuilder::atBlockBegin(module.getBody());
802 FuncOp matcherFunc = builder.create<FuncOp>(
803 module.getLoc(), pdl_interp::PDLInterpDialect::getMatcherFunctionName(),
804 builder.getFunctionType(builder.getType<pdl::OperationType>(),
805 /*results=*/llvm::None),
806 /*attrs=*/llvm::None);
807
808 // Create a nested module to hold the functions invoked for rewriting the IR
809 // after a successful match.
810 ModuleOp rewriterModule = builder.create<ModuleOp>(
811 module.getLoc(), pdl_interp::PDLInterpDialect::getRewriterModuleName());
812
813 // Generate the code for the patterns within the module.
814 PatternLowering generator(matcherFunc, rewriterModule);
815 generator.lower(module);
816
817 // After generation, delete all of the pattern operations.
818 for (pdl::PatternOp pattern :
819 llvm::make_early_inc_range(module.getOps<pdl::PatternOp>()))
820 pattern.erase();
821 }
822
createPDLToPDLInterpPass()823 std::unique_ptr<OperationPass<ModuleOp>> mlir::createPDLToPDLInterpPass() {
824 return std::make_unique<PDLToPDLInterpPass>();
825 }
826