1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file module.cc
22 * \brief TVM module system
23 */
24 #include <tvm/runtime/module.h>
25 #include <tvm/runtime/packed_func.h>
26 #include <tvm/runtime/registry.h>
27
28 #include <cstring>
29 #include <unordered_set>
30
31 #include "file_util.h"
32
33 namespace tvm {
34 namespace runtime {
35
Import(Module other)36 void ModuleNode::Import(Module other) {
37 // specially handle rpc
38 if (!std::strcmp(this->type_key(), "rpc")) {
39 static const PackedFunc* fimport_ = nullptr;
40 if (fimport_ == nullptr) {
41 fimport_ = runtime::Registry::Get("rpc.ImportRemoteModule");
42 CHECK(fimport_ != nullptr);
43 }
44 (*fimport_)(GetRef<Module>(this), other);
45 return;
46 }
47 // cyclic detection.
48 std::unordered_set<const ModuleNode*> visited{other.operator->()};
49 std::vector<const ModuleNode*> stack{other.operator->()};
50 while (!stack.empty()) {
51 const ModuleNode* n = stack.back();
52 stack.pop_back();
53 for (const Module& m : n->imports_) {
54 const ModuleNode* next = m.operator->();
55 if (visited.count(next)) continue;
56 visited.insert(next);
57 stack.push_back(next);
58 }
59 }
60 CHECK(!visited.count(this)) << "Cyclic dependency detected during import";
61 this->imports_.emplace_back(std::move(other));
62 }
63
GetFunction(const std::string & name,bool query_imports)64 PackedFunc ModuleNode::GetFunction(const std::string& name, bool query_imports) {
65 ModuleNode* self = this;
66 PackedFunc pf = self->GetFunction(name, GetObjectPtr<Object>(this));
67 if (pf != nullptr) return pf;
68 if (query_imports) {
69 for (Module& m : self->imports_) {
70 pf = m.operator->()->GetFunction(name, query_imports);
71 if (pf != nullptr) {
72 return pf;
73 }
74 }
75 }
76 return pf;
77 }
78
LoadFromFile(const std::string & file_name,const std::string & format)79 Module Module::LoadFromFile(const std::string& file_name, const std::string& format) {
80 std::string fmt = GetFileFormat(file_name, format);
81 CHECK(fmt.length() != 0) << "Cannot deduce format of file " << file_name;
82 if (fmt == "dll" || fmt == "dylib" || fmt == "dso") {
83 fmt = "so";
84 }
85 std::string load_f_name = "runtime.module.loadfile_" + fmt;
86 const PackedFunc* f = Registry::Get(load_f_name);
87 CHECK(f != nullptr) << "Loader of " << format << "(" << load_f_name << ") is not presented.";
88 Module m = (*f)(file_name, format);
89 return m;
90 }
91
SaveToFile(const std::string & file_name,const std::string & format)92 void ModuleNode::SaveToFile(const std::string& file_name, const std::string& format) {
93 LOG(FATAL) << "Module[" << type_key() << "] does not support SaveToFile";
94 }
95
SaveToBinary(dmlc::Stream * stream)96 void ModuleNode::SaveToBinary(dmlc::Stream* stream) {
97 LOG(FATAL) << "Module[" << type_key() << "] does not support SaveToBinary";
98 }
99
GetSource(const std::string & format)100 std::string ModuleNode::GetSource(const std::string& format) {
101 LOG(FATAL) << "Module[" << type_key() << "] does not support GetSource";
102 return "";
103 }
104
GetFuncFromEnv(const std::string & name)105 const PackedFunc* ModuleNode::GetFuncFromEnv(const std::string& name) {
106 auto it = import_cache_.find(name);
107 if (it != import_cache_.end()) return it->second.get();
108 PackedFunc pf;
109 for (Module& m : this->imports_) {
110 pf = m.GetFunction(name, true);
111 if (pf != nullptr) break;
112 }
113 if (pf == nullptr) {
114 const PackedFunc* f = Registry::Get(name);
115 CHECK(f != nullptr) << "Cannot find function " << name
116 << " in the imported modules or global registry";
117 return f;
118 } else {
119 import_cache_.insert(std::make_pair(name, std::make_shared<PackedFunc>(pf)));
120 return import_cache_.at(name).get();
121 }
122 }
123
RuntimeEnabled(const std::string & target)124 bool RuntimeEnabled(const std::string& target) {
125 std::string f_name;
126 if (target == "cpu") {
127 return true;
128 } else if (target == "cuda" || target == "gpu") {
129 f_name = "device_api.gpu";
130 } else if (target == "cl" || target == "opencl" || target == "sdaccel") {
131 f_name = "device_api.opencl";
132 } else if (target == "mtl" || target == "metal") {
133 f_name = "device_api.metal";
134 } else if (target == "tflite") {
135 f_name = "target.runtime.tflite";
136 } else if (target == "vulkan") {
137 f_name = "device_api.vulkan";
138 } else if (target == "stackvm") {
139 f_name = "target.build.stackvm";
140 } else if (target == "rpc") {
141 f_name = "device_api.rpc";
142 } else if (target == "micro_dev") {
143 f_name = "device_api.micro_dev";
144 } else if (target == "hexagon") {
145 f_name = "device_api.hexagon";
146 } else if (target.length() >= 5 && target.substr(0, 5) == "nvptx") {
147 f_name = "device_api.gpu";
148 } else if (target.length() >= 4 && target.substr(0, 4) == "rocm") {
149 f_name = "device_api.rocm";
150 } else if (target.length() >= 4 && target.substr(0, 4) == "llvm") {
151 const PackedFunc* pf = runtime::Registry::Get("codegen.llvm_target_enabled");
152 if (pf == nullptr) return false;
153 return (*pf)(target);
154 } else {
155 LOG(FATAL) << "Unknown optional runtime " << target;
156 }
157 return runtime::Registry::Get(f_name) != nullptr;
158 }
159
// Expose RuntimeEnabled through the global function registry.
TVM_REGISTER_GLOBAL("runtime.RuntimeEnabled").set_body_typed(RuntimeEnabled);

// Return the source of `mod` in format `fmt` (forwards to ModuleNode::GetSource).
TVM_REGISTER_GLOBAL("runtime.ModuleGetSource").set_body_typed([](Module mod, std::string fmt) {
  return mod->GetSource(fmt);
});

// Return the number of modules directly imported by `mod`, as int64_t.
TVM_REGISTER_GLOBAL("runtime.ModuleImportsSize").set_body_typed([](Module mod) {
  return static_cast<int64_t>(mod->imports().size());
});

// Return the `index`-th direct import of `mod`. The index is bounds-checked
// by at() (presumably std::vector::at, which throws std::out_of_range; a
// negative int converts to a large unsigned value and also fails the check).
TVM_REGISTER_GLOBAL("runtime.ModuleGetImport").set_body_typed([](Module mod, int index) {
  return mod->imports().at(index);
});

// Return the type key of `mod` (e.g. "rpc") as a std::string copy.
TVM_REGISTER_GLOBAL("runtime.ModuleGetTypeKey").set_body_typed([](Module mod) {
  return std::string(mod->type_key());
});

// Expose Module::LoadFromFile(file_name, format) through the registry.
TVM_REGISTER_GLOBAL("runtime.ModuleLoadFromFile").set_body_typed(Module::LoadFromFile);

// Save `mod` to file `name` in format `fmt` (forwards to ModuleNode::SaveToFile).
TVM_REGISTER_GLOBAL("runtime.ModuleSaveToFile")
    .set_body_typed([](Module mod, std::string name, std::string fmt) {
      mod->SaveToFile(name, fmt);
    });

// Register ModuleNode with TVM's object type system.
TVM_REGISTER_OBJECT_TYPE(ModuleNode);
186 } // namespace runtime
187 } // namespace tvm
188