1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file tvm/relay/backend/vm/lambda_lift.cc
22  * \brief Lift all local functions into global functions.
23  */
24 
25 #include <tvm/relay/expr.h>
26 #include <tvm/relay/expr_functor.h>
27 #include <tvm/logging.h>
28 #include <tvm/relay/analysis.h>
29 #include <tvm/relay/transform.h>
30 #include <tvm/runtime/vm.h>
31 #include <iostream>
32 #include <vector>
33 
34 using namespace tvm::runtime;
35 
36 namespace tvm {
37 namespace relay {
38 namespace vm {
39 
/*!
 * \brief Function attribute key used to mark a lifted function as a closure.
 *
 * Written by MarkClosure and read back by IsClosure. Declared constexpr so the
 * pointer itself cannot be reassigned elsewhere in this translation unit.
 */
static constexpr const char* kIsClosure = "IsClosure";
41 
GenerateName(const Function & func)42 inline std::string GenerateName(const Function& func) {
43   size_t hash = StructuralHash()(func);
44   return std::string("lifted_name") + std::to_string(hash);
45 }
46 
IsClosure(const Function & func)47 bool IsClosure(const Function& func) {
48   NodeRef res = FunctionGetAttr(func, kIsClosure);
49   const ir::IntImm* pval = res.as<ir::IntImm>();
50   return pval && pval->value != 0;
51 }
52 
/*!
 * \brief Return a copy of func tagged with kIsClosure = 1.
 *
 * IsClosure reports true for the returned function.
 *
 * \param func The function to tag.
 * \return The function with the closure attribute set.
 */
Function MarkClosure(const Function& func) {
  const tvm::Integer flag(1);
  return FunctionSetAttr(func, kIsClosure, flag);
}
56 
/*!
 * \brief Lifts nested (non-primitive) functions out into top-level global
 * functions of the module.
 *
 * Each nested function is added to the module under a name produced by
 * GenerateName. The lifted global takes one of two shapes:
 *
 *  - If the function captures no free variables and has no free type
 *    variables, it is added as an ordinary global. The nested function is
 *    replaced by a reference to that global.
 *  - Otherwise the global is an "outer" function, tagged with MarkClosure.
 *    Its parameters are the captured free variables and its body is the
 *    (mutated) original function. The nested function is replaced by a call
 *    to the global that passes the captured variables. If no variables are
 *    captured, the bare global is used instead.
 *
 * Let-bound recursive functions (`let f = fn(...) { ... f(...) ... }`) are
 * supported by rewriting calls whose callee is `f`, made inside f's own
 * definition, to the lifted global. When f captures variables, the global is
 * applied to those captured variables first.
 *
 * NOTE(review): only call-position uses of the innermost letrec variable are
 * rewritten. This class does not override the Var visitor, so other
 * references (e.g. passing `f` as a value) are left as-is. Confirm this is
 * intended.
 */
class LambdaLifter : public ExprMutator {
 public:
  /*!
   * \param module The module to transform. Lifted functions are added to it
   *        through module_->Add.
   */
  explicit LambdaLifter(const Module& module) : module_(module) {}

  /*!
   * \brief Track a let-bound non-primitive function while its value is being
   * visited, so recursive references to the bound variable can be recognized.
   */
  Expr VisitExpr_(const LetNode* let_node) final {
    bool is_lambda = false;
    if (auto func = let_node->value.as<FunctionNode>()) {
      if (!func->IsPrimitive()) {
        is_lambda = true;
        // The let variable counts as a recursive reference only while its own
        // value is visited.
        letrec_.push_back(let_node->var);
      }
    }
    auto value = VisitExpr(let_node->value);
    if (is_lambda) {
      letrec_.pop_back();
    }
    // Visited after the pop: uses of the variable in the body refer to the
    // rewritten `value` bound below, not to a recursive self-reference.
    auto body = VisitExpr(let_node->body);
    return LetNode::make(let_node->var, value, body);
  }

  /*!
   * \brief Rewrite a recursive call `f(args)`, made inside the definition of the
   * innermost let-bound function `f`, to call f's lifted replacement.
   */
  Expr VisitExpr_(const CallNode* call_node) final {
    auto call = Downcast<Call>(ExprMutator::VisitExpr_(call_node));
    // Inspect the *original* callee. Only a call whose op is the innermost
    // letrec variable is rewritten.
    if (auto var_node = call_node->op.as<VarNode>()) {
      auto var = GetRef<Var>(var_node);
      if (!letrec_.empty() && var == letrec_.back()) {
        // Expected to have been recorded by the FunctionNode visitor before it
        // visited the function body; the CHECK enforces that.
        auto it = lambda_map_.find(var);
        CHECK(it != lambda_map_.end());
        // Use the mutated arguments, with the original call attrs and type args.
        return CallNode::make(it->second, call->args, call_node->attrs,
                              call_node->type_args);
      }
    }
    return std::move(call);
  }

  /*!
   * \brief Lift a non-primitive function into a module-level global.
   *
   * \return The expression that replaces the function at its original site:
   *         either the global itself, or a call of the global on the captured
   *         variables.
   */
  Expr VisitExpr_(const FunctionNode* func_node) final {
    auto func = GetRef<Function>(func_node);

    // We should not transform primitive functions.
    if (func->IsPrimitive()) {
      return std::move(func);
    }

    // The name is derived from the *original* (unmutated) function.
    auto name = GenerateName(func);
    auto global = GlobalVarNode::make(name);
    auto free_vars = FreeVars(func);
    auto free_type_vars = FreeTypeVars(func, module_);

    // Partition the free variables. If the innermost letrec variable is free
    // here, this function is recursive and that variable is not captured.
    // Every other free variable is captured.
    Array<Var> captured_vars;
    bool recursive = false;
    for (const auto& var : free_vars) {
      if (!letrec_.empty() && var == letrec_.back()) {
        recursive = true;
        continue;
      }
      captured_vars.push_back(var);
    }
    // Record what recursive calls should become. This must happen before the
    // body is visited below, because the CallNode visitor reads lambda_map_.
    // NOTE(review): this stores the freshly created `global`. If the name
    // already exists in the module, `global` is later rebound to the existing
    // GlobalVar (see below), but this entry is not updated. Confirm that case
    // cannot arise for recursive functions.
    if (recursive) {
      if (!captured_vars.empty()) {
        Array<Expr> fvs;
        for (auto fv : captured_vars) {
          fvs.push_back(fv);
        }
        lambda_map_.emplace(letrec_.back(), CallNode::make(global, fvs));
      } else {
        lambda_map_.emplace(letrec_.back(), global);
      }
    }
    // Mutate the function through the base visitor.
    auto body = Downcast<Function>(ExprMutator::VisitExpr_(func_node));

    // When performing this optimization there are two cases.
    //
    // In the first case there are no captured variables and no free type
    // variables. We can lift the function into the global environment
    // without allocating a closure.
    //
    // The second case requires generating a special function that separates
    // allocating a closure from the code of the closure itself.
    //
    // We represent a closure allocation by lifting the closure to a global
    // function. That function takes the captured arguments and directly
    // returns the function representing the closure's code.
    //
    // When code is generated later, a call to the "outer" function marked as
    // a closure is used to emit allocation code for the closure's
    // environment.
    //
    // The "inner" function should be used to generate the code for the
    // closure.
    //
    // NOTE(review): if captured_vars is empty but free_type_vars is not, the
    // lifted function is closure-marked with zero parameters. The code below
    // still returns the bare `global` rather than a call to it. Confirm this
    // matches what the VM compiler expects.
    Function lifted_func;
    if (captured_vars.size() == 0 && free_type_vars.size() == 0) {
      lifted_func = FunctionNode::make(body->params, body->body, body->ret_type, body->type_params);
    } else {
      lifted_func =
          FunctionNode::make(captured_vars, body, func->func_type_annotation(), free_type_vars);
      lifted_func = MarkClosure(lifted_func);
    }

    CHECK(lifted_func.defined());


    // Names come from structural hashes, so an existing global with the same
    // name is expected to be an identical, previously lifted function.
    if (module_->ContainGlobalVar(name)) {
      const auto existing_func = module_->Lookup(name);
      CHECK(AlphaEqual(lifted_func, existing_func)) << "lifted function hash collision";
      // If an identical function already exists, use its global var.
      global = module_->GetGlobalVar(name);
    } else {
      // Add the lifted function to the module.
      module_->Add(global, lifted_func);
    }

    if (captured_vars.size() == 0) {
      return std::move(global);
    } else {
      // A closure must be allocated, so pass the variables in its
      // environment here.
      Array<Expr> fvs;
      for (auto fv : captured_vars) {
        fvs.push_back(fv);
      }
      return CallNode::make(global, fvs);
    }
  }

  /*!
   * \brief Run the lifter over the body of every function in the module and
   * store the rewritten functions back under their original global vars.
   *
   * \return module_, which now also holds the lifted globals.
   */
  Module Lift() {
    // There is an ordering bug here.
    // NOTE(review): the nature of this "ordering bug" is not explained.
    // glob_funcs is taken before the loop. Whether globals added by
    // module_->Add during the loop appear in it depends on Map's copy
    // semantics (TODO confirm).
    auto glob_funcs = module_->functions;
    for (auto pair : glob_funcs) {
      auto func = pair.second;
      // Only the body is mutated; params, return type, type params and attrs
      // are carried over unchanged.
      func = FunctionNode::make(func->params,
                                VisitExpr(func->body),
                                func->ret_type,
                                func->type_params,
                                func->attrs);
      // Third argument presumably allows overwriting the existing definition
      // (an "update" flag). Confirm against Module::Add.
      module_->Add(pair.first, func, true);
    }
    return module_;
  }

 private:
  /*!
   * \brief Maps a recursive let-bound variable to the expression that replaces
   * it as the callee of recursive calls: the lifted global, or the global
   * applied to the captured variables.
   */
  std::unordered_map<Var, Expr, NodeHash, NodeEqual> lambda_map_;
  /*!
   * \brief Stack of let-bound variables whose non-primitive function value is
   * currently being visited. Only back() is treated as a recursive reference.
   */
  std::vector<Var> letrec_;
  /*! \brief The module being transformed; lifted functions are added here. */
  Module module_;
};
209 
210 }  // namespace vm
211 
212 namespace transform {
213 
LambdaLift()214 Pass LambdaLift() {
215   runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func =
216     [=](Module m, PassContext pc) {
217     return relay::vm::LambdaLifter(m).Lift();
218   };
219   return CreateModulePass(pass_func, 1, "LambdaLift", {});
220 }
221 
// Register the LambdaLift pass factory as the global API function
// "relay._transform.LambdaLift" (presumably invoked from the Python frontend).
TVM_REGISTER_API("relay._transform.LambdaLift")
.set_body_typed(LambdaLift);
224 
225 }  // namespace transform
226 
227 }  // namespace relay
228 }  // namespace tvm
229