1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file  module.cc
22  * \brief The global module in Relay.
23  */
24 #include <tvm/ir/module.h>
25 #include <tvm/node/structural_equal.h>
26 #include <tvm/runtime/registry.h>
27 // NOTE: reverse dependency on relay.
28 // These dependencies do not happen at the interface-level,
29 // and are only used in minimum cases where they are clearly marked.
30 //
// Rationale: We call into relay's analysis module to verify correctness.
32 #include <tvm/parser/parser.h>
33 #include <tvm/relay/analysis.h>
34 #include <tvm/relay/transform.h>
35 
36 #include <fstream>
37 #include <sstream>
38 #include <unordered_set>
39 
40 namespace tvm {
41 
IRModule(tvm::Map<GlobalVar,BaseFunc> functions,tvm::Map<GlobalTypeVar,TypeData> type_definitions,std::unordered_set<String> import_set)42 IRModule::IRModule(tvm::Map<GlobalVar, BaseFunc> functions,
43                    tvm::Map<GlobalTypeVar, TypeData> type_definitions,
44                    std::unordered_set<String> import_set) {
45   auto n = make_object<IRModuleNode>();
46   n->functions = std::move(functions);
47   n->type_definitions = std::move(type_definitions);
48   n->global_type_var_map_ = {};
49   n->global_var_map_ = {};
50   n->constructor_tag_map_ = {};
51   n->import_set_ = std::move(import_set);
52 
53   for (const auto& kv : n->functions) {
54     // set global var map
55     CHECK(n->global_var_map_.count(kv.first->name_hint) == 0)
56         << "Duplicate global function name " << kv.first->name_hint;
57     n->global_var_map_.Set(kv.first->name_hint, kv.first);
58   }
59 
60   for (const auto& kv : n->type_definitions) {
61     // set global typevar map
62     CHECK(n->global_type_var_map_.count(kv.first->name_hint) == 0)
63         << "Duplicate global type definition name " << kv.first->name_hint;
64     n->global_type_var_map_.Set(kv.first->name_hint, kv.first);
65     n->RegisterConstructors(kv.first, kv.second);
66   }
67   data_ = std::move(n);
68 }
69 
SEqualReduce(const IRModuleNode * other,SEqualReducer equal) const70 bool IRModuleNode::SEqualReduce(const IRModuleNode* other, SEqualReducer equal) const {
71   if (functions.size() != other->functions.size()) return false;
72   for (const auto& kv : this->functions) {
73     if (!other->ContainGlobalVar(kv.first->name_hint)) return false;
74     if (!equal(kv.second, other->Lookup(kv.first->name_hint))) return false;
75   }
76   if (type_definitions.size() != other->type_definitions.size()) return false;
77   for (const auto& kv : this->type_definitions) {
78     if (!other->ContainGlobalTypeVar(kv.first->name_hint)) return false;
79     if (!equal(kv.second, other->LookupTypeDef(kv.first->name_hint))) return false;
80   }
81   return true;
82 }
83 
SHashReduce(SHashReducer hash_reduce) const84 void IRModuleNode::SHashReduce(SHashReducer hash_reduce) const {
85   using KV = std::pair<std::string, ObjectRef>;
86   // hash the functions.
87   std::vector<KV> temp;
88 
89   auto reduce_temp = [&]() {
90     // sort by the hash key of the keys.
91     std::sort(temp.begin(), temp.end(),
92               [](const KV& lhs, const KV& rhs) { return lhs.first < rhs.first; });
93 
94     hash_reduce(static_cast<uint64_t>(temp.size()));
95     // hash the content
96     for (size_t i = 0; i < temp.size(); ++i) {
97       hash_reduce(temp[i].first);
98       hash_reduce(temp[i].second);
99     }
100   };
101 
102   for (const auto& kv : this->functions) {
103     temp.emplace_back(kv.first->name_hint, kv.second);
104   }
105   reduce_temp();
106 
107   temp.clear();
108   for (const auto& kv : this->type_definitions) {
109     temp.emplace_back(kv.first->name_hint, kv.second);
110   }
111   reduce_temp();
112 }
113 
ContainGlobalVar(const String & name) const114 bool IRModuleNode::ContainGlobalVar(const String& name) const {
115   return global_var_map_.find(name) != global_var_map_.end();
116 }
117 
ContainGlobalTypeVar(const String & name) const118 bool IRModuleNode::ContainGlobalTypeVar(const String& name) const {
119   return global_type_var_map_.find(name) != global_type_var_map_.end();
120 }
121 
GetGlobalVar(const String & name) const122 GlobalVar IRModuleNode::GetGlobalVar(const String& name) const {
123   auto it = global_var_map_.find(name);
124   if (it == global_var_map_.end()) {
125     std::ostringstream msg;
126     msg << "ValueError: Cannot find global var \"" << name << "\" in the Module\n"
127         << "candidates are: [";
128     int counter = 0;
129     for (auto kv : global_var_map_) {
130       if (counter++ != 0) {
131         msg << ", ";
132       }
133       msg << "\"" << kv.first << "\"";
134     }
135     msg << "]";
136     LOG(FATAL) << msg.str();
137   }
138   return (*it).second;
139 }
140 
GetGlobalVars() const141 tvm::Array<GlobalVar> IRModuleNode::GetGlobalVars() const {
142   std::vector<GlobalVar> global_vars;
143   for (const auto& pair : global_var_map_) {
144     global_vars.push_back(pair.second);
145   }
146   return tvm::Array<GlobalVar>(global_vars);
147 }
148 
GetGlobalTypeVar(const String & name) const149 GlobalTypeVar IRModuleNode::GetGlobalTypeVar(const String& name) const {
150   CHECK(global_type_var_map_.defined());
151   auto it = global_type_var_map_.find(name);
152   CHECK(it != global_type_var_map_.end())
153       << "Cannot find global type var " << name << " in the Module";
154   return (*it).second;
155 }
156 
GetConstructor(const String & adt,const String & cons) const157 Constructor IRModuleNode::GetConstructor(const String& adt, const String& cons) const {
158   TypeData typeDef = this->LookupTypeDef(adt);
159   for (Constructor c : typeDef->constructors) {
160     if (cons.compare(c->name_hint) == 0) {
161       return c;
162     }
163   }
164 
165   LOG(FATAL) << adt << " does not contain constructor " << cons;
166   throw std::runtime_error("Constructor Not Found.");
167 }
168 
GetGlobalTypeVars() const169 tvm::Array<GlobalTypeVar> IRModuleNode::GetGlobalTypeVars() const {
170   std::vector<GlobalTypeVar> global_type_vars;
171   for (const auto& pair : global_type_var_map_) {
172     global_type_vars.push_back(pair.second);
173   }
174   return tvm::Array<GlobalTypeVar>(global_type_vars);
175 }
176 
177 template <typename T>
concat(const tvm::Array<T> & l,const tvm::Array<T> & r)178 tvm::Array<T> concat(const tvm::Array<T>& l, const tvm::Array<T>& r) {
179   tvm::Array<T> ret(l);
180   for (const T& t : r) {
181     ret.push_back(t);
182   }
183   return ret;
184 }
185 
186 // helper function to run type check
RunTypeCheck(const IRModule & mod,const GlobalVar & var,relay::Function f)187 relay::Function RunTypeCheck(const IRModule& mod, const GlobalVar& var, relay::Function f) {
188   auto func = Downcast<relay::Function>(relay::DeDup(std::move(f)));
189   // Type check the item before we add it to the module.
190   auto fv = relay::FreeVars(func);
191   auto ftv = relay::FreeTypeVars(func, mod);
192   CHECK_EQ(fv.size(), 0) << "There are free variables: " << fv
193                          << " in function: " << AsText(func, false);
194   CHECK_EQ(ftv.size(), 0) << "There are free type variables: " << fv
195                           << " in function: " << AsText(func, false);
196   // Type check the item before we add it to the module.
197   relay::Function checked_func = InferType(func, mod, var);
198   return checked_func;
199 }
200 
Add(const GlobalVar & var,const BaseFunc & f,bool update)201 void IRModuleNode::Add(const GlobalVar& var, const BaseFunc& f, bool update) {
202   BaseFunc checked_func = f;
203   if (auto* ptr = f.as<relay::FunctionNode>()) {
204     checked_func = RunTypeCheck(GetRef<IRModule>(this), var, GetRef<relay::Function>(ptr));
205   }
206 
207   Type type = checked_func->checked_type();
208   CHECK(type.as<relay::IncompleteTypeNode>() == nullptr);
209 
210   if (functions.find(var) != functions.end()) {
211     CHECK(update) << "Already have definition for " << var->name_hint;
212     auto old_type = functions[var]->checked_type();
213     CHECK(tvm::StructuralEqual()(type, old_type))
214         << "Module#update changes type, not possible in this mode.";
215   }
216   var->checked_type_ = type;
217   AddUnchecked(var, checked_func);
218 }
219 
AddUnchecked(const GlobalVar & var,const BaseFunc & func)220 void IRModuleNode::AddUnchecked(const GlobalVar& var, const BaseFunc& func) {
221   this->functions.Set(var, func);
222 
223   auto it = global_var_map_.find(var->name_hint);
224   if (it != global_var_map_.end()) {
225     CHECK_EQ((*it).second, var);
226   } else {
227     CHECK(global_var_map_.count(var->name_hint) == 0)
228         << "Duplicate global function name " << var->name_hint;
229   }
230 
231   global_var_map_.Set(var->name_hint, var);
232 }
233 
RegisterConstructors(const GlobalTypeVar & var,const TypeData & type)234 void IRModuleNode::RegisterConstructors(const GlobalTypeVar& var, const TypeData& type) {
235   // We hash the global type var name to use as a globally unique prefix for tags.
236   // The hash will be used as the most significant byte of the tag, with the index of
237   // the constructor in the less significant bytes
238   size_t hash = std::hash<std::string>()(var->name_hint);
239   int32_t prefix = static_cast<int32_t>(hash & 0xff) << 24;
240   for (size_t i = 0; i < type->constructors.size(); ++i) {
241     type->constructors[i]->tag = prefix | static_cast<int32_t>(i);
242     constructor_tag_map_[type->constructors[i]->tag] = type->constructors[i];
243   }
244 }
245 
AddTypeDef(const GlobalTypeVar & var,const TypeData & type,bool update)246 void IRModuleNode::AddTypeDef(const GlobalTypeVar& var, const TypeData& type, bool update) {
247   AddTypeDefUnchecked(var, type, update);
248   // need to kind check at the end because the check can look up
249   // a definition potentially
250   CHECK(relay::KindCheck(type, GetRef<IRModule>(this)) == TypeKind::kTypeData)
251       << "Invalid or malformed typedata given to module: " << type;
252 }
253 
AddTypeDefUnchecked(const GlobalTypeVar & var,const TypeData & type,bool update)254 void IRModuleNode::AddTypeDefUnchecked(const GlobalTypeVar& var, const TypeData& type,
255                                        bool update) {
256   this->type_definitions.Set(var, type);
257   if (!update) {
258     // set global type var map
259     CHECK(global_type_var_map_.count(var->name_hint) == 0)
260         << "Duplicate global type definition name " << var->name_hint;
261   }
262   global_type_var_map_.Set(var->name_hint, var);
263   RegisterConstructors(var, type);
264 }
265 
Update(const GlobalVar & var,const BaseFunc & func)266 void IRModuleNode::Update(const GlobalVar& var, const BaseFunc& func) {
267   this->Add(var, func, true);
268 }
269 
UpdateTypeDef(const GlobalTypeVar & var,const TypeData & type)270 void IRModuleNode::UpdateTypeDef(const GlobalTypeVar& var, const TypeData& type) {
271   this->AddTypeDef(var, type, true);
272 }
273 
Remove(const GlobalVar & var)274 void IRModuleNode::Remove(const GlobalVar& var) {
275   auto functions_node = this->functions.CopyOnWrite();
276   functions_node->erase(var);
277   auto gvar_node = global_var_map_.CopyOnWrite();
278   gvar_node->erase(var->name_hint);
279 }
280 
Lookup(const GlobalVar & var) const281 BaseFunc IRModuleNode::Lookup(const GlobalVar& var) const {
282   auto it = functions.find(var);
283   CHECK(it != functions.end()) << "There is no definition of " << var->name_hint;
284   return (*it).second;
285 }
286 
Lookup(const String & name) const287 BaseFunc IRModuleNode::Lookup(const String& name) const {
288   GlobalVar id = this->GetGlobalVar(name);
289   return this->Lookup(id);
290 }
291 
LookupTypeDef(const GlobalTypeVar & var) const292 TypeData IRModuleNode::LookupTypeDef(const GlobalTypeVar& var) const {
293   auto it = type_definitions.find(var);
294   CHECK(it != type_definitions.end()) << "There is no definition of " << var->name_hint;
295   return (*it).second;
296 }
297 
LookupTypeDef(const String & name) const298 TypeData IRModuleNode::LookupTypeDef(const String& name) const {
299   GlobalTypeVar id = this->GetGlobalTypeVar(name);
300   return this->LookupTypeDef(id);
301 }
302 
LookupTag(const int32_t tag)303 Constructor IRModuleNode::LookupTag(const int32_t tag) {
304   auto it = constructor_tag_map_.find(tag);
305   CHECK(it != constructor_tag_map_.end()) << "There is no constructor with the tag " << tag;
306   return (*it).second;
307 }
308 
Update(const IRModule & mod)309 void IRModuleNode::Update(const IRModule& mod) {
310   // add functions and type defs. we add them unchecked first, so all definitions
311   // can reference each other, independent of the order in which they were defined.
312   for (auto pair : mod->functions) {
313     this->AddUnchecked(pair.first, pair.second);
314   }
315   for (auto pair : mod->type_definitions) {
316     this->AddTypeDefUnchecked(pair.first, pair.second);
317   }
318   for (auto pair : mod->functions) {
319     this->Update(pair.first, pair.second);
320   }
321   for (auto pair : mod->type_definitions) {
322     this->UpdateTypeDef(pair.first, pair.second);
323   }
324 }
325 
FromExpr(const RelayExpr & expr,const tvm::Map<GlobalVar,BaseFunc> & global_funcs,const tvm::Map<GlobalTypeVar,TypeData> & type_definitions)326 IRModule IRModule::FromExpr(const RelayExpr& expr,
327                             const tvm::Map<GlobalVar, BaseFunc>& global_funcs,
328                             const tvm::Map<GlobalTypeVar, TypeData>& type_definitions) {
329   auto mod = IRModule(global_funcs, type_definitions);
330   BaseFunc func;
331   std::string gv_name = "main";
332 
333   if (auto* func_node = expr.as<BaseFuncNode>()) {
334     func = GetRef<BaseFunc>(func_node);
335     if (auto opt = func->GetAttr<String>(tvm::attr::kGlobalSymbol)) {
336       gv_name = opt.value();
337     }
338 
339   } else {
340     func = relay::Function(relay::FreeVars(expr), expr, Type(), relay::FreeTypeVars(expr, mod), {});
341   }
342   auto main_gv = GlobalVar(gv_name);
343   mod->Add(main_gv, func);
344   return mod;
345 }
346 
Import(const String & path)347 void IRModuleNode::Import(const String& path) {
348   if (this->import_set_.count(path) == 0) {
349     this->import_set_.insert(path);
350     DLOG(INFO) << "Importing: " << path;
351     std::fstream src_file(path, std::fstream::in);
352     std::string file_contents{std::istreambuf_iterator<char>(src_file),
353                               std::istreambuf_iterator<char>()};
354     auto mod_to_import = IRModule::FromText(file_contents, path);
355     Update(mod_to_import);
356   }
357 }
358 
ImportFromStd(const String & path)359 void IRModuleNode::ImportFromStd(const String& path) {
360   auto* f = tvm::runtime::Registry::Get("tvm.relay.std_path");
361   CHECK(f != nullptr) << "The Relay std_path is not set, please register tvm.relay.std_path.";
362   std::string std_path = (*f)();
363   this->Import(std_path + "/" + path);
364 }
365 
Imports() const366 std::unordered_set<String> IRModuleNode::Imports() const { return this->import_set_; }
367 
FromText(const String & text,const String & source_path)368 IRModule IRModule::FromText(const String& text, const String& source_path) {
369   return tvm::parser::ParseModule(source_path, text);
370 }
371 
372 TVM_REGISTER_NODE_TYPE(IRModuleNode);
373 
374 TVM_REGISTER_GLOBAL("ir.IRModule")
375     .set_body_typed([](tvm::Map<GlobalVar, BaseFunc> funcs,
__anon85f932db0302(tvm::Map<GlobalVar, BaseFunc> funcs, tvm::Map<GlobalTypeVar, TypeData> types) 376                        tvm::Map<GlobalTypeVar, TypeData> types) {
377       return IRModule(funcs, types, {});
378     });
379 
// FFI: ir.Module_Add(mod, var, val, update) -> mod
// `val` must be a Relay expression. It is handled as follows:
//   - a function is added as-is;
//   - a GlobalVar is resolved by eta-expanding a copy of `mod` and taking the
//     resulting function bound to that name;
//   - any other expression is wrapped in a parameterless relay::Function.
// The module is mutated in place and also returned.
TVM_REGISTER_GLOBAL("ir.Module_Add").set_body([](TVMArgs args, TVMRetValue* ret) {
  IRModule mod = args[0];
  GlobalVar var = args[1];
  ObjectRef val = args[2];
  bool update = args[3];
  CHECK(val->IsInstance<RelayExprNode>());

  BaseFunc to_add;
  if (val->IsInstance<BaseFuncNode>()) {
    to_add = Downcast<BaseFunc>(val);
  } else if (val->IsInstance<GlobalVarNode>()) {
    GlobalVar target = Downcast<GlobalVar>(val);
    // Work on a shallow node copy so `mod` itself is not eta-expanded.
    auto expanded = IRModule(make_object<IRModuleNode>(*mod.operator->()));
    expanded = relay::transform::EtaExpand(
        /* expand_constructor */ false,
        /* expand_global_var */ true)(expanded);
    to_add = Downcast<relay::Function>(expanded->Lookup(target->name_hint));
  } else {
    to_add = relay::Function({}, Downcast<RelayExpr>(val), Type(nullptr), {});
  }
  mod->Add(var, to_add, update);

  *ret = mod;
});
403 
404 TVM_REGISTER_GLOBAL("ir.Module_AddDef").set_body_method<IRModule>(&IRModuleNode::AddTypeDef);
405 
406 TVM_REGISTER_GLOBAL("ir.Module_GetGlobalVar")
407     .set_body_method<IRModule>(&IRModuleNode::GetGlobalVar);
408 
409 TVM_REGISTER_GLOBAL("ir.Module_GetGlobalVars")
410     .set_body_method<IRModule>(&IRModuleNode::GetGlobalVars);
411 
412 TVM_REGISTER_GLOBAL("ir.Module_GetGlobalTypeVars")
413     .set_body_method<IRModule>(&IRModuleNode::GetGlobalTypeVars);
414 
415 TVM_REGISTER_GLOBAL("ir.Module_ContainGlobalVar")
416     .set_body_method<IRModule>(&IRModuleNode::ContainGlobalVar);
417 
418 TVM_REGISTER_GLOBAL("ir.Module_GetGlobalTypeVar")
419     .set_body_method<IRModule>(&IRModuleNode::GetGlobalTypeVar);
420 
// ---- FFI lookups -------------------------------------------------------------

TVM_REGISTER_GLOBAL("ir.Module_Lookup").set_body_typed([](IRModule module, GlobalVar gvar) {
  return module->Lookup(gvar);
});

TVM_REGISTER_GLOBAL("ir.Module_Lookup_str").set_body_typed([](IRModule module, String name) {
  return module->Lookup(name);
});

TVM_REGISTER_GLOBAL("ir.Module_LookupDef")
    .set_body_typed([](IRModule module, GlobalTypeVar type_var) {
      return module->LookupTypeDef(type_var);
    });

TVM_REGISTER_GLOBAL("ir.Module_LookupDef_str")
    .set_body_typed([](IRModule module, String type_name) {
      return module->LookupTypeDef(type_name);
    });

TVM_REGISTER_GLOBAL("ir.Module_LookupTag").set_body_typed([](IRModule module, int32_t ctor_tag) {
  return module->LookupTag(ctor_tag);
});

// ---- FFI construction and mutation ---------------------------------------------

TVM_REGISTER_GLOBAL("ir.Module_FromExpr")
    .set_body_typed([](RelayExpr expr, tvm::Map<GlobalVar, BaseFunc> function_map,
                       tvm::Map<GlobalTypeVar, TypeData> type_map) {
      return IRModule::FromExpr(expr, function_map, type_map);
    });

TVM_REGISTER_GLOBAL("ir.Module_Update").set_body_typed([](IRModule module, IRModule source) {
  module->Update(source);
});

TVM_REGISTER_GLOBAL("ir.Module_UpdateFunction")
    .set_body_typed([](IRModule module, GlobalVar gvar, BaseFunc fn) {
      module->Update(gvar, fn);
    });

TVM_REGISTER_GLOBAL("ir.Module_Import").set_body_typed([](IRModule module, String file_path) {
  module->Import(file_path);
});

TVM_REGISTER_GLOBAL("ir.Module_ImportFromStd")
    .set_body_typed([](IRModule module, String file_path) { module->ImportFromStd(file_path); });
461 
462 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
__anon85f932db0f02(const ObjectRef& ref, ReprPrinter* p) 463     .set_dispatch<IRModuleNode>([](const ObjectRef& ref, ReprPrinter* p) {
464       auto* node = static_cast<const IRModuleNode*>(ref.get());
465       p->stream << "IRModuleNode( " << node->functions << ")";
466     });
467 
468 }  // namespace tvm
469