1 //===- FrozenRewritePatternList.cpp - Frozen Pattern List -------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Rewrite/FrozenRewritePatternList.h"
10 #include "ByteCode.h"
11 #include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
12 #include "mlir/Dialect/PDL/IR/PDLOps.h"
13 #include "mlir/Interfaces/SideEffectInterfaces.h"
14 #include "mlir/Pass/Pass.h"
15 #include "mlir/Pass/PassManager.h"
16 
17 using namespace mlir;
18 
convertPDLToPDLInterp(ModuleOp pdlModule)19 static LogicalResult convertPDLToPDLInterp(ModuleOp pdlModule) {
20   // Skip the conversion if the module doesn't contain pdl.
21   if (llvm::empty(pdlModule.getOps<pdl::PatternOp>()))
22     return success();
23 
24   // Simplify the provided PDL module. Note that we can't use the canonicalizer
25   // here because it would create a cyclic dependency.
26   auto simplifyFn = [](Operation *op) {
27     // TODO: Add folding here if ever necessary.
28     if (isOpTriviallyDead(op))
29       op->erase();
30   };
31   pdlModule.getBody()->walk(simplifyFn);
32 
33   /// Lower the PDL pattern module to the interpreter dialect.
34   PassManager pdlPipeline(pdlModule.getContext());
35 #ifdef NDEBUG
36   // We don't want to incur the hit of running the verifier when in release
37   // mode.
38   pdlPipeline.enableVerifier(false);
39 #endif
40   pdlPipeline.addPass(createPDLToPDLInterpPass());
41   if (failed(pdlPipeline.run(pdlModule)))
42     return failure();
43 
44   // Simplify again after running the lowering pipeline.
45   pdlModule.getBody()->walk(simplifyFn);
46   return success();
47 }
48 
49 //===----------------------------------------------------------------------===//
50 // FrozenRewritePatternList
51 //===----------------------------------------------------------------------===//
52 
FrozenRewritePatternList()53 FrozenRewritePatternList::FrozenRewritePatternList()
54     : impl(std::make_shared<Impl>()) {}
55 
FrozenRewritePatternList(OwningRewritePatternList && patterns)56 FrozenRewritePatternList::FrozenRewritePatternList(
57     OwningRewritePatternList &&patterns)
58     : impl(std::make_shared<Impl>()) {
59   impl->nativePatterns = std::move(patterns.getNativePatterns());
60 
61   // Generate the bytecode for the PDL patterns if any were provided.
62   PDLPatternModule &pdlPatterns = patterns.getPDLPatterns();
63   ModuleOp pdlModule = pdlPatterns.getModule();
64   if (!pdlModule)
65     return;
66   if (failed(convertPDLToPDLInterp(pdlModule)))
67     llvm::report_fatal_error(
68         "failed to lower PDL pattern module to the PDL Interpreter");
69 
70   // Generate the pdl bytecode.
71   impl->pdlByteCode = std::make_unique<detail::PDLByteCode>(
72       pdlModule, pdlPatterns.takeConstraintFunctions(),
73       pdlPatterns.takeCreateFunctions(), pdlPatterns.takeRewriteFunctions());
74 }
75 
~FrozenRewritePatternList()76 FrozenRewritePatternList::~FrozenRewritePatternList() {}
77