//===- PatternMatch.h - PatternMatcher classes -------------------*- C++ -*-===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_PATTERNMATCHER_H
10 #define MLIR_PATTERNMATCHER_H
11 
12 #include "mlir/IR/Builders.h"
13 
14 namespace mlir {
15 
16 class PatternRewriter;
17 
18 //===----------------------------------------------------------------------===//
19 // PatternBenefit class
20 //===----------------------------------------------------------------------===//
21 
/// This class represents the benefit of a pattern match in a unitless scheme
/// that ranges from 0 (very little benefit) to 65K.  The most common unit to
/// use here is the "number of operations matched" by the pattern.
///
/// This also has a sentinel representation that can be used for patterns that
/// fail to match.
///
class PatternBenefit {
  enum { ImpossibleToMatchSentinel = 65535 };

public:
  /*implicit*/ PatternBenefit(unsigned benefit);
  PatternBenefit(const PatternBenefit &) = default;
  PatternBenefit &operator=(const PatternBenefit &) = default;

  /// Return the sentinel benefit used for patterns that can never match.
  static PatternBenefit impossibleToMatch() { return PatternBenefit(); }
  /// Return true if this benefit is the impossible-to-match sentinel.
  bool isImpossibleToMatch() const { return *this == impossibleToMatch(); }

  /// If the corresponding pattern can match, return its benefit.  If the
  /// corresponding pattern isImpossibleToMatch() then this aborts.
  unsigned short getBenefit() const;

  /// Benefits are ordered by their raw representation.  Because the sentinel
  /// is the largest `unsigned short` value, no benefit compares greater than
  /// impossibleToMatch().
  bool operator==(const PatternBenefit &rhs) const {
    return representation == rhs.representation;
  }
  bool operator!=(const PatternBenefit &rhs) const { return !(*this == rhs); }
  bool operator<(const PatternBenefit &rhs) const {
    return representation < rhs.representation;
  }
  // The remaining relational operators are all derived from operator< so the
  // ordering stays consistent.
  bool operator>(const PatternBenefit &rhs) const { return rhs < *this; }
  bool operator<=(const PatternBenefit &rhs) const { return !(rhs < *this); }
  bool operator>=(const PatternBenefit &rhs) const { return !(*this < rhs); }

private:
  /// Constructs the impossible-to-match sentinel; use impossibleToMatch().
  PatternBenefit() = default;

  unsigned short representation = ImpossibleToMatchSentinel;
};
56 
/// Base class for state that a pattern wants to carry from its match phase to
/// its rewrite phase.  The constructor is protected, so a PatternState can only
/// be created through a pattern-specific subclass.  The destructor is virtual
/// so that state owned through a base pointer (e.g. a
/// std::unique_ptr<PatternState>) destroys the full derived object.
class PatternState {
protected:
  // Only subclasses may construct a PatternState.
  PatternState() = default;

public:
  virtual ~PatternState() = default;
};
68 
/// This is the type returned by a pattern match.  A match failure returns a
/// None (empty) value.  A match success returns a populated value holding any
/// state the pattern needs to carry into its rewrite phase; the contained
/// std::unique_ptr may itself be null when the pattern needs no state.
using PatternMatchResult = Optional<std::unique_ptr<PatternState>>;
73 
74 //===----------------------------------------------------------------------===//
75 // Pattern class
76 //===----------------------------------------------------------------------===//
77 
78 /// Instances of Pattern can be matched against SSA IR.  These matches get used
79 /// in ways dependent on their subclasses and the driver doing the matching.
80 /// For example, RewritePatterns implement a rewrite from one matched pattern
81 /// to a replacement DAG tile.
82 class Pattern {
83 public:
84   /// Return the benefit (the inverse of "cost") of matching this pattern.  The
85   /// benefit of a Pattern is always static - rewrites that may have dynamic
86   /// benefit can be instantiated multiple times (different Pattern instances)
87   /// for each benefit that they may return, and be guarded by different match
88   /// condition predicates.
getBenefit()89   PatternBenefit getBenefit() const { return benefit; }
90 
91   /// Return the root node that this pattern matches.  Patterns that can
92   /// match multiple root types are instantiated once per root.
getRootKind()93   OperationName getRootKind() const { return rootKind; }
94 
95   //===--------------------------------------------------------------------===//
96   // Implementation hooks for patterns to implement.
97   //===--------------------------------------------------------------------===//
98 
99   /// Attempt to match against code rooted at the specified operation,
100   /// which is the same operation code as getRootKind().  On failure, this
101   /// returns a None value.  On success it returns a (possibly null)
102   /// pattern-specific state wrapped in an Optional.
103   virtual PatternMatchResult match(Operation *op) const = 0;
104 
~Pattern()105   virtual ~Pattern() {}
106 
107   //===--------------------------------------------------------------------===//
108   // Helper methods to simplify pattern implementations
109   //===--------------------------------------------------------------------===//
110 
111   /// This method indicates that no match was found.
matchFailure()112   static PatternMatchResult matchFailure() { return None; }
113 
114   /// This method indicates that a match was found and has the specified cost.
115   PatternMatchResult
116   matchSuccess(std::unique_ptr<PatternState> state = {}) const {
117     return PatternMatchResult(std::move(state));
118   }
119 
120 protected:
121   /// Patterns must specify the root operation name they match against, and can
122   /// also specify the benefit of the pattern matching.
123   Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context);
124 
125 private:
126   const OperationName rootKind;
127   const PatternBenefit benefit;
128 
129   virtual void anchor();
130 };
131 
132 /// RewritePattern is the common base class for all DAG to DAG replacements.
133 /// There are two possible usages of this class:
134 ///   * Multi-step RewritePattern with "match" and "rewrite"
135 ///     - By overloading the "match" and "rewrite" functions, the user can
136 ///       separate the concerns of matching and rewriting.
137 ///   * Single-step RewritePattern with "matchAndRewrite"
138 ///     - By overloading the "matchAndRewrite" function, the user can perform
139 ///       the rewrite in the same call as the match. This removes the need for
140 ///       any PatternState.
141 ///
142 class RewritePattern : public Pattern {
143 public:
144   /// Rewrite the IR rooted at the specified operation with the result of
145   /// this pattern, generating any new operations with the specified
146   /// rewriter.  If an unexpected error is encountered (an internal
147   /// compiler error), it is emitted through the normal MLIR diagnostic
148   /// hooks and the IR is left in a valid state.
149   virtual void rewrite(Operation *op, std::unique_ptr<PatternState> state,
150                        PatternRewriter &rewriter) const;
151 
152   /// Rewrite the IR rooted at the specified operation with the result of
153   /// this pattern, generating any new operations with the specified
154   /// builder.  If an unexpected error is encountered (an internal
155   /// compiler error), it is emitted through the normal MLIR diagnostic
156   /// hooks and the IR is left in a valid state.
157   virtual void rewrite(Operation *op, PatternRewriter &rewriter) const;
158 
159   /// Attempt to match against code rooted at the specified operation,
160   /// which is the same operation code as getRootKind().  On failure, this
161   /// returns a None value.  On success, it returns a (possibly null)
162   /// pattern-specific state wrapped in an Optional.  This state is passed back
163   /// into the rewrite function if this match is selected.
164   PatternMatchResult match(Operation *op) const override;
165 
166   /// Attempt to match against code rooted at the specified operation,
167   /// which is the same operation code as getRootKind(). If successful, this
168   /// function will automatically perform the rewrite.
matchAndRewrite(Operation * op,PatternRewriter & rewriter)169   virtual PatternMatchResult matchAndRewrite(Operation *op,
170                                              PatternRewriter &rewriter) const {
171     if (auto matchResult = match(op)) {
172       rewrite(op, std::move(*matchResult), rewriter);
173       return matchSuccess();
174     }
175     return matchFailure();
176   }
177 
178   /// Return a list of operations that may be generated when rewriting an
179   /// operation instance with this pattern.
getGeneratedOps()180   ArrayRef<OperationName> getGeneratedOps() const { return generatedOps; }
181 
182 protected:
183   /// Patterns must specify the root operation name they match against, and can
184   /// also specify the benefit of the pattern matching.
RewritePattern(StringRef rootName,PatternBenefit benefit,MLIRContext * context)185   RewritePattern(StringRef rootName, PatternBenefit benefit,
186                  MLIRContext *context)
187       : Pattern(rootName, benefit, context) {}
188   /// Patterns must specify the root operation name they match against, and can
189   /// also specify the benefit of the pattern matching. They can also specify
190   /// the names of operations that may be generated during a successful rewrite.
191   RewritePattern(StringRef rootName, ArrayRef<StringRef> generatedNames,
192                  PatternBenefit benefit, MLIRContext *context);
193 
194   /// A list of the potential operations that may be generated when rewriting
195   /// an op with this pattern.
196   SmallVector<OperationName, 2> generatedOps;
197 };
198 
199 /// OpRewritePattern is a wrapper around RewritePattern that allows for
200 /// matching and rewriting against an instance of a derived operation class as
201 /// opposed to a raw Operation.
202 template <typename SourceOp> struct OpRewritePattern : public RewritePattern {
203   /// Patterns must specify the root operation name they match against, and can
204   /// also specify the benefit of the pattern matching.
205   OpRewritePattern(MLIRContext *context, PatternBenefit benefit = 1)
RewritePatternOpRewritePattern206       : RewritePattern(SourceOp::getOperationName(), benefit, context) {}
207 
208   /// Wrappers around the RewritePattern methods that pass the derived op type.
rewriteOpRewritePattern209   void rewrite(Operation *op, std::unique_ptr<PatternState> state,
210                PatternRewriter &rewriter) const final {
211     rewrite(cast<SourceOp>(op), std::move(state), rewriter);
212   }
rewriteOpRewritePattern213   void rewrite(Operation *op, PatternRewriter &rewriter) const final {
214     rewrite(cast<SourceOp>(op), rewriter);
215   }
matchOpRewritePattern216   PatternMatchResult match(Operation *op) const final {
217     return match(cast<SourceOp>(op));
218   }
matchAndRewriteOpRewritePattern219   PatternMatchResult matchAndRewrite(Operation *op,
220                                      PatternRewriter &rewriter) const final {
221     return matchAndRewrite(cast<SourceOp>(op), rewriter);
222   }
223 
224   /// Rewrite and Match methods that operate on the SourceOp type. These must be
225   /// overridden by the derived pattern class.
rewriteOpRewritePattern226   virtual void rewrite(SourceOp op, std::unique_ptr<PatternState> state,
227                        PatternRewriter &rewriter) const {
228     rewrite(op, rewriter);
229   }
rewriteOpRewritePattern230   virtual void rewrite(SourceOp op, PatternRewriter &rewriter) const {
231     llvm_unreachable("must override matchAndRewrite or a rewrite method");
232   }
matchOpRewritePattern233   virtual PatternMatchResult match(SourceOp op) const {
234     llvm_unreachable("must override match or matchAndRewrite");
235   }
matchAndRewriteOpRewritePattern236   virtual PatternMatchResult matchAndRewrite(SourceOp op,
237                                              PatternRewriter &rewriter) const {
238     if (auto matchResult = match(op)) {
239       rewrite(op, std::move(*matchResult), rewriter);
240       return matchSuccess();
241     }
242     return matchFailure();
243   }
244 };
245 
246 //===----------------------------------------------------------------------===//
247 // PatternRewriter class
248 //===----------------------------------------------------------------------===//
249 
250 /// This class coordinates the application of a pattern to the current function,
251 /// providing a way to create operations and keep track of what gets deleted.
252 ///
253 /// These class serves two purposes:
254 ///  1) it is the interface that patterns interact with to make mutations to the
255 ///     IR they are being applied to.
256 ///  2) It is a base class that clients of the PatternMatcher use when they want
257 ///     to apply patterns and observe their effects (e.g. to keep worklists or
258 ///     other data structures up to date).
259 ///
260 class PatternRewriter : public OpBuilder {
261 public:
262   /// Create operation of specific op type at the current insertion point
263   /// without verifying to see if it is valid.
264   template <typename OpTy, typename... Args>
create(Location location,Args...args)265   OpTy create(Location location, Args... args) {
266     OperationState state(location, OpTy::getOperationName());
267     OpTy::build(this, state, args...);
268     auto *op = createOperation(state);
269     auto result = dyn_cast<OpTy>(op);
270     assert(result && "Builder didn't return the right type");
271     return result;
272   }
273 
274   /// Creates an operation of specific op type at the current insertion point.
275   /// If the result is an invalid op (the verifier hook fails), emit an error
276   /// and return null.
277   template <typename OpTy, typename... Args>
createChecked(Location location,Args...args)278   OpTy createChecked(Location location, Args... args) {
279     OperationState state(location, OpTy::getOperationName());
280     OpTy::build(this, state, args...);
281     auto *op = createOperation(state);
282 
283     // If the Operation we produce is valid, return it.
284     if (!OpTy::verifyInvariants(op)) {
285       auto result = dyn_cast<OpTy>(op);
286       assert(result && "Builder didn't return the right type");
287       return result;
288     }
289 
290     // Otherwise, the error message got emitted.  Just remove the operation
291     // we made.
292     op->erase();
293     return OpTy();
294   }
295 
296   /// This is implemented to insert the specified operation and serves as a
297   /// notification hook for rewriters that want to know about new operations.
298   virtual Operation *insert(Operation *op) = 0;
299 
300   /// Move the blocks that belong to "region" before the given position in
301   /// another region "parent". The two regions must be different. The caller
302   /// is responsible for creating or updating the operation transferring flow
303   /// of control to the region and passing it the correct block arguments.
304   virtual void inlineRegionBefore(Region &region, Region &parent,
305                                   Region::iterator before);
306   void inlineRegionBefore(Region &region, Block *before);
307 
308   /// Clone the blocks that belong to "region" before the given position in
309   /// another region "parent". The two regions must be different. The caller is
310   /// responsible for creating or updating the operation transferring flow of
311   /// control to the region and passing it the correct block arguments.
312   virtual void cloneRegionBefore(Region &region, Region &parent,
313                                  Region::iterator before,
314                                  BlockAndValueMapping &mapping);
315   void cloneRegionBefore(Region &region, Region &parent,
316                          Region::iterator before);
317   void cloneRegionBefore(Region &region, Block *before);
318 
319   /// This method performs the final replacement for a pattern, where the
320   /// results of the operation are updated to use the specified list of SSA
321   /// values.  In addition to replacing and removing the specified operation,
322   /// clients can specify a list of other nodes that this replacement may make
323   /// (perhaps transitively) dead.  If any of those values are dead, this will
324   /// remove them as well.
325   virtual void replaceOp(Operation *op, ValueRange newValues,
326                          ValueRange valuesToRemoveIfDead);
replaceOp(Operation * op,ValueRange newValues)327   void replaceOp(Operation *op, ValueRange newValues) {
328     replaceOp(op, newValues, llvm::None);
329   }
330 
331   /// Replaces the result op with a new op that is created without verification.
332   /// The result values of the two ops must be the same types.
333   template <typename OpTy, typename... Args>
replaceOpWithNewOp(Operation * op,Args &&...args)334   void replaceOpWithNewOp(Operation *op, Args &&... args) {
335     auto newOp = create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
336     replaceOpWithResultsOfAnotherOp(op, newOp.getOperation(), {});
337   }
338 
339   /// Replaces the result op with a new op that is created without verification.
340   /// The result values of the two ops must be the same types.  This allows
341   /// specifying a list of ops that may be removed if dead.
342   template <typename OpTy, typename... Args>
replaceOpWithNewOp(ValueRange valuesToRemoveIfDead,Operation * op,Args &&...args)343   void replaceOpWithNewOp(ValueRange valuesToRemoveIfDead, Operation *op,
344                           Args &&... args) {
345     auto newOp = create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
346     replaceOpWithResultsOfAnotherOp(op, newOp.getOperation(),
347                                     valuesToRemoveIfDead);
348   }
349 
350   /// This method erases an operation that is known to have no uses.
351   virtual void eraseOp(Operation *op);
352 
353   /// Merge the operations of block 'source' into the end of block 'dest'.
354   /// 'source's predecessors must either be empty or only contain 'dest`.
355   /// 'argValues' is used to replace the block arguments of 'source' after
356   /// merging.
357   virtual void mergeBlocks(Block *source, Block *dest,
358                            ValueRange argValues = llvm::None);
359 
360   /// Split the operations starting at "before" (inclusive) out of the given
361   /// block into a new block, and return it.
362   virtual Block *splitBlock(Block *block, Block::iterator before);
363 
364   /// This method is used to notify the rewriter that an in-place operation
365   /// modification is about to happen. A call to this function *must* be
366   /// followed by a call to either `finalizeRootUpdate` or `cancelRootUpdate`.
367   /// This is a minor efficiency win (it avoids creating a new operation and
368   /// removing the old one) but also often allows simpler code in the client.
startRootUpdate(Operation * op)369   virtual void startRootUpdate(Operation *op) {}
370 
371   /// This method is used to signal the end of a root update on the given
372   /// operation. This can only be called on operations that were provided to a
373   /// call to `startRootUpdate`.
finalizeRootUpdate(Operation * op)374   virtual void finalizeRootUpdate(Operation *op) {}
375 
376   /// This method cancels a pending root update. This can only be called on
377   /// operations that were provided to a call to `startRootUpdate`.
cancelRootUpdate(Operation * op)378   virtual void cancelRootUpdate(Operation *op) {}
379 
380   /// This method is a utility wrapper around a root update of an operation. It
381   /// wraps calls to `startRootUpdate` and `finalizeRootUpdate` around the given
382   /// callable.
383   template <typename CallableT>
updateRootInPlace(Operation * root,CallableT && callable)384   void updateRootInPlace(Operation *root, CallableT &&callable) {
385     startRootUpdate(root);
386     callable();
387     finalizeRootUpdate(root);
388   }
389 
390 protected:
PatternRewriter(MLIRContext * ctx)391   explicit PatternRewriter(MLIRContext *ctx) : OpBuilder(ctx) {}
392   virtual ~PatternRewriter();
393 
394   // These are the callback methods that subclasses can choose to implement if
395   // they would like to be notified about certain types of mutations.
396 
397   /// Notify the pattern rewriter that the specified operation is about to be
398   /// replaced with another set of operations.  This is called before the uses
399   /// of the operation have been changed.
notifyRootReplaced(Operation * op)400   virtual void notifyRootReplaced(Operation *op) {}
401 
402   /// This is called on an operation that a pattern match is removing, right
403   /// before the operation is deleted.  At this point, the operation has zero
404   /// uses.
notifyOperationRemoved(Operation * op)405   virtual void notifyOperationRemoved(Operation *op) {}
406 
407 private:
408   /// op and newOp are known to have the same number of results, replace the
409   /// uses of op with uses of newOp
410   void replaceOpWithResultsOfAnotherOp(Operation *op, Operation *newOp,
411                                        ValueRange valuesToRemoveIfDead);
412 };
413 
414 //===----------------------------------------------------------------------===//
415 // Pattern-driven rewriters
416 //===----------------------------------------------------------------------===//
417 
418 class OwningRewritePatternList {
419   using PatternListT = std::vector<std::unique_ptr<RewritePattern>>;
420 
421 public:
begin()422   PatternListT::iterator begin() { return patterns.begin(); }
end()423   PatternListT::iterator end() { return patterns.end(); }
begin()424   PatternListT::const_iterator begin() const { return patterns.begin(); }
end()425   PatternListT::const_iterator end() const { return patterns.end(); }
clear()426   void clear() { patterns.clear(); }
427 
428   //===--------------------------------------------------------------------===//
429   // Pattern Insertion
430   //===--------------------------------------------------------------------===//
431 
432   /// Add an instance of each of the pattern types 'Ts' to the pattern list with
433   /// the given arguments.
434   /// Note: ConstructorArg is necessary here to separate the two variadic lists.
435   template <typename... Ts, typename ConstructorArg,
436             typename... ConstructorArgs,
437             typename = std::enable_if_t<sizeof...(Ts) != 0>>
insert(ConstructorArg && arg,ConstructorArgs &&...args)438   void insert(ConstructorArg &&arg, ConstructorArgs &&... args) {
439     // The following expands a call to emplace_back for each of the pattern
440     // types 'Ts'. This magic is necessary due to a limitation in the places
441     // that a parameter pack can be expanded in c++11.
442     // FIXME: In c++17 this can be simplified by using 'fold expressions'.
443     using dummy = int[];
444     (void)dummy{
445         0, (patterns.emplace_back(std::make_unique<Ts>(arg, args...)), 0)...};
446   }
447 
448 private:
449   PatternListT patterns;
450 };
451 
452 /// This class manages optimization and execution of a group of rewrite
453 /// patterns, providing an API for finding and applying, the best match against
454 /// a given node.
455 ///
456 class RewritePatternMatcher {
457 public:
458   /// Create a RewritePatternMatcher with the specified set of patterns.
459   explicit RewritePatternMatcher(const OwningRewritePatternList &patterns);
460 
461   /// Try to match the given operation to a pattern and rewrite it. Return
462   /// true if any pattern matches.
463   bool matchAndRewrite(Operation *op, PatternRewriter &rewriter);
464 
465 private:
466   RewritePatternMatcher(const RewritePatternMatcher &) = delete;
467   void operator=(const RewritePatternMatcher &) = delete;
468 
469   /// The group of patterns that are matched for optimization through this
470   /// matcher.
471   std::vector<RewritePattern *> patterns;
472 };
473 
474 /// Rewrite the regions of the specified operation, which must be isolated from
475 /// above, by repeatedly applying the highest benefit patterns in a greedy
476 /// work-list driven manner. Return true if no more patterns can be matched in
477 /// the result operation regions.
478 /// Note: This does not apply patterns to the top-level operation itself.
479 /// Note: These methods also perform folding and simple dead-code elimination
480 ///       before attempting to match any of the provided patterns.
481 ///
482 bool applyPatternsGreedily(Operation *op,
483                            const OwningRewritePatternList &patterns);
484 /// Rewrite the given regions, which must be isolated from above.
485 bool applyPatternsGreedily(MutableArrayRef<Region> regions,
486                            const OwningRewritePatternList &patterns);
487 } // end namespace mlir
488 
#endif // MLIR_PATTERNMATCHER_H
490