1 //===- PatternApplicator.cpp - Pattern Application Engine -------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements an applicator that applies pattern rewrites based upon a
// user-defined cost model.
11 //
12 //===----------------------------------------------------------------------===//
13 
#include "mlir/Rewrite/PatternApplicator.h"
#include "ByteCode.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
17 
18 #define DEBUG_TYPE "pattern-application"
19 
20 using namespace mlir;
21 using namespace mlir::detail;
22 
PatternApplicator(const FrozenRewritePatternSet & frozenPatternList)23 PatternApplicator::PatternApplicator(
24     const FrozenRewritePatternSet &frozenPatternList)
25     : frozenPatternList(frozenPatternList) {
26   if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) {
27     mutableByteCodeState = std::make_unique<PDLByteCodeMutableState>();
28     bytecode->initializeMutableState(*mutableByteCodeState);
29   }
30 }
~PatternApplicator()31 PatternApplicator::~PatternApplicator() {}
32 
33 #ifndef NDEBUG
34 /// Log a message for a pattern that is impossible to match.
logImpossibleToMatch(const Pattern & pattern)35 static void logImpossibleToMatch(const Pattern &pattern) {
36     llvm::dbgs() << "Ignoring pattern '" << pattern.getRootKind()
37                  << "' because it is impossible to match or cannot lead "
38                     "to legal IR (by cost model)\n";
39 }
40 
41 /// Log IR after pattern application.
getDumpRootOp(Operation * op)42 static Operation *getDumpRootOp(Operation *op) {
43   return op->getParentWithTrait<mlir::OpTrait::IsIsolatedFromAbove>();
44 }
logSucessfulPatternApplication(Operation * op)45 static void logSucessfulPatternApplication(Operation *op) {
46   llvm::dbgs() << "// *** IR Dump After Pattern Application ***\n";
47   op->dump();
48   llvm::dbgs() << "\n\n";
49 }
50 #endif
51 
applyCostModel(CostModel model)52 void PatternApplicator::applyCostModel(CostModel model) {
53   // Apply the cost model to the bytecode patterns first, and then the native
54   // patterns.
55   if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) {
56     for (auto it : llvm::enumerate(bytecode->getPatterns()))
57       mutableByteCodeState->updatePatternBenefit(it.index(), model(it.value()));
58   }
59 
60   // Copy over the patterns so that we can sort by benefit based on the cost
61   // model. Patterns that are already impossible to match are ignored.
62   patterns.clear();
63   for (const auto &it : frozenPatternList.getOpSpecificNativePatterns()) {
64     for (const RewritePattern *pattern : it.second) {
65       if (pattern->getBenefit().isImpossibleToMatch())
66         LLVM_DEBUG(logImpossibleToMatch(*pattern));
67       else
68         patterns[it.first].push_back(pattern);
69     }
70   }
71   anyOpPatterns.clear();
72   for (const RewritePattern &pattern :
73        frozenPatternList.getMatchAnyOpNativePatterns()) {
74     if (pattern.getBenefit().isImpossibleToMatch())
75       LLVM_DEBUG(logImpossibleToMatch(pattern));
76     else
77       anyOpPatterns.push_back(&pattern);
78   }
79 
80   // Sort the patterns using the provided cost model.
81   llvm::SmallDenseMap<const Pattern *, PatternBenefit> benefits;
82   auto cmp = [&benefits](const Pattern *lhs, const Pattern *rhs) {
83     return benefits[lhs] > benefits[rhs];
84   };
85   auto processPatternList = [&](SmallVectorImpl<const RewritePattern *> &list) {
86     // Special case for one pattern in the list, which is the most common case.
87     if (list.size() == 1) {
88       if (model(*list.front()).isImpossibleToMatch()) {
89         LLVM_DEBUG(logImpossibleToMatch(*list.front()));
90         list.clear();
91       }
92       return;
93     }
94 
95     // Collect the dynamic benefits for the current pattern list.
96     benefits.clear();
97     for (const Pattern *pat : list)
98       benefits.try_emplace(pat, model(*pat));
99 
100     // Sort patterns with highest benefit first, and remove those that are
101     // impossible to match.
102     std::stable_sort(list.begin(), list.end(), cmp);
103     while (!list.empty() && benefits[list.back()].isImpossibleToMatch()) {
104       LLVM_DEBUG(logImpossibleToMatch(*list.back()));
105       list.pop_back();
106     }
107   };
108   for (auto &it : patterns)
109     processPatternList(it.second);
110   processPatternList(anyOpPatterns);
111 }
112 
walkAllPatterns(function_ref<void (const Pattern &)> walk)113 void PatternApplicator::walkAllPatterns(
114     function_ref<void(const Pattern &)> walk) {
115   for (const auto &it : frozenPatternList.getOpSpecificNativePatterns())
116     for (const auto &pattern : it.second)
117       walk(*pattern);
118   for (const Pattern &it : frozenPatternList.getMatchAnyOpNativePatterns())
119     walk(it);
120   if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) {
121     for (const Pattern &it : bytecode->getPatterns())
122       walk(it);
123   }
124 }
125 
matchAndRewrite(Operation * op,PatternRewriter & rewriter,function_ref<bool (const Pattern &)> canApply,function_ref<void (const Pattern &)> onFailure,function_ref<LogicalResult (const Pattern &)> onSuccess)126 LogicalResult PatternApplicator::matchAndRewrite(
127     Operation *op, PatternRewriter &rewriter,
128     function_ref<bool(const Pattern &)> canApply,
129     function_ref<void(const Pattern &)> onFailure,
130     function_ref<LogicalResult(const Pattern &)> onSuccess) {
131   // Before checking native patterns, first match against the bytecode. This
132   // won't automatically perform any rewrites so there is no need to worry about
133   // conflicts.
134   SmallVector<PDLByteCode::MatchResult, 4> pdlMatches;
135   const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode();
136   if (bytecode)
137     bytecode->match(op, rewriter, pdlMatches, *mutableByteCodeState);
138 
139   // Check to see if there are patterns matching this specific operation type.
140   MutableArrayRef<const RewritePattern *> opPatterns;
141   auto patternIt = patterns.find(op->getName());
142   if (patternIt != patterns.end())
143     opPatterns = patternIt->second;
144 
145   // Process the patterns for that match the specific operation type, and any
146   // operation type in an interleaved fashion.
147   unsigned opIt = 0, opE = opPatterns.size();
148   unsigned anyIt = 0, anyE = anyOpPatterns.size();
149   unsigned pdlIt = 0, pdlE = pdlMatches.size();
150   LogicalResult result = failure();
151   do {
152     // Find the next pattern with the highest benefit.
153     const Pattern *bestPattern = nullptr;
154     unsigned *bestPatternIt = &opIt;
155     const PDLByteCode::MatchResult *pdlMatch = nullptr;
156 
157     /// Operation specific patterns.
158     if (opIt < opE)
159       bestPattern = opPatterns[opIt];
160     /// Operation agnostic patterns.
161     if (anyIt < anyE &&
162         (!bestPattern ||
163          bestPattern->getBenefit() < anyOpPatterns[anyIt]->getBenefit())) {
164       bestPatternIt = &anyIt;
165       bestPattern = anyOpPatterns[anyIt];
166     }
167     /// PDL patterns.
168     if (pdlIt < pdlE && (!bestPattern || bestPattern->getBenefit() <
169                                              pdlMatches[pdlIt].benefit)) {
170       bestPatternIt = &pdlIt;
171       pdlMatch = &pdlMatches[pdlIt];
172       bestPattern = pdlMatch->pattern;
173     }
174     if (!bestPattern)
175       break;
176 
177     // Update the pattern iterator on failure so that this pattern isn't
178     // attempted again.
179     ++(*bestPatternIt);
180 
181     // Check that the pattern can be applied.
182     if (canApply && !canApply(*bestPattern))
183       continue;
184 
185     // Try to match and rewrite this pattern. The patterns are sorted by
186     // benefit, so if we match we can immediately rewrite. For PDL patterns, the
187     // match has already been performed, we just need to rewrite.
188     rewriter.setInsertionPoint(op);
189 #ifndef NDEBUG
190     // Operation `op` may be invalidated after applying the rewrite pattern.
191     Operation *dumpRootOp = getDumpRootOp(op);
192 #endif
193     if (pdlMatch) {
194       bytecode->rewrite(rewriter, *pdlMatch, *mutableByteCodeState);
195       result = success(!onSuccess || succeeded(onSuccess(*bestPattern)));
196     } else {
197       const auto *pattern = static_cast<const RewritePattern *>(bestPattern);
198 
199       LLVM_DEBUG(llvm::dbgs()
200                  << "Trying to match \"" << pattern->getDebugName() << "\"\n");
201       result = pattern->matchAndRewrite(op, rewriter);
202       LLVM_DEBUG(llvm::dbgs() << "\"" << pattern->getDebugName() << "\" result "
203                               << succeeded(result) << "\n");
204 
205       if (succeeded(result) && onSuccess && failed(onSuccess(*pattern)))
206         result = failure();
207     }
208     if (succeeded(result)) {
209       LLVM_DEBUG(logSucessfulPatternApplication(dumpRootOp));
210       break;
211     }
212 
213     // Perform any necessary cleanups.
214     if (onFailure)
215       onFailure(*bestPattern);
216   } while (true);
217 
218   if (mutableByteCodeState)
219     mutableByteCodeState->cleanupAfterMatchAndRewrite();
220   return result;
221 }
222