1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file source_module.cc
22 * \brief Source code module, only for viewing
23 */
#include <tvm/runtime/packed_func.h>
#include <functional>
#include <string>
#include <unordered_map>
#include <utility>
#include "codegen_source_base.h"
#include "../runtime/file_util.h"
#include "../runtime/meta_data.h"
28
29 namespace tvm {
30 namespace codegen {
31
32 using runtime::TVMArgs;
33 using runtime::TVMRetValue;
34 using runtime::PackedFunc;
35
36 using runtime::GetFileFormat;
37 using runtime::GetMetaFilePath;
38 using runtime::FunctionInfo;
39 using runtime::SaveBinaryToFile;
40
41 // Simulator function
42 class SourceModuleNode : public runtime::ModuleNode {
43 public:
SourceModuleNode(std::string code,std::string fmt)44 SourceModuleNode(std::string code,
45 std::string fmt)
46 : code_(code), fmt_(fmt) {}
type_key() const47 const char* type_key() const {
48 return "source";
49 }
50
GetFunction(const std::string & name,const ObjectPtr<Object> & sptr_to_self)51 PackedFunc GetFunction(
52 const std::string& name,
53 const ObjectPtr<Object>& sptr_to_self) final {
54 LOG(FATAL) << "Source module cannot execute, to get executable module"
55 << " build TVM with \'" << fmt_ << "\' runtime support";
56 return PackedFunc();
57 }
58
GetSource(const std::string & format)59 std::string GetSource(const std::string& format) final {
60 return code_;
61 }
62
63 protected:
64 std::string code_;
65 std::string fmt_;
66 };
67
SourceModuleCreate(std::string code,std::string fmt)68 runtime::Module SourceModuleCreate(std::string code, std::string fmt) {
69 auto n = make_object<SourceModuleNode>(code, fmt);
70 return runtime::Module(n);
71 }
72
73 // Simulator function
74 class CSourceModuleNode : public runtime::ModuleNode {
75 public:
CSourceModuleNode(std::string code,std::string fmt)76 CSourceModuleNode(std::string code,
77 std::string fmt)
78 : code_(code), fmt_(fmt) {}
type_key() const79 const char* type_key() const {
80 return "c";
81 }
82
GetFunction(const std::string & name,const ObjectPtr<Object> & sptr_to_self)83 PackedFunc GetFunction(
84 const std::string& name,
85 const ObjectPtr<Object>& sptr_to_self) final {
86 LOG(FATAL) << "C Source module cannot execute, to get executable module"
87 << " build TVM with \'" << fmt_ << "\' runtime support";
88 return PackedFunc();
89 }
90
GetSource(const std::string & format)91 std::string GetSource(const std::string& format) final {
92 return code_;
93 }
94
SaveToFile(const std::string & file_name,const std::string & format)95 void SaveToFile(const std::string& file_name,
96 const std::string& format) final {
97 std::string fmt = GetFileFormat(file_name, format);
98 std::string meta_file = GetMetaFilePath(file_name);
99 if (fmt == "cc") {
100 CHECK_NE(code_.length(), 0);
101 SaveBinaryToFile(file_name, code_);
102 } else {
103 CHECK_EQ(fmt, fmt_)
104 << "Can only save to format=" << fmt_;
105 }
106 }
107
108 protected:
109 std::string code_;
110 std::string fmt_;
111 };
112
CSourceModuleCreate(std::string code,std::string fmt)113 runtime::Module CSourceModuleCreate(std::string code, std::string fmt) {
114 auto n = make_object<CSourceModuleNode>(code, fmt);
115 return runtime::Module(n);
116 }
117
118 // supports limited save without cross compile
119 class DeviceSourceModuleNode final : public runtime::ModuleNode {
120 public:
DeviceSourceModuleNode(std::string data,std::string fmt,std::unordered_map<std::string,FunctionInfo> fmap,std::string type_key,std::function<std::string (const std::string &)> fget_source)121 DeviceSourceModuleNode(std::string data,
122 std::string fmt,
123 std::unordered_map<std::string, FunctionInfo> fmap,
124 std::string type_key,
125 std::function<std::string(const std::string&)> fget_source)
126 : data_(data),
127 fmt_(fmt),
128 fmap_(fmap),
129 type_key_(type_key),
130 fget_source_(fget_source) {}
131
GetFunction(const std::string & name,const ObjectPtr<Object> & sptr_to_self)132 PackedFunc GetFunction(
133 const std::string& name,
134 const ObjectPtr<Object>& sptr_to_self) final {
135 LOG(FATAL) << "Source module cannot execute, to get executable module"
136 << " build TVM with \'" << fmt_ << "\' runtime support";
137 return PackedFunc();
138 }
139
GetSource(const std::string & format)140 std::string GetSource(const std::string& format) final {
141 if (fget_source_ != nullptr) {
142 return fget_source_(format);
143 } else {
144 return data_;
145 }
146 }
147
type_key() const148 const char* type_key() const {
149 return type_key_.c_str();
150 }
151
SaveToFile(const std::string & file_name,const std::string & format)152 void SaveToFile(const std::string& file_name,
153 const std::string& format) final {
154 std::string fmt = GetFileFormat(file_name, format);
155 CHECK_EQ(fmt, fmt_)
156 << "Can only save to format=" << fmt_;
157 std::string meta_file = GetMetaFilePath(file_name);
158 SaveMetaDataToFile(meta_file, fmap_);
159 SaveBinaryToFile(file_name, data_);
160 }
161
SaveToBinary(dmlc::Stream * stream)162 void SaveToBinary(dmlc::Stream* stream) final {
163 stream->Write(fmt_);
164 stream->Write(fmap_);
165 stream->Write(data_);
166 }
167
168 private:
169 std::string data_;
170 std::string fmt_;
171 std::unordered_map<std::string, FunctionInfo> fmap_;
172 std::string type_key_;
173 std::function<std::string(const std::string&)> fget_source_;
174 };
175
DeviceSourceModuleCreate(std::string data,std::string fmt,std::unordered_map<std::string,FunctionInfo> fmap,std::string type_key,std::function<std::string (const std::string &)> fget_source)176 runtime::Module DeviceSourceModuleCreate(
177 std::string data,
178 std::string fmt,
179 std::unordered_map<std::string, FunctionInfo> fmap,
180 std::string type_key,
181 std::function<std::string(const std::string&)> fget_source) {
182 auto n = make_object<DeviceSourceModuleNode>(data, fmt, fmap, type_key, fget_source);
183 return runtime::Module(n);
184 }
185
// Register SourceModuleCreate as the typed body of the global function
// named "module.source_module_create".
TVM_REGISTER_GLOBAL("module.source_module_create")
.set_body_typed(SourceModuleCreate);

// Register CSourceModuleCreate as the typed body of the global function
// named "module.csource_module_create".
TVM_REGISTER_GLOBAL("module.csource_module_create")
.set_body_typed(CSourceModuleCreate);
191 } // namespace codegen
192 } // namespace tvm
193