1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file opencl_module.cc
22  */
#include <dmlc/memory_io.h>
#include <tvm/runtime/registry.h>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "opencl_common.h"
#include "opencl_module.h"
30 
31 namespace tvm {
32 namespace runtime {
33 
34 class OpenCLWrappedFunc {
35  public:
36   // initialize the OpenCL function.
Init(OpenCLModuleNode * m,ObjectPtr<Object> sptr,OpenCLModuleNode::KTRefEntry entry,std::string func_name,std::vector<size_t> arg_size,const std::vector<std::string> & thread_axis_tags)37   void Init(OpenCLModuleNode* m,
38             ObjectPtr<Object> sptr,
39             OpenCLModuleNode::KTRefEntry entry,
40             std::string func_name,
41             std::vector<size_t> arg_size,
42             const std::vector<std::string>& thread_axis_tags)  {
43     w_ = m->GetGlobalWorkspace().get();
44     m_ = m;
45     sptr_ = sptr;
46     entry_ = entry;
47     func_name_ = func_name;
48     arg_size_ = arg_size;
49     thread_axis_cfg_.Init(arg_size.size(), thread_axis_tags);
50   }
51   // invoke the function with void arguments
operator ()(TVMArgs args,TVMRetValue * rv,void ** void_args) const52   void operator()(TVMArgs args,
53                   TVMRetValue* rv,
54                   void** void_args) const {
55     CHECK(w_->context != nullptr) << "No OpenCL device";
56     cl::OpenCLThreadEntry* t = w_->GetThreadEntry();
57     // get the kernel from thread local kernel table.
58     if (entry_.kernel_id >= t->kernel_table.size()) {
59       t->kernel_table.resize(entry_.kernel_id + 1);
60     }
61     const auto& e = t->kernel_table[entry_.kernel_id];
62     cl_kernel kernel = e.kernel;
63     if (kernel == nullptr || e.version != entry_.version) {
64       kernel = m_->InstallKernel(w_, t, func_name_, entry_);
65     }
66     // setup arguments.
67     for (cl_uint i = 0; i < arg_size_.size(); ++i) {
68       OPENCL_CALL(clSetKernelArg(kernel, i, arg_size_[i], void_args[i]));
69     }
70     cl_command_queue queue = w_->GetQueue(t->context);
71     ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
72     cl_uint work_dim = static_cast<cl_uint>(thread_axis_cfg_.work_dim());
73     for (cl_uint i = 0; i < work_dim; ++i) {
74       wl.work_size[i] *= wl.work_size[i + 3];
75     }
76     // launch kernel
77     OPENCL_CALL(clEnqueueNDRangeKernel(
78         queue, kernel, work_dim, nullptr,
79         wl.work_size,
80         wl.work_size + 3,
81         0, nullptr, nullptr));
82   }
83 
84  private:
85   // global workspace.
86   cl::OpenCLWorkspace* w_;
87   // The module
88   OpenCLModuleNode* m_;
89   // resource handle
90   ObjectPtr<Object> sptr_;
91   // global kernel id in the kernel table.
92   OpenCLModuleNode::KTRefEntry entry_;
93   // The name of the function.
94   std::string func_name_;
95   // convert code for void argument
96   std::vector<size_t> arg_size_;
97   // thread axis config
98   ThreadAxisConfig thread_axis_cfg_;
99 };
100 
~OpenCLModuleNode()101 OpenCLModuleNode::~OpenCLModuleNode() {
102   {
103     // free the kernel ids in global table.
104     std::lock_guard<std::mutex> lock(workspace_->mu);
105     for (auto& kv : kid_map_) {
106       workspace_->free_kernel_ids.push_back(kv.second.kernel_id);
107     }
108   }
109   // free the kernels
110   for (cl_kernel k : kernels_) {
111     OPENCL_CALL(clReleaseKernel(k));
112   }
113   if (program_) {
114     OPENCL_CALL(clReleaseProgram(program_));
115   }
116 }
117 
GetGlobalWorkspace()118 const std::shared_ptr<cl::OpenCLWorkspace>& OpenCLModuleNode::GetGlobalWorkspace() {
119   return cl::OpenCLWorkspace::Global();
120 }
121 
GetFunction(const std::string & name,const ObjectPtr<Object> & sptr_to_self)122 PackedFunc OpenCLModuleNode::GetFunction(
123     const std::string& name,
124     const ObjectPtr<Object>& sptr_to_self) {
125   CHECK_EQ(sptr_to_self.get(), this);
126   CHECK_NE(name, symbol::tvm_module_main)
127       << "Device function do not have main";
128   auto it = fmap_.find(name);
129   if (it == fmap_.end()) return PackedFunc();
130   const FunctionInfo& info = it->second;
131   OpenCLWrappedFunc f;
132   std::vector<size_t> arg_size(info.arg_types.size());
133   for (size_t i = 0; i < info.arg_types.size(); ++i) {
134     TVMType t = info.arg_types[i];
135     CHECK_EQ(t.lanes, 1U);
136     if (t.code == kHandle) {
137       // specially store pointer type size in OpenCL driver
138       arg_size[i] = sizeof(void*);
139     } else {
140       uint32_t bits = t.bits;
141       CHECK_EQ(bits % 8, 0U);
142       arg_size[i] = bits / 8;
143     }
144   }
145   // initialize the wrapped func.
146   f.Init(this, sptr_to_self, kid_map_.at(name),
147          name, arg_size, info.thread_axis_tags);
148   return PackFuncVoidAddr(f, info.arg_types);
149 }
150 
SaveToFile(const std::string & file_name,const std::string & format)151 void OpenCLModuleNode::SaveToFile(const std::string& file_name,
152                                   const std::string& format) {
153   std::string fmt = GetFileFormat(file_name, format);
154   CHECK_EQ(fmt, fmt_)
155       << "Can only save to format=" << fmt_;
156   std::string meta_file = GetMetaFilePath(file_name);
157   SaveMetaDataToFile(meta_file, fmap_);
158   SaveBinaryToFile(file_name, data_);
159 }
160 
SaveToBinary(dmlc::Stream * stream)161 void OpenCLModuleNode::SaveToBinary(dmlc::Stream* stream) {
162   stream->Write(fmt_);
163   stream->Write(fmap_);
164   stream->Write(data_);
165 }
166 
GetSource(const std::string & format)167 std::string OpenCLModuleNode::GetSource(const std::string& format) {
168   if (format == fmt_) return data_;
169   if (fmt_ == "cl") {
170     return data_;
171   } else {
172     return source_;
173   }
174 }
175 
Init()176 void OpenCLModuleNode::Init() {
177   workspace_ = GetGlobalWorkspace();
178   workspace_->Init();
179   device_built_flag_.resize(workspace_->devices.size(), false);
180   // initialize the kernel id, need to lock global table.
181   std::lock_guard<std::mutex> lock(workspace_->mu);
182   for (const auto& kv : fmap_) {
183     const std::string& key = kv.first;
184     KTRefEntry e;
185     if (workspace_->free_kernel_ids.size() != 0) {
186       e.kernel_id = workspace_->free_kernel_ids.back();
187       workspace_->free_kernel_ids.pop_back();
188     } else {
189       e.kernel_id = workspace_->num_registered_kernels++;
190     }
191     e.version = workspace_->timestamp++;
192     kid_map_[key] = e;
193   }
194 }
195 
InstallKernel(cl::OpenCLWorkspace * w,cl::OpenCLThreadEntry * t,const std::string & func_name,const KTRefEntry & e)196 cl_kernel OpenCLModuleNode::InstallKernel(cl::OpenCLWorkspace* w,
197                                           cl::OpenCLThreadEntry* t,
198                                           const std::string& func_name,
199                                           const KTRefEntry& e) {
200   std::lock_guard<std::mutex> lock(build_lock_);
201   int device_id = t->context.device_id;
202   if (!device_built_flag_[device_id]) {
203     // create program
204     if (fmt_ == "cl") {
205       if (program_ == nullptr) {
206         const char* s = data_.c_str();
207         size_t len = data_.length();
208         cl_int err;
209         program_ = clCreateProgramWithSource(w->context, 1, &s, &len, &err);
210         OPENCL_CHECK_ERROR(err);
211       }
212     } else if (fmt_ == "xclbin" || fmt_ == "awsxclbin" || fmt_ == "aocx") {
213       const unsigned char* s = (const unsigned char *)data_.c_str();
214       size_t len = data_.length();
215       cl_int err;
216       cl_device_id dev = w->devices[device_id];
217       program_ = clCreateProgramWithBinary(w->context, 1, &dev, &len, &s, NULL, &err);
218       OPENCL_CHECK_ERROR(err);
219     } else {
220       LOG(FATAL) << "Unknown OpenCL format " << fmt_;
221     }
222     // build program
223     cl_int err;
224     cl_device_id dev = w->devices[device_id];
225     err = clBuildProgram(program_, 1, &dev, nullptr, nullptr, nullptr);
226     if (err != CL_SUCCESS) {
227       size_t len;
228       std::string log;
229       clGetProgramBuildInfo(
230           program_, dev, CL_PROGRAM_BUILD_LOG, 0, nullptr, &len);
231       log.resize(len);
232       clGetProgramBuildInfo(
233           program_, dev, CL_PROGRAM_BUILD_LOG, len, &log[0], nullptr);
234       LOG(FATAL) << "OpenCL build error for device=" << dev << log;
235     }
236     device_built_flag_[device_id] = true;
237   }
238   // build kernel
239   cl_int err;
240   cl_kernel kernel = clCreateKernel(program_, func_name.c_str(), &err);
241   OPENCL_CHECK_ERROR(err);
242   t->kernel_table[e.kernel_id].kernel = kernel;
243   t->kernel_table[e.kernel_id].version = e.version;
244   kernels_.push_back(kernel);
245   return kernel;
246 }
247 
OpenCLModuleCreate(std::string data,std::string fmt,std::unordered_map<std::string,FunctionInfo> fmap,std::string source)248 Module OpenCLModuleCreate(
249     std::string data,
250     std::string fmt,
251     std::unordered_map<std::string, FunctionInfo> fmap,
252     std::string source) {
253   auto n = make_object<OpenCLModuleNode>(data, fmt, fmap, source);
254   n->Init();
255   return Module(n);
256 }
257 
258 // Load module from module.
OpenCLModuleLoadFile(const std::string & file_name,const std::string & format)259 Module OpenCLModuleLoadFile(const std::string& file_name,
260                             const std::string& format) {
261   std::string data;
262   std::unordered_map<std::string, FunctionInfo> fmap;
263   std::string fmt = GetFileFormat(file_name, format);
264   std::string meta_file = GetMetaFilePath(file_name);
265   LoadBinaryFromFile(file_name, &data);
266   LoadMetaDataFromFile(meta_file, &fmap);
267   return OpenCLModuleCreate(data, fmt, fmap, std::string());
268 }
269 
OpenCLModuleLoadBinary(void * strm)270 Module OpenCLModuleLoadBinary(void* strm) {
271   dmlc::Stream* stream = static_cast<dmlc::Stream*>(strm);
272   std::string data;
273   std::unordered_map<std::string, FunctionInfo> fmap;
274   std::string fmt;
275   stream->Read(&fmt);
276   stream->Read(&fmap);
277   stream->Read(&data);
278   return OpenCLModuleCreate(data, fmt, fmap, std::string());
279 }
280 
281 TVM_REGISTER_GLOBAL("module.loadfile_cl")
282 .set_body_typed(OpenCLModuleLoadFile);
283 
284 TVM_REGISTER_GLOBAL("module.loadfile_clbin")
285 .set_body_typed(OpenCLModuleLoadFile);
286 
287 TVM_REGISTER_GLOBAL("module.loadbinary_opencl")
288 .set_body_typed(OpenCLModuleLoadBinary);
289 }  // namespace runtime
290 }  // namespace tvm
291