1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file module.cc
22  * \brief TVM module system
23  */
24 #include <tvm/runtime/module.h>
25 #include <tvm/runtime/registry.h>
26 #include <tvm/runtime/packed_func.h>
27 #include <unordered_set>
28 #include <cstring>
29 #ifndef _LIBCPP_SGX_CONFIG
30 #include "file_util.h"
31 #endif
32 
33 namespace tvm {
34 namespace runtime {
35 
Import(Module other)36 void ModuleNode::Import(Module other) {
37   // specially handle rpc
38   if (!std::strcmp(this->type_key(), "rpc")) {
39     static const PackedFunc* fimport_ = nullptr;
40     if (fimport_ == nullptr) {
41       fimport_ = runtime::Registry::Get("rpc._ImportRemoteModule");
42       CHECK(fimport_ != nullptr);
43     }
44     (*fimport_)(GetRef<Module>(this), other);
45     return;
46   }
47   // cyclic detection.
48   std::unordered_set<const ModuleNode*> visited{other.operator->()};
49   std::vector<const ModuleNode*> stack{other.operator->()};
50   while (!stack.empty()) {
51     const ModuleNode* n = stack.back();
52     stack.pop_back();
53     for (const Module& m : n->imports_) {
54       const ModuleNode* next = m.operator->();
55       if (visited.count(next)) continue;
56       visited.insert(next);
57       stack.push_back(next);
58     }
59   }
60   CHECK(!visited.count(this))
61       << "Cyclic dependency detected during import";
62   this->imports_.emplace_back(std::move(other));
63 }
64 
GetFunction(const std::string & name,bool query_imports)65 PackedFunc ModuleNode::GetFunction(const std::string& name, bool query_imports) {
66   ModuleNode* self = this;
67   PackedFunc pf = self->GetFunction(name, GetObjectPtr<Object>(this));
68   if (pf != nullptr) return pf;
69   if (query_imports) {
70     for (Module& m : self->imports_) {
71       pf = m->GetFunction(name, m.data_);
72       if (pf != nullptr) return pf;
73     }
74   }
75   return pf;
76 }
77 
LoadFromFile(const std::string & file_name,const std::string & format)78 Module Module::LoadFromFile(const std::string& file_name,
79                             const std::string& format) {
80 #ifndef _LIBCPP_SGX_CONFIG
81   std::string fmt = GetFileFormat(file_name, format);
82   CHECK(fmt.length() != 0)
83       << "Cannot deduce format of file " << file_name;
84   if (fmt == "dll" || fmt == "dylib" || fmt == "dso") {
85     fmt = "so";
86   }
87   std::string load_f_name = "module.loadfile_" + fmt;
88   const PackedFunc* f = Registry::Get(load_f_name);
89   CHECK(f != nullptr)
90       << "Loader of " << format << "("
91       << load_f_name << ") is not presented.";
92   Module m = (*f)(file_name, format);
93   return m;
94 #else
95   LOG(FATAL) << "SGX does not support LoadFromFile";
96 #endif
97 }
98 
SaveToFile(const std::string & file_name,const std::string & format)99 void ModuleNode::SaveToFile(const std::string& file_name,
100                             const std::string& format) {
101   LOG(FATAL) << "Module[" << type_key() << "] does not support SaveToFile";
102 }
103 
SaveToBinary(dmlc::Stream * stream)104 void ModuleNode::SaveToBinary(dmlc::Stream* stream) {
105   LOG(FATAL) << "Module[" << type_key() << "] does not support SaveToBinary";
106 }
107 
GetSource(const std::string & format)108 std::string ModuleNode::GetSource(const std::string& format) {
109   LOG(FATAL) << "Module[" << type_key() << "] does not support GetSource";
110   return "";
111 }
112 
GetFuncFromEnv(const std::string & name)113 const PackedFunc* ModuleNode::GetFuncFromEnv(const std::string& name) {
114   auto it = import_cache_.find(name);
115   if (it != import_cache_.end()) return it->second.get();
116   PackedFunc pf;
117   for (Module& m : this->imports_) {
118     pf = m.GetFunction(name, false);
119     if (pf != nullptr) break;
120   }
121   if (pf == nullptr) {
122     const PackedFunc* f = Registry::Get(name);
123     CHECK(f != nullptr)
124         << "Cannot find function " << name
125         << " in the imported modules or global registry";
126     return f;
127   } else {
128     std::unique_ptr<PackedFunc> f(new PackedFunc(pf));
129     import_cache_[name] = std::move(f);
130     return import_cache_.at(name).get();
131   }
132 }
133 
RuntimeEnabled(const std::string & target)134 bool RuntimeEnabled(const std::string& target) {
135   std::string f_name;
136   if (target == "cpu") {
137     return true;
138   } else if (target == "cuda" || target == "gpu") {
139     f_name = "device_api.gpu";
140   } else if (target == "cl" || target == "opencl" || target == "sdaccel") {
141     f_name = "device_api.opencl";
142   } else if (target == "gl" || target == "opengl") {
143     f_name = "device_api.opengl";
144   } else if (target == "mtl" || target == "metal") {
145     f_name = "device_api.metal";
146   } else if (target == "vulkan") {
147     f_name = "device_api.vulkan";
148   } else if (target == "stackvm") {
149     f_name = "codegen.build_stackvm";
150   } else if (target == "rpc") {
151     f_name = "device_api.rpc";
152   } else if (target == "vpi" || target == "verilog") {
153     f_name = "device_api.vpi";
154   } else if (target == "micro_dev") {
155     f_name = "device_api.micro_dev";
156   } else if (target.length() >= 5 && target.substr(0, 5) == "nvptx") {
157     f_name = "device_api.gpu";
158   } else if (target.length() >= 4 && target.substr(0, 4) == "rocm") {
159     f_name = "device_api.rocm";
160   } else if (target.length() >= 4 && target.substr(0, 4) == "llvm") {
161     const PackedFunc* pf = runtime::Registry::Get("codegen.llvm_target_enabled");
162     if (pf == nullptr) return false;
163     return (*pf)(target);
164   } else {
165     LOG(FATAL) << "Unknown optional runtime " << target;
166   }
167   return runtime::Registry::Get(f_name) != nullptr;
168 }
169 
170 TVM_REGISTER_GLOBAL("module._Enabled")
__anon19ca7abf0102(TVMArgs args, TVMRetValue *ret) 171 .set_body([](TVMArgs args, TVMRetValue *ret) {
172     *ret = RuntimeEnabled(args[0]);
173     });
174 
175 TVM_REGISTER_GLOBAL("module._GetSource")
__anon19ca7abf0202(TVMArgs args, TVMRetValue *ret) 176 .set_body([](TVMArgs args, TVMRetValue *ret) {
177     *ret = args[0].operator Module()->GetSource(args[1]);
178     });
179 
180 TVM_REGISTER_GLOBAL("module._ImportsSize")
__anon19ca7abf0302(TVMArgs args, TVMRetValue *ret) 181 .set_body([](TVMArgs args, TVMRetValue *ret) {
182     *ret = static_cast<int64_t>(
183         args[0].operator Module()->imports().size());
184     });
185 
186 TVM_REGISTER_GLOBAL("module._GetImport")
__anon19ca7abf0402(TVMArgs args, TVMRetValue *ret) 187 .set_body([](TVMArgs args, TVMRetValue *ret) {
188     *ret = args[0].operator Module()->
189         imports().at(args[1].operator int());
190     });
191 
192 TVM_REGISTER_GLOBAL("module._GetTypeKey")
__anon19ca7abf0502(TVMArgs args, TVMRetValue *ret) 193 .set_body([](TVMArgs args, TVMRetValue *ret) {
194     *ret = std::string(args[0].operator Module()->type_key());
195     });
196 
197 TVM_REGISTER_GLOBAL("module._LoadFromFile")
__anon19ca7abf0602(TVMArgs args, TVMRetValue *ret) 198 .set_body([](TVMArgs args, TVMRetValue *ret) {
199     *ret = Module::LoadFromFile(args[0], args[1]);
200     });
201 
202 TVM_REGISTER_GLOBAL("module._SaveToFile")
__anon19ca7abf0702(TVMArgs args, TVMRetValue *ret) 203 .set_body([](TVMArgs args, TVMRetValue *ret) {
204     args[0].operator Module()->
205         SaveToFile(args[1], args[2]);
206     });
207 }  // namespace runtime
208 }  // namespace tvm
209