1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 *
22 * \file eliminate_common_subexpr.cc
23 * \brief Combine common subexpressions.
24 *
 * This is an optimization pass that eliminates common subexpressions. During the pass, it tries
 * to replace each expression with a previously seen expression that has the same inputs and
 * attributes. Only calls to non-stateful operators with at least one argument are considered.
 * The fskip callback argument allows specific expressions to be skipped.
28 */
29 #include <tvm/relay/analysis.h>
30 #include <tvm/relay/expr_functor.h>
31 #include <tvm/relay/transform.h>
32 #include <unordered_map>
33 #include "./pattern_util.h"
34
35 namespace tvm {
36 namespace relay {
37
38 class CommonSubexprEliminator : public ExprMutator {
39 public:
CommonSubexprEliminator(runtime::TypedPackedFunc<bool (Expr)> fskip)40 explicit CommonSubexprEliminator(runtime::TypedPackedFunc<bool(Expr)> fskip): fskip_(fskip) {}
41
VisitExpr_(const CallNode * call)42 Expr VisitExpr_(const CallNode* call) final {
43 static auto op_stateful = Op::GetAttr<TOpIsStateful>("TOpIsStateful");
44 Expr new_expr = ExprMutator::VisitExpr_(call);
45 const CallNode* new_call = new_expr.as<CallNode>();
46 CHECK(new_call);
47 const OpNode* op = new_call->op.as<OpNode>();
48 AttrsEqual attrs_equal;
49
50 if (new_call->args.size() == 0 || op == nullptr || op_stateful.get(GetRef<Op>(op), false)) {
51 return new_expr;
52 }
53 if (fskip_ != nullptr && fskip_(new_expr)) {
54 return new_expr;
55 }
56
57 auto it = expr_map_.find(new_call->op);
58 if (it != expr_map_.end()) {
59 for (const CallNode* candidate : it->second) {
60 bool is_equivalent = true;
61 if (!attrs_equal(new_call->attrs, candidate->attrs)) {
62 continue;
63 }
64 for (size_t i = 0; i < new_call->args.size(); i++) {
65 if (!new_call->args[i].same_as(candidate->args[i]) &&
66 !IsEqualScalar(new_call->args[i], candidate->args[i])) {
67 is_equivalent = false;
68 break;
69 }
70 }
71 if (!is_equivalent) continue;
72 return GetRef<Call>(candidate);
73 }
74 }
75 expr_map_[new_call->op].push_back(new_call);
76 return new_expr;
77 }
78
79 std::unordered_map<Expr, std::vector<const CallNode*>, NodeHash, NodeEqual> expr_map_;
80 runtime::TypedPackedFunc<bool(Expr)> fskip_;
81 };
82
/*!
 * \brief Run common subexpression elimination over a single expression.
 * \param expr The expression to rewrite.
 * \param callback Predicate forwarded to the eliminator as its skip function; calls for
 *        which it returns true are left untouched.
 * \return The rewritten expression, with equivalent calls sharing one node.
 */
Expr EliminateCommonSubexpr(const Expr& expr, PackedFunc callback) {
  CommonSubexprEliminator eliminator(callback);
  return eliminator(expr);
}
86
87 namespace transform {
88
EliminateCommonSubexpr(PackedFunc fskip)89 Pass EliminateCommonSubexpr(PackedFunc fskip) {
90 runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
91 [=](Function f, Module m, PassContext pc) {
92 return Downcast<Function>(EliminateCommonSubexpr(f, fskip));
93 };
94 return CreateFunctionPass(pass_func, 3, "EliminateCommonSubexpr",
95 {ir::StringImm::make("InferType")});
96 }
97
98 TVM_REGISTER_API("relay._transform.EliminateCommonSubexpr")
99 .set_body_typed(EliminateCommonSubexpr);
100
101 } // namespace transform
102
103 } // namespace relay
104 } // namespace tvm
105