1 //===- FrozenRewritePatternSet.cpp - Frozen Pattern List -------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
10 #include "ByteCode.h"
11 #include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
12 #include "mlir/Dialect/PDL/IR/PDLOps.h"
13 #include "mlir/Interfaces/SideEffectInterfaces.h"
14 #include "mlir/Pass/Pass.h"
15 #include "mlir/Pass/PassManager.h"
16 
17 using namespace mlir;
18 
convertPDLToPDLInterp(ModuleOp pdlModule)19 static LogicalResult convertPDLToPDLInterp(ModuleOp pdlModule) {
20   // Skip the conversion if the module doesn't contain pdl.
21   if (llvm::empty(pdlModule.getOps<pdl::PatternOp>()))
22     return success();
23 
24   // Simplify the provided PDL module. Note that we can't use the canonicalizer
25   // here because it would create a cyclic dependency.
26   auto simplifyFn = [](Operation *op) {
27     // TODO: Add folding here if ever necessary.
28     if (isOpTriviallyDead(op))
29       op->erase();
30   };
31   pdlModule.getBody()->walk(simplifyFn);
32 
33   /// Lower the PDL pattern module to the interpreter dialect.
34   PassManager pdlPipeline(pdlModule.getContext());
35 #ifdef NDEBUG
36   // We don't want to incur the hit of running the verifier when in release
37   // mode.
38   pdlPipeline.enableVerifier(false);
39 #endif
40   pdlPipeline.addPass(createPDLToPDLInterpPass());
41   if (failed(pdlPipeline.run(pdlModule)))
42     return failure();
43 
44   // Simplify again after running the lowering pipeline.
45   pdlModule.getBody()->walk(simplifyFn);
46   return success();
47 }
48 
49 //===----------------------------------------------------------------------===//
50 // FrozenRewritePatternSet
51 //===----------------------------------------------------------------------===//
52 
FrozenRewritePatternSet()53 FrozenRewritePatternSet::FrozenRewritePatternSet()
54     : impl(std::make_shared<Impl>()) {}
55 
FrozenRewritePatternSet(RewritePatternSet && patterns,ArrayRef<std::string> disabledPatternLabels,ArrayRef<std::string> enabledPatternLabels)56 FrozenRewritePatternSet::FrozenRewritePatternSet(
57     RewritePatternSet &&patterns, ArrayRef<std::string> disabledPatternLabels,
58     ArrayRef<std::string> enabledPatternLabels)
59     : impl(std::make_shared<Impl>()) {
60   DenseSet<StringRef> disabledPatterns, enabledPatterns;
61   disabledPatterns.insert(disabledPatternLabels.begin(),
62                           disabledPatternLabels.end());
63   enabledPatterns.insert(enabledPatternLabels.begin(),
64                          enabledPatternLabels.end());
65 
66   // Functor used to walk all of the operations registered in the context. This
67   // is useful for patterns that get applied to multiple operations, such as
68   // interface and trait based patterns.
69   std::vector<AbstractOperation *> abstractOps;
70   auto addToOpsWhen = [&](std::unique_ptr<RewritePattern> &pattern,
71                           function_ref<bool(AbstractOperation *)> callbackFn) {
72     if (abstractOps.empty())
73       abstractOps = pattern->getContext()->getRegisteredOperations();
74     for (AbstractOperation *absOp : abstractOps) {
75       if (callbackFn(absOp)) {
76         OperationName opName(absOp);
77         impl->nativeOpSpecificPatternMap[opName].push_back(pattern.get());
78       }
79     }
80     impl->nativeOpSpecificPatternList.push_back(std::move(pattern));
81   };
82 
83   for (std::unique_ptr<RewritePattern> &pat : patterns.getNativePatterns()) {
84     // Don't add patterns that haven't been enabled by the user.
85     if (!enabledPatterns.empty()) {
86       auto isEnabledFn = [&](StringRef label) {
87         return enabledPatterns.count(label);
88       };
89       if (!isEnabledFn(pat->getDebugName()) &&
90           llvm::none_of(pat->getDebugLabels(), isEnabledFn))
91         continue;
92     }
93     // Don't add patterns that have been disabled by the user.
94     if (!disabledPatterns.empty()) {
95       auto isDisabledFn = [&](StringRef label) {
96         return disabledPatterns.count(label);
97       };
98       if (isDisabledFn(pat->getDebugName()) ||
99           llvm::any_of(pat->getDebugLabels(), isDisabledFn))
100         continue;
101     }
102 
103     if (Optional<OperationName> rootName = pat->getRootKind()) {
104       impl->nativeOpSpecificPatternMap[*rootName].push_back(pat.get());
105       impl->nativeOpSpecificPatternList.push_back(std::move(pat));
106       continue;
107     }
108     if (Optional<TypeID> interfaceID = pat->getRootInterfaceID()) {
109       addToOpsWhen(pat, [&](AbstractOperation *absOp) {
110         return absOp->hasInterface(*interfaceID);
111       });
112       continue;
113     }
114     if (Optional<TypeID> traitID = pat->getRootTraitID()) {
115       addToOpsWhen(pat, [&](AbstractOperation *absOp) {
116         return absOp->hasTrait(*traitID);
117       });
118       continue;
119     }
120     impl->nativeAnyOpPatterns.push_back(std::move(pat));
121   }
122 
123   // Generate the bytecode for the PDL patterns if any were provided.
124   PDLPatternModule &pdlPatterns = patterns.getPDLPatterns();
125   ModuleOp pdlModule = pdlPatterns.getModule();
126   if (!pdlModule)
127     return;
128   if (failed(convertPDLToPDLInterp(pdlModule)))
129     llvm::report_fatal_error(
130         "failed to lower PDL pattern module to the PDL Interpreter");
131 
132   // Generate the pdl bytecode.
133   impl->pdlByteCode = std::make_unique<detail::PDLByteCode>(
134       pdlModule, pdlPatterns.takeConstraintFunctions(),
135       pdlPatterns.takeRewriteFunctions());
136 }
137 
~FrozenRewritePatternSet()138 FrozenRewritePatternSet::~FrozenRewritePatternSet() {}
139