1 //===- PatternApplicator.cpp - Pattern Application Engine -------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements an applicator that applies pattern rewrites based upon a
10 // user defined cost model.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "mlir/Rewrite/PatternApplicator.h"
15 #include "llvm/Support/Debug.h"
16
17 using namespace mlir;
18
19 #define DEBUG_TYPE "pattern-match"
20
applyCostModel(CostModel model)21 void PatternApplicator::applyCostModel(CostModel model) {
22 // Separate patterns by root kind to simplify lookup later on.
23 patterns.clear();
24 anyOpPatterns.clear();
25 for (const auto &pat : frozenPatternList.getPatterns()) {
26 // If the pattern is always impossible to match, just ignore it.
27 if (pat.getBenefit().isImpossibleToMatch()) {
28 LLVM_DEBUG({
29 llvm::dbgs()
30 << "Ignoring pattern '" << pat.getRootKind()
31 << "' because it is impossible to match (by pattern benefit)\n";
32 });
33 continue;
34 }
35 if (Optional<OperationName> opName = pat.getRootKind())
36 patterns[*opName].push_back(&pat);
37 else
38 anyOpPatterns.push_back(&pat);
39 }
40
41 // Sort the patterns using the provided cost model.
42 llvm::SmallDenseMap<const Pattern *, PatternBenefit> benefits;
43 auto cmp = [&benefits](const Pattern *lhs, const Pattern *rhs) {
44 return benefits[lhs] > benefits[rhs];
45 };
46 auto processPatternList = [&](SmallVectorImpl<const RewritePattern *> &list) {
47 // Special case for one pattern in the list, which is the most common case.
48 if (list.size() == 1) {
49 if (model(*list.front()).isImpossibleToMatch()) {
50 LLVM_DEBUG({
51 llvm::dbgs() << "Ignoring pattern '" << list.front()->getRootKind()
52 << "' because it is impossible to match or cannot lead "
53 "to legal IR (by cost model)\n";
54 });
55 list.clear();
56 }
57 return;
58 }
59
60 // Collect the dynamic benefits for the current pattern list.
61 benefits.clear();
62 for (const Pattern *pat : list)
63 benefits.try_emplace(pat, model(*pat));
64
65 // Sort patterns with highest benefit first, and remove those that are
66 // impossible to match.
67 std::stable_sort(list.begin(), list.end(), cmp);
68 while (!list.empty() && benefits[list.back()].isImpossibleToMatch()) {
69 LLVM_DEBUG({
70 llvm::dbgs() << "Ignoring pattern '" << list.back()->getRootKind()
71 << "' because it is impossible to match or cannot lead to "
72 "legal IR (by cost model)\n";
73 });
74 list.pop_back();
75 }
76 };
77 for (auto &it : patterns)
78 processPatternList(it.second);
79 processPatternList(anyOpPatterns);
80 }
81
walkAllPatterns(function_ref<void (const Pattern &)> walk)82 void PatternApplicator::walkAllPatterns(
83 function_ref<void(const Pattern &)> walk) {
84 for (auto &it : frozenPatternList.getPatterns())
85 walk(it);
86 }
87
matchAndRewrite(Operation * op,PatternRewriter & rewriter,function_ref<bool (const Pattern &)> canApply,function_ref<void (const Pattern &)> onFailure,function_ref<LogicalResult (const Pattern &)> onSuccess)88 LogicalResult PatternApplicator::matchAndRewrite(
89 Operation *op, PatternRewriter &rewriter,
90 function_ref<bool(const Pattern &)> canApply,
91 function_ref<void(const Pattern &)> onFailure,
92 function_ref<LogicalResult(const Pattern &)> onSuccess) {
93 // Check to see if there are patterns matching this specific operation type.
94 MutableArrayRef<const RewritePattern *> opPatterns;
95 auto patternIt = patterns.find(op->getName());
96 if (patternIt != patterns.end())
97 opPatterns = patternIt->second;
98
99 // Process the patterns for that match the specific operation type, and any
100 // operation type in an interleaved fashion.
101 // FIXME: It'd be nice to just write an llvm::make_merge_range utility
102 // and pass in a comparison function. That would make this code trivial.
103 auto opIt = opPatterns.begin(), opE = opPatterns.end();
104 auto anyIt = anyOpPatterns.begin(), anyE = anyOpPatterns.end();
105 while (opIt != opE && anyIt != anyE) {
106 // Try to match the pattern providing the most benefit.
107 const RewritePattern *pattern;
108 if ((*opIt)->getBenefit() >= (*anyIt)->getBenefit())
109 pattern = *(opIt++);
110 else
111 pattern = *(anyIt++);
112
113 // Otherwise, try to match the generic pattern.
114 if (succeeded(matchAndRewrite(op, *pattern, rewriter, canApply, onFailure,
115 onSuccess)))
116 return success();
117 }
118 // If we break from the loop, then only one of the ranges can still have
119 // elements. Loop over both without checking given that we don't need to
120 // interleave anymore.
121 for (const RewritePattern *pattern : llvm::concat<const RewritePattern *>(
122 llvm::make_range(opIt, opE), llvm::make_range(anyIt, anyE))) {
123 if (succeeded(matchAndRewrite(op, *pattern, rewriter, canApply, onFailure,
124 onSuccess)))
125 return success();
126 }
127 return failure();
128 }
129
matchAndRewrite(Operation * op,const RewritePattern & pattern,PatternRewriter & rewriter,function_ref<bool (const Pattern &)> canApply,function_ref<void (const Pattern &)> onFailure,function_ref<LogicalResult (const Pattern &)> onSuccess)130 LogicalResult PatternApplicator::matchAndRewrite(
131 Operation *op, const RewritePattern &pattern, PatternRewriter &rewriter,
132 function_ref<bool(const Pattern &)> canApply,
133 function_ref<void(const Pattern &)> onFailure,
134 function_ref<LogicalResult(const Pattern &)> onSuccess) {
135 // Check that the pattern can be applied.
136 if (canApply && !canApply(pattern))
137 return failure();
138
139 // Try to match and rewrite this pattern. The patterns are sorted by
140 // benefit, so if we match we can immediately rewrite.
141 rewriter.setInsertionPoint(op);
142 if (succeeded(pattern.matchAndRewrite(op, rewriter)))
143 return success(!onSuccess || succeeded(onSuccess(pattern)));
144
145 if (onFailure)
146 onFailure(pattern);
147 return failure();
148 }
149