1 //===- FrozenRewritePatternSet.cpp - Frozen Pattern List -------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
10 #include "ByteCode.h"
11 #include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
12 #include "mlir/Dialect/PDL/IR/PDLOps.h"
13 #include "mlir/Interfaces/SideEffectInterfaces.h"
14 #include "mlir/Pass/Pass.h"
15 #include "mlir/Pass/PassManager.h"
16
17 using namespace mlir;
18
convertPDLToPDLInterp(ModuleOp pdlModule)19 static LogicalResult convertPDLToPDLInterp(ModuleOp pdlModule) {
20 // Skip the conversion if the module doesn't contain pdl.
21 if (llvm::empty(pdlModule.getOps<pdl::PatternOp>()))
22 return success();
23
24 // Simplify the provided PDL module. Note that we can't use the canonicalizer
25 // here because it would create a cyclic dependency.
26 auto simplifyFn = [](Operation *op) {
27 // TODO: Add folding here if ever necessary.
28 if (isOpTriviallyDead(op))
29 op->erase();
30 };
31 pdlModule.getBody()->walk(simplifyFn);
32
33 /// Lower the PDL pattern module to the interpreter dialect.
34 PassManager pdlPipeline(pdlModule.getContext());
35 #ifdef NDEBUG
36 // We don't want to incur the hit of running the verifier when in release
37 // mode.
38 pdlPipeline.enableVerifier(false);
39 #endif
40 pdlPipeline.addPass(createPDLToPDLInterpPass());
41 if (failed(pdlPipeline.run(pdlModule)))
42 return failure();
43
44 // Simplify again after running the lowering pipeline.
45 pdlModule.getBody()->walk(simplifyFn);
46 return success();
47 }
48
49 //===----------------------------------------------------------------------===//
50 // FrozenRewritePatternSet
51 //===----------------------------------------------------------------------===//
52
FrozenRewritePatternSet()53 FrozenRewritePatternSet::FrozenRewritePatternSet()
54 : impl(std::make_shared<Impl>()) {}
55
FrozenRewritePatternSet(RewritePatternSet && patterns,ArrayRef<std::string> disabledPatternLabels,ArrayRef<std::string> enabledPatternLabels)56 FrozenRewritePatternSet::FrozenRewritePatternSet(
57 RewritePatternSet &&patterns, ArrayRef<std::string> disabledPatternLabels,
58 ArrayRef<std::string> enabledPatternLabels)
59 : impl(std::make_shared<Impl>()) {
60 DenseSet<StringRef> disabledPatterns, enabledPatterns;
61 disabledPatterns.insert(disabledPatternLabels.begin(),
62 disabledPatternLabels.end());
63 enabledPatterns.insert(enabledPatternLabels.begin(),
64 enabledPatternLabels.end());
65
66 // Functor used to walk all of the operations registered in the context. This
67 // is useful for patterns that get applied to multiple operations, such as
68 // interface and trait based patterns.
69 std::vector<AbstractOperation *> abstractOps;
70 auto addToOpsWhen = [&](std::unique_ptr<RewritePattern> &pattern,
71 function_ref<bool(AbstractOperation *)> callbackFn) {
72 if (abstractOps.empty())
73 abstractOps = pattern->getContext()->getRegisteredOperations();
74 for (AbstractOperation *absOp : abstractOps) {
75 if (callbackFn(absOp)) {
76 OperationName opName(absOp);
77 impl->nativeOpSpecificPatternMap[opName].push_back(pattern.get());
78 }
79 }
80 impl->nativeOpSpecificPatternList.push_back(std::move(pattern));
81 };
82
83 for (std::unique_ptr<RewritePattern> &pat : patterns.getNativePatterns()) {
84 // Don't add patterns that haven't been enabled by the user.
85 if (!enabledPatterns.empty()) {
86 auto isEnabledFn = [&](StringRef label) {
87 return enabledPatterns.count(label);
88 };
89 if (!isEnabledFn(pat->getDebugName()) &&
90 llvm::none_of(pat->getDebugLabels(), isEnabledFn))
91 continue;
92 }
93 // Don't add patterns that have been disabled by the user.
94 if (!disabledPatterns.empty()) {
95 auto isDisabledFn = [&](StringRef label) {
96 return disabledPatterns.count(label);
97 };
98 if (isDisabledFn(pat->getDebugName()) ||
99 llvm::any_of(pat->getDebugLabels(), isDisabledFn))
100 continue;
101 }
102
103 if (Optional<OperationName> rootName = pat->getRootKind()) {
104 impl->nativeOpSpecificPatternMap[*rootName].push_back(pat.get());
105 impl->nativeOpSpecificPatternList.push_back(std::move(pat));
106 continue;
107 }
108 if (Optional<TypeID> interfaceID = pat->getRootInterfaceID()) {
109 addToOpsWhen(pat, [&](AbstractOperation *absOp) {
110 return absOp->hasInterface(*interfaceID);
111 });
112 continue;
113 }
114 if (Optional<TypeID> traitID = pat->getRootTraitID()) {
115 addToOpsWhen(pat, [&](AbstractOperation *absOp) {
116 return absOp->hasTrait(*traitID);
117 });
118 continue;
119 }
120 impl->nativeAnyOpPatterns.push_back(std::move(pat));
121 }
122
123 // Generate the bytecode for the PDL patterns if any were provided.
124 PDLPatternModule &pdlPatterns = patterns.getPDLPatterns();
125 ModuleOp pdlModule = pdlPatterns.getModule();
126 if (!pdlModule)
127 return;
128 if (failed(convertPDLToPDLInterp(pdlModule)))
129 llvm::report_fatal_error(
130 "failed to lower PDL pattern module to the PDL Interpreter");
131
132 // Generate the pdl bytecode.
133 impl->pdlByteCode = std::make_unique<detail::PDLByteCode>(
134 pdlModule, pdlPatterns.takeConstraintFunctions(),
135 pdlPatterns.takeRewriteFunctions());
136 }
137
~FrozenRewritePatternSet()138 FrozenRewritePatternSet::~FrozenRewritePatternSet() {}
139