1 //===- PatternApplicator.cpp - Pattern Application Engine -------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements an applicator that applies pattern rewrites based upon a
10 // user defined cost model.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Rewrite/PatternApplicator.h"
15 #include "ByteCode.h"
16 #include "llvm/Support/Debug.h"
17 
18 using namespace mlir;
19 using namespace mlir::detail;
20 
PatternApplicator(const FrozenRewritePatternList & frozenPatternList)21 PatternApplicator::PatternApplicator(
22     const FrozenRewritePatternList &frozenPatternList)
23     : frozenPatternList(frozenPatternList) {
24   if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) {
25     mutableByteCodeState = std::make_unique<PDLByteCodeMutableState>();
26     bytecode->initializeMutableState(*mutableByteCodeState);
27   }
28 }
~PatternApplicator()29 PatternApplicator::~PatternApplicator() {}
30 
31 #define DEBUG_TYPE "pattern-match"
32 
applyCostModel(CostModel model)33 void PatternApplicator::applyCostModel(CostModel model) {
34   // Apply the cost model to the bytecode patterns first, and then the native
35   // patterns.
36   if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) {
37     for (auto it : llvm::enumerate(bytecode->getPatterns()))
38       mutableByteCodeState->updatePatternBenefit(it.index(), model(it.value()));
39   }
40 
41   // Separate patterns by root kind to simplify lookup later on.
42   patterns.clear();
43   anyOpPatterns.clear();
44   for (const auto &pat : frozenPatternList.getNativePatterns()) {
45     // If the pattern is always impossible to match, just ignore it.
46     if (pat.getBenefit().isImpossibleToMatch()) {
47       LLVM_DEBUG({
48         llvm::dbgs()
49             << "Ignoring pattern '" << pat.getRootKind()
50             << "' because it is impossible to match (by pattern benefit)\n";
51       });
52       continue;
53     }
54     if (Optional<OperationName> opName = pat.getRootKind())
55       patterns[*opName].push_back(&pat);
56     else
57       anyOpPatterns.push_back(&pat);
58   }
59 
60   // Sort the patterns using the provided cost model.
61   llvm::SmallDenseMap<const Pattern *, PatternBenefit> benefits;
62   auto cmp = [&benefits](const Pattern *lhs, const Pattern *rhs) {
63     return benefits[lhs] > benefits[rhs];
64   };
65   auto processPatternList = [&](SmallVectorImpl<const RewritePattern *> &list) {
66     // Special case for one pattern in the list, which is the most common case.
67     if (list.size() == 1) {
68       if (model(*list.front()).isImpossibleToMatch()) {
69         LLVM_DEBUG({
70           llvm::dbgs() << "Ignoring pattern '" << list.front()->getRootKind()
71                        << "' because it is impossible to match or cannot lead "
72                           "to legal IR (by cost model)\n";
73         });
74         list.clear();
75       }
76       return;
77     }
78 
79     // Collect the dynamic benefits for the current pattern list.
80     benefits.clear();
81     for (const Pattern *pat : list)
82       benefits.try_emplace(pat, model(*pat));
83 
84     // Sort patterns with highest benefit first, and remove those that are
85     // impossible to match.
86     std::stable_sort(list.begin(), list.end(), cmp);
87     while (!list.empty() && benefits[list.back()].isImpossibleToMatch()) {
88       LLVM_DEBUG({
89         llvm::dbgs() << "Ignoring pattern '" << list.back()->getRootKind()
90                      << "' because it is impossible to match or cannot lead to "
91                         "legal IR (by cost model)\n";
92       });
93       list.pop_back();
94     }
95   };
96   for (auto &it : patterns)
97     processPatternList(it.second);
98   processPatternList(anyOpPatterns);
99 }
100 
walkAllPatterns(function_ref<void (const Pattern &)> walk)101 void PatternApplicator::walkAllPatterns(
102     function_ref<void(const Pattern &)> walk) {
103   for (const Pattern &it : frozenPatternList.getNativePatterns())
104     walk(it);
105   if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) {
106     for (const Pattern &it : bytecode->getPatterns())
107       walk(it);
108   }
109 }
110 
matchAndRewrite(Operation * op,PatternRewriter & rewriter,function_ref<bool (const Pattern &)> canApply,function_ref<void (const Pattern &)> onFailure,function_ref<LogicalResult (const Pattern &)> onSuccess)111 LogicalResult PatternApplicator::matchAndRewrite(
112     Operation *op, PatternRewriter &rewriter,
113     function_ref<bool(const Pattern &)> canApply,
114     function_ref<void(const Pattern &)> onFailure,
115     function_ref<LogicalResult(const Pattern &)> onSuccess) {
116   // Before checking native patterns, first match against the bytecode. This
117   // won't automatically perform any rewrites so there is no need to worry about
118   // conflicts.
119   SmallVector<PDLByteCode::MatchResult, 4> pdlMatches;
120   const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode();
121   if (bytecode)
122     bytecode->match(op, rewriter, pdlMatches, *mutableByteCodeState);
123 
124   // Check to see if there are patterns matching this specific operation type.
125   MutableArrayRef<const RewritePattern *> opPatterns;
126   auto patternIt = patterns.find(op->getName());
127   if (patternIt != patterns.end())
128     opPatterns = patternIt->second;
129 
130   // Process the patterns for that match the specific operation type, and any
131   // operation type in an interleaved fashion.
132   auto opIt = opPatterns.begin(), opE = opPatterns.end();
133   auto anyIt = anyOpPatterns.begin(), anyE = anyOpPatterns.end();
134   auto pdlIt = pdlMatches.begin(), pdlE = pdlMatches.end();
135   while (true) {
136     // Find the next pattern with the highest benefit.
137     const Pattern *bestPattern = nullptr;
138     const PDLByteCode::MatchResult *pdlMatch = nullptr;
139     /// Operation specific patterns.
140     if (opIt != opE)
141       bestPattern = *(opIt++);
142     /// Operation agnostic patterns.
143     if (anyIt != anyE &&
144         (!bestPattern || bestPattern->getBenefit() < (*anyIt)->getBenefit()))
145       bestPattern = *(anyIt++);
146     /// PDL patterns.
147     if (pdlIt != pdlE &&
148         (!bestPattern || bestPattern->getBenefit() < pdlIt->benefit)) {
149       pdlMatch = pdlIt;
150       bestPattern = (pdlIt++)->pattern;
151     }
152     if (!bestPattern)
153       break;
154 
155     // Check that the pattern can be applied.
156     if (canApply && !canApply(*bestPattern))
157       continue;
158 
159     // Try to match and rewrite this pattern. The patterns are sorted by
160     // benefit, so if we match we can immediately rewrite. For PDL patterns, the
161     // match has already been performed, we just need to rewrite.
162     rewriter.setInsertionPoint(op);
163     LogicalResult result = success();
164     if (pdlMatch) {
165       bytecode->rewrite(rewriter, *pdlMatch, *mutableByteCodeState);
166     } else {
167       result = static_cast<const RewritePattern *>(bestPattern)
168                    ->matchAndRewrite(op, rewriter);
169     }
170     if (succeeded(result) && (!onSuccess || succeeded(onSuccess(*bestPattern))))
171       return success();
172 
173     // Perform any necessary cleanups.
174     if (onFailure)
175       onFailure(*bestPattern);
176   }
177   return failure();
178 }
179