//===- ToyCombine.cpp - Toy High Level Optimizer --------------------------===//
//
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a set of simple combiners for optimizing operations in
// the Toy dialect.
//
//===----------------------------------------------------------------------===//
13 
14 #include "mlir/IR/Matchers.h"
15 #include "mlir/IR/PatternMatch.h"
16 #include "toy/Dialect.h"
17 #include <numeric>
18 using namespace mlir;
19 using namespace toy;
20 
namespace {
/// Include the rewrite patterns defined with the Declarative Rewrite framework
/// (presumably TableGen-generated into ToyCombine.inc — confirm in the build).
/// Wrapping the include in an anonymous namespace gives the generated pattern
/// classes internal linkage, keeping them private to this translation unit.
#include "ToyCombine.inc"
} // end anonymous namespace
25 
26 /// Fold simple cast operations that return the same type as the input.
fold(ArrayRef<Attribute> operands)27 OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
28   return mlir::impl::foldCastOp(*this);
29 }
30 
31 /// This is an example of a c++ rewrite pattern for the TransposeOp. It
32 /// optimizes the following scenario: transpose(transpose(x)) -> transpose(x)
33 struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
34   /// We register this pattern to match every toy.transpose in the IR.
35   /// The "benefit" is used by the framework to order the patterns and process
36   /// them in order of profitability.
SimplifyRedundantTransposeSimplifyRedundantTranspose37   SimplifyRedundantTranspose(mlir::MLIRContext *context)
38       : OpRewritePattern<TransposeOp>(context, /*benefit=*/1) {}
39 
40   /// This method attempts to match a pattern and rewrite it. The rewriter
41   /// argument is the orchestrator of the sequence of rewrites. The pattern is
42   /// expected to interact with it to perform any changes to the IR from here.
43   mlir::PatternMatchResult
matchAndRewriteSimplifyRedundantTranspose44   matchAndRewrite(TransposeOp op,
45                   mlir::PatternRewriter &rewriter) const override {
46     // Look through the input of the current transpose.
47     mlir::Value transposeInput = op.getOperand();
48     TransposeOp transposeInputOp =
49         llvm::dyn_cast_or_null<TransposeOp>(transposeInput.getDefiningOp());
50 
51     // If the input is defined by another Transpose, bingo!
52     if (!transposeInputOp)
53       return matchFailure();
54 
55     // Use the rewriter to perform the replacement.
56     rewriter.replaceOp(op, {transposeInputOp.getOperand()}, {transposeInputOp});
57     return matchSuccess();
58   }
59 };
60 
61 /// Register our patterns as "canonicalization" patterns on the TransposeOp so
62 /// that they can be picked up by the Canonicalization framework.
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)63 void TransposeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
64                                               MLIRContext *context) {
65   results.insert<SimplifyRedundantTranspose>(context);
66 }
67 
68 /// Register our patterns as "canonicalization" patterns on the ReshapeOp so
69 /// that they can be picked up by the Canonicalization framework.
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)70 void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
71                                             MLIRContext *context) {
72   results.insert<ReshapeReshapeOptPattern, RedundantReshapeOptPattern,
73                  FoldConstantReshapeOptPattern>(context);
74 }
75