1 //===- PatternApplicator.cpp - Pattern Application Engine -------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements an applicator that applies pattern rewrites based upon a
10 // user defined cost model.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Rewrite/PatternApplicator.h"
15 #include "llvm/Support/Debug.h"
16 
17 using namespace mlir;
18 
19 #define DEBUG_TYPE "pattern-match"
20 
applyCostModel(CostModel model)21 void PatternApplicator::applyCostModel(CostModel model) {
22   // Separate patterns by root kind to simplify lookup later on.
23   patterns.clear();
24   anyOpPatterns.clear();
25   for (const auto &pat : frozenPatternList.getPatterns()) {
26     // If the pattern is always impossible to match, just ignore it.
27     if (pat.getBenefit().isImpossibleToMatch()) {
28       LLVM_DEBUG({
29         llvm::dbgs()
30             << "Ignoring pattern '" << pat.getRootKind()
31             << "' because it is impossible to match (by pattern benefit)\n";
32       });
33       continue;
34     }
35     if (Optional<OperationName> opName = pat.getRootKind())
36       patterns[*opName].push_back(&pat);
37     else
38       anyOpPatterns.push_back(&pat);
39   }
40 
41   // Sort the patterns using the provided cost model.
42   llvm::SmallDenseMap<const Pattern *, PatternBenefit> benefits;
43   auto cmp = [&benefits](const Pattern *lhs, const Pattern *rhs) {
44     return benefits[lhs] > benefits[rhs];
45   };
46   auto processPatternList = [&](SmallVectorImpl<const RewritePattern *> &list) {
47     // Special case for one pattern in the list, which is the most common case.
48     if (list.size() == 1) {
49       if (model(*list.front()).isImpossibleToMatch()) {
50         LLVM_DEBUG({
51           llvm::dbgs() << "Ignoring pattern '" << list.front()->getRootKind()
52                        << "' because it is impossible to match or cannot lead "
53                           "to legal IR (by cost model)\n";
54         });
55         list.clear();
56       }
57       return;
58     }
59 
60     // Collect the dynamic benefits for the current pattern list.
61     benefits.clear();
62     for (const Pattern *pat : list)
63       benefits.try_emplace(pat, model(*pat));
64 
65     // Sort patterns with highest benefit first, and remove those that are
66     // impossible to match.
67     std::stable_sort(list.begin(), list.end(), cmp);
68     while (!list.empty() && benefits[list.back()].isImpossibleToMatch()) {
69       LLVM_DEBUG({
70         llvm::dbgs() << "Ignoring pattern '" << list.back()->getRootKind()
71                      << "' because it is impossible to match or cannot lead to "
72                         "legal IR (by cost model)\n";
73       });
74       list.pop_back();
75     }
76   };
77   for (auto &it : patterns)
78     processPatternList(it.second);
79   processPatternList(anyOpPatterns);
80 }
81 
walkAllPatterns(function_ref<void (const Pattern &)> walk)82 void PatternApplicator::walkAllPatterns(
83     function_ref<void(const Pattern &)> walk) {
84   for (auto &it : frozenPatternList.getPatterns())
85     walk(it);
86 }
87 
matchAndRewrite(Operation * op,PatternRewriter & rewriter,function_ref<bool (const Pattern &)> canApply,function_ref<void (const Pattern &)> onFailure,function_ref<LogicalResult (const Pattern &)> onSuccess)88 LogicalResult PatternApplicator::matchAndRewrite(
89     Operation *op, PatternRewriter &rewriter,
90     function_ref<bool(const Pattern &)> canApply,
91     function_ref<void(const Pattern &)> onFailure,
92     function_ref<LogicalResult(const Pattern &)> onSuccess) {
93   // Check to see if there are patterns matching this specific operation type.
94   MutableArrayRef<const RewritePattern *> opPatterns;
95   auto patternIt = patterns.find(op->getName());
96   if (patternIt != patterns.end())
97     opPatterns = patternIt->second;
98 
99   // Process the patterns for that match the specific operation type, and any
100   // operation type in an interleaved fashion.
101   // FIXME: It'd be nice to just write an llvm::make_merge_range utility
102   // and pass in a comparison function. That would make this code trivial.
103   auto opIt = opPatterns.begin(), opE = opPatterns.end();
104   auto anyIt = anyOpPatterns.begin(), anyE = anyOpPatterns.end();
105   while (opIt != opE && anyIt != anyE) {
106     // Try to match the pattern providing the most benefit.
107     const RewritePattern *pattern;
108     if ((*opIt)->getBenefit() >= (*anyIt)->getBenefit())
109       pattern = *(opIt++);
110     else
111       pattern = *(anyIt++);
112 
113     // Otherwise, try to match the generic pattern.
114     if (succeeded(matchAndRewrite(op, *pattern, rewriter, canApply, onFailure,
115                                   onSuccess)))
116       return success();
117   }
118   // If we break from the loop, then only one of the ranges can still have
119   // elements. Loop over both without checking given that we don't need to
120   // interleave anymore.
121   for (const RewritePattern *pattern : llvm::concat<const RewritePattern *>(
122            llvm::make_range(opIt, opE), llvm::make_range(anyIt, anyE))) {
123     if (succeeded(matchAndRewrite(op, *pattern, rewriter, canApply, onFailure,
124                                   onSuccess)))
125       return success();
126   }
127   return failure();
128 }
129 
matchAndRewrite(Operation * op,const RewritePattern & pattern,PatternRewriter & rewriter,function_ref<bool (const Pattern &)> canApply,function_ref<void (const Pattern &)> onFailure,function_ref<LogicalResult (const Pattern &)> onSuccess)130 LogicalResult PatternApplicator::matchAndRewrite(
131     Operation *op, const RewritePattern &pattern, PatternRewriter &rewriter,
132     function_ref<bool(const Pattern &)> canApply,
133     function_ref<void(const Pattern &)> onFailure,
134     function_ref<LogicalResult(const Pattern &)> onSuccess) {
135   // Check that the pattern can be applied.
136   if (canApply && !canApply(pattern))
137     return failure();
138 
139   // Try to match and rewrite this pattern. The patterns are sorted by
140   // benefit, so if we match we can immediately rewrite.
141   rewriter.setInsertionPoint(op);
142   if (succeeded(pattern.matchAndRewrite(op, rewriter)))
143     return success(!onSuccess || succeeded(onSuccess(pattern)));
144 
145   if (onFailure)
146     onFailure(pattern);
147   return failure();
148 }
149