1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 // Additional use of MXNET_USE_CUDA is not needed to guard a '.cu' file.
21 #if MXNET_ENABLE_CUDA_RTC
22 
#include <sys/stat.h>
#include <nvrtc.h>
#include <cuda.h>
#include <nnvm/pass_functions.h>
#include <algorithm>
#include <cstring>
#include <mutex>
#include "./fused_op.h"
#include "./fused_op-inl.h"
#include "../operator_common.h"
#include "../elemwise_op_common.h"
#include "../../executor/exec_pass.h"
#include "../../common/cuda_utils.h"
35 
36 namespace mxnet {
37 
38 namespace {
39 
mshadowTypeToString(int type)40 inline std::string mshadowTypeToString(int type) {
41   switch (type) {
42     case mshadow::kFloat32:
43       return "float";
44     case mshadow::kFloat64:
45       return "double";
46     case mshadow::kFloat16:
47       return "half";
48     case mshadow::kUint8:
49       return "unsigned char";
50     case mshadow::kInt8:
51       return "char";
52     case mshadow::kInt32:
53       return "int";
54     case mshadow::kInt64:
55       return "long long";
56     case mshadow::kBool:
57       return "bool";
58     default:
59       LOG(FATAL) << "Unknown type enum " << type;
60   }
61   return "";
62 }
63 
mshadowTypeToVectorLength(int type)64 inline int mshadowTypeToVectorLength(int type) {
65   switch (type) {
66     case mshadow::kFloat32:
67       return 1;
68     case mshadow::kFloat64:
69       return 1;
70     case mshadow::kFloat16:
71       return 2;
72     case mshadow::kUint8:
73       return 4;
74     case mshadow::kInt8:
75       return 4;
76     case mshadow::kInt32:
77       return 1;
78     case mshadow::kInt64:
79       return 1;
80     case mshadow::kBool:
81       return 4 / sizeof(bool);
82     default:
83       LOG(FATAL) << "Unknown type enum " << type;
84   }
85   return 0;
86 }
87 
// Replace every occurrence of `old` in `*input` with `repl`.
// Advancing past the inserted text guarantees termination even when
// `repl` contains `old`.  An empty `old` is a no-op (previously it made
// the loop insert `repl` forever at successive positions).
inline void replaceString(std::string *input, const std::string& old, const std::string& repl) {
    if (old.empty()) {
      return;
    }
    size_t pos = 0;
    while ((pos = input->find(old, pos)) != std::string::npos) {
        input->replace(pos, old.size(), repl);
        // Skip the replacement text so it is never re-scanned.
        pos += repl.size();
    }
}
95 
// Parse a delimited tuple literal such as "(1,2,3)" or "{INT_MAX,5}" into
// a vector of ints.  The first and last characters (the brackets) are
// stripped; tokens equal to `def` map to 0; empty tokens are skipped.
// Inputs shorter than 2 characters yield an empty vector (previously
// substr(1, ...) threw std::out_of_range on them).
inline std::vector<int> splitStringToVector(const std::string& input, const std::string& def) {
    std::vector<int> res;
    if (input.length() < 2) {
        return res;
    }
    // Own the substring instead of binding a reference to the temporary.
    const std::string s = input.substr(1, input.length() - 2);

    auto convert_token = [&def](const std::string& token) {
        return token == def ? 0 : std::stoi(token);
    };

    size_t pos_start = 0;
    size_t pos_end;
    while ((pos_end = s.find(',', pos_start)) != std::string::npos) {
        const std::string token = s.substr(pos_start, pos_end - pos_start);
        pos_start = pos_end + 1;
        if (!token.empty()) {
            res.push_back(convert_token(token));
        }
    }
    // Trailing token after the last comma, if any.
    if (pos_start < s.length()) {
        res.push_back(convert_token(s.substr(pos_start)));
    }
    return res;
}
121 
ParseOpDescription(const std::vector<std::string> & op_desc,const std::map<std::pair<int,int>,std::string> & variables,const nnvm::IndexedGraph::Node & node)122 std::string ParseOpDescription(const std::vector<std::string>& op_desc,
123                                const std::map<std::pair<int, int>, std::string>& variables,
124                                const nnvm::IndexedGraph::Node& node) {
125   const auto* source = node.source;
126   std::string fmt = op_desc[0];
127   for (size_t j = 1; j < op_desc.size(); ++j) {
128     const std::string& desc = op_desc[j];
129     std::string sub;
130     if (desc[0] == '_') {
131       // Argument
132       const int arg_id = std::stoi(desc.substr(1));
133       sub = variables.at({node.inputs[arg_id].node_id, node.inputs[arg_id].index});
134     } else {
135       sub = source->attrs.dict.at(desc);
136     }
137     size_t pos = fmt.find("%");
138     CHECK_NE(pos, std::string::npos);
139     fmt.replace(pos, 1, sub);
140   }
141   return fmt;
142 }
143 
AddShape(const mxnet::TShape & shape,std::vector<std::vector<int>> * shapes)144 void AddShape(const mxnet::TShape& shape,
145               std::vector<std::vector<int>>* shapes) {
146   // We need alignment to 8 bytes for size_t in the Shape struct
147   // so if ndim is odd, there will be 4B of padding
148   int ndim = shape.ndim();
149   const int offset = ndim % 2 == 0 ? 2 : 3;
150   shapes->push_back(std::vector<int>(ndim + offset));
151   std::vector<int>& tensor_shapes = shapes->back();
152   size_t total_size = 1;
153   for (int i = ndim-1; i >= 0; i--) {
154     tensor_shapes[i] = shape[i];
155     total_size *= shape[i];
156   }
157   size_t * shape_size_ptr = reinterpret_cast<size_t*>(&tensor_shapes[ndim + offset - 2]);
158   *shape_size_ptr = total_size;
159 }
160 
// Append the device data pointer of `data` to `ptrs` and its shape info
// (via AddShape) to `shapes`, keeping the two vectors index-aligned.
void AddPointerAndShape(const TBlob& data,
                        std::vector<void*> *ptrs,
                        std::vector<std::vector<int>>* shapes,
                        mshadow::Stream<gpu> * s) {
  using namespace mshadow;
  // Dispatch on the runtime dtype (bool included) so FlatTo1D is
  // instantiated with the matching DType; only the raw pointer is kept.
  MSHADOW_TYPE_SWITCH_WITH_BOOL(data.type_flag_, DType, {
    Tensor<gpu, 1, DType> tensor = data.FlatTo1D<gpu, DType>(s);
    ptrs->push_back(tensor.dptr_);
    AddShape(data.shape_, shapes);
  });
}
172 
173 // Obtain compilation log from the program.
GetCompileLog(nvrtcProgram program)174 std::string GetCompileLog(nvrtcProgram program) {
175   size_t log_size_including_null;
176   NVRTC_CALL(nvrtcGetProgramLogSize(program, &log_size_including_null));
177   // For most std::string implementations, this is probably 1 char bigger than needed.  OK though.
178   std::string log(log_size_including_null, '\0');
179   NVRTC_CALL(nvrtcGetProgramLog(program, &log[0]));
180   // Make sure the string reflects the true size (so minus the null terminator).
181   log.resize(log_size_including_null - 1);
182   return log;
183 }
184 
185 // Obtain compilation result (ptx assembly) from the program.
GetPtx(nvrtcProgram program)186 std::string GetPtx(nvrtcProgram program) {
187   size_t ptx_size_including_null;
188   NVRTC_CALL(nvrtcGetPTXSize(program, &ptx_size_including_null));
189   // For most std::string implementations, this is probably 1 char bigger than needed.  OK though.
190   std::string ptx(ptx_size_including_null, '\0');
191   NVRTC_CALL(nvrtcGetPTX(program, &ptx[0]));
192   // Make sure the string reflects the true size (so minus the null terminator).
193   ptx.resize(ptx_size_including_null - 1);
194   return ptx;
195 }
196 
197 }  // namespace
198 
GenerateCode(const std::vector<OpReqType> & req,const std::vector<int> & in_dtypes,const std::vector<int> & out_dtypes,const std::vector<int> & in_ndims,const std::vector<int> & out_ndims,const mxnet::ShapeVector & node_shapes,const std::vector<int> & node_dtypes,const int nvec,const std::string & kernel_name,std::vector<uint32_t> * check_shapes)199 std::string FusedOp::GenerateCode(const std::vector<OpReqType> &req,
200                            const std::vector<int> &in_dtypes,
201                            const std::vector<int> &out_dtypes,
202                            const std::vector<int> &in_ndims,
203                            const std::vector<int> &out_ndims,
204                            const mxnet::ShapeVector &node_shapes,
205                            const std::vector<int> &node_dtypes,
206                            const int nvec,
207                            const std::string &kernel_name,
208                            std::vector<uint32_t>* check_shapes) {
209   const auto& g = subgraph_.indexed_graph();
210   std::string code = "";
211   int temp_name_counter = 0;
212   using NodeEntry = nnvm::IndexedGraph::NodeEntry;
213   std::map<std::pair<int, int>, std::string> variables;
214   std::map<int, int> load_index;
215   bool check_shapes_compile = true;
216 
217   std::vector<uint32_t> outputs(g.num_nodes());
218 
219   for (size_t i = 0; i < g.num_nodes(); ++i) {
220     const auto& node = g[i];
221     if (node.source != nullptr) {
222       outputs[i] = node.source->num_outputs();
223     } else {
224       outputs[i] = 0;
225     }
226   }
227 
228   for (size_t i = 0; i < g.num_nodes(); ++i) {
229     const auto& node = g[i];
230     const auto* source = node.source;
231     if (source != nullptr) {
232         if (source->is_variable()) {
233             load_index[i] = 1;
234         } else {
235             std::string op_name = source->op()->name;
236             if (fusion::slice_ops.find(op_name) != fusion::slice_ops.end()) {
237                 load_index[node.inputs[0].node_id] = 0;
238             }
239         }
240     }
241   }
242   for (size_t i = 0; i < g.num_nodes(); ++i) {
243     const auto& node = g[i];
244     const auto* source = node.source;
245     if (source != nullptr) {
246       if (source->is_variable()) {
247         if (load_index[i]) {
248           const auto& var_name = source->attrs.name;
249           code += "const auto vec_" + var_name + " = op::load_index<nvec>(" +
250                    var_name + ", offset, " + var_name + "_shape);\n";
251           variables[{i, 0}] = var_name;
252         }
253         CHECK_EQ(outputs[i], 1);
254       } else {
255         std::string op_name = source->op()->name;
256         if (fusion::slice_ops.find(op_name) != fusion::slice_ops.end()) {
257           int node_id = node.inputs[0].node_id;
258           const uint32_t input_entry_id = g.entry_id(node.inputs[0]);
259           const auto& shape = node_shapes[input_entry_id];
260           const int ndim = shape.ndim();
261           const auto& var_name = g[node_id].source->attrs.name;
262           const auto vec_name = "vec_" + var_name + "_" + std::to_string(i);
263           load_index[node_id] = 0;
264           auto parse_tuple = [ndim](const std::string& input, const std::string& def) {
265             std::string out = input;
266             replaceString(&out, " ", "");
267             if (out[0] == '(') {
268               replaceString(&out, "(", "{");
269               replaceString(&out, ")", "}");
270               // First check if out is ()
271               int n_entries = out.size() != 2;
272               for (size_t i = 1; i < out.size() - 1; ++i) {
273                 if (out[i] == ',') {
274                   ++n_entries;
275                 }
276               }
277               if (n_entries != ndim) {
278                 out.pop_back();
279                 for (int i = n_entries; i < ndim; ++i) {
280                   out += "," + def;
281                 }
282                 out += "}";
283               }
284             } else {
285               out = "{" + std::move(out);
286               for (int i = 1; i < ndim; ++i) {
287                 out += "," + def;
288               }
289               out += "}";
290             }
291             replaceString(&out, "None", def);
292             return out;
293           };
294           auto parse_int = [](const std::string& input, const std::string& def) {
295             std::string out = input;
296             replaceString(&out, " ", "");
297             replaceString(&out, "None", def);
298             return out;
299           };
300           auto build_tuple = [ndim](int axis, const std::string str, const std::string def) {
301             if (axis < 0 &&
302                 axis >= -ndim) {
303               axis += ndim;
304             }
305             if (axis < 0 || axis >= ndim) {
306               LOG(FATAL) << "Axis " << axis << " is out of bounds for array of dimension " << ndim;
307             }
308             std::string tuple = "{";
309             for (int i = 0; i < axis; i++) {
310                 tuple += def + ",";
311             }
312             tuple += str;
313             for (int i = axis + 1; i < ndim; i++) {
314                 tuple += "," + def;
315             }
316             tuple += "}";
317             return tuple;
318           };
319           auto check_tuple = [ndim, nvec](const std::string str) {
320             std::vector<int> tuple = splitStringToVector(str, "INT_MAX");
321             if (tuple[ndim-1] % nvec == 0) {
322               return true;
323             }
324             return false;
325           };
326           auto build_string_end = [i, ndim, var_name](std::string* code) {
327             std::string end_var_name = var_name + "_" + std::to_string(i) + "_end";
328             *code += "op::Shape<" + std::to_string(ndim) + "> "+ end_var_name + ";\n";
329             *code += end_var_name + ".set(INT_MAX);\n";
330             return end_var_name;
331           };
332           std::string begin;
333           std::string end;
334           if (op_name == "broadcast_like" || op_name == "slice_like") {
335             uint32_t like_id = g.entry_id(i, 0);
336             begin = build_tuple(0, "0", "0");
337             std::string extra_var_name = "extra_" + std::to_string(like_id) + "_shape";
338             if (std::find(extra_shape_args_.begin(), extra_shape_args_.end(), like_id) ==
339                 extra_shape_args_.end()) {
340                 extra_shape_args_.push_back(like_id);
341             }
342             if (check_shapes) {
343               check_shapes->push_back(like_id);
344               check_shapes->push_back(input_entry_id);
345             }
346             end = extra_var_name;
347           } else {
348             if (op_name == "slice_axis") {
349               begin = parse_int(source->attrs.dict.at("begin"), "0");
350               end = parse_int(source->attrs.dict.at("end"), "INT_MAX");
351               int axis = std::stoi(source->attrs.dict.at("axis"));
352               begin = build_tuple(axis, begin, "0");
353               end = build_tuple(axis, end, "INT_MAX");
354             } else {
355               begin = parse_tuple(source->attrs.dict.at("begin"), "0");
356               end = parse_tuple(source->attrs.dict.at("end"), "INT_MAX");
357             }
358             if (check_shapes) {
359               if (check_tuple(begin) && check_tuple(end)) {
360                 check_shapes->push_back(input_entry_id);
361               } else {
362                 check_shapes_compile = false;
363               }
364             }
365           }
366           std::string slice_func = "load_slice";
367           if (!check_shapes) {
368             slice_func = "fast_" + slice_func;
369           }
370           code += "const auto " + vec_name + " = op::" + slice_func + "<nvec>(" +
371                   var_name + ", " + var_name + "_shape," + begin +
372                   "," + end + ", offset);\n";
373           CHECK_EQ(outputs[i], 1);
374           variables[{i, 0}] = vec_name;
375           continue;
376         }
377       }
378     }
379   }
380 
381   if (!check_shapes_compile) {
382       check_shapes->clear();
383   }
384 
385   size_t counter = 0;
386   for (const auto& entry : g.outputs()) {
387     std::string var_name = "output" + std::to_string(counter);
388     code += "op::VectorType<DType_" + var_name + \
389             ", nvec> vec_" + var_name + ";\n";
390     ++counter;
391   }
392 
393   code += "for (int j = 0; j < nvec; j++ ) {\n";
394 
395 
396   for (size_t i = 0; i < g.num_nodes(); ++i) {
397     const auto& node = g[i];
398     const auto* source = node.source;
399     if (source != nullptr) {
400       std::string var_name = "temp" + std::to_string(temp_name_counter++);
401       if (source->is_variable()) {
402         if (load_index[i]) {
403             code += "const auto " + var_name + " = op::load(vec_" +
404                     variables[{i, 0}] + ".x[j]);\n";
405             CHECK_EQ(outputs[i], 1);
406             variables[{i, 0}] = var_name;
407         }
408       } else {
409         std::string op_name = source->op()->name;
410         if (fusion::ops_desc.find(op_name) != fusion::ops_desc.end()) {
411           const std::vector<std::vector<std::string>>& op_descs =
412             fusion::ops_desc.at(op_name);
413           CHECK_EQ(outputs[i], op_descs.size());
414           size_t count = 0;
415           for (const auto& op_desc : op_descs) {
416             var_name = "temp" + std::to_string(temp_name_counter++);
417             const std::string& fmt = ParseOpDescription(op_desc, variables, node);
418             code += "const auto " + var_name + " = " + fmt + ";\n";
419             variables[{i, count}] = var_name;
420             ++count;
421           }
422           continue;
423         }
424 
425         if (fusion::slice_ops.find(op_name) != fusion::slice_ops.end()) {
426           code += "const auto " + var_name + " = op::load(" + variables[{i, 0}] + ".x[j]);\n";
427           variables[{i, 0}] = var_name;
428           continue;
429         }
430 
431 
432         // Special cases with variable number
433         // of inputs/outputs, listed in
434         // fusion::variable_io_ops
435         if (op_name == "add_n") {
436           CHECK_EQ(outputs[i], 1);
437           const auto& arg = variables[{node.inputs[0].node_id, node.inputs[0].index}];
438           code += "auto " + var_name + " = " + arg + ";\n";
439           for (size_t inp = 1; inp < node.inputs.size(); ++inp) {
440             const auto& temp_arg = variables[{node.inputs[inp].node_id, node.inputs[inp].index}];
441             code += var_name + " = op::add(" + var_name + ", " + temp_arg + ");\n";
442           }
443           variables[{i, 0}] = var_name;
444           continue;
445         }
446 
447         if (op_name == "_backward_Activation") {
448           CHECK_EQ(outputs[i], 1);
449           std::string act_type = node.source->attrs.dict.at("act_type");
450           std::string rhs, lhs;
451           rhs = variables[{node.inputs[0].node_id, node.inputs[0].index}];
452           if (act_type == "relu" ||
453               act_type == "sigmoid" ||
454               act_type == "tanh") {
455             lhs = variables[{node.inputs[1].node_id, node.inputs[1].index}];
456           } else {
457             lhs = variables[{node.inputs[2].node_id, node.inputs[2].index}];
458           }
459           code += "const auto " + var_name + " = op::backward_" + act_type +
460                   "(" + lhs + ", " + rhs + ");\n";
461 
462           variables[{i, 0}] = var_name;
463           continue;
464         }
465 
466         if (op_name == "amp_multicast" || op_name == "_backward_amp_multicast") {
467           CHECK_EQ(outputs[i], node.inputs.size());
468           for (size_t counter = 0; counter < outputs[i]; ++counter) {
469             const auto& input = node.inputs[counter];
470             var_name = "temp" + std::to_string(temp_name_counter++);
471             const auto& arg = variables[{input.node_id, input.index}];
472             code += "const auto " + var_name + " = " + arg + ";\n";
473             variables[{i, counter}] = var_name;
474           }
475           continue;
476         }
477 
478         if (op_name == "_backward_cast") {
479           CHECK_EQ(outputs[i], 1);
480           const int output_type = node_dtypes[g.entry_id(i, 0)];
481           const auto& arg = variables[{node.inputs[0].node_id, node.inputs[0].index}];
482           code += "const auto " + var_name + " = op::cast<" + mshadowTypeToString(output_type) +
483                   ">(" + arg + ");\n";
484           variables[{i, 0}] = var_name;
485           continue;
486         }
487 
488         // LeakyReLU, look for act_type
489         if (op_name == "LeakyReLU") {
490             std::string act_type = node.source->attrs.dict.at("act_type");
491             const std::vector<std::vector<std::string>>& op_descs =
492                 fusion::LeakyReLU_ops.at(act_type);
493             if (fusion::LeakyReLU_ops.find(act_type) != fusion::LeakyReLU_ops.end()) {
494               CHECK_EQ(outputs[i], op_descs.size());
495               size_t count = 0;
496               for (const auto& op_desc : op_descs) {
497                 var_name = "temp" + std::to_string(temp_name_counter++);
498                 const std::string& fmt = ParseOpDescription(op_desc, variables, node);
499                 code += "const auto " + var_name + " = " + fmt + ";\n";
500                 variables[{i, count}] = var_name;
501                 ++count;
502               }
503               continue;
504             }
505         }
506         if (op_name == "_backward_LeakyReLU") {
507             std::string act_type = node.source->attrs.dict.at("act_type");
508             const std::vector<std::vector<std::string>>& op_descs =
509                 fusion::LeakyReLU_bwd_ops.at(act_type);
510             if (fusion::LeakyReLU_ops.find(act_type) != fusion::LeakyReLU_bwd_ops.end()) {
511               CHECK_EQ(outputs[i], op_descs.size());
512               size_t count = 0;
513               for (const auto& op_desc : op_descs) {
514                 var_name = "temp" + std::to_string(temp_name_counter++);
515                 const std::string& fmt = ParseOpDescription(op_desc, variables, node);
516                 code += "const auto " + var_name + " = " + fmt + ";\n";
517                 variables[{i, count}] = var_name;
518                 ++count;
519               }
520               continue;
521             }
522         }
523 
524         LOG(FATAL) << "Unrecognized op " + op_name;
525       }
526     } else {
527       LOG(FATAL) << "Encountered node with NULL source.";
528     }
529   }
530 
531   counter = 0;
532   for (const auto& entry : g.outputs()) {
533     const std::string& var = variables[{entry.node_id, entry.index}];
534     const auto var_name = "output" + std::to_string(counter);
535     code += "vec_" + var_name + ".x[j] = op::store("+ var +", " + var_name + ");\n";
536     ++counter;
537   }
538 
539   code += "}\n";
540 
541   counter = 0;
542 
543   for (const auto& entry : g.outputs()) {
544     const std::string& var = variables[{entry.node_id, entry.index}];
545     if (req[counter] == kWriteTo || req[counter] == kWriteInplace) {
546       const auto var_name = "output" + std::to_string(counter);
547       code += "op::store_index(vec_" + var_name + ", i, " + var_name + ", " +
548               var_name + "_shape);\n";
549     } else if (req[counter] == kAddTo) {
550       const auto var_name = "output" + std::to_string(counter);
551       code += "op::store_add_index(vec_" + var_name + ", i, " + var_name + ", " +
552               var_name + "_shape);\n";
553     } else if (req[counter] == kNullOp) {
554       // nullptr req, do not do anything
555     } else {
556       LOG(FATAL) << "Encountered unexpected req.";
557     }
558     ++counter;
559   }
560 
561   // Add boilerplate and type information
562   std::string kernel_params = "";
563   std::string tensor_params = "";
564   nnvm::Symbol sym;
565   sym.outputs = subgraph_.outputs;
566   const std::vector<std::string> input_names = sym.ListInputNames(nnvm::Symbol::kAll);
567   size_t num_params = in_dtypes.size() + out_dtypes.size();
568   size_t i = 0;
569   std::string aux_code = "static const int nvec = " + std::to_string(nvec) + ";\n";
570 
571   for (const auto &shape_id : extra_shape_args_) {
572       std::string shape_name = "extra_" + std::to_string(shape_id) + "_shape";
573       int ndim = node_shapes[shape_id].ndim();
574       kernel_params += " const op::Shape<" + std::to_string(ndim) + "> " + shape_name;
575       kernel_params += ", ";
576   }
577   for (const auto &type : in_dtypes) {
578     std::string type_name = mshadowTypeToString(type);
579     std::string dtype_var = "DType_" + input_names[i];
580     std::string dim_var = "ndim_" + input_names[i];
581     std::string dim_val = std::to_string(in_ndims[i]);
582     aux_code = "using " + dtype_var + " = " + type_name + ";\n" + aux_code;
583     aux_code = "static const int " + dim_var + " = " + dim_val + ";\n" + aux_code;
584     tensor_params += dtype_var + "* " +input_names[i];
585     kernel_params += " const op::Shape<" + dim_val + "> " + input_names[i]+"_shape";
586     ++i;
587     if (i < num_params) {
588       tensor_params += ", ";
589     }
590     kernel_params += ", ";
591   }
592   for (const auto &type : out_dtypes) {
593     std::string type_name = mshadowTypeToString(type);
594     std::string out_name = "output" + std::to_string(i - in_dtypes.size());
595     std::string dtype_var = "DType_" + out_name;
596     std::string dim_var = "ndim_" + out_name;
597     std::string dim_val = std::to_string(out_ndims[i - in_dtypes.size()]);
598     aux_code = "static const int " + dim_var + " = " + dim_val + ";\n" + aux_code;
599     aux_code = "using " + dtype_var + " = " + type_name + ";\n" + aux_code;
600     tensor_params += dtype_var + "* " + out_name;
601     kernel_params += " const op::Shape<" + dim_val + "> " + out_name+"_shape";
602     ++i;
603     if (i < num_params) {
604       tensor_params += ", ";
605     }
606     kernel_params += ", ";
607   }
608   kernel_params += tensor_params;
609 
610   // Create kernel source (minus the common header)
611   return aux_code + "\n" +
612          "__launch_bounds__(" + std::to_string(FusedOp::NTHREADS) + ")\n" +
613          "__global__ void FusedKernel_" + kernel_name +
614          "(size_t N, " + kernel_params + ") {\n" +
615          fusion::kernel_begin + "\n" +
616          code + "\n" +
617          fusion::kernel_end;
618 }
619 
/*!
 * \brief Compile the generated kernel source with NVRTC and return the
 *        jit-compiled CUfunction for `dev_id`.
 *
 * Results are cached per SM architecture and per source string, so
 * compilation happens once per unique kernel; the per-device CUfunction is
 * jit-loaded lazily from the cached PTX.
 *
 * NOTE(review): `compiled_kernels` is function-static but guarded by the
 * member `mutex_` — this is only race-free if `mutex_` is shared (static)
 * across FusedOp instances; confirm its declaration in fused_op.h.
 */
CUfunction FusedOp::CompileCode(const std::string &code,
                                const std::string &kernel_name,
                                int dev_id) {
  // Guard NVRTC calls
  std::lock_guard<std::mutex> lock_nvrtc(mutex_);
  // Local class for value type of compile cache
  struct KernelInfo {
    std::string mangled_name;
    std::string ptx;
    std::vector<CUfunction> functions;  // indexed by dev_id; nullptr = not yet jit-loaded
  };
  // Maps from the cuda source code (minus header) to the ptx and jit-compiled CUfunctions.
  using KernelCache = std::map<std::string, KernelInfo>;
  // Per-gpu-architecture compiled kernel cache with jit-compiled function for each device context
  static std::map<int32_t, KernelCache> compiled_kernels;
  int sm_arch = SMArch(dev_id);
  KernelCache& compiled_kernels_this_arch = compiled_kernels[sm_arch];  // make null map as needed
  KernelInfo& kinfo = compiled_kernels_this_arch[code];                 // make KernelInfo as needed
  // Empty ptx means this source has not been compiled for this arch yet.
  if (kinfo.ptx.size() == 0) {
    // It's the first time we've seen this kernel, so we need to generate the ptx and mangled_name.
    static std::string common_header =
        std::string(fusion::fp16_support_string) + "\n" +
        fusion::type_support_string + "\n" +
        fusion::function_definitions + "\n" +
        fusion::backward_function_definitions + "\n";
    std::string code_with_header = common_header + code;
    // If verbose mode, output kernel source, though not including the common header
    if (dmlc::GetEnv("MXNET_FUSION_VERBOSE", false)) {
      LOG(INFO) << "\n" << std::string(80, '-') << "\n" << code;
    }
    // Warn exactly once, when the cache first exceeds the threshold.
    if (compiled_kernels_this_arch.size() == CACHESIZE_WARN_THRESHOLD + 1 &&
        dmlc::GetEnv("MXNET_FUSION_SIZE_WARNING", true)) {
      LOG(WARNING) << "The number of different fused ops exceeds " << CACHESIZE_WARN_THRESHOLD
                   << ".  Set MXNET_FUSION_SIZE_WARNING=0 to quiet this warning.";
    }
    nvrtcProgram program;
    NVRTC_CALL(nvrtcCreateProgram(&program,                                  // prog
                                  &code_with_header[0],                      // buffer
                                  (kernel_name + "_kernel.cu").c_str(),      // name
                                  0,                                         // num headers
                                  nullptr,                                      // headers
                                  nullptr));                                    // include names
    // Compile for the device's real architecture so the PTX can be jit-loaded.
    std::string gpu_arch_arg = "--gpu-architecture=compute_" + std::to_string(sm_arch);
    const char *opts[] = {gpu_arch_arg.c_str(),
                          "--std=c++11"};
    const std::string kernel_name_demangled = "FusedKernel_" + kernel_name;
    // Register the name expression so the mangled name can be queried below.
    NVRTC_CALL(nvrtcAddNameExpression(program, (kernel_name_demangled).c_str()));

    nvrtcResult compileResult = nvrtcCompileProgram(program,  // prog
                                                    2,        // num options
                                                    opts);    // options
    CHECK_EQ(compileResult, NVRTC_SUCCESS)
        << "NVRTC Compilation failed. Please set environment variable MXNET_USE_FUSION to 0.\n"
        << GetCompileLog(program);

    kinfo.ptx = GetPtx(program);
    const char *mangled_name;
    NVRTC_CALL(nvrtcGetLoweredName(program,
                                   kernel_name_demangled.c_str(),
                                   &mangled_name));
    // Copy before destroying the program: the pointer is owned by `program`.
    kinfo.mangled_name = mangled_name;
    // Destroy the program.
    NVRTC_CALL(nvrtcDestroyProgram(&program));
  }
  // Ensure function array is deep enough to index by dev_id
  while (kinfo.functions.size() <= static_cast<size_t>(dev_id))
    kinfo.functions.push_back(static_cast<CUfunction>(nullptr));
  // Jit-compile ptx for the device as needed
  if (kinfo.functions[dev_id] == static_cast<CUfunction>(nullptr)) {
    // Make sure driver context is set to the proper device
    CUdevice cu_device;
    CUcontext context;
    CUDA_DRIVER_CALL(cuDeviceGet(&cu_device, dev_id));
    CUDA_DRIVER_CALL(cuDevicePrimaryCtxRetain(&context, cu_device));
    // Jit-compile ptx for the driver's current context
    // NOTE(review): the primary context is retained but never released and
    // the CUmodule is never unloaded — presumably intentional because the
    // cache lives for the process lifetime; confirm.
    CUmodule module;
    CUDA_DRIVER_CALL(cuModuleLoadData(&module, kinfo.ptx.c_str()));
    CUDA_DRIVER_CALL(cuModuleGetFunction(&kinfo.functions[dev_id],
                                         module,
                                         kinfo.mangled_name.c_str()));
  }
  return kinfo.functions[dev_id];
}
704 
705 
CheckShapesAndTypes(const std::vector<TBlob> & inputs,const std::vector<TBlob> & outputs,std::vector<int> * in_dtypes,std::vector<int> * in_ndims,std::vector<int> * out_dtypes,std::vector<int> * out_ndims,int * nvec)706 void FusedOp::CheckShapesAndTypes(const std::vector<TBlob> &inputs,
707                                   const std::vector<TBlob> &outputs,
708                                   std::vector<int> *in_dtypes,
709                                   std::vector<int> *in_ndims,
710                                   std::vector<int> *out_dtypes,
711                                   std::vector<int> *out_ndims,
712                                   int *nvec) {
713   std::vector<mxnet::TShape> in_shapes;
714   std::vector<mxnet::TShape> out_shapes;
715   CHECK_EQ(inputs.size(), inputs_.size());
716   CHECK_EQ(outputs.size(), outputs_.size());
717 
718   for (size_t counter = 0; counter < inputs.size(); ++counter) {
719     const auto& blob = inputs[counter];
720     in_dtypes->push_back(blob.type_flag_);
721     in_ndims->push_back(blob.ndim());
722     in_shapes.push_back(blob.shape_);
723     initialized_ = initialized_ && blob.type_flag_ == inputs_[counter].dtype;
724     initialized_ = initialized_ && blob.ndim() == inputs_[counter].ndim;
725     inputs_[counter].dtype = blob.type_flag_;
726     inputs_[counter].ndim = blob.ndim();
727     *nvec = max(*nvec, mshadowTypeToVectorLength(blob.type_flag_));
728   }
729 
730   for (size_t counter = 0; counter < outputs.size(); ++counter) {
731     const auto& blob = outputs[counter];
732     out_dtypes->push_back(blob.type_flag_);
733     out_ndims->push_back(blob.ndim());
734     out_shapes.push_back(blob.shape_);
735     initialized_ = initialized_ && blob.type_flag_ == outputs_[counter].dtype;
736     initialized_ = initialized_ && blob.ndim() == outputs_[counter].ndim;
737     outputs_[counter].dtype = blob.type_flag_;
738     outputs_[counter].ndim = blob.ndim();
739     *nvec = max(*nvec, mshadowTypeToVectorLength(blob.type_flag_));
740   }
741 
742   for (auto it = intermediate_shapes_.begin();
743        it != intermediate_shapes_.end();
744        ++it) {
745     if (it->input_attr == in_shapes && it->output_attr == out_shapes) {
746       intermediate_shapes_.erase(intermediate_shapes_.begin(), it);
747       break;
748     }
749   }
750   for (auto it = intermediate_dtypes_.begin();
751        it != intermediate_dtypes_.end();
752        ++it) {
753     if (it->input_attr == *in_dtypes && it->output_attr == *out_dtypes) {
754       intermediate_dtypes_.erase(intermediate_dtypes_.begin(), it);
755       break;
756     }
757   }
758 }
759 
// GPU forward pass for a fused op: (re)generates and JIT-compiles the fused
// CUDA kernel if the cached one is stale (dtype/ndim/req change), then packs
// the launch arguments and launches via the CUDA driver API.
template <>
void FusedOp::Forward<gpu>(const nnvm::NodeAttrs& attrs,
                           const OpContext &ctx,
                           const std::vector<TBlob> &inputs,
                           const std::vector<OpReqType> &req,
                           const std::vector<TBlob> &outputs) {
  using namespace mshadow;
  // Serialize concurrent calls on this op instance: the cached kernel state
  // (initialized_, kernel_functions_, saved_reqs_) is mutated below.
  std::lock_guard<std::mutex> lock(my_mutex_);
  CHECK_GE(outputs.size(), 1) << "There needs to be at least 1 output.";

  std::vector<int> in_dtypes;
  std::vector<int> in_ndims;
  std::vector<int> out_dtypes;
  std::vector<int> out_ndims;
  // Maximum vector length (elements per thread) across all in/out dtypes.
  int nvec = 1;

  // Fills the vectors above, clears initialized_ on dtype/rank changes, and
  // rotates the matching intermediate shape/dtype cache entries to index 0.
  CheckShapesAndTypes(inputs, outputs, &in_dtypes, &in_ndims,
                      &out_dtypes, &out_ndims, &nvec);

  // Valid because CheckShapesAndTypes moved the matching entry to the front.
  const auto& node_shapes = intermediate_shapes_[0].internal_attr;
  const auto& node_dtypes = intermediate_dtypes_[0].internal_attr;

  int dev_id = ctx.run_ctx.ctx.dev_id;

  // A change between training and inference modes may require different kernel functions
  initialized_ = initialized_ && (req == saved_reqs_);
  saved_reqs_ = req;

  if (!initialized_) {
    // First GenerateCode call populates check_shape_args_ (shapes that must be
    // checked at runtime for vectorization eligibility).
    const auto& code = GenerateCode(req, in_dtypes, out_dtypes, in_ndims, out_ndims,
                       node_shapes, node_dtypes, nvec, attrs.name, &check_shape_args_);
    kernel_functions_[fusion::kGeneral] = CompileCode(code, attrs.name, dev_id);
    if (check_shape_args_.size() > 0) {
      // Second pass (nullptr => no shape checks collected) builds the
      // shape-optimized variant, used when runtime shapes allow vector loads.
      const auto& code = GenerateCode(req, in_dtypes, out_dtypes, in_ndims, out_ndims,
                           node_shapes, node_dtypes, nvec, attrs.name, nullptr);
      kernel_functions_[fusion::kShapeOptimized] = CompileCode(code, attrs.name, dev_id);
    }
    initialized_ = true;
    kernel_function_dev_id_ = dev_id;
  }

  // A change in device would force recompiling, but this is unexpected so signal as an error
  if (dev_id != kernel_function_dev_id_)
    LOG(FATAL) << "Fused op compiled for device " << kernel_function_dev_id_
               <<  ", not expecting switch to device " << dev_id;

  Stream<gpu>* s = ctx.get_stream<gpu>();
  auto stream = Stream<gpu>::GetStream(s);
  std::vector<void*> args;
  // N = largest output element count, then divided (rounded up) by nvec to get
  // the number of vectorized work items; it is the kernel's first argument.
  size_t N = 0;
  for (const auto& output : outputs) {
    N = std::max(N, output.shape_.Size());
  }
  N = (N + nvec - 1)/nvec;
  args.push_back(&N);

  unsigned int num_blocks = (N + FusedOp::NTHREADS - 1) / FusedOp::NTHREADS;

  std::vector<void*> ptrs;
  std::vector<std::vector<int>> shapes;

  // Both vectors below are fully populated BEFORE any address/data() pointer
  // from them is pushed into args — the driver reads through those pointers at
  // launch, so no reallocation may happen after they are taken.
  for (const auto &shape_id : extra_shape_args_) {
    AddShape(node_shapes[shape_id], &shapes);
  }
  for (const auto &data : inputs) {
    AddPointerAndShape(data, &ptrs, &shapes, s);
  }
  for (const auto &data : outputs) {
    AddPointerAndShape(data, &ptrs, &shapes, s);
  }

  for (auto &tensor_shapes : shapes) {
    args.push_back(tensor_shapes.data());
  }
  for (auto &ptr : ptrs) {
    args.push_back(reinterpret_cast<void *>(&ptr));
  }
  // Prefer the shape-optimized kernel, but fall back to the general one if any
  // checked shape's innermost dimension is not a multiple of nvec (which would
  // break vectorized accesses).
  int kernel_variant = fusion::kGeneral;
  if (check_shape_args_.size() > 0) {
    kernel_variant = fusion::kShapeOptimized;
      for (const auto &shape_id : check_shape_args_) {
          const auto& shape = node_shapes[shape_id];
          if (shape[shape.ndim()-1] % nvec != 0) {
            kernel_variant = fusion::kGeneral;
          }
      }
  }
  CUDA_DRIVER_CALL(
      cuLaunchKernel(kernel_functions_[kernel_variant],
        num_blocks, 1, 1,          // grid dim
        FusedOp::NTHREADS, 1, 1,   // block dim
        0, stream,                 // shared mem and stream
        &(args[0]), 0));           // arguments
}
854 
FusedOpForwardGPU(const nnvm::NodeAttrs & attrs,const OpContext & ctx,const std::vector<TBlob> & inputs,const std::vector<OpReqType> & req,const std::vector<TBlob> & outputs)855 void FusedOpForwardGPU(const nnvm::NodeAttrs& attrs,
856                     const OpContext &ctx,
857                     const std::vector<TBlob> &inputs,
858                     const std::vector<OpReqType> &req,
859                     const std::vector<TBlob> &outputs) {
860   const FusedOpPtr& op = nnvm::get<FusedOpPtr>(attrs.parsed);
861   op->Forward<gpu>(attrs, ctx, inputs, req, outputs);
862 }
863 
// Attach the GPU compute function to the _FusedOp operator; this file only
// registers the GPU path (the fused kernels are generated as CUDA code).
NNVM_REGISTER_OP(_FusedOp)
.set_attr<FCompute>("FCompute<gpu>", FusedOpForwardGPU);
866 
867 }  // namespace mxnet
868 
869 #endif  // MXNET_ENABLE_CUDA_RTC
870