1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file tvm/relay/backend/vm/lambda_lift.cc
22 * \brief Lift all local functions into global functions.
23 */
24
25 #include <tvm/relay/expr.h>
26 #include <tvm/relay/expr_functor.h>
27 #include <tvm/logging.h>
28 #include <tvm/relay/analysis.h>
29 #include <tvm/relay/transform.h>
30 #include <tvm/runtime/vm.h>
31 #include <iostream>
32 #include <vector>
33
34 using namespace tvm::runtime;
35
36 namespace tvm {
37 namespace relay {
38 namespace vm {
39
// Name of the function attribute used to mark a lifted function as a closure.
// Both the pointer and the pointee are constant so the key cannot be rebound.
static constexpr const char* kIsClosure = "IsClosure";
41
GenerateName(const Function & func)42 inline std::string GenerateName(const Function& func) {
43 size_t hash = StructuralHash()(func);
44 return std::string("lifted_name") + std::to_string(hash);
45 }
46
IsClosure(const Function & func)47 bool IsClosure(const Function& func) {
48 NodeRef res = FunctionGetAttr(func, kIsClosure);
49 const ir::IntImm* pval = res.as<ir::IntImm>();
50 return pval && pval->value != 0;
51 }
52
/*!
 * \brief Return a copy of \p func with the kIsClosure attribute set to 1.
 * \param func The function to mark.
 * \return The marked function.
 */
Function MarkClosure(const Function& func) {
  const tvm::Integer closure_flag(1);
  return FunctionSetAttr(func, kIsClosure, closure_flag);
}
56
57 /* The goal of this class is to lift out any nested functions into top-level
58 * functions.
59 *
60 * We will lift a function out into a global which takes the set of the free
61 * vars and then return the new created function.
62 */
63 class LambdaLifter : public ExprMutator {
64 public:
LambdaLifter(const Module & module)65 explicit LambdaLifter(const Module& module) : module_(module) {}
66
VisitExpr_(const LetNode * let_node)67 Expr VisitExpr_(const LetNode* let_node) final {
68 bool is_lambda = false;
69 if (auto func = let_node->value.as<FunctionNode>()) {
70 if (!func->IsPrimitive()) {
71 is_lambda = true;
72 letrec_.push_back(let_node->var);
73 }
74 }
75 auto value = VisitExpr(let_node->value);
76 if (is_lambda) {
77 letrec_.pop_back();
78 }
79 auto body = VisitExpr(let_node->body);
80 return LetNode::make(let_node->var, value, body);
81 }
82
VisitExpr_(const CallNode * call_node)83 Expr VisitExpr_(const CallNode* call_node) final {
84 auto call = Downcast<Call>(ExprMutator::VisitExpr_(call_node));
85 if (auto var_node = call_node->op.as<VarNode>()) {
86 auto var = GetRef<Var>(var_node);
87 if (!letrec_.empty() && var == letrec_.back()) {
88 auto it = lambda_map_.find(var);
89 CHECK(it != lambda_map_.end());
90 return CallNode::make(it->second, call->args, call_node->attrs,
91 call_node->type_args);
92 }
93 }
94 return std::move(call);
95 }
96
VisitExpr_(const FunctionNode * func_node)97 Expr VisitExpr_(const FunctionNode* func_node) final {
98 auto func = GetRef<Function>(func_node);
99
100 // We should not transform primitive functions.
101 if (func->IsPrimitive()) {
102 return std::move(func);
103 }
104
105 auto name = GenerateName(func);
106 auto global = GlobalVarNode::make(name);
107 auto free_vars = FreeVars(func);
108 auto free_type_vars = FreeTypeVars(func, module_);
109
110 Array<Var> captured_vars;
111 bool recursive = false;
112 for (const auto& var : free_vars) {
113 if (!letrec_.empty() && var == letrec_.back()) {
114 recursive = true;
115 continue;
116 }
117 captured_vars.push_back(var);
118 }
119 if (recursive) {
120 if (!captured_vars.empty()) {
121 Array<Expr> fvs;
122 for (auto fv : captured_vars) {
123 fvs.push_back(fv);
124 }
125 lambda_map_.emplace(letrec_.back(), CallNode::make(global, fvs));
126 } else {
127 lambda_map_.emplace(letrec_.back(), global);
128 }
129 }
130 auto body = Downcast<Function>(ExprMutator::VisitExpr_(func_node));
131
132 // When performing this optimization there are two cases.
133 //
134 // The first case in which we have no free variables
135 // we can just lift the function into the global
136 // environment without needing to allocate a closure.
137 //
138 //
139 // The second case requires that we generate a special
140 // function which makes a distinction between allocating
141 // a closure, and then the code for the closure.
142 //
143 // We represent a closure allocation by lifting the
144 // closure to a global function which takes its
145 // captured arguments and then directly returns
146 // the function representing the closure's code.
147 //
148 // When we generate code later on a call to the "outer"
149 // function marked as a closure is used to emit allocation
150 // code for the closure's environment.
151 //
152 // The "inner" function should be used to generate the
153 // code for the closure.
154 Function lifted_func;
155 if (captured_vars.size() == 0 && free_type_vars.size() == 0) {
156 lifted_func = FunctionNode::make(body->params, body->body, body->ret_type, body->type_params);
157 } else {
158 lifted_func =
159 FunctionNode::make(captured_vars, body, func->func_type_annotation(), free_type_vars);
160 lifted_func = MarkClosure(lifted_func);
161 }
162
163 CHECK(lifted_func.defined());
164
165
166 if (module_->ContainGlobalVar(name)) {
167 const auto existing_func = module_->Lookup(name);
168 CHECK(AlphaEqual(lifted_func, existing_func)) << "lifted function hash collision";
169 // If an identical function already exists, use its global var.
170 global = module_->GetGlobalVar(name);
171 } else {
172 // Add the lifted function to the module.
173 module_->Add(global, lifted_func);
174 }
175
176 if (captured_vars.size() == 0) {
177 return std::move(global);
178 } else {
179 // If we need to allocate a closure,
180 // we pass the variables in its environment here.
181 Array<Expr> fvs;
182 for (auto fv : captured_vars) {
183 fvs.push_back(fv);
184 }
185 return CallNode::make(global, fvs);
186 }
187 }
188
Lift()189 Module Lift() {
190 // There is an ordering bug here.
191 auto glob_funcs = module_->functions;
192 for (auto pair : glob_funcs) {
193 auto func = pair.second;
194 func = FunctionNode::make(func->params,
195 VisitExpr(func->body),
196 func->ret_type,
197 func->type_params,
198 func->attrs);
199 module_->Add(pair.first, func, true);
200 }
201 return module_;
202 }
203
204 private:
205 std::unordered_map<Var, Expr, NodeHash, NodeEqual> lambda_map_;
206 std::vector<Var> letrec_;
207 Module module_;
208 };
209
210 } // namespace vm
211
212 namespace transform {
213
LambdaLift()214 Pass LambdaLift() {
215 runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func =
216 [=](Module m, PassContext pc) {
217 return relay::vm::LambdaLifter(m).Lift();
218 };
219 return CreateModulePass(pass_func, 1, "LambdaLift", {});
220 }
221
222 TVM_REGISTER_API("relay._transform.LambdaLift")
223 .set_body_typed(LambdaLift);
224
225 } // namespace transform
226
227 } // namespace relay
228 } // namespace tvm
229