1 //===- PatternMatch.h - PatternMatcher classes -------==---------*- C++ -*-===// 2 // 3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_PATTERNMATCHER_H 10 #define MLIR_PATTERNMATCHER_H 11 12 #include "mlir/IR/Builders.h" 13 14 namespace mlir { 15 16 class PatternRewriter; 17 18 //===----------------------------------------------------------------------===// 19 // PatternBenefit class 20 //===----------------------------------------------------------------------===// 21 22 /// This class represents the benefit of a pattern match in a unitless scheme 23 /// that ranges from 0 (very little benefit) to 65K. The most common unit to 24 /// use here is the "number of operations matched" by the pattern. 25 /// 26 /// This also has a sentinel representation that can be used for patterns that 27 /// fail to match. 28 /// 29 class PatternBenefit { 30 enum { ImpossibleToMatchSentinel = 65535 }; 31 32 public: 33 /*implicit*/ PatternBenefit(unsigned benefit); 34 PatternBenefit(const PatternBenefit &) = default; 35 PatternBenefit &operator=(const PatternBenefit &) = default; 36 impossibleToMatch()37 static PatternBenefit impossibleToMatch() { return PatternBenefit(); } isImpossibleToMatch()38 bool isImpossibleToMatch() const { return *this == impossibleToMatch(); } 39 40 /// If the corresponding pattern can match, return its benefit. If the 41 // corresponding pattern isImpossibleToMatch() then this aborts. 42 unsigned short getBenefit() const; 43 44 bool operator==(const PatternBenefit &rhs) const { 45 return representation == rhs.representation; 46 } 47 bool operator!=(const PatternBenefit &rhs) const { return !(*this == rhs); } 48 bool operator<(const PatternBenefit &rhs) const { 49 return representation < rhs.representation; 50 } 51 52 private: PatternBenefit()53 PatternBenefit() : representation(ImpossibleToMatchSentinel) {} 54 unsigned short representation; 55 }; 56 57 /// Pattern state is used by patterns that want to maintain state between their 58 /// match and rewrite phases. Patterns can define a pattern-specific subclass 59 /// of this. 60 class PatternState { 61 public: ~PatternState()62 virtual ~PatternState() {} 63 64 protected: 65 // Must be subclassed. PatternState()66 PatternState() {} 67 }; 68 69 /// This is the type returned by a pattern match. A match failure returns a 70 /// None value. A match success returns a Some value with any state the pattern 71 /// may need to maintain (but may also be null). 72 using PatternMatchResult = Optional<std::unique_ptr<PatternState>>; 73 74 //===----------------------------------------------------------------------===// 75 // Pattern class 76 //===----------------------------------------------------------------------===// 77 78 /// Instances of Pattern can be matched against SSA IR. These matches get used 79 /// in ways dependent on their subclasses and the driver doing the matching. 80 /// For example, RewritePatterns implement a rewrite from one matched pattern 81 /// to a replacement DAG tile. 82 class Pattern { 83 public: 84 /// Return the benefit (the inverse of "cost") of matching this pattern. The 85 /// benefit of a Pattern is always static - rewrites that may have dynamic 86 /// benefit can be instantiated multiple times (different Pattern instances) 87 /// for each benefit that they may return, and be guarded by different match 88 /// condition predicates. getBenefit()89 PatternBenefit getBenefit() const { return benefit; } 90 91 /// Return the root node that this pattern matches. Patterns that can 92 /// match multiple root types are instantiated once per root. getRootKind()93 OperationName getRootKind() const { return rootKind; } 94 95 //===--------------------------------------------------------------------===// 96 // Implementation hooks for patterns to implement. 97 //===--------------------------------------------------------------------===// 98 99 /// Attempt to match against code rooted at the specified operation, 100 /// which is the same operation code as getRootKind(). On failure, this 101 /// returns a None value. On success it returns a (possibly null) 102 /// pattern-specific state wrapped in an Optional. 103 virtual PatternMatchResult match(Operation *op) const = 0; 104 ~Pattern()105 virtual ~Pattern() {} 106 107 //===--------------------------------------------------------------------===// 108 // Helper methods to simplify pattern implementations 109 //===--------------------------------------------------------------------===// 110 111 /// This method indicates that no match was found. matchFailure()112 static PatternMatchResult matchFailure() { return None; } 113 114 /// This method indicates that a match was found and has the specified cost. 115 PatternMatchResult 116 matchSuccess(std::unique_ptr<PatternState> state = {}) const { 117 return PatternMatchResult(std::move(state)); 118 } 119 120 protected: 121 /// Patterns must specify the root operation name they match against, and can 122 /// also specify the benefit of the pattern matching. 123 Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context); 124 125 private: 126 const OperationName rootKind; 127 const PatternBenefit benefit; 128 129 virtual void anchor(); 130 }; 131 132 /// RewritePattern is the common base class for all DAG to DAG replacements. 133 /// There are two possible usages of this class: 134 /// * Multi-step RewritePattern with "match" and "rewrite" 135 /// - By overloading the "match" and "rewrite" functions, the user can 136 /// separate the concerns of matching and rewriting. 137 /// * Single-step RewritePattern with "matchAndRewrite" 138 /// - By overloading the "matchAndRewrite" function, the user can perform 139 /// the rewrite in the same call as the match. This removes the need for 140 /// any PatternState. 141 /// 142 class RewritePattern : public Pattern { 143 public: 144 /// Rewrite the IR rooted at the specified operation with the result of 145 /// this pattern, generating any new operations with the specified 146 /// rewriter. If an unexpected error is encountered (an internal 147 /// compiler error), it is emitted through the normal MLIR diagnostic 148 /// hooks and the IR is left in a valid state. 149 virtual void rewrite(Operation *op, std::unique_ptr<PatternState> state, 150 PatternRewriter &rewriter) const; 151 152 /// Rewrite the IR rooted at the specified operation with the result of 153 /// this pattern, generating any new operations with the specified 154 /// builder. If an unexpected error is encountered (an internal 155 /// compiler error), it is emitted through the normal MLIR diagnostic 156 /// hooks and the IR is left in a valid state. 157 virtual void rewrite(Operation *op, PatternRewriter &rewriter) const; 158 159 /// Attempt to match against code rooted at the specified operation, 160 /// which is the same operation code as getRootKind(). On failure, this 161 /// returns a None value. On success, it returns a (possibly null) 162 /// pattern-specific state wrapped in an Optional. This state is passed back 163 /// into the rewrite function if this match is selected. 164 PatternMatchResult match(Operation *op) const override; 165 166 /// Attempt to match against code rooted at the specified operation, 167 /// which is the same operation code as getRootKind(). If successful, this 168 /// function will automatically perform the rewrite. matchAndRewrite(Operation * op,PatternRewriter & rewriter)169 virtual PatternMatchResult matchAndRewrite(Operation *op, 170 PatternRewriter &rewriter) const { 171 if (auto matchResult = match(op)) { 172 rewrite(op, std::move(*matchResult), rewriter); 173 return matchSuccess(); 174 } 175 return matchFailure(); 176 } 177 178 /// Return a list of operations that may be generated when rewriting an 179 /// operation instance with this pattern. getGeneratedOps()180 ArrayRef<OperationName> getGeneratedOps() const { return generatedOps; } 181 182 protected: 183 /// Patterns must specify the root operation name they match against, and can 184 /// also specify the benefit of the pattern matching. RewritePattern(StringRef rootName,PatternBenefit benefit,MLIRContext * context)185 RewritePattern(StringRef rootName, PatternBenefit benefit, 186 MLIRContext *context) 187 : Pattern(rootName, benefit, context) {} 188 /// Patterns must specify the root operation name they match against, and can 189 /// also specify the benefit of the pattern matching. They can also specify 190 /// the names of operations that may be generated during a successful rewrite. 191 RewritePattern(StringRef rootName, ArrayRef<StringRef> generatedNames, 192 PatternBenefit benefit, MLIRContext *context); 193 194 /// A list of the potential operations that may be generated when rewriting 195 /// an op with this pattern. 196 SmallVector<OperationName, 2> generatedOps; 197 }; 198 199 /// OpRewritePattern is a wrapper around RewritePattern that allows for 200 /// matching and rewriting against an instance of a derived operation class as 201 /// opposed to a raw Operation. 202 template <typename SourceOp> struct OpRewritePattern : public RewritePattern { 203 /// Patterns must specify the root operation name they match against, and can 204 /// also specify the benefit of the pattern matching. 205 OpRewritePattern(MLIRContext *context, PatternBenefit benefit = 1) RewritePatternOpRewritePattern206 : RewritePattern(SourceOp::getOperationName(), benefit, context) {} 207 208 /// Wrappers around the RewritePattern methods that pass the derived op type. rewriteOpRewritePattern209 void rewrite(Operation *op, std::unique_ptr<PatternState> state, 210 PatternRewriter &rewriter) const final { 211 rewrite(cast<SourceOp>(op), std::move(state), rewriter); 212 } rewriteOpRewritePattern213 void rewrite(Operation *op, PatternRewriter &rewriter) const final { 214 rewrite(cast<SourceOp>(op), rewriter); 215 } matchOpRewritePattern216 PatternMatchResult match(Operation *op) const final { 217 return match(cast<SourceOp>(op)); 218 } matchAndRewriteOpRewritePattern219 PatternMatchResult matchAndRewrite(Operation *op, 220 PatternRewriter &rewriter) const final { 221 return matchAndRewrite(cast<SourceOp>(op), rewriter); 222 } 223 224 /// Rewrite and Match methods that operate on the SourceOp type. These must be 225 /// overridden by the derived pattern class. rewriteOpRewritePattern226 virtual void rewrite(SourceOp op, std::unique_ptr<PatternState> state, 227 PatternRewriter &rewriter) const { 228 rewrite(op, rewriter); 229 } rewriteOpRewritePattern230 virtual void rewrite(SourceOp op, PatternRewriter &rewriter) const { 231 llvm_unreachable("must override matchAndRewrite or a rewrite method"); 232 } matchOpRewritePattern233 virtual PatternMatchResult match(SourceOp op) const { 234 llvm_unreachable("must override match or matchAndRewrite"); 235 } matchAndRewriteOpRewritePattern236 virtual PatternMatchResult matchAndRewrite(SourceOp op, 237 PatternRewriter &rewriter) const { 238 if (auto matchResult = match(op)) { 239 rewrite(op, std::move(*matchResult), rewriter); 240 return matchSuccess(); 241 } 242 return matchFailure(); 243 } 244 }; 245 246 //===----------------------------------------------------------------------===// 247 // PatternRewriter class 248 //===----------------------------------------------------------------------===// 249 250 /// This class coordinates the application of a pattern to the current function, 251 /// providing a way to create operations and keep track of what gets deleted. 252 /// 253 /// These class serves two purposes: 254 /// 1) it is the interface that patterns interact with to make mutations to the 255 /// IR they are being applied to. 256 /// 2) It is a base class that clients of the PatternMatcher use when they want 257 /// to apply patterns and observe their effects (e.g. to keep worklists or 258 /// other data structures up to date). 259 /// 260 class PatternRewriter : public OpBuilder { 261 public: 262 /// Create operation of specific op type at the current insertion point 263 /// without verifying to see if it is valid. 264 template <typename OpTy, typename... Args> create(Location location,Args...args)265 OpTy create(Location location, Args... args) { 266 OperationState state(location, OpTy::getOperationName()); 267 OpTy::build(this, state, args...); 268 auto *op = createOperation(state); 269 auto result = dyn_cast<OpTy>(op); 270 assert(result && "Builder didn't return the right type"); 271 return result; 272 } 273 274 /// Creates an operation of specific op type at the current insertion point. 275 /// If the result is an invalid op (the verifier hook fails), emit an error 276 /// and return null. 277 template <typename OpTy, typename... Args> createChecked(Location location,Args...args)278 OpTy createChecked(Location location, Args... args) { 279 OperationState state(location, OpTy::getOperationName()); 280 OpTy::build(this, state, args...); 281 auto *op = createOperation(state); 282 283 // If the Operation we produce is valid, return it. 284 if (!OpTy::verifyInvariants(op)) { 285 auto result = dyn_cast<OpTy>(op); 286 assert(result && "Builder didn't return the right type"); 287 return result; 288 } 289 290 // Otherwise, the error message got emitted. Just remove the operation 291 // we made. 292 op->erase(); 293 return OpTy(); 294 } 295 296 /// This is implemented to insert the specified operation and serves as a 297 /// notification hook for rewriters that want to know about new operations. 298 virtual Operation *insert(Operation *op) = 0; 299 300 /// Move the blocks that belong to "region" before the given position in 301 /// another region "parent". The two regions must be different. The caller 302 /// is responsible for creating or updating the operation transferring flow 303 /// of control to the region and passing it the correct block arguments. 304 virtual void inlineRegionBefore(Region ®ion, Region &parent, 305 Region::iterator before); 306 void inlineRegionBefore(Region ®ion, Block *before); 307 308 /// Clone the blocks that belong to "region" before the given position in 309 /// another region "parent". The two regions must be different. The caller is 310 /// responsible for creating or updating the operation transferring flow of 311 /// control to the region and passing it the correct block arguments. 312 virtual void cloneRegionBefore(Region ®ion, Region &parent, 313 Region::iterator before, 314 BlockAndValueMapping &mapping); 315 void cloneRegionBefore(Region ®ion, Region &parent, 316 Region::iterator before); 317 void cloneRegionBefore(Region ®ion, Block *before); 318 319 /// This method performs the final replacement for a pattern, where the 320 /// results of the operation are updated to use the specified list of SSA 321 /// values. In addition to replacing and removing the specified operation, 322 /// clients can specify a list of other nodes that this replacement may make 323 /// (perhaps transitively) dead. If any of those values are dead, this will 324 /// remove them as well. 325 virtual void replaceOp(Operation *op, ValueRange newValues, 326 ValueRange valuesToRemoveIfDead); replaceOp(Operation * op,ValueRange newValues)327 void replaceOp(Operation *op, ValueRange newValues) { 328 replaceOp(op, newValues, llvm::None); 329 } 330 331 /// Replaces the result op with a new op that is created without verification. 332 /// The result values of the two ops must be the same types. 333 template <typename OpTy, typename... Args> replaceOpWithNewOp(Operation * op,Args &&...args)334 void replaceOpWithNewOp(Operation *op, Args &&... args) { 335 auto newOp = create<OpTy>(op->getLoc(), std::forward<Args>(args)...); 336 replaceOpWithResultsOfAnotherOp(op, newOp.getOperation(), {}); 337 } 338 339 /// Replaces the result op with a new op that is created without verification. 340 /// The result values of the two ops must be the same types. This allows 341 /// specifying a list of ops that may be removed if dead. 342 template <typename OpTy, typename... Args> replaceOpWithNewOp(ValueRange valuesToRemoveIfDead,Operation * op,Args &&...args)343 void replaceOpWithNewOp(ValueRange valuesToRemoveIfDead, Operation *op, 344 Args &&... args) { 345 auto newOp = create<OpTy>(op->getLoc(), std::forward<Args>(args)...); 346 replaceOpWithResultsOfAnotherOp(op, newOp.getOperation(), 347 valuesToRemoveIfDead); 348 } 349 350 /// This method erases an operation that is known to have no uses. 351 virtual void eraseOp(Operation *op); 352 353 /// Merge the operations of block 'source' into the end of block 'dest'. 354 /// 'source's predecessors must either be empty or only contain 'dest`. 355 /// 'argValues' is used to replace the block arguments of 'source' after 356 /// merging. 357 virtual void mergeBlocks(Block *source, Block *dest, 358 ValueRange argValues = llvm::None); 359 360 /// Split the operations starting at "before" (inclusive) out of the given 361 /// block into a new block, and return it. 362 virtual Block *splitBlock(Block *block, Block::iterator before); 363 364 /// This method is used to notify the rewriter that an in-place operation 365 /// modification is about to happen. A call to this function *must* be 366 /// followed by a call to either `finalizeRootUpdate` or `cancelRootUpdate`. 367 /// This is a minor efficiency win (it avoids creating a new operation and 368 /// removing the old one) but also often allows simpler code in the client. startRootUpdate(Operation * op)369 virtual void startRootUpdate(Operation *op) {} 370 371 /// This method is used to signal the end of a root update on the given 372 /// operation. This can only be called on operations that were provided to a 373 /// call to `startRootUpdate`. finalizeRootUpdate(Operation * op)374 virtual void finalizeRootUpdate(Operation *op) {} 375 376 /// This method cancels a pending root update. This can only be called on 377 /// operations that were provided to a call to `startRootUpdate`. cancelRootUpdate(Operation * op)378 virtual void cancelRootUpdate(Operation *op) {} 379 380 /// This method is a utility wrapper around a root update of an operation. It 381 /// wraps calls to `startRootUpdate` and `finalizeRootUpdate` around the given 382 /// callable. 383 template <typename CallableT> updateRootInPlace(Operation * root,CallableT && callable)384 void updateRootInPlace(Operation *root, CallableT &&callable) { 385 startRootUpdate(root); 386 callable(); 387 finalizeRootUpdate(root); 388 } 389 390 protected: PatternRewriter(MLIRContext * ctx)391 explicit PatternRewriter(MLIRContext *ctx) : OpBuilder(ctx) {} 392 virtual ~PatternRewriter(); 393 394 // These are the callback methods that subclasses can choose to implement if 395 // they would like to be notified about certain types of mutations. 396 397 /// Notify the pattern rewriter that the specified operation is about to be 398 /// replaced with another set of operations. This is called before the uses 399 /// of the operation have been changed. notifyRootReplaced(Operation * op)400 virtual void notifyRootReplaced(Operation *op) {} 401 402 /// This is called on an operation that a pattern match is removing, right 403 /// before the operation is deleted. At this point, the operation has zero 404 /// uses. notifyOperationRemoved(Operation * op)405 virtual void notifyOperationRemoved(Operation *op) {} 406 407 private: 408 /// op and newOp are known to have the same number of results, replace the 409 /// uses of op with uses of newOp 410 void replaceOpWithResultsOfAnotherOp(Operation *op, Operation *newOp, 411 ValueRange valuesToRemoveIfDead); 412 }; 413 414 //===----------------------------------------------------------------------===// 415 // Pattern-driven rewriters 416 //===----------------------------------------------------------------------===// 417 418 class OwningRewritePatternList { 419 using PatternListT = std::vector<std::unique_ptr<RewritePattern>>; 420 421 public: begin()422 PatternListT::iterator begin() { return patterns.begin(); } end()423 PatternListT::iterator end() { return patterns.end(); } begin()424 PatternListT::const_iterator begin() const { return patterns.begin(); } end()425 PatternListT::const_iterator end() const { return patterns.end(); } clear()426 void clear() { patterns.clear(); } 427 428 //===--------------------------------------------------------------------===// 429 // Pattern Insertion 430 //===--------------------------------------------------------------------===// 431 432 /// Add an instance of each of the pattern types 'Ts' to the pattern list with 433 /// the given arguments. 434 /// Note: ConstructorArg is necessary here to separate the two variadic lists. 435 template <typename... Ts, typename ConstructorArg, 436 typename... ConstructorArgs, 437 typename = std::enable_if_t<sizeof...(Ts) != 0>> insert(ConstructorArg && arg,ConstructorArgs &&...args)438 void insert(ConstructorArg &&arg, ConstructorArgs &&... args) { 439 // The following expands a call to emplace_back for each of the pattern 440 // types 'Ts'. This magic is necessary due to a limitation in the places 441 // that a parameter pack can be expanded in c++11. 442 // FIXME: In c++17 this can be simplified by using 'fold expressions'. 443 using dummy = int[]; 444 (void)dummy{ 445 0, (patterns.emplace_back(std::make_unique<Ts>(arg, args...)), 0)...}; 446 } 447 448 private: 449 PatternListT patterns; 450 }; 451 452 /// This class manages optimization and execution of a group of rewrite 453 /// patterns, providing an API for finding and applying, the best match against 454 /// a given node. 455 /// 456 class RewritePatternMatcher { 457 public: 458 /// Create a RewritePatternMatcher with the specified set of patterns. 459 explicit RewritePatternMatcher(const OwningRewritePatternList &patterns); 460 461 /// Try to match the given operation to a pattern and rewrite it. Return 462 /// true if any pattern matches. 463 bool matchAndRewrite(Operation *op, PatternRewriter &rewriter); 464 465 private: 466 RewritePatternMatcher(const RewritePatternMatcher &) = delete; 467 void operator=(const RewritePatternMatcher &) = delete; 468 469 /// The group of patterns that are matched for optimization through this 470 /// matcher. 471 std::vector<RewritePattern *> patterns; 472 }; 473 474 /// Rewrite the regions of the specified operation, which must be isolated from 475 /// above, by repeatedly applying the highest benefit patterns in a greedy 476 /// work-list driven manner. Return true if no more patterns can be matched in 477 /// the result operation regions. 478 /// Note: This does not apply patterns to the top-level operation itself. 479 /// Note: These methods also perform folding and simple dead-code elimination 480 /// before attempting to match any of the provided patterns. 481 /// 482 bool applyPatternsGreedily(Operation *op, 483 const OwningRewritePatternList &patterns); 484 /// Rewrite the given regions, which must be isolated from above. 485 bool applyPatternsGreedily(MutableArrayRef<Region> regions, 486 const OwningRewritePatternList &patterns); 487 } // end namespace mlir 488 489 #endif // MLIR_PATTERN_MATCH_H 490