1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  *
22  * \file eliminate_common_subexpr.cc
23  * \brief Combine common subexpressions.
24  *
 * This is an optimization pass that eliminates common subexpressions. During the pass, it tries
 * to replace an expression with an equivalent expression that appeared earlier, i.e. one with the
 * same inputs and attributes. The optional fskip callback lets callers exclude specific
 * expressions: any expression for which it returns true is left untouched.
28  */
29 #include <tvm/relay/analysis.h>
30 #include <tvm/relay/expr_functor.h>
31 #include <tvm/relay/transform.h>
32 
33 #include <unordered_map>
34 
35 #include "pattern_util.h"
36 
37 namespace tvm {
38 namespace relay {
39 
40 class CommonSubexprEliminator : public MixedModeMutator {
41  public:
CommonSubexprEliminator(runtime::TypedPackedFunc<bool (Expr)> fskip)42   explicit CommonSubexprEliminator(runtime::TypedPackedFunc<bool(Expr)> fskip) : fskip_(fskip) {}
43 
Rewrite_(const CallNode * call,const Expr & post)44   Expr Rewrite_(const CallNode* call, const Expr& post) final {
45     static auto op_stateful = Op::GetAttrMap<TOpIsStateful>("TOpIsStateful");
46     Expr new_expr = post;
47     const CallNode* new_call = new_expr.as<CallNode>();
48     CHECK(new_call);
49     const OpNode* op = new_call->op.as<OpNode>();
50     StructuralEqual attrs_equal;
51 
52     if (new_call->args.size() == 0 || op == nullptr || op_stateful.get(GetRef<Op>(op), false)) {
53       return new_expr;
54     }
55     if (fskip_ != nullptr && fskip_(new_expr)) {
56       return new_expr;
57     }
58 
59     auto it = expr_map_.find(new_call->op);
60     if (it != expr_map_.end()) {
61       for (const Expr& candidate_expr : it->second) {
62         if (const CallNode* candidate = candidate_expr.as<CallNode>()) {
63           bool is_equivalent = true;
64           if (!attrs_equal(new_call->attrs, candidate->attrs)) {
65             continue;
66           }
67           for (size_t i = 0; i < new_call->args.size(); i++) {
68             if (!new_call->args[i].same_as(candidate->args[i]) &&
69                 !IsEqualScalar(new_call->args[i], candidate->args[i])) {
70               is_equivalent = false;
71               break;
72             }
73           }
74           if (!is_equivalent) continue;
75           return GetRef<Call>(candidate);
76         }
77       }
78     }
79     expr_map_[new_call->op].push_back(new_expr);
80     return new_expr;
81   }
82 
Rewrite_(const TupleGetItemNode * op,const Expr & post)83   Expr Rewrite_(const TupleGetItemNode* op, const Expr& post) final {
84     Expr new_expr = post;
85     const TupleGetItemNode* new_tuple_item = new_expr.as<TupleGetItemNode>();
86     CHECK(new_tuple_item);
87 
88     if (fskip_ != nullptr && fskip_(new_expr)) {
89       return new_expr;
90     }
91 
92     auto it = expr_map_.find(new_tuple_item->tuple);
93     if (it != expr_map_.end()) {
94       for (const Expr& candidate_expr : it->second) {
95         if (const TupleGetItemNode* candidate = candidate_expr.as<TupleGetItemNode>()) {
96           if (new_tuple_item->index == candidate->index) {
97             return GetRef<Expr>(candidate);
98           }
99         }
100       }
101     }
102     expr_map_[new_tuple_item->tuple].push_back(new_expr);
103     return new_expr;
104   }
105 
106   std::unordered_map<Expr, std::vector<Expr>, ObjectPtrHash, ObjectPtrEqual> expr_map_;
107   runtime::TypedPackedFunc<bool(Expr)> fskip_;
108 };
109 
/*!
 * \brief Run common subexpression elimination over an expression.
 * \param expr The expression to optimize.
 * \param callback Optional predicate (may be null); expressions for which it returns true
 *        are left untouched.
 * \return The rewritten expression.
 */
Expr EliminateCommonSubexpr(const Expr& expr, PackedFunc callback) {
  CommonSubexprEliminator eliminator(callback);
  return eliminator(expr);
}
113 
114 namespace transform {
115 
EliminateCommonSubexpr(PackedFunc fskip)116 Pass EliminateCommonSubexpr(PackedFunc fskip) {
117   runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
118       [=](Function f, IRModule m, PassContext pc) {
119         return Downcast<Function>(EliminateCommonSubexpr(f, fskip));
120       };
121   return CreateFunctionPass(pass_func, 3, "EliminateCommonSubexpr", {"InferType"});
122 }
123 
// Register the pass factory above as the global packed function
// "relay._transform.EliminateCommonSubexpr" (presumably the entry point used by the
// Python frontend — confirm against the bindings).
TVM_REGISTER_GLOBAL("relay._transform.EliminateCommonSubexpr")
    .set_body_typed(EliminateCommonSubexpr);
126 
127 }  // namespace transform
128 
129 }  // namespace relay
130 }  // namespace tvm
131