1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
/*!
 *
 * \file eliminate_common_subexpr.cc
 * \brief Combine common subexpressions.
 *
 * This is an optimization pass that eliminates common subexpressions. During the pass, it tries
 * to replace each operator call with an equivalent call that appeared earlier, i.e. one with the
 * same operator, inputs, and attributes. The fskip callback argument allows callers to exclude
 * specific expressions from elimination.
 */
29 #include <tvm/relay/analysis.h>
30 #include <tvm/relay/expr_functor.h>
31 #include <tvm/relay/transform.h>
32 #include <unordered_map>
33 #include "./pattern_util.h"
34 
35 namespace tvm {
36 namespace relay {
37 
38 class CommonSubexprEliminator : public ExprMutator {
39  public:
CommonSubexprEliminator(runtime::TypedPackedFunc<bool (Expr)> fskip)40   explicit CommonSubexprEliminator(runtime::TypedPackedFunc<bool(Expr)> fskip): fskip_(fskip) {}
41 
VisitExpr_(const CallNode * call)42   Expr VisitExpr_(const CallNode* call) final {
43     static auto op_stateful = Op::GetAttr<TOpIsStateful>("TOpIsStateful");
44     Expr new_expr = ExprMutator::VisitExpr_(call);
45     const CallNode* new_call = new_expr.as<CallNode>();
46     CHECK(new_call);
47     const OpNode* op = new_call->op.as<OpNode>();
48     AttrsEqual attrs_equal;
49 
50     if (new_call->args.size() == 0 || op == nullptr || op_stateful.get(GetRef<Op>(op), false)) {
51       return new_expr;
52     }
53     if (fskip_ != nullptr && fskip_(new_expr)) {
54       return new_expr;
55     }
56 
57     auto it = expr_map_.find(new_call->op);
58     if (it != expr_map_.end()) {
59       for (const CallNode* candidate : it->second) {
60         bool is_equivalent = true;
61         if (!attrs_equal(new_call->attrs, candidate->attrs)) {
62           continue;
63         }
64         for (size_t i = 0; i < new_call->args.size(); i++) {
65           if (!new_call->args[i].same_as(candidate->args[i]) &&
66               !IsEqualScalar(new_call->args[i], candidate->args[i])) {
67             is_equivalent = false;
68             break;
69           }
70         }
71         if (!is_equivalent) continue;
72         return GetRef<Call>(candidate);
73       }
74     }
75     expr_map_[new_call->op].push_back(new_call);
76     return new_expr;
77   }
78 
79   std::unordered_map<Expr, std::vector<const CallNode*>, NodeHash, NodeEqual> expr_map_;
80   runtime::TypedPackedFunc<bool(Expr)> fskip_;
81 };
82 
/*!
 * \brief Run common-subexpression elimination over an expression.
 * \param expr The expression to rewrite.
 * \param callback Skip predicate handed to the eliminator; calls for which it
 *        returns true are not merged. May be null.
 * \return The rewritten expression.
 */
Expr EliminateCommonSubexpr(const Expr& expr, PackedFunc callback) {
  CommonSubexprEliminator eliminator(callback);
  return eliminator(expr);
}
86 
87 namespace transform {
88 
EliminateCommonSubexpr(PackedFunc fskip)89 Pass EliminateCommonSubexpr(PackedFunc fskip) {
90   runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
91     [=](Function f, Module m, PassContext pc) {
92       return Downcast<Function>(EliminateCommonSubexpr(f, fskip));
93   };
94   return CreateFunctionPass(pass_func, 3, "EliminateCommonSubexpr",
95                             {ir::StringImm::make("InferType")});
96 }
97 
98 TVM_REGISTER_API("relay._transform.EliminateCommonSubexpr")
99 .set_body_typed(EliminateCommonSubexpr);
100 
101 }  // namespace transform
102 
103 }  // namespace relay
104 }  // namespace tvm
105