1 //===- ToyCombine.cpp - Toy High Level Optimizer --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a set of simple combiners for optimizing operations in
10 // the Toy dialect.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/IR/Matchers.h"
15 #include "mlir/IR/PatternMatch.h"
16 #include "toy/Dialect.h"
17 #include <numeric>
18 using namespace mlir;
19 using namespace toy;
20 
21 namespace {
22 /// Include the patterns defined in the Declarative Rewrite framework.
23 #include "ToyCombine.inc"
24 } // end anonymous namespace
25 
26 /// Fold simple cast operations that return the same type as the input.
fold(ArrayRef<Attribute> operands)27 OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
28   return mlir::impl::foldCastOp(*this);
29 }
30 
31 /// This is an example of a c++ rewrite pattern for the TransposeOp. It
32 /// optimizes the following scenario: transpose(transpose(x)) -> x
33 struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
34   /// We register this pattern to match every toy.transpose in the IR.
35   /// The "benefit" is used by the framework to order the patterns and process
36   /// them in order of profitability.
SimplifyRedundantTransposeSimplifyRedundantTranspose37   SimplifyRedundantTranspose(mlir::MLIRContext *context)
38       : OpRewritePattern<TransposeOp>(context, /*benefit=*/1) {}
39 
40   /// This method attempts to match a pattern and rewrite it. The rewriter
41   /// argument is the orchestrator of the sequence of rewrites. The pattern is
42   /// expected to interact with it to perform any changes to the IR from here.
43   mlir::LogicalResult
matchAndRewriteSimplifyRedundantTranspose44   matchAndRewrite(TransposeOp op,
45                   mlir::PatternRewriter &rewriter) const override {
46     // Look through the input of the current transpose.
47     mlir::Value transposeInput = op.getOperand();
48     TransposeOp transposeInputOp = transposeInput.getDefiningOp<TransposeOp>();
49 
50     // Input defined by another transpose? If not, no match.
51     if (!transposeInputOp)
52       return failure();
53 
54     // Otherwise, we have a redundant transpose. Use the rewriter.
55     rewriter.replaceOp(op, {transposeInputOp.getOperand()});
56     return success();
57   }
58 };
59 
60 /// Register our patterns as "canonicalization" patterns on the TransposeOp so
61 /// that they can be picked up by the Canonicalization framework.
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)62 void TransposeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
63                                               MLIRContext *context) {
64   results.insert<SimplifyRedundantTranspose>(context);
65 }
66 
67 /// Register our patterns as "canonicalization" patterns on the ReshapeOp so
68 /// that they can be picked up by the Canonicalization framework.
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)69 void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
70                                             MLIRContext *context) {
71   results.insert<ReshapeReshapeOptPattern, RedundantReshapeOptPattern,
72                  FoldConstantReshapeOptPattern>(context);
73 }
74