1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file module.cc
22  * \brief TVM module system
23  */
24 #include <tvm/runtime/module.h>
25 #include <tvm/runtime/packed_func.h>
26 #include <tvm/runtime/registry.h>
27 
28 #include <cstring>
29 #include <unordered_set>
30 
31 #include "file_util.h"
32 
33 namespace tvm {
34 namespace runtime {
35 
Import(Module other)36 void ModuleNode::Import(Module other) {
37   // specially handle rpc
38   if (!std::strcmp(this->type_key(), "rpc")) {
39     static const PackedFunc* fimport_ = nullptr;
40     if (fimport_ == nullptr) {
41       fimport_ = runtime::Registry::Get("rpc.ImportRemoteModule");
42       CHECK(fimport_ != nullptr);
43     }
44     (*fimport_)(GetRef<Module>(this), other);
45     return;
46   }
47   // cyclic detection.
48   std::unordered_set<const ModuleNode*> visited{other.operator->()};
49   std::vector<const ModuleNode*> stack{other.operator->()};
50   while (!stack.empty()) {
51     const ModuleNode* n = stack.back();
52     stack.pop_back();
53     for (const Module& m : n->imports_) {
54       const ModuleNode* next = m.operator->();
55       if (visited.count(next)) continue;
56       visited.insert(next);
57       stack.push_back(next);
58     }
59   }
60   CHECK(!visited.count(this)) << "Cyclic dependency detected during import";
61   this->imports_.emplace_back(std::move(other));
62 }
63 
GetFunction(const std::string & name,bool query_imports)64 PackedFunc ModuleNode::GetFunction(const std::string& name, bool query_imports) {
65   ModuleNode* self = this;
66   PackedFunc pf = self->GetFunction(name, GetObjectPtr<Object>(this));
67   if (pf != nullptr) return pf;
68   if (query_imports) {
69     for (Module& m : self->imports_) {
70       pf = m.operator->()->GetFunction(name, query_imports);
71       if (pf != nullptr) {
72         return pf;
73       }
74     }
75   }
76   return pf;
77 }
78 
LoadFromFile(const std::string & file_name,const std::string & format)79 Module Module::LoadFromFile(const std::string& file_name, const std::string& format) {
80   std::string fmt = GetFileFormat(file_name, format);
81   CHECK(fmt.length() != 0) << "Cannot deduce format of file " << file_name;
82   if (fmt == "dll" || fmt == "dylib" || fmt == "dso") {
83     fmt = "so";
84   }
85   std::string load_f_name = "runtime.module.loadfile_" + fmt;
86   const PackedFunc* f = Registry::Get(load_f_name);
87   CHECK(f != nullptr) << "Loader of " << format << "(" << load_f_name << ") is not presented.";
88   Module m = (*f)(file_name, format);
89   return m;
90 }
91 
SaveToFile(const std::string & file_name,const std::string & format)92 void ModuleNode::SaveToFile(const std::string& file_name, const std::string& format) {
93   LOG(FATAL) << "Module[" << type_key() << "] does not support SaveToFile";
94 }
95 
SaveToBinary(dmlc::Stream * stream)96 void ModuleNode::SaveToBinary(dmlc::Stream* stream) {
97   LOG(FATAL) << "Module[" << type_key() << "] does not support SaveToBinary";
98 }
99 
GetSource(const std::string & format)100 std::string ModuleNode::GetSource(const std::string& format) {
101   LOG(FATAL) << "Module[" << type_key() << "] does not support GetSource";
102   return "";
103 }
104 
GetFuncFromEnv(const std::string & name)105 const PackedFunc* ModuleNode::GetFuncFromEnv(const std::string& name) {
106   auto it = import_cache_.find(name);
107   if (it != import_cache_.end()) return it->second.get();
108   PackedFunc pf;
109   for (Module& m : this->imports_) {
110     pf = m.GetFunction(name, true);
111     if (pf != nullptr) break;
112   }
113   if (pf == nullptr) {
114     const PackedFunc* f = Registry::Get(name);
115     CHECK(f != nullptr) << "Cannot find function " << name
116                         << " in the imported modules or global registry";
117     return f;
118   } else {
119     import_cache_.insert(std::make_pair(name, std::make_shared<PackedFunc>(pf)));
120     return import_cache_.at(name).get();
121   }
122 }
123 
RuntimeEnabled(const std::string & target)124 bool RuntimeEnabled(const std::string& target) {
125   std::string f_name;
126   if (target == "cpu") {
127     return true;
128   } else if (target == "cuda" || target == "gpu") {
129     f_name = "device_api.gpu";
130   } else if (target == "cl" || target == "opencl" || target == "sdaccel") {
131     f_name = "device_api.opencl";
132   } else if (target == "mtl" || target == "metal") {
133     f_name = "device_api.metal";
134   } else if (target == "tflite") {
135     f_name = "target.runtime.tflite";
136   } else if (target == "vulkan") {
137     f_name = "device_api.vulkan";
138   } else if (target == "stackvm") {
139     f_name = "target.build.stackvm";
140   } else if (target == "rpc") {
141     f_name = "device_api.rpc";
142   } else if (target == "micro_dev") {
143     f_name = "device_api.micro_dev";
144   } else if (target == "hexagon") {
145     f_name = "device_api.hexagon";
146   } else if (target.length() >= 5 && target.substr(0, 5) == "nvptx") {
147     f_name = "device_api.gpu";
148   } else if (target.length() >= 4 && target.substr(0, 4) == "rocm") {
149     f_name = "device_api.rocm";
150   } else if (target.length() >= 4 && target.substr(0, 4) == "llvm") {
151     const PackedFunc* pf = runtime::Registry::Get("codegen.llvm_target_enabled");
152     if (pf == nullptr) return false;
153     return (*pf)(target);
154   } else {
155     LOG(FATAL) << "Unknown optional runtime " << target;
156   }
157   return runtime::Registry::Get(f_name) != nullptr;
158 }
159 
160 TVM_REGISTER_GLOBAL("runtime.RuntimeEnabled").set_body_typed(RuntimeEnabled);
161 
__anon95708a840102(Module mod, std::string fmt) 162 TVM_REGISTER_GLOBAL("runtime.ModuleGetSource").set_body_typed([](Module mod, std::string fmt) {
163   return mod->GetSource(fmt);
164 });
165 
__anon95708a840202(Module mod) 166 TVM_REGISTER_GLOBAL("runtime.ModuleImportsSize").set_body_typed([](Module mod) {
167   return static_cast<int64_t>(mod->imports().size());
168 });
169 
__anon95708a840302(Module mod, int index) 170 TVM_REGISTER_GLOBAL("runtime.ModuleGetImport").set_body_typed([](Module mod, int index) {
171   return mod->imports().at(index);
172 });
173 
__anon95708a840402(Module mod) 174 TVM_REGISTER_GLOBAL("runtime.ModuleGetTypeKey").set_body_typed([](Module mod) {
175   return std::string(mod->type_key());
176 });
177 
178 TVM_REGISTER_GLOBAL("runtime.ModuleLoadFromFile").set_body_typed(Module::LoadFromFile);
179 
180 TVM_REGISTER_GLOBAL("runtime.ModuleSaveToFile")
__anon95708a840502(Module mod, std::string name, std::string fmt) 181     .set_body_typed([](Module mod, std::string name, std::string fmt) {
182       mod->SaveToFile(name, fmt);
183     });
184 
185 TVM_REGISTER_OBJECT_TYPE(ModuleNode);
186 }  // namespace runtime
187 }  // namespace tvm
188