1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file module.cc
22 * \brief The global module in Relay.
23 */
24 #include <tvm/ir/module.h>
25 #include <tvm/node/structural_equal.h>
26 #include <tvm/runtime/registry.h>
27 // NOTE: reverse dependency on relay.
// These dependencies do not appear at the interface level,
// and are only used in a minimal number of places, where they are clearly marked.
//
// Rationale: we call into relay's analysis module to verify correctness.
32 #include <tvm/parser/parser.h>
33 #include <tvm/relay/analysis.h>
34 #include <tvm/relay/transform.h>
35
36 #include <fstream>
37 #include <sstream>
38 #include <unordered_set>
39
40 namespace tvm {
41
IRModule(tvm::Map<GlobalVar,BaseFunc> functions,tvm::Map<GlobalTypeVar,TypeData> type_definitions,std::unordered_set<String> import_set)42 IRModule::IRModule(tvm::Map<GlobalVar, BaseFunc> functions,
43 tvm::Map<GlobalTypeVar, TypeData> type_definitions,
44 std::unordered_set<String> import_set) {
45 auto n = make_object<IRModuleNode>();
46 n->functions = std::move(functions);
47 n->type_definitions = std::move(type_definitions);
48 n->global_type_var_map_ = {};
49 n->global_var_map_ = {};
50 n->constructor_tag_map_ = {};
51 n->import_set_ = std::move(import_set);
52
53 for (const auto& kv : n->functions) {
54 // set global var map
55 CHECK(n->global_var_map_.count(kv.first->name_hint) == 0)
56 << "Duplicate global function name " << kv.first->name_hint;
57 n->global_var_map_.Set(kv.first->name_hint, kv.first);
58 }
59
60 for (const auto& kv : n->type_definitions) {
61 // set global typevar map
62 CHECK(n->global_type_var_map_.count(kv.first->name_hint) == 0)
63 << "Duplicate global type definition name " << kv.first->name_hint;
64 n->global_type_var_map_.Set(kv.first->name_hint, kv.first);
65 n->RegisterConstructors(kv.first, kv.second);
66 }
67 data_ = std::move(n);
68 }
69
SEqualReduce(const IRModuleNode * other,SEqualReducer equal) const70 bool IRModuleNode::SEqualReduce(const IRModuleNode* other, SEqualReducer equal) const {
71 if (functions.size() != other->functions.size()) return false;
72 for (const auto& kv : this->functions) {
73 if (!other->ContainGlobalVar(kv.first->name_hint)) return false;
74 if (!equal(kv.second, other->Lookup(kv.first->name_hint))) return false;
75 }
76 if (type_definitions.size() != other->type_definitions.size()) return false;
77 for (const auto& kv : this->type_definitions) {
78 if (!other->ContainGlobalTypeVar(kv.first->name_hint)) return false;
79 if (!equal(kv.second, other->LookupTypeDef(kv.first->name_hint))) return false;
80 }
81 return true;
82 }
83
SHashReduce(SHashReducer hash_reduce) const84 void IRModuleNode::SHashReduce(SHashReducer hash_reduce) const {
85 using KV = std::pair<std::string, ObjectRef>;
86 // hash the functions.
87 std::vector<KV> temp;
88
89 auto reduce_temp = [&]() {
90 // sort by the hash key of the keys.
91 std::sort(temp.begin(), temp.end(),
92 [](const KV& lhs, const KV& rhs) { return lhs.first < rhs.first; });
93
94 hash_reduce(static_cast<uint64_t>(temp.size()));
95 // hash the content
96 for (size_t i = 0; i < temp.size(); ++i) {
97 hash_reduce(temp[i].first);
98 hash_reduce(temp[i].second);
99 }
100 };
101
102 for (const auto& kv : this->functions) {
103 temp.emplace_back(kv.first->name_hint, kv.second);
104 }
105 reduce_temp();
106
107 temp.clear();
108 for (const auto& kv : this->type_definitions) {
109 temp.emplace_back(kv.first->name_hint, kv.second);
110 }
111 reduce_temp();
112 }
113
ContainGlobalVar(const String & name) const114 bool IRModuleNode::ContainGlobalVar(const String& name) const {
115 return global_var_map_.find(name) != global_var_map_.end();
116 }
117
ContainGlobalTypeVar(const String & name) const118 bool IRModuleNode::ContainGlobalTypeVar(const String& name) const {
119 return global_type_var_map_.find(name) != global_type_var_map_.end();
120 }
121
GetGlobalVar(const String & name) const122 GlobalVar IRModuleNode::GetGlobalVar(const String& name) const {
123 auto it = global_var_map_.find(name);
124 if (it == global_var_map_.end()) {
125 std::ostringstream msg;
126 msg << "ValueError: Cannot find global var \"" << name << "\" in the Module\n"
127 << "candidates are: [";
128 int counter = 0;
129 for (auto kv : global_var_map_) {
130 if (counter++ != 0) {
131 msg << ", ";
132 }
133 msg << "\"" << kv.first << "\"";
134 }
135 msg << "]";
136 LOG(FATAL) << msg.str();
137 }
138 return (*it).second;
139 }
140
GetGlobalVars() const141 tvm::Array<GlobalVar> IRModuleNode::GetGlobalVars() const {
142 std::vector<GlobalVar> global_vars;
143 for (const auto& pair : global_var_map_) {
144 global_vars.push_back(pair.second);
145 }
146 return tvm::Array<GlobalVar>(global_vars);
147 }
148
/*!
 * \brief Return the GlobalTypeVar registered under `name`.
 *
 * CHECK-fails if the type var table is undefined or if no type definition
 * with that name exists in this module.
 */
GlobalTypeVar IRModuleNode::GetGlobalTypeVar(const String& name) const {
  CHECK(global_type_var_map_.defined());
  auto it = global_type_var_map_.find(name);
  CHECK(it != global_type_var_map_.end())
      << "Cannot find global type var " << name << " in the Module";
  return (*it).second;
}
156
GetConstructor(const String & adt,const String & cons) const157 Constructor IRModuleNode::GetConstructor(const String& adt, const String& cons) const {
158 TypeData typeDef = this->LookupTypeDef(adt);
159 for (Constructor c : typeDef->constructors) {
160 if (cons.compare(c->name_hint) == 0) {
161 return c;
162 }
163 }
164
165 LOG(FATAL) << adt << " does not contain constructor " << cons;
166 throw std::runtime_error("Constructor Not Found.");
167 }
168
GetGlobalTypeVars() const169 tvm::Array<GlobalTypeVar> IRModuleNode::GetGlobalTypeVars() const {
170 std::vector<GlobalTypeVar> global_type_vars;
171 for (const auto& pair : global_type_var_map_) {
172 global_type_vars.push_back(pair.second);
173 }
174 return tvm::Array<GlobalTypeVar>(global_type_vars);
175 }
176
177 template <typename T>
concat(const tvm::Array<T> & l,const tvm::Array<T> & r)178 tvm::Array<T> concat(const tvm::Array<T>& l, const tvm::Array<T>& r) {
179 tvm::Array<T> ret(l);
180 for (const T& t : r) {
181 ret.push_back(t);
182 }
183 return ret;
184 }
185
186 // helper function to run type check
RunTypeCheck(const IRModule & mod,const GlobalVar & var,relay::Function f)187 relay::Function RunTypeCheck(const IRModule& mod, const GlobalVar& var, relay::Function f) {
188 auto func = Downcast<relay::Function>(relay::DeDup(std::move(f)));
189 // Type check the item before we add it to the module.
190 auto fv = relay::FreeVars(func);
191 auto ftv = relay::FreeTypeVars(func, mod);
192 CHECK_EQ(fv.size(), 0) << "There are free variables: " << fv
193 << " in function: " << AsText(func, false);
194 CHECK_EQ(ftv.size(), 0) << "There are free type variables: " << fv
195 << " in function: " << AsText(func, false);
196 // Type check the item before we add it to the module.
197 relay::Function checked_func = InferType(func, mod, var);
198 return checked_func;
199 }
200
Add(const GlobalVar & var,const BaseFunc & f,bool update)201 void IRModuleNode::Add(const GlobalVar& var, const BaseFunc& f, bool update) {
202 BaseFunc checked_func = f;
203 if (auto* ptr = f.as<relay::FunctionNode>()) {
204 checked_func = RunTypeCheck(GetRef<IRModule>(this), var, GetRef<relay::Function>(ptr));
205 }
206
207 Type type = checked_func->checked_type();
208 CHECK(type.as<relay::IncompleteTypeNode>() == nullptr);
209
210 if (functions.find(var) != functions.end()) {
211 CHECK(update) << "Already have definition for " << var->name_hint;
212 auto old_type = functions[var]->checked_type();
213 CHECK(tvm::StructuralEqual()(type, old_type))
214 << "Module#update changes type, not possible in this mode.";
215 }
216 var->checked_type_ = type;
217 AddUnchecked(var, checked_func);
218 }
219
AddUnchecked(const GlobalVar & var,const BaseFunc & func)220 void IRModuleNode::AddUnchecked(const GlobalVar& var, const BaseFunc& func) {
221 this->functions.Set(var, func);
222
223 auto it = global_var_map_.find(var->name_hint);
224 if (it != global_var_map_.end()) {
225 CHECK_EQ((*it).second, var);
226 } else {
227 CHECK(global_var_map_.count(var->name_hint) == 0)
228 << "Duplicate global function name " << var->name_hint;
229 }
230
231 global_var_map_.Set(var->name_hint, var);
232 }
233
RegisterConstructors(const GlobalTypeVar & var,const TypeData & type)234 void IRModuleNode::RegisterConstructors(const GlobalTypeVar& var, const TypeData& type) {
235 // We hash the global type var name to use as a globally unique prefix for tags.
236 // The hash will be used as the most significant byte of the tag, with the index of
237 // the constructor in the less significant bytes
238 size_t hash = std::hash<std::string>()(var->name_hint);
239 int32_t prefix = static_cast<int32_t>(hash & 0xff) << 24;
240 for (size_t i = 0; i < type->constructors.size(); ++i) {
241 type->constructors[i]->tag = prefix | static_cast<int32_t>(i);
242 constructor_tag_map_[type->constructors[i]->tag] = type->constructors[i];
243 }
244 }
245
/*!
 * \brief Add (or, with `update`, replace) the type definition of `var`, then kind-check it.
 *
 * The definition is registered first via AddTypeDefUnchecked; afterwards
 * relay::KindCheck must classify it as TypeKind::kTypeData, otherwise this
 * CHECK-fails.
 */
void IRModuleNode::AddTypeDef(const GlobalTypeVar& var, const TypeData& type, bool update) {
  AddTypeDefUnchecked(var, type, update);
  // need to kind check at the end because the check can look up
  // a definition potentially
  CHECK(relay::KindCheck(type, GetRef<IRModule>(this)) == TypeKind::kTypeData)
      << "Invalid or malformed typedata given to module: " << type;
}
253
/*!
 * \brief Store the type definition of `var` without kind checking.
 *
 * When `update` is false, CHECK-fails if a type definition with the same name
 * is already registered. Note that type_definitions is written before that
 * check runs. Also binds the name in the global type var table and
 * (re)assigns constructor tags via RegisterConstructors.
 */
void IRModuleNode::AddTypeDefUnchecked(const GlobalTypeVar& var, const TypeData& type,
                                       bool update) {
  this->type_definitions.Set(var, type);
  if (!update) {
    // set global type var map
    CHECK(global_type_var_map_.count(var->name_hint) == 0)
        << "Duplicate global type definition name " << var->name_hint;
  }
  global_type_var_map_.Set(var->name_hint, var);
  RegisterConstructors(var, type);
}
265
Update(const GlobalVar & var,const BaseFunc & func)266 void IRModuleNode::Update(const GlobalVar& var, const BaseFunc& func) {
267 this->Add(var, func, true);
268 }
269
UpdateTypeDef(const GlobalTypeVar & var,const TypeData & type)270 void IRModuleNode::UpdateTypeDef(const GlobalTypeVar& var, const TypeData& type) {
271 this->AddTypeDef(var, type, true);
272 }
273
Remove(const GlobalVar & var)274 void IRModuleNode::Remove(const GlobalVar& var) {
275 auto functions_node = this->functions.CopyOnWrite();
276 functions_node->erase(var);
277 auto gvar_node = global_var_map_.CopyOnWrite();
278 gvar_node->erase(var->name_hint);
279 }
280
// Return the function bound to `var`; CHECK-fails if `var` has no definition in this module.
BaseFunc IRModuleNode::Lookup(const GlobalVar& var) const {
  auto it = functions.find(var);
  CHECK(it != functions.end()) << "There is no definition of " << var->name_hint;
  return (*it).second;
}
286
Lookup(const String & name) const287 BaseFunc IRModuleNode::Lookup(const String& name) const {
288 GlobalVar id = this->GetGlobalVar(name);
289 return this->Lookup(id);
290 }
291
// Return the type definition bound to `var`; CHECK-fails if `var` has no definition.
TypeData IRModuleNode::LookupTypeDef(const GlobalTypeVar& var) const {
  auto it = type_definitions.find(var);
  CHECK(it != type_definitions.end()) << "There is no definition of " << var->name_hint;
  return (*it).second;
}
297
LookupTypeDef(const String & name) const298 TypeData IRModuleNode::LookupTypeDef(const String& name) const {
299 GlobalTypeVar id = this->GetGlobalTypeVar(name);
300 return this->LookupTypeDef(id);
301 }
302
// Return the constructor that RegisterConstructors assigned `tag` to;
// CHECK-fails if no constructor carries that tag.
Constructor IRModuleNode::LookupTag(const int32_t tag) {
  auto it = constructor_tag_map_.find(tag);
  CHECK(it != constructor_tag_map_.end()) << "There is no constructor with the tag " << tag;
  return (*it).second;
}
308
Update(const IRModule & mod)309 void IRModuleNode::Update(const IRModule& mod) {
310 // add functions and type defs. we add them unchecked first, so all definitions
311 // can reference each other, independent of the order in which they were defined.
312 for (auto pair : mod->functions) {
313 this->AddUnchecked(pair.first, pair.second);
314 }
315 for (auto pair : mod->type_definitions) {
316 this->AddTypeDefUnchecked(pair.first, pair.second);
317 }
318 for (auto pair : mod->functions) {
319 this->Update(pair.first, pair.second);
320 }
321 for (auto pair : mod->type_definitions) {
322 this->UpdateTypeDef(pair.first, pair.second);
323 }
324 }
325
FromExpr(const RelayExpr & expr,const tvm::Map<GlobalVar,BaseFunc> & global_funcs,const tvm::Map<GlobalTypeVar,TypeData> & type_definitions)326 IRModule IRModule::FromExpr(const RelayExpr& expr,
327 const tvm::Map<GlobalVar, BaseFunc>& global_funcs,
328 const tvm::Map<GlobalTypeVar, TypeData>& type_definitions) {
329 auto mod = IRModule(global_funcs, type_definitions);
330 BaseFunc func;
331 std::string gv_name = "main";
332
333 if (auto* func_node = expr.as<BaseFuncNode>()) {
334 func = GetRef<BaseFunc>(func_node);
335 if (auto opt = func->GetAttr<String>(tvm::attr::kGlobalSymbol)) {
336 gv_name = opt.value();
337 }
338
339 } else {
340 func = relay::Function(relay::FreeVars(expr), expr, Type(), relay::FreeTypeVars(expr, mod), {});
341 }
342 auto main_gv = GlobalVar(gv_name);
343 mod->Add(main_gv, func);
344 return mod;
345 }
346
Import(const String & path)347 void IRModuleNode::Import(const String& path) {
348 if (this->import_set_.count(path) == 0) {
349 this->import_set_.insert(path);
350 DLOG(INFO) << "Importing: " << path;
351 std::fstream src_file(path, std::fstream::in);
352 std::string file_contents{std::istreambuf_iterator<char>(src_file),
353 std::istreambuf_iterator<char>()};
354 auto mod_to_import = IRModule::FromText(file_contents, path);
355 Update(mod_to_import);
356 }
357 }
358
ImportFromStd(const String & path)359 void IRModuleNode::ImportFromStd(const String& path) {
360 auto* f = tvm::runtime::Registry::Get("tvm.relay.std_path");
361 CHECK(f != nullptr) << "The Relay std_path is not set, please register tvm.relay.std_path.";
362 std::string std_path = (*f)();
363 this->Import(std_path + "/" + path);
364 }
365
Imports() const366 std::unordered_set<String> IRModuleNode::Imports() const { return this->import_set_; }
367
FromText(const String & text,const String & source_path)368 IRModule IRModule::FromText(const String& text, const String& source_path) {
369 return tvm::parser::ParseModule(source_path, text);
370 }
371
// Register IRModuleNode with the object/reflection system.
TVM_REGISTER_NODE_TYPE(IRModuleNode);

// FFI constructor: build a module from function and type-definition maps,
// with an empty import set.
TVM_REGISTER_GLOBAL("ir.IRModule")
    .set_body_typed([](tvm::Map<GlobalVar, BaseFunc> funcs,
                       tvm::Map<GlobalTypeVar, TypeData> types) {
      return IRModule(funcs, types, {});
    });
379
// FFI: add `val` to the module under `var` (args: mod, var, val, update) and return the module.
// `val` must be a RelayExpr:
//  - a BaseFunc is added directly;
//  - a GlobalVar is eta-expanded (on a copy of the module node) and the resulting
//    function is added;
//  - any other expression is wrapped in a relay::Function with no parameters.
TVM_REGISTER_GLOBAL("ir.Module_Add").set_body([](TVMArgs args, TVMRetValue* ret) {
  IRModule mod = args[0];
  GlobalVar var = args[1];
  ObjectRef val = args[2];
  bool update = args[3];
  CHECK(val->IsInstance<RelayExprNode>());

  if (val->IsInstance<BaseFuncNode>()) {
    mod->Add(var, Downcast<BaseFunc>(val), update);
  } else if (val->IsInstance<GlobalVarNode>()) {
    GlobalVar gv = Downcast<GlobalVar>(val);
    // Run EtaExpand on a copy of the module node rather than on `mod` itself.
    auto mod_copy = IRModule(make_object<IRModuleNode>(*mod.operator->()));
    mod_copy = relay::transform::EtaExpand(
        /* expand_constructor */ false,
        /* expand_global_var */ true)(mod_copy);
    auto func = mod_copy->Lookup(gv->name_hint);
    mod->Add(var, Downcast<relay::Function>(func), update);
  } else {
    auto func = relay::Function({}, Downcast<RelayExpr>(val), Type(nullptr), {});
    mod->Add(var, func, update);
  }
  *ret = mod;
});
403
// FFI bindings that expose IRModuleNode member functions directly; set_body_method<IRModule>
// takes the module as the receiver argument (see Registry::set_body_method).
TVM_REGISTER_GLOBAL("ir.Module_AddDef").set_body_method<IRModule>(&IRModuleNode::AddTypeDef);

TVM_REGISTER_GLOBAL("ir.Module_GetGlobalVar")
    .set_body_method<IRModule>(&IRModuleNode::GetGlobalVar);

TVM_REGISTER_GLOBAL("ir.Module_GetGlobalVars")
    .set_body_method<IRModule>(&IRModuleNode::GetGlobalVars);

TVM_REGISTER_GLOBAL("ir.Module_GetGlobalTypeVars")
    .set_body_method<IRModule>(&IRModuleNode::GetGlobalTypeVars);

TVM_REGISTER_GLOBAL("ir.Module_ContainGlobalVar")
    .set_body_method<IRModule>(&IRModuleNode::ContainGlobalVar);

TVM_REGISTER_GLOBAL("ir.Module_GetGlobalTypeVar")
    .set_body_method<IRModule>(&IRModuleNode::GetGlobalTypeVar);
420
// FFI lookups. Separate names are registered for the GlobalVar/GlobalTypeVar and the
// String (name) overloads, since each lambda has a fixed argument type.
TVM_REGISTER_GLOBAL("ir.Module_Lookup").set_body_typed([](IRModule mod, GlobalVar var) {
  return mod->Lookup(var);
});

TVM_REGISTER_GLOBAL("ir.Module_Lookup_str").set_body_typed([](IRModule mod, String var) {
  return mod->Lookup(var);
});

TVM_REGISTER_GLOBAL("ir.Module_LookupDef").set_body_typed([](IRModule mod, GlobalTypeVar var) {
  return mod->LookupTypeDef(var);
});

TVM_REGISTER_GLOBAL("ir.Module_LookupDef_str").set_body_typed([](IRModule mod, String var) {
  return mod->LookupTypeDef(var);
});

// Constructor lookup by the tag assigned in RegisterConstructors.
TVM_REGISTER_GLOBAL("ir.Module_LookupTag").set_body_typed([](IRModule mod, int32_t tag) {
  return mod->LookupTag(tag);
});
440
// FFI: build a module from an expression plus existing functions/type definitions.
TVM_REGISTER_GLOBAL("ir.Module_FromExpr")
    .set_body_typed([](RelayExpr e, tvm::Map<GlobalVar, BaseFunc> funcs,
                       tvm::Map<GlobalTypeVar, TypeData> type_defs) {
      return IRModule::FromExpr(e, funcs, type_defs);
    });

// FFI: merge every definition of `from` into `mod` (in place).
TVM_REGISTER_GLOBAL("ir.Module_Update").set_body_typed([](IRModule mod, IRModule from) {
  mod->Update(from);
});

// FFI: replace the definition of a single global function.
TVM_REGISTER_GLOBAL("ir.Module_UpdateFunction")
    .set_body_typed([](IRModule mod, GlobalVar gv, BaseFunc func) { mod->Update(gv, func); });

// FFI: import a Relay text file by path.
TVM_REGISTER_GLOBAL("ir.Module_Import").set_body_typed([](IRModule mod, String path) {
  mod->Import(path);
});

// FFI: import a file relative to the registered Relay std_path.
TVM_REGISTER_GLOBAL("ir.Module_ImportFromStd").set_body_typed([](IRModule mod, String path) {
  mod->ImportFromStd(path);
});
461
// ReprPrinter dispatch for IRModuleNode: prints only the `functions` map
// (type definitions are not included in the output).
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<IRModuleNode>([](const ObjectRef& ref, ReprPrinter* p) {
      auto* node = static_cast<const IRModuleNode*>(ref.get());
      p->stream << "IRModuleNode( " << node->functions << ")";
    });
467
468 } // namespace tvm
469