1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file module.cc
22 * \brief TVM module system
23 */
24 #include <tvm/runtime/module.h>
25 #include <tvm/runtime/registry.h>
26 #include <tvm/runtime/packed_func.h>
27 #include <unordered_set>
28 #include <cstring>
29 #ifndef _LIBCPP_SGX_CONFIG
30 #include "file_util.h"
31 #endif
32
33 namespace tvm {
34 namespace runtime {
35
Import(Module other)36 void ModuleNode::Import(Module other) {
37 // specially handle rpc
38 if (!std::strcmp(this->type_key(), "rpc")) {
39 static const PackedFunc* fimport_ = nullptr;
40 if (fimport_ == nullptr) {
41 fimport_ = runtime::Registry::Get("rpc._ImportRemoteModule");
42 CHECK(fimport_ != nullptr);
43 }
44 (*fimport_)(GetRef<Module>(this), other);
45 return;
46 }
47 // cyclic detection.
48 std::unordered_set<const ModuleNode*> visited{other.operator->()};
49 std::vector<const ModuleNode*> stack{other.operator->()};
50 while (!stack.empty()) {
51 const ModuleNode* n = stack.back();
52 stack.pop_back();
53 for (const Module& m : n->imports_) {
54 const ModuleNode* next = m.operator->();
55 if (visited.count(next)) continue;
56 visited.insert(next);
57 stack.push_back(next);
58 }
59 }
60 CHECK(!visited.count(this))
61 << "Cyclic dependency detected during import";
62 this->imports_.emplace_back(std::move(other));
63 }
64
GetFunction(const std::string & name,bool query_imports)65 PackedFunc ModuleNode::GetFunction(const std::string& name, bool query_imports) {
66 ModuleNode* self = this;
67 PackedFunc pf = self->GetFunction(name, GetObjectPtr<Object>(this));
68 if (pf != nullptr) return pf;
69 if (query_imports) {
70 for (Module& m : self->imports_) {
71 pf = m->GetFunction(name, m.data_);
72 if (pf != nullptr) return pf;
73 }
74 }
75 return pf;
76 }
77
LoadFromFile(const std::string & file_name,const std::string & format)78 Module Module::LoadFromFile(const std::string& file_name,
79 const std::string& format) {
80 #ifndef _LIBCPP_SGX_CONFIG
81 std::string fmt = GetFileFormat(file_name, format);
82 CHECK(fmt.length() != 0)
83 << "Cannot deduce format of file " << file_name;
84 if (fmt == "dll" || fmt == "dylib" || fmt == "dso") {
85 fmt = "so";
86 }
87 std::string load_f_name = "module.loadfile_" + fmt;
88 const PackedFunc* f = Registry::Get(load_f_name);
89 CHECK(f != nullptr)
90 << "Loader of " << format << "("
91 << load_f_name << ") is not presented.";
92 Module m = (*f)(file_name, format);
93 return m;
94 #else
95 LOG(FATAL) << "SGX does not support LoadFromFile";
96 #endif
97 }
98
SaveToFile(const std::string & file_name,const std::string & format)99 void ModuleNode::SaveToFile(const std::string& file_name,
100 const std::string& format) {
101 LOG(FATAL) << "Module[" << type_key() << "] does not support SaveToFile";
102 }
103
SaveToBinary(dmlc::Stream * stream)104 void ModuleNode::SaveToBinary(dmlc::Stream* stream) {
105 LOG(FATAL) << "Module[" << type_key() << "] does not support SaveToBinary";
106 }
107
GetSource(const std::string & format)108 std::string ModuleNode::GetSource(const std::string& format) {
109 LOG(FATAL) << "Module[" << type_key() << "] does not support GetSource";
110 return "";
111 }
112
GetFuncFromEnv(const std::string & name)113 const PackedFunc* ModuleNode::GetFuncFromEnv(const std::string& name) {
114 auto it = import_cache_.find(name);
115 if (it != import_cache_.end()) return it->second.get();
116 PackedFunc pf;
117 for (Module& m : this->imports_) {
118 pf = m.GetFunction(name, false);
119 if (pf != nullptr) break;
120 }
121 if (pf == nullptr) {
122 const PackedFunc* f = Registry::Get(name);
123 CHECK(f != nullptr)
124 << "Cannot find function " << name
125 << " in the imported modules or global registry";
126 return f;
127 } else {
128 std::unique_ptr<PackedFunc> f(new PackedFunc(pf));
129 import_cache_[name] = std::move(f);
130 return import_cache_.at(name).get();
131 }
132 }
133
RuntimeEnabled(const std::string & target)134 bool RuntimeEnabled(const std::string& target) {
135 std::string f_name;
136 if (target == "cpu") {
137 return true;
138 } else if (target == "cuda" || target == "gpu") {
139 f_name = "device_api.gpu";
140 } else if (target == "cl" || target == "opencl" || target == "sdaccel") {
141 f_name = "device_api.opencl";
142 } else if (target == "gl" || target == "opengl") {
143 f_name = "device_api.opengl";
144 } else if (target == "mtl" || target == "metal") {
145 f_name = "device_api.metal";
146 } else if (target == "vulkan") {
147 f_name = "device_api.vulkan";
148 } else if (target == "stackvm") {
149 f_name = "codegen.build_stackvm";
150 } else if (target == "rpc") {
151 f_name = "device_api.rpc";
152 } else if (target == "vpi" || target == "verilog") {
153 f_name = "device_api.vpi";
154 } else if (target == "micro_dev") {
155 f_name = "device_api.micro_dev";
156 } else if (target.length() >= 5 && target.substr(0, 5) == "nvptx") {
157 f_name = "device_api.gpu";
158 } else if (target.length() >= 4 && target.substr(0, 4) == "rocm") {
159 f_name = "device_api.rocm";
160 } else if (target.length() >= 4 && target.substr(0, 4) == "llvm") {
161 const PackedFunc* pf = runtime::Registry::Get("codegen.llvm_target_enabled");
162 if (pf == nullptr) return false;
163 return (*pf)(target);
164 } else {
165 LOG(FATAL) << "Unknown optional runtime " << target;
166 }
167 return runtime::Registry::Get(f_name) != nullptr;
168 }
169
170 TVM_REGISTER_GLOBAL("module._Enabled")
__anon19ca7abf0102(TVMArgs args, TVMRetValue *ret) 171 .set_body([](TVMArgs args, TVMRetValue *ret) {
172 *ret = RuntimeEnabled(args[0]);
173 });
174
175 TVM_REGISTER_GLOBAL("module._GetSource")
__anon19ca7abf0202(TVMArgs args, TVMRetValue *ret) 176 .set_body([](TVMArgs args, TVMRetValue *ret) {
177 *ret = args[0].operator Module()->GetSource(args[1]);
178 });
179
180 TVM_REGISTER_GLOBAL("module._ImportsSize")
__anon19ca7abf0302(TVMArgs args, TVMRetValue *ret) 181 .set_body([](TVMArgs args, TVMRetValue *ret) {
182 *ret = static_cast<int64_t>(
183 args[0].operator Module()->imports().size());
184 });
185
186 TVM_REGISTER_GLOBAL("module._GetImport")
__anon19ca7abf0402(TVMArgs args, TVMRetValue *ret) 187 .set_body([](TVMArgs args, TVMRetValue *ret) {
188 *ret = args[0].operator Module()->
189 imports().at(args[1].operator int());
190 });
191
192 TVM_REGISTER_GLOBAL("module._GetTypeKey")
__anon19ca7abf0502(TVMArgs args, TVMRetValue *ret) 193 .set_body([](TVMArgs args, TVMRetValue *ret) {
194 *ret = std::string(args[0].operator Module()->type_key());
195 });
196
197 TVM_REGISTER_GLOBAL("module._LoadFromFile")
__anon19ca7abf0602(TVMArgs args, TVMRetValue *ret) 198 .set_body([](TVMArgs args, TVMRetValue *ret) {
199 *ret = Module::LoadFromFile(args[0], args[1]);
200 });
201
202 TVM_REGISTER_GLOBAL("module._SaveToFile")
__anon19ca7abf0702(TVMArgs args, TVMRetValue *ret) 203 .set_body([](TVMArgs args, TVMRetValue *ret) {
204 args[0].operator Module()->
205 SaveToFile(args[1], args[2]);
206 });
207 } // namespace runtime
208 } // namespace tvm
209