1 //===- FrozenRewritePatternSet.cpp - Frozen Pattern List -------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Rewrite/FrozenRewritePatternSet.h" 10 #include "ByteCode.h" 11 #include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h" 12 #include "mlir/Dialect/PDL/IR/PDLOps.h" 13 #include "mlir/Interfaces/SideEffectInterfaces.h" 14 #include "mlir/Pass/Pass.h" 15 #include "mlir/Pass/PassManager.h" 16 17 using namespace mlir; 18 19 static LogicalResult convertPDLToPDLInterp(ModuleOp pdlModule) { 20 // Skip the conversion if the module doesn't contain pdl. 21 if (llvm::empty(pdlModule.getOps<pdl::PatternOp>())) 22 return success(); 23 24 // Simplify the provided PDL module. Note that we can't use the canonicalizer 25 // here because it would create a cyclic dependency. 26 auto simplifyFn = [](Operation *op) { 27 // TODO: Add folding here if ever necessary. 28 if (isOpTriviallyDead(op)) 29 op->erase(); 30 }; 31 pdlModule.getBody()->walk(simplifyFn); 32 33 /// Lower the PDL pattern module to the interpreter dialect. 34 PassManager pdlPipeline(pdlModule.getContext()); 35 #ifdef NDEBUG 36 // We don't want to incur the hit of running the verifier when in release 37 // mode. 38 pdlPipeline.enableVerifier(false); 39 #endif 40 pdlPipeline.addPass(createPDLToPDLInterpPass()); 41 if (failed(pdlPipeline.run(pdlModule))) 42 return failure(); 43 44 // Simplify again after running the lowering pipeline. 45 pdlModule.getBody()->walk(simplifyFn); 46 return success(); 47 } 48 49 //===----------------------------------------------------------------------===// 50 // FrozenRewritePatternSet 51 //===----------------------------------------------------------------------===// 52 53 FrozenRewritePatternSet::FrozenRewritePatternSet() 54 : impl(std::make_shared<Impl>()) {} 55 56 FrozenRewritePatternSet::FrozenRewritePatternSet( 57 RewritePatternSet &&patterns, ArrayRef<std::string> disabledPatternLabels, 58 ArrayRef<std::string> enabledPatternLabels) 59 : impl(std::make_shared<Impl>()) { 60 DenseSet<StringRef> disabledPatterns, enabledPatterns; 61 disabledPatterns.insert(disabledPatternLabels.begin(), 62 disabledPatternLabels.end()); 63 enabledPatterns.insert(enabledPatternLabels.begin(), 64 enabledPatternLabels.end()); 65 66 // Functor used to walk all of the operations registered in the context. This 67 // is useful for patterns that get applied to multiple operations, such as 68 // interface and trait based patterns. 69 std::vector<AbstractOperation *> abstractOps; 70 auto addToOpsWhen = [&](std::unique_ptr<RewritePattern> &pattern, 71 function_ref<bool(AbstractOperation *)> callbackFn) { 72 if (abstractOps.empty()) 73 abstractOps = pattern->getContext()->getRegisteredOperations(); 74 for (AbstractOperation *absOp : abstractOps) { 75 if (callbackFn(absOp)) { 76 OperationName opName(absOp); 77 impl->nativeOpSpecificPatternMap[opName].push_back(pattern.get()); 78 } 79 } disconnect_atexit(void)80 impl->nativeOpSpecificPatternList.push_back(std::move(pattern)); 81 }; 82 83 for (std::unique_ptr<RewritePattern> &pat : patterns.getNativePatterns()) { 84 // Don't add patterns that haven't been enabled by the user. 85 if (!enabledPatterns.empty()) { 86 auto isEnabledFn = [&](StringRef label) { 87 return enabledPatterns.count(label); 88 }; 89 if (!isEnabledFn(pat->getDebugName()) && main(int argc,char ** argv)90 llvm::none_of(pat->getDebugLabels(), isEnabledFn)) 91 continue; 92 } 93 // Don't add patterns that have been disabled by the user. 94 if (!disabledPatterns.empty()) { 95 auto isDisabledFn = [&](StringRef label) { 96 return disabledPatterns.count(label); 97 }; 98 if (isDisabledFn(pat->getDebugName()) || 99 llvm::any_of(pat->getDebugLabels(), isDisabledFn)) 100 continue; 101 } 102 103 if (Optional<OperationName> rootName = pat->getRootKind()) { 104 impl->nativeOpSpecificPatternMap[*rootName].push_back(pat.get()); 105 impl->nativeOpSpecificPatternList.push_back(std::move(pat)); 106 continue; 107 } 108 if (Optional<TypeID> interfaceID = pat->getRootInterfaceID()) { 109 addToOpsWhen(pat, [&](AbstractOperation *absOp) { 110 return absOp->hasInterface(*interfaceID); 111 }); 112 continue; 113 } 114 if (Optional<TypeID> traitID = pat->getRootTraitID()) { 115 addToOpsWhen(pat, [&](AbstractOperation *absOp) { 116 return absOp->hasTrait(*traitID); 117 }); 118 continue; 119 } 120 impl->nativeAnyOpPatterns.push_back(std::move(pat)); 121 } 122 123 // Generate the bytecode for the PDL patterns if any were provided. 124 PDLPatternModule &pdlPatterns = patterns.getPDLPatterns(); 125 ModuleOp pdlModule = pdlPatterns.getModule(); 126 if (!pdlModule) 127 return; 128 if (failed(convertPDLToPDLInterp(pdlModule))) 129 llvm::report_fatal_error( 130 "failed to lower PDL pattern module to the PDL Interpreter"); 131 132 // Generate the pdl bytecode. 133 impl->pdlByteCode = std::make_unique<detail::PDLByteCode>( 134 pdlModule, pdlPatterns.takeConstraintFunctions(), 135 pdlPatterns.takeRewriteFunctions()); 136 } 137 138 FrozenRewritePatternSet::~FrozenRewritePatternSet() {} 139