1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file source_module.cc
22  * \brief Source code module, only for viewing
23  */
#include <tvm/runtime/packed_func.h>

#include <functional>
#include <string>
#include <unordered_map>
#include <utility>

#include "codegen_source_base.h"
#include "../runtime/file_util.h"
#include "../runtime/meta_data.h"
28 
29 namespace tvm {
30 namespace codegen {
31 
32 using runtime::TVMArgs;
33 using runtime::TVMRetValue;
34 using runtime::PackedFunc;
35 
36 using runtime::GetFileFormat;
37 using runtime::GetMetaFilePath;
38 using runtime::FunctionInfo;
39 using runtime::SaveBinaryToFile;
40 
41 // Simulator function
42 class SourceModuleNode : public runtime::ModuleNode {
43  public:
SourceModuleNode(std::string code,std::string fmt)44   SourceModuleNode(std::string code,
45                    std::string fmt)
46       : code_(code), fmt_(fmt) {}
type_key() const47   const char* type_key() const {
48     return "source";
49   }
50 
GetFunction(const std::string & name,const ObjectPtr<Object> & sptr_to_self)51   PackedFunc GetFunction(
52       const std::string& name,
53       const ObjectPtr<Object>& sptr_to_self) final {
54     LOG(FATAL) << "Source module cannot execute, to get executable module"
55                << " build TVM with \'" << fmt_ << "\' runtime support";
56     return PackedFunc();
57   }
58 
GetSource(const std::string & format)59   std::string GetSource(const std::string& format) final {
60     return code_;
61   }
62 
63  protected:
64   std::string code_;
65   std::string fmt_;
66 };
67 
SourceModuleCreate(std::string code,std::string fmt)68 runtime::Module SourceModuleCreate(std::string code, std::string fmt) {
69   auto n = make_object<SourceModuleNode>(code, fmt);
70   return runtime::Module(n);
71 }
72 
73 // Simulator function
74 class CSourceModuleNode : public runtime::ModuleNode {
75  public:
CSourceModuleNode(std::string code,std::string fmt)76   CSourceModuleNode(std::string code,
77                    std::string fmt)
78       : code_(code), fmt_(fmt) {}
type_key() const79   const char* type_key() const {
80     return "c";
81   }
82 
GetFunction(const std::string & name,const ObjectPtr<Object> & sptr_to_self)83   PackedFunc GetFunction(
84       const std::string& name,
85       const ObjectPtr<Object>& sptr_to_self) final {
86     LOG(FATAL) << "C Source module cannot execute, to get executable module"
87                << " build TVM with \'" << fmt_ << "\' runtime support";
88     return PackedFunc();
89   }
90 
GetSource(const std::string & format)91   std::string GetSource(const std::string& format) final {
92     return code_;
93   }
94 
SaveToFile(const std::string & file_name,const std::string & format)95   void SaveToFile(const std::string& file_name,
96                   const std::string& format) final {
97     std::string fmt = GetFileFormat(file_name, format);
98     std::string meta_file = GetMetaFilePath(file_name);
99     if (fmt == "cc") {
100       CHECK_NE(code_.length(), 0);
101       SaveBinaryToFile(file_name, code_);
102     } else {
103       CHECK_EQ(fmt, fmt_)
104           << "Can only save to format=" << fmt_;
105     }
106   }
107 
108  protected:
109   std::string code_;
110   std::string fmt_;
111 };
112 
CSourceModuleCreate(std::string code,std::string fmt)113 runtime::Module CSourceModuleCreate(std::string code, std::string fmt) {
114   auto n = make_object<CSourceModuleNode>(code, fmt);
115   return runtime::Module(n);
116 }
117 
118 // supports limited save without cross compile
119 class DeviceSourceModuleNode final : public runtime::ModuleNode {
120  public:
DeviceSourceModuleNode(std::string data,std::string fmt,std::unordered_map<std::string,FunctionInfo> fmap,std::string type_key,std::function<std::string (const std::string &)> fget_source)121   DeviceSourceModuleNode(std::string data,
122                          std::string fmt,
123                          std::unordered_map<std::string, FunctionInfo> fmap,
124                          std::string type_key,
125                          std::function<std::string(const std::string&)> fget_source)
126     : data_(data),
127       fmt_(fmt),
128       fmap_(fmap),
129       type_key_(type_key),
130       fget_source_(fget_source) {}
131 
GetFunction(const std::string & name,const ObjectPtr<Object> & sptr_to_self)132   PackedFunc GetFunction(
133         const std::string& name,
134         const ObjectPtr<Object>& sptr_to_self) final {
135     LOG(FATAL) << "Source module cannot execute, to get executable module"
136                << " build TVM with \'" << fmt_ << "\' runtime support";
137     return PackedFunc();
138   }
139 
GetSource(const std::string & format)140   std::string GetSource(const std::string& format) final {
141     if (fget_source_ != nullptr) {
142       return fget_source_(format);
143     } else {
144       return data_;
145     }
146   }
147 
type_key() const148   const char* type_key() const {
149     return type_key_.c_str();
150   }
151 
SaveToFile(const std::string & file_name,const std::string & format)152   void SaveToFile(const std::string& file_name,
153                   const std::string& format) final {
154     std::string fmt = GetFileFormat(file_name, format);
155     CHECK_EQ(fmt, fmt_)
156         << "Can only save to format=" << fmt_;
157     std::string meta_file = GetMetaFilePath(file_name);
158     SaveMetaDataToFile(meta_file, fmap_);
159     SaveBinaryToFile(file_name, data_);
160   }
161 
SaveToBinary(dmlc::Stream * stream)162   void SaveToBinary(dmlc::Stream* stream) final {
163     stream->Write(fmt_);
164     stream->Write(fmap_);
165     stream->Write(data_);
166   }
167 
168  private:
169   std::string data_;
170   std::string fmt_;
171   std::unordered_map<std::string, FunctionInfo> fmap_;
172   std::string type_key_;
173   std::function<std::string(const std::string&)> fget_source_;
174 };
175 
DeviceSourceModuleCreate(std::string data,std::string fmt,std::unordered_map<std::string,FunctionInfo> fmap,std::string type_key,std::function<std::string (const std::string &)> fget_source)176 runtime::Module DeviceSourceModuleCreate(
177     std::string data,
178     std::string fmt,
179     std::unordered_map<std::string, FunctionInfo> fmap,
180     std::string type_key,
181     std::function<std::string(const std::string&)> fget_source) {
182   auto n = make_object<DeviceSourceModuleNode>(data, fmt, fmap, type_key, fget_source);
183   return runtime::Module(n);
184 }
185 
186 TVM_REGISTER_GLOBAL("module.source_module_create")
187 .set_body_typed(SourceModuleCreate);
188 
189 TVM_REGISTER_GLOBAL("module.csource_module_create")
190 .set_body_typed(CSourceModuleCreate);
191 }  // namespace codegen
192 }  // namespace tvm
193