1 //===- ToyCombine.cpp - Toy High Level Optimizer --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a set of simple combiners for optimizing operations in
10 // the Toy dialect.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/IR/Matchers.h"
15 #include "mlir/IR/PatternMatch.h"
16 #include "toy/Dialect.h"
17 #include <numeric>
18 using namespace mlir;
19 using namespace toy;
20 
21 namespace {
22 /// Include the patterns defined in the Declarative Rewrite framework.
23 #include "ToyCombine.inc"
24 } // end anonymous namespace
25 
26 /// Fold simple cast operations that return the same type as the input.
fold(ArrayRef<Attribute> operands)27 OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
28   return mlir::impl::foldCastOp(*this);
29 }
30 
31 /// Fold constants.
fold(ArrayRef<Attribute> operands)32 OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) { return value(); }
33 
34 /// Fold struct constants.
fold(ArrayRef<Attribute> operands)35 OpFoldResult StructConstantOp::fold(ArrayRef<Attribute> operands) {
36   return value();
37 }
38 
39 /// Fold simple struct access operations that access into a constant.
fold(ArrayRef<Attribute> operands)40 OpFoldResult StructAccessOp::fold(ArrayRef<Attribute> operands) {
41   auto structAttr = operands.front().dyn_cast_or_null<mlir::ArrayAttr>();
42   if (!structAttr)
43     return nullptr;
44 
45   size_t elementIndex = index().getZExtValue();
46   return structAttr[elementIndex];
47 }
48 
49 /// This is an example of a c++ rewrite pattern for the TransposeOp. It
50 /// optimizes the following scenario: transpose(transpose(x)) -> x
51 struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
52   /// We register this pattern to match every toy.transpose in the IR.
53   /// The "benefit" is used by the framework to order the patterns and process
54   /// them in order of profitability.
SimplifyRedundantTransposeSimplifyRedundantTranspose55   SimplifyRedundantTranspose(mlir::MLIRContext *context)
56       : OpRewritePattern<TransposeOp>(context, /*benefit=*/1) {}
57 
58   /// This method attempts to match a pattern and rewrite it. The rewriter
59   /// argument is the orchestrator of the sequence of rewrites. The pattern is
60   /// expected to interact with it to perform any changes to the IR from here.
61   mlir::LogicalResult
matchAndRewriteSimplifyRedundantTranspose62   matchAndRewrite(TransposeOp op,
63                   mlir::PatternRewriter &rewriter) const override {
64     // Look through the input of the current transpose.
65     mlir::Value transposeInput = op.getOperand();
66     TransposeOp transposeInputOp = transposeInput.getDefiningOp<TransposeOp>();
67 
68     // Input defined by another transpose? If not, no match.
69     if (!transposeInputOp)
70       return failure();
71 
72     // Otherwise, we have a redundant transpose. Use the rewriter.
73     rewriter.replaceOp(op, {transposeInputOp.getOperand()});
74     return success();
75   }
76 };
77 
78 /// Register our patterns as "canonicalization" patterns on the TransposeOp so
79 /// that they can be picked up by the Canonicalization framework.
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)80 void TransposeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
81                                               MLIRContext *context) {
82   results.insert<SimplifyRedundantTranspose>(context);
83 }
84 
85 /// Register our patterns as "canonicalization" patterns on the ReshapeOp so
86 /// that they can be picked up by the Canonicalization framework.
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)87 void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
88                                             MLIRContext *context) {
89   results.insert<ReshapeReshapeOptPattern, RedundantReshapeOptPattern,
90                  FoldConstantReshapeOptPattern>(context);
91 }
92