1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 // Additional use of MXNET_USE_CUDA is not needed to guard a '.cu' file.
21 #if MXNET_ENABLE_CUDA_RTC
22
#include <sys/stat.h>
#include <nvrtc.h>
#include <cuda.h>
#include <nnvm/pass_functions.h>
#include <algorithm>
#include <cstring>
#include <map>
#include <mutex>
#include <string>
#include <vector>
#include "./fused_op.h"
#include "./fused_op-inl.h"
#include "../operator_common.h"
#include "../elemwise_op_common.h"
#include "../../executor/exec_pass.h"
#include "../../common/cuda_utils.h"
35
36 namespace mxnet {
37
38 namespace {
39
mshadowTypeToString(int type)40 inline std::string mshadowTypeToString(int type) {
41 switch (type) {
42 case mshadow::kFloat32:
43 return "float";
44 case mshadow::kFloat64:
45 return "double";
46 case mshadow::kFloat16:
47 return "half";
48 case mshadow::kUint8:
49 return "unsigned char";
50 case mshadow::kInt8:
51 return "char";
52 case mshadow::kInt32:
53 return "int";
54 case mshadow::kInt64:
55 return "long long";
56 case mshadow::kBool:
57 return "bool";
58 default:
59 LOG(FATAL) << "Unknown type enum " << type;
60 }
61 return "";
62 }
63
mshadowTypeToVectorLength(int type)64 inline int mshadowTypeToVectorLength(int type) {
65 switch (type) {
66 case mshadow::kFloat32:
67 return 1;
68 case mshadow::kFloat64:
69 return 1;
70 case mshadow::kFloat16:
71 return 2;
72 case mshadow::kUint8:
73 return 4;
74 case mshadow::kInt8:
75 return 4;
76 case mshadow::kInt32:
77 return 1;
78 case mshadow::kInt64:
79 return 1;
80 case mshadow::kBool:
81 return 4 / sizeof(bool);
82 default:
83 LOG(FATAL) << "Unknown type enum " << type;
84 }
85 return 0;
86 }
87
/*!
 * \brief Replace every non-overlapping occurrence of `old` in `*input` with `repl`,
 *        scanning left to right. Text inserted by a replacement is not re-scanned.
 * \param input string modified in place.
 * \param old   pattern to search for; an empty pattern leaves `*input` unchanged.
 * \param repl  replacement text.
 */
inline void replaceString(std::string *input, const std::string& old, const std::string& repl) {
  // An empty pattern matches at every position, which would make the loop below never end.
  if (old.empty()) {
    return;
  }
  size_t pos = 0;
  while ((pos = input->find(old, pos)) != std::string::npos) {
    input->replace(pos, old.size(), repl);
    // Skip over the inserted text so a `repl` containing `old` is not expanded again.
    pos += repl.size();
  }
}
95
/*!
 * \brief Parse a bracketed, comma-separated list such as "{1,2,INT_MAX}" into integers.
 * \param input list whose first and last characters are the enclosing brackets.
 * \param def   token that is mapped to 0 instead of being parsed (e.g. "INT_MAX").
 * \return the parsed values; empty tokens are skipped, and an input shorter than the
 *         two brackets yields an empty vector. Non-numeric tokens throw from std::stoi.
 */
inline std::vector<int> splitStringToVector(const std::string& input, const std::string& def) {
  std::vector<int> res;
  // Nothing between the brackets (and substr(1, ...) would throw on an empty input).
  if (input.size() < 2) {
    return res;
  }
  const std::string inner = input.substr(1, input.size() - 2);

  auto convert_token = [&def](const std::string& token) {
    return token == def ? 0 : std::stoi(token);
  };

  size_t pos_start = 0;
  size_t pos_end;
  while ((pos_end = inner.find(',', pos_start)) != std::string::npos) {
    if (pos_end > pos_start) {
      res.push_back(convert_token(inner.substr(pos_start, pos_end - pos_start)));
    }
    pos_start = pos_end + 1;
  }
  if (pos_start < inner.size()) {
    res.push_back(convert_token(inner.substr(pos_start)));
  }
  return res;
}
121
ParseOpDescription(const std::vector<std::string> & op_desc,const std::map<std::pair<int,int>,std::string> & variables,const nnvm::IndexedGraph::Node & node)122 std::string ParseOpDescription(const std::vector<std::string>& op_desc,
123 const std::map<std::pair<int, int>, std::string>& variables,
124 const nnvm::IndexedGraph::Node& node) {
125 const auto* source = node.source;
126 std::string fmt = op_desc[0];
127 for (size_t j = 1; j < op_desc.size(); ++j) {
128 const std::string& desc = op_desc[j];
129 std::string sub;
130 if (desc[0] == '_') {
131 // Argument
132 const int arg_id = std::stoi(desc.substr(1));
133 sub = variables.at({node.inputs[arg_id].node_id, node.inputs[arg_id].index});
134 } else {
135 sub = source->attrs.dict.at(desc);
136 }
137 size_t pos = fmt.find("%");
138 CHECK_NE(pos, std::string::npos);
139 fmt.replace(pos, 1, sub);
140 }
141 return fmt;
142 }
143
AddShape(const mxnet::TShape & shape,std::vector<std::vector<int>> * shapes)144 void AddShape(const mxnet::TShape& shape,
145 std::vector<std::vector<int>>* shapes) {
146 // We need alignment to 8 bytes for size_t in the Shape struct
147 // so if ndim is odd, there will be 4B of padding
148 int ndim = shape.ndim();
149 const int offset = ndim % 2 == 0 ? 2 : 3;
150 shapes->push_back(std::vector<int>(ndim + offset));
151 std::vector<int>& tensor_shapes = shapes->back();
152 size_t total_size = 1;
153 for (int i = ndim-1; i >= 0; i--) {
154 tensor_shapes[i] = shape[i];
155 total_size *= shape[i];
156 }
157 size_t * shape_size_ptr = reinterpret_cast<size_t*>(&tensor_shapes[ndim + offset - 2]);
158 *shape_size_ptr = total_size;
159 }
160
AddPointerAndShape(const TBlob & data,std::vector<void * > * ptrs,std::vector<std::vector<int>> * shapes,mshadow::Stream<gpu> * s)161 void AddPointerAndShape(const TBlob& data,
162 std::vector<void*> *ptrs,
163 std::vector<std::vector<int>>* shapes,
164 mshadow::Stream<gpu> * s) {
165 using namespace mshadow;
166 MSHADOW_TYPE_SWITCH_WITH_BOOL(data.type_flag_, DType, {
167 Tensor<gpu, 1, DType> tensor = data.FlatTo1D<gpu, DType>(s);
168 ptrs->push_back(tensor.dptr_);
169 AddShape(data.shape_, shapes);
170 });
171 }
172
173 // Obtain compilation log from the program.
GetCompileLog(nvrtcProgram program)174 std::string GetCompileLog(nvrtcProgram program) {
175 size_t log_size_including_null;
176 NVRTC_CALL(nvrtcGetProgramLogSize(program, &log_size_including_null));
177 // For most std::string implementations, this is probably 1 char bigger than needed. OK though.
178 std::string log(log_size_including_null, '\0');
179 NVRTC_CALL(nvrtcGetProgramLog(program, &log[0]));
180 // Make sure the string reflects the true size (so minus the null terminator).
181 log.resize(log_size_including_null - 1);
182 return log;
183 }
184
185 // Obtain compilation result (ptx assembly) from the program.
GetPtx(nvrtcProgram program)186 std::string GetPtx(nvrtcProgram program) {
187 size_t ptx_size_including_null;
188 NVRTC_CALL(nvrtcGetPTXSize(program, &ptx_size_including_null));
189 // For most std::string implementations, this is probably 1 char bigger than needed. OK though.
190 std::string ptx(ptx_size_including_null, '\0');
191 NVRTC_CALL(nvrtcGetPTX(program, &ptx[0]));
192 // Make sure the string reflects the true size (so minus the null terminator).
193 ptx.resize(ptx_size_including_null - 1);
194 return ptx;
195 }
196
197 } // namespace
198
GenerateCode(const std::vector<OpReqType> & req,const std::vector<int> & in_dtypes,const std::vector<int> & out_dtypes,const std::vector<int> & in_ndims,const std::vector<int> & out_ndims,const mxnet::ShapeVector & node_shapes,const std::vector<int> & node_dtypes,const int nvec,const std::string & kernel_name,std::vector<uint32_t> * check_shapes)199 std::string FusedOp::GenerateCode(const std::vector<OpReqType> &req,
200 const std::vector<int> &in_dtypes,
201 const std::vector<int> &out_dtypes,
202 const std::vector<int> &in_ndims,
203 const std::vector<int> &out_ndims,
204 const mxnet::ShapeVector &node_shapes,
205 const std::vector<int> &node_dtypes,
206 const int nvec,
207 const std::string &kernel_name,
208 std::vector<uint32_t>* check_shapes) {
209 const auto& g = subgraph_.indexed_graph();
210 std::string code = "";
211 int temp_name_counter = 0;
212 using NodeEntry = nnvm::IndexedGraph::NodeEntry;
213 std::map<std::pair<int, int>, std::string> variables;
214 std::map<int, int> load_index;
215 bool check_shapes_compile = true;
216
217 std::vector<uint32_t> outputs(g.num_nodes());
218
219 for (size_t i = 0; i < g.num_nodes(); ++i) {
220 const auto& node = g[i];
221 if (node.source != nullptr) {
222 outputs[i] = node.source->num_outputs();
223 } else {
224 outputs[i] = 0;
225 }
226 }
227
228 for (size_t i = 0; i < g.num_nodes(); ++i) {
229 const auto& node = g[i];
230 const auto* source = node.source;
231 if (source != nullptr) {
232 if (source->is_variable()) {
233 load_index[i] = 1;
234 } else {
235 std::string op_name = source->op()->name;
236 if (fusion::slice_ops.find(op_name) != fusion::slice_ops.end()) {
237 load_index[node.inputs[0].node_id] = 0;
238 }
239 }
240 }
241 }
242 for (size_t i = 0; i < g.num_nodes(); ++i) {
243 const auto& node = g[i];
244 const auto* source = node.source;
245 if (source != nullptr) {
246 if (source->is_variable()) {
247 if (load_index[i]) {
248 const auto& var_name = source->attrs.name;
249 code += "const auto vec_" + var_name + " = op::load_index<nvec>(" +
250 var_name + ", offset, " + var_name + "_shape);\n";
251 variables[{i, 0}] = var_name;
252 }
253 CHECK_EQ(outputs[i], 1);
254 } else {
255 std::string op_name = source->op()->name;
256 if (fusion::slice_ops.find(op_name) != fusion::slice_ops.end()) {
257 int node_id = node.inputs[0].node_id;
258 const uint32_t input_entry_id = g.entry_id(node.inputs[0]);
259 const auto& shape = node_shapes[input_entry_id];
260 const int ndim = shape.ndim();
261 const auto& var_name = g[node_id].source->attrs.name;
262 const auto vec_name = "vec_" + var_name + "_" + std::to_string(i);
263 load_index[node_id] = 0;
264 auto parse_tuple = [ndim](const std::string& input, const std::string& def) {
265 std::string out = input;
266 replaceString(&out, " ", "");
267 if (out[0] == '(') {
268 replaceString(&out, "(", "{");
269 replaceString(&out, ")", "}");
270 // First check if out is ()
271 int n_entries = out.size() != 2;
272 for (size_t i = 1; i < out.size() - 1; ++i) {
273 if (out[i] == ',') {
274 ++n_entries;
275 }
276 }
277 if (n_entries != ndim) {
278 out.pop_back();
279 for (int i = n_entries; i < ndim; ++i) {
280 out += "," + def;
281 }
282 out += "}";
283 }
284 } else {
285 out = "{" + std::move(out);
286 for (int i = 1; i < ndim; ++i) {
287 out += "," + def;
288 }
289 out += "}";
290 }
291 replaceString(&out, "None", def);
292 return out;
293 };
294 auto parse_int = [](const std::string& input, const std::string& def) {
295 std::string out = input;
296 replaceString(&out, " ", "");
297 replaceString(&out, "None", def);
298 return out;
299 };
300 auto build_tuple = [ndim](int axis, const std::string str, const std::string def) {
301 if (axis < 0 &&
302 axis >= -ndim) {
303 axis += ndim;
304 }
305 if (axis < 0 || axis >= ndim) {
306 LOG(FATAL) << "Axis " << axis << " is out of bounds for array of dimension " << ndim;
307 }
308 std::string tuple = "{";
309 for (int i = 0; i < axis; i++) {
310 tuple += def + ",";
311 }
312 tuple += str;
313 for (int i = axis + 1; i < ndim; i++) {
314 tuple += "," + def;
315 }
316 tuple += "}";
317 return tuple;
318 };
319 auto check_tuple = [ndim, nvec](const std::string str) {
320 std::vector<int> tuple = splitStringToVector(str, "INT_MAX");
321 if (tuple[ndim-1] % nvec == 0) {
322 return true;
323 }
324 return false;
325 };
326 auto build_string_end = [i, ndim, var_name](std::string* code) {
327 std::string end_var_name = var_name + "_" + std::to_string(i) + "_end";
328 *code += "op::Shape<" + std::to_string(ndim) + "> "+ end_var_name + ";\n";
329 *code += end_var_name + ".set(INT_MAX);\n";
330 return end_var_name;
331 };
332 std::string begin;
333 std::string end;
334 if (op_name == "broadcast_like" || op_name == "slice_like") {
335 uint32_t like_id = g.entry_id(i, 0);
336 begin = build_tuple(0, "0", "0");
337 std::string extra_var_name = "extra_" + std::to_string(like_id) + "_shape";
338 if (std::find(extra_shape_args_.begin(), extra_shape_args_.end(), like_id) ==
339 extra_shape_args_.end()) {
340 extra_shape_args_.push_back(like_id);
341 }
342 if (check_shapes) {
343 check_shapes->push_back(like_id);
344 check_shapes->push_back(input_entry_id);
345 }
346 end = extra_var_name;
347 } else {
348 if (op_name == "slice_axis") {
349 begin = parse_int(source->attrs.dict.at("begin"), "0");
350 end = parse_int(source->attrs.dict.at("end"), "INT_MAX");
351 int axis = std::stoi(source->attrs.dict.at("axis"));
352 begin = build_tuple(axis, begin, "0");
353 end = build_tuple(axis, end, "INT_MAX");
354 } else {
355 begin = parse_tuple(source->attrs.dict.at("begin"), "0");
356 end = parse_tuple(source->attrs.dict.at("end"), "INT_MAX");
357 }
358 if (check_shapes) {
359 if (check_tuple(begin) && check_tuple(end)) {
360 check_shapes->push_back(input_entry_id);
361 } else {
362 check_shapes_compile = false;
363 }
364 }
365 }
366 std::string slice_func = "load_slice";
367 if (!check_shapes) {
368 slice_func = "fast_" + slice_func;
369 }
370 code += "const auto " + vec_name + " = op::" + slice_func + "<nvec>(" +
371 var_name + ", " + var_name + "_shape," + begin +
372 "," + end + ", offset);\n";
373 CHECK_EQ(outputs[i], 1);
374 variables[{i, 0}] = vec_name;
375 continue;
376 }
377 }
378 }
379 }
380
381 if (!check_shapes_compile) {
382 check_shapes->clear();
383 }
384
385 size_t counter = 0;
386 for (const auto& entry : g.outputs()) {
387 std::string var_name = "output" + std::to_string(counter);
388 code += "op::VectorType<DType_" + var_name + \
389 ", nvec> vec_" + var_name + ";\n";
390 ++counter;
391 }
392
393 code += "for (int j = 0; j < nvec; j++ ) {\n";
394
395
396 for (size_t i = 0; i < g.num_nodes(); ++i) {
397 const auto& node = g[i];
398 const auto* source = node.source;
399 if (source != nullptr) {
400 std::string var_name = "temp" + std::to_string(temp_name_counter++);
401 if (source->is_variable()) {
402 if (load_index[i]) {
403 code += "const auto " + var_name + " = op::load(vec_" +
404 variables[{i, 0}] + ".x[j]);\n";
405 CHECK_EQ(outputs[i], 1);
406 variables[{i, 0}] = var_name;
407 }
408 } else {
409 std::string op_name = source->op()->name;
410 if (fusion::ops_desc.find(op_name) != fusion::ops_desc.end()) {
411 const std::vector<std::vector<std::string>>& op_descs =
412 fusion::ops_desc.at(op_name);
413 CHECK_EQ(outputs[i], op_descs.size());
414 size_t count = 0;
415 for (const auto& op_desc : op_descs) {
416 var_name = "temp" + std::to_string(temp_name_counter++);
417 const std::string& fmt = ParseOpDescription(op_desc, variables, node);
418 code += "const auto " + var_name + " = " + fmt + ";\n";
419 variables[{i, count}] = var_name;
420 ++count;
421 }
422 continue;
423 }
424
425 if (fusion::slice_ops.find(op_name) != fusion::slice_ops.end()) {
426 code += "const auto " + var_name + " = op::load(" + variables[{i, 0}] + ".x[j]);\n";
427 variables[{i, 0}] = var_name;
428 continue;
429 }
430
431
432 // Special cases with variable number
433 // of inputs/outputs, listed in
434 // fusion::variable_io_ops
435 if (op_name == "add_n") {
436 CHECK_EQ(outputs[i], 1);
437 const auto& arg = variables[{node.inputs[0].node_id, node.inputs[0].index}];
438 code += "auto " + var_name + " = " + arg + ";\n";
439 for (size_t inp = 1; inp < node.inputs.size(); ++inp) {
440 const auto& temp_arg = variables[{node.inputs[inp].node_id, node.inputs[inp].index}];
441 code += var_name + " = op::add(" + var_name + ", " + temp_arg + ");\n";
442 }
443 variables[{i, 0}] = var_name;
444 continue;
445 }
446
447 if (op_name == "_backward_Activation") {
448 CHECK_EQ(outputs[i], 1);
449 std::string act_type = node.source->attrs.dict.at("act_type");
450 std::string rhs, lhs;
451 rhs = variables[{node.inputs[0].node_id, node.inputs[0].index}];
452 if (act_type == "relu" ||
453 act_type == "sigmoid" ||
454 act_type == "tanh") {
455 lhs = variables[{node.inputs[1].node_id, node.inputs[1].index}];
456 } else {
457 lhs = variables[{node.inputs[2].node_id, node.inputs[2].index}];
458 }
459 code += "const auto " + var_name + " = op::backward_" + act_type +
460 "(" + lhs + ", " + rhs + ");\n";
461
462 variables[{i, 0}] = var_name;
463 continue;
464 }
465
466 if (op_name == "amp_multicast" || op_name == "_backward_amp_multicast") {
467 CHECK_EQ(outputs[i], node.inputs.size());
468 for (size_t counter = 0; counter < outputs[i]; ++counter) {
469 const auto& input = node.inputs[counter];
470 var_name = "temp" + std::to_string(temp_name_counter++);
471 const auto& arg = variables[{input.node_id, input.index}];
472 code += "const auto " + var_name + " = " + arg + ";\n";
473 variables[{i, counter}] = var_name;
474 }
475 continue;
476 }
477
478 if (op_name == "_backward_cast") {
479 CHECK_EQ(outputs[i], 1);
480 const int output_type = node_dtypes[g.entry_id(i, 0)];
481 const auto& arg = variables[{node.inputs[0].node_id, node.inputs[0].index}];
482 code += "const auto " + var_name + " = op::cast<" + mshadowTypeToString(output_type) +
483 ">(" + arg + ");\n";
484 variables[{i, 0}] = var_name;
485 continue;
486 }
487
488 // LeakyReLU, look for act_type
489 if (op_name == "LeakyReLU") {
490 std::string act_type = node.source->attrs.dict.at("act_type");
491 const std::vector<std::vector<std::string>>& op_descs =
492 fusion::LeakyReLU_ops.at(act_type);
493 if (fusion::LeakyReLU_ops.find(act_type) != fusion::LeakyReLU_ops.end()) {
494 CHECK_EQ(outputs[i], op_descs.size());
495 size_t count = 0;
496 for (const auto& op_desc : op_descs) {
497 var_name = "temp" + std::to_string(temp_name_counter++);
498 const std::string& fmt = ParseOpDescription(op_desc, variables, node);
499 code += "const auto " + var_name + " = " + fmt + ";\n";
500 variables[{i, count}] = var_name;
501 ++count;
502 }
503 continue;
504 }
505 }
506 if (op_name == "_backward_LeakyReLU") {
507 std::string act_type = node.source->attrs.dict.at("act_type");
508 const std::vector<std::vector<std::string>>& op_descs =
509 fusion::LeakyReLU_bwd_ops.at(act_type);
510 if (fusion::LeakyReLU_ops.find(act_type) != fusion::LeakyReLU_bwd_ops.end()) {
511 CHECK_EQ(outputs[i], op_descs.size());
512 size_t count = 0;
513 for (const auto& op_desc : op_descs) {
514 var_name = "temp" + std::to_string(temp_name_counter++);
515 const std::string& fmt = ParseOpDescription(op_desc, variables, node);
516 code += "const auto " + var_name + " = " + fmt + ";\n";
517 variables[{i, count}] = var_name;
518 ++count;
519 }
520 continue;
521 }
522 }
523
524 LOG(FATAL) << "Unrecognized op " + op_name;
525 }
526 } else {
527 LOG(FATAL) << "Encountered node with NULL source.";
528 }
529 }
530
531 counter = 0;
532 for (const auto& entry : g.outputs()) {
533 const std::string& var = variables[{entry.node_id, entry.index}];
534 const auto var_name = "output" + std::to_string(counter);
535 code += "vec_" + var_name + ".x[j] = op::store("+ var +", " + var_name + ");\n";
536 ++counter;
537 }
538
539 code += "}\n";
540
541 counter = 0;
542
543 for (const auto& entry : g.outputs()) {
544 const std::string& var = variables[{entry.node_id, entry.index}];
545 if (req[counter] == kWriteTo || req[counter] == kWriteInplace) {
546 const auto var_name = "output" + std::to_string(counter);
547 code += "op::store_index(vec_" + var_name + ", i, " + var_name + ", " +
548 var_name + "_shape);\n";
549 } else if (req[counter] == kAddTo) {
550 const auto var_name = "output" + std::to_string(counter);
551 code += "op::store_add_index(vec_" + var_name + ", i, " + var_name + ", " +
552 var_name + "_shape);\n";
553 } else if (req[counter] == kNullOp) {
554 // nullptr req, do not do anything
555 } else {
556 LOG(FATAL) << "Encountered unexpected req.";
557 }
558 ++counter;
559 }
560
561 // Add boilerplate and type information
562 std::string kernel_params = "";
563 std::string tensor_params = "";
564 nnvm::Symbol sym;
565 sym.outputs = subgraph_.outputs;
566 const std::vector<std::string> input_names = sym.ListInputNames(nnvm::Symbol::kAll);
567 size_t num_params = in_dtypes.size() + out_dtypes.size();
568 size_t i = 0;
569 std::string aux_code = "static const int nvec = " + std::to_string(nvec) + ";\n";
570
571 for (const auto &shape_id : extra_shape_args_) {
572 std::string shape_name = "extra_" + std::to_string(shape_id) + "_shape";
573 int ndim = node_shapes[shape_id].ndim();
574 kernel_params += " const op::Shape<" + std::to_string(ndim) + "> " + shape_name;
575 kernel_params += ", ";
576 }
577 for (const auto &type : in_dtypes) {
578 std::string type_name = mshadowTypeToString(type);
579 std::string dtype_var = "DType_" + input_names[i];
580 std::string dim_var = "ndim_" + input_names[i];
581 std::string dim_val = std::to_string(in_ndims[i]);
582 aux_code = "using " + dtype_var + " = " + type_name + ";\n" + aux_code;
583 aux_code = "static const int " + dim_var + " = " + dim_val + ";\n" + aux_code;
584 tensor_params += dtype_var + "* " +input_names[i];
585 kernel_params += " const op::Shape<" + dim_val + "> " + input_names[i]+"_shape";
586 ++i;
587 if (i < num_params) {
588 tensor_params += ", ";
589 }
590 kernel_params += ", ";
591 }
592 for (const auto &type : out_dtypes) {
593 std::string type_name = mshadowTypeToString(type);
594 std::string out_name = "output" + std::to_string(i - in_dtypes.size());
595 std::string dtype_var = "DType_" + out_name;
596 std::string dim_var = "ndim_" + out_name;
597 std::string dim_val = std::to_string(out_ndims[i - in_dtypes.size()]);
598 aux_code = "static const int " + dim_var + " = " + dim_val + ";\n" + aux_code;
599 aux_code = "using " + dtype_var + " = " + type_name + ";\n" + aux_code;
600 tensor_params += dtype_var + "* " + out_name;
601 kernel_params += " const op::Shape<" + dim_val + "> " + out_name+"_shape";
602 ++i;
603 if (i < num_params) {
604 tensor_params += ", ";
605 }
606 kernel_params += ", ";
607 }
608 kernel_params += tensor_params;
609
610 // Create kernel source (minus the common header)
611 return aux_code + "\n" +
612 "__launch_bounds__(" + std::to_string(FusedOp::NTHREADS) + ")\n" +
613 "__global__ void FusedKernel_" + kernel_name +
614 "(size_t N, " + kernel_params + ") {\n" +
615 fusion::kernel_begin + "\n" +
616 code + "\n" +
617 fusion::kernel_end;
618 }
619
CompileCode(const std::string & code,const std::string & kernel_name,int dev_id)620 CUfunction FusedOp::CompileCode(const std::string &code,
621 const std::string &kernel_name,
622 int dev_id) {
623 // Guard NVRTC calls
624 std::lock_guard<std::mutex> lock_nvrtc(mutex_);
625 // Local class for value type of compile cache
626 struct KernelInfo {
627 std::string mangled_name;
628 std::string ptx;
629 std::vector<CUfunction> functions;
630 };
631 // Maps from the cuda source code (minus header) to the ptx and jit-compiled CUfunctions.
632 using KernelCache = std::map<std::string, KernelInfo>;
633 // Per-gpu-architecture compiled kernel cache with jit-compiled function for each device context
634 static std::map<int32_t, KernelCache> compiled_kernels;
635 int sm_arch = SMArch(dev_id);
636 KernelCache& compiled_kernels_this_arch = compiled_kernels[sm_arch]; // make null map as needed
637 KernelInfo& kinfo = compiled_kernels_this_arch[code]; // make KernelInfo as needed
638 if (kinfo.ptx.size() == 0) {
639 // It's the first time we've seen this kernel, so we need to generate the ptx and mangled_name.
640 static std::string common_header =
641 std::string(fusion::fp16_support_string) + "\n" +
642 fusion::type_support_string + "\n" +
643 fusion::function_definitions + "\n" +
644 fusion::backward_function_definitions + "\n";
645 std::string code_with_header = common_header + code;
646 // If verbose mode, output kernel source, though not including the common header
647 if (dmlc::GetEnv("MXNET_FUSION_VERBOSE", false)) {
648 LOG(INFO) << "\n" << std::string(80, '-') << "\n" << code;
649 }
650 if (compiled_kernels_this_arch.size() == CACHESIZE_WARN_THRESHOLD + 1 &&
651 dmlc::GetEnv("MXNET_FUSION_SIZE_WARNING", true)) {
652 LOG(WARNING) << "The number of different fused ops exceeds " << CACHESIZE_WARN_THRESHOLD
653 << ". Set MXNET_FUSION_SIZE_WARNING=0 to quiet this warning.";
654 }
655 nvrtcProgram program;
656 NVRTC_CALL(nvrtcCreateProgram(&program, // prog
657 &code_with_header[0], // buffer
658 (kernel_name + "_kernel.cu").c_str(), // name
659 0, // num headers
660 nullptr, // headers
661 nullptr)); // include names
662
663 std::string gpu_arch_arg = "--gpu-architecture=compute_" + std::to_string(sm_arch);
664 const char *opts[] = {gpu_arch_arg.c_str(),
665 "--std=c++11"};
666 const std::string kernel_name_demangled = "FusedKernel_" + kernel_name;
667 NVRTC_CALL(nvrtcAddNameExpression(program, (kernel_name_demangled).c_str()));
668
669 nvrtcResult compileResult = nvrtcCompileProgram(program, // prog
670 2, // num options
671 opts); // options
672 CHECK_EQ(compileResult, NVRTC_SUCCESS)
673 << "NVRTC Compilation failed. Please set environment variable MXNET_USE_FUSION to 0.\n"
674 << GetCompileLog(program);
675
676 kinfo.ptx = GetPtx(program);
677 const char *mangled_name;
678 NVRTC_CALL(nvrtcGetLoweredName(program,
679 kernel_name_demangled.c_str(),
680 &mangled_name));
681 kinfo.mangled_name = mangled_name;
682 // Destroy the program.
683 NVRTC_CALL(nvrtcDestroyProgram(&program));
684 }
685 // Ensure function array is deep enough to index by dev_id
686 while (kinfo.functions.size() <= static_cast<size_t>(dev_id))
687 kinfo.functions.push_back(static_cast<CUfunction>(nullptr));
688 // Jit-compile ptx for the device as needed
689 if (kinfo.functions[dev_id] == static_cast<CUfunction>(nullptr)) {
690 // Make sure driver context is set to the proper device
691 CUdevice cu_device;
692 CUcontext context;
693 CUDA_DRIVER_CALL(cuDeviceGet(&cu_device, dev_id));
694 CUDA_DRIVER_CALL(cuDevicePrimaryCtxRetain(&context, cu_device));
695 // Jit-compile ptx for the driver's current context
696 CUmodule module;
697 CUDA_DRIVER_CALL(cuModuleLoadData(&module, kinfo.ptx.c_str()));
698 CUDA_DRIVER_CALL(cuModuleGetFunction(&kinfo.functions[dev_id],
699 module,
700 kinfo.mangled_name.c_str()));
701 }
702 return kinfo.functions[dev_id];
703 }
704
705
CheckShapesAndTypes(const std::vector<TBlob> & inputs,const std::vector<TBlob> & outputs,std::vector<int> * in_dtypes,std::vector<int> * in_ndims,std::vector<int> * out_dtypes,std::vector<int> * out_ndims,int * nvec)706 void FusedOp::CheckShapesAndTypes(const std::vector<TBlob> &inputs,
707 const std::vector<TBlob> &outputs,
708 std::vector<int> *in_dtypes,
709 std::vector<int> *in_ndims,
710 std::vector<int> *out_dtypes,
711 std::vector<int> *out_ndims,
712 int *nvec) {
713 std::vector<mxnet::TShape> in_shapes;
714 std::vector<mxnet::TShape> out_shapes;
715 CHECK_EQ(inputs.size(), inputs_.size());
716 CHECK_EQ(outputs.size(), outputs_.size());
717
718 for (size_t counter = 0; counter < inputs.size(); ++counter) {
719 const auto& blob = inputs[counter];
720 in_dtypes->push_back(blob.type_flag_);
721 in_ndims->push_back(blob.ndim());
722 in_shapes.push_back(blob.shape_);
723 initialized_ = initialized_ && blob.type_flag_ == inputs_[counter].dtype;
724 initialized_ = initialized_ && blob.ndim() == inputs_[counter].ndim;
725 inputs_[counter].dtype = blob.type_flag_;
726 inputs_[counter].ndim = blob.ndim();
727 *nvec = max(*nvec, mshadowTypeToVectorLength(blob.type_flag_));
728 }
729
730 for (size_t counter = 0; counter < outputs.size(); ++counter) {
731 const auto& blob = outputs[counter];
732 out_dtypes->push_back(blob.type_flag_);
733 out_ndims->push_back(blob.ndim());
734 out_shapes.push_back(blob.shape_);
735 initialized_ = initialized_ && blob.type_flag_ == outputs_[counter].dtype;
736 initialized_ = initialized_ && blob.ndim() == outputs_[counter].ndim;
737 outputs_[counter].dtype = blob.type_flag_;
738 outputs_[counter].ndim = blob.ndim();
739 *nvec = max(*nvec, mshadowTypeToVectorLength(blob.type_flag_));
740 }
741
742 for (auto it = intermediate_shapes_.begin();
743 it != intermediate_shapes_.end();
744 ++it) {
745 if (it->input_attr == in_shapes && it->output_attr == out_shapes) {
746 intermediate_shapes_.erase(intermediate_shapes_.begin(), it);
747 break;
748 }
749 }
750 for (auto it = intermediate_dtypes_.begin();
751 it != intermediate_dtypes_.end();
752 ++it) {
753 if (it->input_attr == *in_dtypes && it->output_attr == *out_dtypes) {
754 intermediate_dtypes_.erase(intermediate_dtypes_.begin(), it);
755 break;
756 }
757 }
758 }
759
760 template <>
Forward(const nnvm::NodeAttrs & attrs,const OpContext & ctx,const std::vector<TBlob> & inputs,const std::vector<OpReqType> & req,const std::vector<TBlob> & outputs)761 void FusedOp::Forward<gpu>(const nnvm::NodeAttrs& attrs,
762 const OpContext &ctx,
763 const std::vector<TBlob> &inputs,
764 const std::vector<OpReqType> &req,
765 const std::vector<TBlob> &outputs) {
766 using namespace mshadow;
767 std::lock_guard<std::mutex> lock(my_mutex_);
768 CHECK_GE(outputs.size(), 1) << "There needs to be at least 1 output.";
769
770 std::vector<int> in_dtypes;
771 std::vector<int> in_ndims;
772 std::vector<int> out_dtypes;
773 std::vector<int> out_ndims;
774 int nvec = 1;
775
776 CheckShapesAndTypes(inputs, outputs, &in_dtypes, &in_ndims,
777 &out_dtypes, &out_ndims, &nvec);
778
779 const auto& node_shapes = intermediate_shapes_[0].internal_attr;
780 const auto& node_dtypes = intermediate_dtypes_[0].internal_attr;
781
782 int dev_id = ctx.run_ctx.ctx.dev_id;
783
784 // A change between training and inference modes may require different kernel functions
785 initialized_ = initialized_ && (req == saved_reqs_);
786 saved_reqs_ = req;
787
788 if (!initialized_) {
789 const auto& code = GenerateCode(req, in_dtypes, out_dtypes, in_ndims, out_ndims,
790 node_shapes, node_dtypes, nvec, attrs.name, &check_shape_args_);
791 kernel_functions_[fusion::kGeneral] = CompileCode(code, attrs.name, dev_id);
792 if (check_shape_args_.size() > 0) {
793 const auto& code = GenerateCode(req, in_dtypes, out_dtypes, in_ndims, out_ndims,
794 node_shapes, node_dtypes, nvec, attrs.name, nullptr);
795 kernel_functions_[fusion::kShapeOptimized] = CompileCode(code, attrs.name, dev_id);
796 }
797 initialized_ = true;
798 kernel_function_dev_id_ = dev_id;
799 }
800
801 // A change in device would force recompiling, but this is unexpected so signal as an error
802 if (dev_id != kernel_function_dev_id_)
803 LOG(FATAL) << "Fused op compiled for device " << kernel_function_dev_id_
804 << ", not expecting switch to device " << dev_id;
805
806 Stream<gpu>* s = ctx.get_stream<gpu>();
807 auto stream = Stream<gpu>::GetStream(s);
808 std::vector<void*> args;
809 size_t N = 0;
810 for (const auto& output : outputs) {
811 N = std::max(N, output.shape_.Size());
812 }
813 N = (N + nvec - 1)/nvec;
814 args.push_back(&N);
815
816 unsigned int num_blocks = (N + FusedOp::NTHREADS - 1) / FusedOp::NTHREADS;
817
818 std::vector<void*> ptrs;
819 std::vector<std::vector<int>> shapes;
820
821 for (const auto &shape_id : extra_shape_args_) {
822 AddShape(node_shapes[shape_id], &shapes);
823 }
824 for (const auto &data : inputs) {
825 AddPointerAndShape(data, &ptrs, &shapes, s);
826 }
827 for (const auto &data : outputs) {
828 AddPointerAndShape(data, &ptrs, &shapes, s);
829 }
830
831 for (auto &tensor_shapes : shapes) {
832 args.push_back(tensor_shapes.data());
833 }
834 for (auto &ptr : ptrs) {
835 args.push_back(reinterpret_cast<void *>(&ptr));
836 }
837 int kernel_variant = fusion::kGeneral;
838 if (check_shape_args_.size() > 0) {
839 kernel_variant = fusion::kShapeOptimized;
840 for (const auto &shape_id : check_shape_args_) {
841 const auto& shape = node_shapes[shape_id];
842 if (shape[shape.ndim()-1] % nvec != 0) {
843 kernel_variant = fusion::kGeneral;
844 }
845 }
846 }
847 CUDA_DRIVER_CALL(
848 cuLaunchKernel(kernel_functions_[kernel_variant],
849 num_blocks, 1, 1, // grid dim
850 FusedOp::NTHREADS, 1, 1, // block dim
851 0, stream, // shared mem and stream
852 &(args[0]), 0)); // arguments
853 }
854
FusedOpForwardGPU(const nnvm::NodeAttrs & attrs,const OpContext & ctx,const std::vector<TBlob> & inputs,const std::vector<OpReqType> & req,const std::vector<TBlob> & outputs)855 void FusedOpForwardGPU(const nnvm::NodeAttrs& attrs,
856 const OpContext &ctx,
857 const std::vector<TBlob> &inputs,
858 const std::vector<OpReqType> &req,
859 const std::vector<TBlob> &outputs) {
860 const FusedOpPtr& op = nnvm::get<FusedOpPtr>(attrs.parsed);
861 op->Forward<gpu>(attrs, ctx, inputs, req, outputs);
862 }
863
864 NNVM_REGISTER_OP(_FusedOp)
865 .set_attr<FCompute>("FCompute<gpu>", FusedOpForwardGPU);
866
867 } // namespace mxnet
868
869 #endif // MXNET_ENABLE_CUDA_RTC
870