1 //===- ToyCombine.cpp - Toy High Level Optimizer --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a set of simple combiners for optimizing operations in
10 // the Toy dialect.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "mlir/IR/Matchers.h"
15 #include "mlir/IR/PatternMatch.h"
16 #include "toy/Dialect.h"
17 #include <numeric>
18 using namespace mlir;
19 using namespace toy;
20
21 namespace {
22 /// Include the patterns defined in the Declarative Rewrite framework.
23 #include "ToyCombine.inc"
24 } // end anonymous namespace
25
26 /// Fold simple cast operations that return the same type as the input.
fold(ArrayRef<Attribute> operands)27 OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
28 return mlir::impl::foldCastOp(*this);
29 }
30
31 /// Fold constants.
fold(ArrayRef<Attribute> operands)32 OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) { return value(); }
33
34 /// Fold struct constants.
fold(ArrayRef<Attribute> operands)35 OpFoldResult StructConstantOp::fold(ArrayRef<Attribute> operands) {
36 return value();
37 }
38
39 /// Fold simple struct access operations that access into a constant.
fold(ArrayRef<Attribute> operands)40 OpFoldResult StructAccessOp::fold(ArrayRef<Attribute> operands) {
41 auto structAttr = operands.front().dyn_cast_or_null<mlir::ArrayAttr>();
42 if (!structAttr)
43 return nullptr;
44
45 size_t elementIndex = index().getZExtValue();
46 return structAttr[elementIndex];
47 }
48
49 /// This is an example of a c++ rewrite pattern for the TransposeOp. It
50 /// optimizes the following scenario: transpose(transpose(x)) -> x
51 struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
52 /// We register this pattern to match every toy.transpose in the IR.
53 /// The "benefit" is used by the framework to order the patterns and process
54 /// them in order of profitability.
SimplifyRedundantTransposeSimplifyRedundantTranspose55 SimplifyRedundantTranspose(mlir::MLIRContext *context)
56 : OpRewritePattern<TransposeOp>(context, /*benefit=*/1) {}
57
58 /// This method attempts to match a pattern and rewrite it. The rewriter
59 /// argument is the orchestrator of the sequence of rewrites. The pattern is
60 /// expected to interact with it to perform any changes to the IR from here.
61 mlir::LogicalResult
matchAndRewriteSimplifyRedundantTranspose62 matchAndRewrite(TransposeOp op,
63 mlir::PatternRewriter &rewriter) const override {
64 // Look through the input of the current transpose.
65 mlir::Value transposeInput = op.getOperand();
66 TransposeOp transposeInputOp = transposeInput.getDefiningOp<TransposeOp>();
67
68 // Input defined by another transpose? If not, no match.
69 if (!transposeInputOp)
70 return failure();
71
72 // Otherwise, we have a redundant transpose. Use the rewriter.
73 rewriter.replaceOp(op, {transposeInputOp.getOperand()});
74 return success();
75 }
76 };
77
78 /// Register our patterns as "canonicalization" patterns on the TransposeOp so
79 /// that they can be picked up by the Canonicalization framework.
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)80 void TransposeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
81 MLIRContext *context) {
82 results.insert<SimplifyRedundantTranspose>(context);
83 }
84
85 /// Register our patterns as "canonicalization" patterns on the ReshapeOp so
86 /// that they can be picked up by the Canonicalization framework.
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)87 void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
88 MLIRContext *context) {
89 results.insert<ReshapeReshapeOptPattern, RedundantReshapeOptPattern,
90 FoldConstantReshapeOptPattern>(context);
91 }
92