/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

#include <mxnet/rtc.h>
#include <typeinfo>

#include "../common/cuda_utils.h"
#include "../operator/operator_common.h"

#if MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC

namespace mxnet {
namespace rtc {

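// Compiles `source` with NVRTC. Exported kernel names are registered as
// name expressions (CUDA 8.0+) so their lowered, mangled names can be
// recovered later; on compilation failure the full NVRTC log is reported.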
CudaModule::Chunk::Chunk(
    const char* source,
    const std::vector<std::string>& options,
    const std::vector<std::string>& exports) {
  NVRTC_CALL(nvrtcCreateProgram(&prog_, source, "source.cu", 0, nullptr, nullptr));
  for (const auto& i : exports) exports_.insert(i);
#if CUDA_VERSION >= 8000
  for (const auto& func : exports) {
    NVRTC_CALL(nvrtcAddNameExpression(prog_, func.c_str()));
  }
#else
  CHECK_EQ(exports.size(), 0)
      << "Exporting is only supported with CUDA 8.0 and above. "
      << "For lower versions of CUDA, please prepend your kernel definitions "
      << "with extern \"C\" instead.";
#endif
  std::vector<const char*> c_options;
  for (const auto& i : options) c_options.push_back(i.c_str());
  nvrtcResult compile_res = nvrtcCompileProgram(prog_, c_options.size(), c_options.data());
  if (compile_res != NVRTC_SUCCESS) {
    size_t err_size;
    NVRTC_CALL(nvrtcGetProgramLogSize(prog_, &err_size));
    std::vector<char> err(err_size);
    NVRTC_CALL(nvrtcGetProgramLog(prog_, err.data()));
    LOG(FATAL) << err.data();
  }

  size_t ptx_size;
  NVRTC_CALL(nvrtcGetPTXSize(prog_, &ptx_size));
  ptx_ = new char[ptx_size];
  NVRTC_CALL(nvrtcGetPTX(prog_, ptx_));
}

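// Unloads every per-device module, destroys the NVRTC program, and frees
// the compiled PTX buffer.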
CudaModule::Chunk::~Chunk() {
  for (const auto& kv : mod_) {
    CUDA_DRIVER_CALL(cuModuleUnload(kv.second));
  }
  NVRTC_CALL(nvrtcDestroyProgram(&prog_));
  delete[] ptx_;  // allocated with new[], so the array form of delete is required
}

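// Returns the CUfunction for `mangled_name` on the device of `ctx`, lazily
// loading the compiled PTX into a per-device CUmodule on first use.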
CUfunction CudaModule::Chunk::GetFunction(
    const std::string& mangled_name,
    const Context& ctx) {
  CHECK_EQ(ctx.dev_mask(), Context::kGPU)
      << "CUDA runtime compilation only supports NVIDIA GPUs.";
  auto iter = mod_.find(ctx.dev_id);
  mxnet::common::cuda::DeviceStore device_store;
  CUmodule module;
  if (iter != mod_.end()) {
    module = iter->second;
  } else {
    device_store.SetDevice(ctx.dev_id);
    CUDA_DRIVER_CALL(cuModuleLoadDataEx(&module, ptx_, 0, nullptr, nullptr));
    mod_[ctx.dev_id] = module;
  }
  CUfunction function;
  auto err = cuModuleGetFunction(&function, module, mangled_name.c_str());
  if (err == CUDA_ERROR_NOT_FOUND) {
    LOG(FATAL) << "Cannot find CUDA kernel with name '" << mangled_name
               << "'. Please either prepend the kernel definition "
               << "with 'extern \"C\"' or add its name to exports "
               << "when creating CudaModule.";
  }
  CUDA_DRIVER_CALL(err);
  return function;
}

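// Creates a Kernel handle for `name`. If the name was exported, the lowered
// (mangled) name is recovered with nvrtcGetLoweredName; otherwise the name
// is used as-is, which assumes an extern "C" definition.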
std::shared_ptr<CudaModule::Kernel> CudaModule::GetKernel(
    const std::string& name, const std::vector<ArgType>& signature) {
  std::string mangled_name = name;
#if CUDA_VERSION >= 8000
  if (ptr_->exports_.count(name)) {
    const char* c_mangled_name;
    NVRTC_CALL(nvrtcGetLoweredName(ptr_->prog_, name.c_str(), &c_mangled_name));
    mangled_name = c_mangled_name;
  }
#endif
  return std::shared_ptr<Kernel>(new Kernel(ptr_, mangled_name, signature));
}

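// A Kernel holds a shared reference to its Chunk so the compiled module
// outlives the handle for as long as the kernel can still be launched.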
CudaModule::Kernel::Kernel(
    const std::shared_ptr<CudaModule::Chunk>& mod,
    const std::string& mangled_name,
    const std::vector<ArgType>& signature)
    : mangled_name_(mangled_name), signature_(signature), mod_(mod) {
}

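// Launches the kernel asynchronously through the engine. NDArray arguments
// are registered as read or write dependencies according to the signature;
// scalar arguments are passed to cuLaunchKernel by address.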
void CudaModule::Kernel::Launch(
    const Context& ctx, const std::vector<dmlc::any>& args,
    uint32_t grid_dim_x, uint32_t grid_dim_y, uint32_t grid_dim_z,
    uint32_t block_dim_x, uint32_t block_dim_y, uint32_t block_dim_z,
    uint32_t shared_mem) {
  CHECK_EQ(ctx.dev_mask(), Context::kGPU)
      << "CUDA runtime compilation only supports NVIDIA GPUs.";

  // Copy the shared_ptr so the compiled chunk stays alive inside the lambda below.
  auto mod = mod_;
  auto arg_types = signature();

  // Look up the per-device function handle, resolving it on first use.
  CUfunction function;
  auto iter = func_.find(ctx.dev_id);
  if (iter != func_.end()) {
    function = iter->second;
  } else {
    function = mod_->GetFunction(mangled_name_, ctx);
    func_[ctx.dev_id] = function;
  }

  std::vector<Engine::VarHandle> read_vars, write_vars;
  for (size_t i = 0; i < arg_types.size(); ++i) {
    if (!arg_types[i].is_ndarray) continue;
    const auto& array = dmlc::get<NDArray>(args[i]);
    CHECK_EQ(array.dtype(), arg_types[i].dtype)
        << "Argument " << i << " is expected to be an NDArray of "
        << op::type_string(arg_types[i].dtype) << " type, but got "
        << op::type_string(array.dtype()) << " instead.";
    if (arg_types[i].is_const) {
      read_vars.emplace_back(array.var());
    } else {
      write_vars.emplace_back(array.var());
    }
  }

  Engine::Get()->PushSync(
      [function, mod, args, arg_types, grid_dim_x, grid_dim_y, grid_dim_z,
       block_dim_x, block_dim_y, block_dim_z, shared_mem](RunContext rctx) {
        // Build the kernel parameter array: NDArray arguments pass the address
        // of their device pointer, scalars pass the address of the value itself.
        std::vector<void*> p_args;
        for (size_t i = 0; i < arg_types.size(); ++i) {
          if (arg_types[i].is_ndarray) {
            const auto& array = dmlc::get<NDArray>(args[i]);
            p_args.push_back(reinterpret_cast<void*>(const_cast<void**>(&array.data().dptr_)));
          } else {
            MSHADOW_TYPE_SWITCH(arg_types[i].dtype, DType, {
              const auto& number = dmlc::get<DType>(args[i]);
              p_args.push_back(const_cast<DType*>(&number));
            });
          }
        }

        mshadow::Stream<gpu>* s = rctx.get_stream<gpu>();
        CUDA_DRIVER_CALL(cuLaunchKernel(
            function, grid_dim_x, grid_dim_y, grid_dim_z,
            block_dim_x, block_dim_y, block_dim_z,
            shared_mem, s->stream_,
            p_args.data(), nullptr));
        CUDA_CALL(cudaStreamSynchronize(s->stream_));
      }, ctx, read_vars, write_vars, FnProperty::kNormal, 0,
      mangled_name_.c_str());
}

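// A minimal usage sketch (hypothetical: the kernel source, signature setup,
// and variable names below are illustrative; see include/mxnet/rtc.h for the
// exact CudaModule/ArgType declarations):
//
//   const char* source = R"code(
//       extern "C" __global__ void axpy(const float* x, float* y, float alpha) {
//         int i = threadIdx.x + blockIdx.x * blockDim.x;
//         y[i] += alpha * x[i];
//       }
//   )code";
//   CudaModule module(source, {}, {});  // extern "C" kernel, so no exports needed
//   // signature: const float NDArray, mutable float NDArray, float scalar
//   auto kernel = module.GetKernel("axpy", signature);
//   kernel->Launch(Context::GPU(0), {x, y, dmlc::any(3.0f)},
//                  n_blocks, 1, 1, 256, 1, 1, 0);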
} // namespace rtc
} // namespace mxnet

#endif // MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC