/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tensorrt.cc
 * \brief TensorRT operation registration
 * \author Marek Kolodziej, Clement Fuji Tsang, Serge Panev
 */

#if MXNET_USE_TENSORRT

#include "./tensorrt-inl.h"

#include <NvInfer.h>

namespace mxnet {
namespace op {

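// Returns the number of runtime inputs of the TRT subgraph op, i.e. the
// subgraph inputs that were not baked into the engine as weights.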
inline uint32_t TRTNumInputs(const nnvm::NodeAttrs& attrs) {
  const TRTParam& param = nnvm::get<TRTParam>(attrs.parsed);
  return param.inputs_to_idx.size();
}

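// Lists the runtime input names, ordered by their input index.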
inline std::vector<std::string> TRTListInputNames(const nnvm::NodeAttrs& attrs) {
  const TRTParam& param = nnvm::get<TRTParam>(attrs.parsed);
  const auto& inputs_to_idx = param.inputs_to_idx;
  // Size the vector up front: the map's iteration order is arbitrary, so each
  // name is written directly at its input index.
  std::vector<std::string> outputs(inputs_to_idx.size());
  for (const auto& p : inputs_to_idx) {
    outputs[p.second] = p.first;
  }
  return outputs;
}

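// Shape inference for the TRT subgraph op. Known shapes (stored weights plus
// the caller's input/output shapes) are seeded into the subgraph, propagated
// with exec::InferShape, and the results are copied back to the caller.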
inline bool TRTInferShape(const nnvm::NodeAttrs& attrs,
                          std::vector<TShape> *in_shapes,
                          std::vector<TShape> *out_shapes) {
  using namespace exec;
  const nnvm::Symbol subgraph_sym = *(attrs.subgraphs[0]);
  const TRTParam& param = nnvm::get<TRTParam>(attrs.parsed);
  const auto& params_map = param.params_map;
  const auto& inputs_to_idx = param.inputs_to_idx;
  nnvm::Graph g;
  g.outputs = subgraph_sym.outputs;
  const auto& idx_g = g.indexed_graph();
  CHECK_EQ(idx_g.input_nodes().size(), in_shapes->size() + params_map.size());
  CHECK_EQ(idx_g.outputs().size(), out_shapes->size());

  // Put the input and output shapes into the shape vector.
  mxnet::ShapeVector shapes(idx_g.num_node_entries());
  const auto& input_nids = idx_g.input_nodes();
  for (size_t i = 0; i < input_nids.size(); i++) {
    auto node = idx_g[input_nids[i]].source;
    auto eid = idx_g.entry_id(input_nids[i], 0);
    auto it_params = params_map.find(node->attrs.name);
    auto it_inputs = inputs_to_idx.find(node->attrs.name);
    if (it_params != params_map.end()) {
      shapes[eid] = it_params->second.shape();
    } else if (it_inputs != inputs_to_idx.end()) {
      shapes[eid] = in_shapes->at(it_inputs->second);
    } else {
      LOG(FATAL) << node->attrs.name << " shape information is missing for attributes inference";
    }
  }
  CHECK_EQ(g.outputs.size(), out_shapes->size());
  for (size_t i = 0; i < out_shapes->size(); i++) {
    auto eid = idx_g.entry_id(g.outputs[i]);
    shapes[eid] = out_shapes->at(i);
  }

  // Infer shape of the graph.
  g.attrs["shape"] = std::make_shared<dmlc::any>(std::move(shapes));
  g = exec::InferShape(std::move(g));
  // Copy the inferred shapes back to the input and output shapes.
  shapes = g.GetAttr<mxnet::ShapeVector>("shape");
  // assign to in_shapes
  for (size_t i = 0; i < input_nids.size(); ++i) {
    const auto node = idx_g[input_nids[i]].source;
    const auto eid = idx_g.entry_id(input_nids[i], 0);
    auto it = inputs_to_idx.find(node->attrs.name);
    if (it != inputs_to_idx.end()) {
      SHAPE_ASSIGN_CHECK(*in_shapes, it->second, shapes[eid]);
    }
  }
  // assign to out_shapes
  for (size_t i = 0; i < g.outputs.size(); ++i) {
    const auto eid = idx_g.entry_id(g.outputs[i]);
    SHAPE_ASSIGN_CHECK(*out_shapes, i, shapes[eid]);
  }
  // Check if we have inferred the shapes correctly.
  return g.GetAttr<size_t>("shape_num_unknown_nodes") == 0;
}

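// Type inference for the TRT subgraph op: same seeding/propagation scheme as
// TRTInferShape. Weight entries are seeded as -1 (unknown) and left to be
// resolved by type propagation.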
inline bool TRTInferType(const nnvm::NodeAttrs& attrs,
                         std::vector<int> *in_types,
                         std::vector<int> *out_types) {
  const nnvm::Symbol subgraph_sym = *(attrs.subgraphs[0]);
  const TRTParam& param = nnvm::get<TRTParam>(attrs.parsed);
  const auto& params_map = param.params_map;
  const auto& inputs_to_idx = param.inputs_to_idx;

  nnvm::Graph g;
  g.outputs = subgraph_sym.outputs;
  const auto& idx_g = g.indexed_graph();
  CHECK_EQ(idx_g.input_nodes().size(), in_types->size() + params_map.size());
  CHECK_EQ(idx_g.outputs().size(), out_types->size());

  // Put the input and output data types into the dtype vector.
  nnvm::DTypeVector types(idx_g.num_node_entries(), -1);
  const auto& input_nids = idx_g.input_nodes();
  for (size_t i = 0; i < input_nids.size(); i++) {
    auto node = idx_g[input_nids[i]].source;
    auto eid = idx_g.entry_id(input_nids[i], 0);
    auto it_params = params_map.find(node->attrs.name);
    auto it_inputs = inputs_to_idx.find(node->attrs.name);
    if (it_params != params_map.end()) {
      types[eid] = -1;
    } else if (it_inputs != inputs_to_idx.end()) {
      types[eid] = in_types->at(it_inputs->second);
    } else {
      LOG(FATAL) << node->attrs.name
                 << " dtype information is missing for attributes inference";
    }
  }
  CHECK_EQ(g.outputs.size(), out_types->size());
  for (size_t i = 0; i < out_types->size(); i++) {
    auto eid = idx_g.entry_id(g.outputs[i]);
    types[eid] = out_types->at(i);
  }

  // Infer data type of the graph.
  g.attrs["dtype"] = std::make_shared<dmlc::any>(std::move(types));
  g = exec::InferType(std::move(g));

  types = g.GetAttr<nnvm::DTypeVector>("dtype");
  // assign to in_types
  for (size_t i = 0; i < input_nids.size(); ++i) {
    const auto node = idx_g[input_nids[i]].source;
    const auto eid = idx_g.entry_id(input_nids[i], 0);
    auto it = inputs_to_idx.find(node->attrs.name);
    if (it != inputs_to_idx.end()) {
      TYPE_ASSIGN_CHECK(*in_types, it->second, types[eid]);
    }
  }
  // assign to out_types
  for (size_t i = 0; i < g.outputs.size(); ++i) {
    const auto eid = idx_g.entry_id(g.outputs[i]);
    TYPE_ASSIGN_CHECK(*out_types, i, types[eid]);
  }

  // Check if we have inferred the dtypes correctly.
  return g.GetAttr<size_t>("dtype_num_unknown_nodes") == 0;
}

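// Storage type inference for the TRT subgraph op. All entries start as
// kUndefinedStorage; known storage types are seeded and then propagated with
// exec::InferStorageType under the given device mask.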
inline bool TRTInferStorageType(const nnvm::NodeAttrs& attrs,
                                const int dev_mask,
                                DispatchMode* dispatch_mode,
                                std::vector<int>* in_stypes,
                                std::vector<int>* out_stypes) {
  const nnvm::Symbol subgraph_sym = *(attrs.subgraphs[0]);
  const TRTParam& param = nnvm::get<TRTParam>(attrs.parsed);
  const auto& params_map = param.params_map;
  const auto& inputs_to_idx = param.inputs_to_idx;
  nnvm::Graph g;
  g.outputs = subgraph_sym.outputs;
  const auto& idx_g = g.indexed_graph();
  CHECK_EQ(idx_g.input_nodes().size(), in_stypes->size() + params_map.size());
  CHECK_EQ(idx_g.outputs().size(), out_stypes->size());
  exec::DevMaskVector dev_masks(idx_g.num_node_entries(), dev_mask);

  // Put the input and output storage types into the storage vector.
  StorageTypeVector stypes(idx_g.num_node_entries(), kUndefinedStorage);
  const auto& input_nids = idx_g.input_nodes();
  for (size_t i = 0; i < input_nids.size(); i++) {
    auto node = idx_g[input_nids[i]].source;
    auto eid = idx_g.entry_id(input_nids[i], 0);
    auto it_params = params_map.find(node->attrs.name);
    auto it_inputs = inputs_to_idx.find(node->attrs.name);
    if (it_params != params_map.end()) {
      stypes[eid] = it_params->second.storage_type();
    } else if (it_inputs != inputs_to_idx.end()) {
      stypes[eid] = in_stypes->at(it_inputs->second);
    } else {
      LOG(FATAL) << node->attrs.name
                 << " storage type information is missing for attributes inference";
    }
  }
  CHECK_EQ(g.outputs.size(), out_stypes->size());
  for (size_t i = 0; i < out_stypes->size(); i++) {
    auto eid = idx_g.entry_id(g.outputs[i]);
    stypes[eid] = out_stypes->at(i);
  }

  // Infer storage type of the graph.
  bool dev_match = g.attrs.count("dev_mask") &&
                   g.GetAttr<exec::DevMaskVector>("dev_mask") == dev_masks;
  if (!dev_match) {
    g.attrs["dev_mask"] = std::make_shared<dmlc::any>(std::move(dev_masks));
  }
  g.attrs["storage_type"] = std::make_shared<dmlc::any>(std::move(stypes));
  g = exec::InferStorageType(std::move(g));

  stypes = g.GetAttr<StorageTypeVector>("storage_type");
  // assign to in_stypes
  for (size_t i = 0; i < input_nids.size(); ++i) {
    const auto node = idx_g[input_nids[i]].source;
    const auto eid = idx_g.entry_id(input_nids[i], 0);
    auto it = inputs_to_idx.find(node->attrs.name);
    if (it != inputs_to_idx.end()) {
      STORAGE_TYPE_ASSIGN_CHECK(*in_stypes, it->second, stypes[eid]);
    }
  }

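  // The TRT op is stateful, so execution goes through the FStatefulComputeEx
  // path; force the dispatch mode accordingly.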
  DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
  // assign to out_stypes
  for (size_t i = 0; i < g.outputs.size(); ++i) {
    const auto eid = idx_g.entry_id(g.outputs[i]);
    STORAGE_TYPE_ASSIGN_CHECK(*out_stypes, i, stypes[eid]);
  }
  // Check if we have inferred the storage types correctly.
  return g.GetAttr<size_t>("storage_type_num_unknown_nodes") == 0;
}

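// Attribute parser for the TRT op. Weights arrive in the node attributes as
// "subgraph_param_<name>" entries whose values are NDArray pointers encoded
// as strings; each is copied into params_map and the entry is removed from
// the attribute dictionary.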
void TRTParamParser(nnvm::NodeAttrs* attrs) {
  TRTParam& _param = nnvm::get<TRTParam>(attrs->parsed);
  const std::string prefix = "subgraph_param_";
  for (auto it = attrs->dict.begin(); it != attrs->dict.end();) {
    const std::string& attrs_name = it->first;
    if (attrs_name.compare(0, prefix.size(), prefix) == 0) {
      std::string param_name = attrs_name.substr(prefix.size(),
                                                 attrs_name.size() - prefix.size());
      // TODO(cfujitsang): find a less dirty way to give weights
      NDArray *cache = reinterpret_cast<NDArray*>(stoll(it->second));
      _param.params_map.emplace(param_name, cache->Copy(Context()));
      _param.params_map[param_name].WaitToRead();
      it = attrs->dict.erase(it);
    } else {
      ++it;
    }
  }
  attrs->parsed = std::move(_param);
}

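// Creates the operator state: runs shape/dtype inference on the subgraph,
// converts it to ONNX, and builds the TensorRT engine (optionally with an
// int8 calibrator).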
OpStatePtr TRTCreateState(const nnvm::NodeAttrs& attrs, Context ctx,
                          const std::vector<TShape>& in_shape,
                          const std::vector<int>& in_type) {
  const auto& node_param = nnvm::get<TRTParam>(attrs.parsed);
  const bool tensorrt_int8 = node_param.int8_mode;
  nnvm::Graph graph;
  graph.outputs = attrs.subgraphs[0]->outputs;
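  // The engine is built for a fixed max batch size: the first input's leading
  // dimension, unless MXNET_TENSORRT_MAX_BATCH_SIZE requests a larger one.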
  uint32_t max_batch_size = dmlc::GetEnv("MXNET_TENSORRT_MAX_BATCH_SIZE", in_shape[0][0]);
  if (max_batch_size < in_shape[0][0]) {
    LOG(INFO) << "Warning: max batch size changed to " << in_shape[0][0]
              << " instead of " << max_batch_size;
    max_batch_size = in_shape[0][0];
  }
  std::unordered_map<std::string, NDArray> params_map = node_param.params_map;
  const auto& inputs_to_idx = node_param.inputs_to_idx;
  const auto& outputs_to_idx = node_param.outputs_to_idx;
  const auto& idx_g = graph.indexed_graph();
  const auto& input_nids = idx_g.input_nodes();

  // needed by the int8 calibrator
  std::unordered_map<std::string, std::pair<void*, size_t>> input_buffers;
  mxnet::ShapeVector shape_inputs(input_nids.size());
  nnvm::DTypeVector dtype_inputs(input_nids.size());
  for (size_t i = 0; i < input_nids.size(); ++i) {
    auto node = idx_g[input_nids[i]].source;
    auto it_params = params_map.find(node->attrs.name);
    auto it_inputs = inputs_to_idx.find(node->attrs.name);
    if (it_params != params_map.end()) {
      shape_inputs[i] = it_params->second.shape();
      dtype_inputs[i] = it_params->second.dtype();
    } else if (it_inputs != inputs_to_idx.end()) {
      shape_inputs[i] = in_shape[it_inputs->second];
      dtype_inputs[i] = in_type[it_inputs->second];
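      // For int8 calibration, pre-allocate a device buffer per runtime input;
      // these buffers are handed to the TRTInt8Calibrator below.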
      if (tensorrt_int8) {
        int dtype_size = 0;
        if (dtype_inputs[i] == mshadow::kFloat32) {
          dtype_size = 4;
        } else if (dtype_inputs[i] == mshadow::kFloat16) {
          dtype_size = 2;
        } else {
          LOG(FATAL) << "TensorRT op supports only float32 and float16 inputs.";
        }
        size_t buffer_size = shape_inputs[i].Size() * dtype_size;
        void *ptr;
        MSHADOW_CUDA_CALL(cudaMalloc(&ptr, buffer_size));
        input_buffers.emplace(node->attrs.name,
                              std::make_pair(ptr, buffer_size));
      }
    } else {
      LOG(FATAL) << node->attrs.name << " attribute is missing for attributes inference";
    }
  }
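  // Run full shape/dtype inference over the subgraph so the ONNX converter
  // sees concrete shapes and types for every node entry.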
  mxnet::ShapeVector out_shape(graph.outputs.size());
  nnvm::DTypeVector out_type(graph.outputs.size(), -1);
  mxnet::ShapeVector _in_shape(in_shape.begin(), in_shape.end());
  nnvm::DTypeVector _in_type(in_type.begin(), in_type.end());
  TRTInferShape(attrs, &_in_shape, &out_shape);
  TRTInferType(attrs, &_in_type, &out_type);
  nnvm::DTypeVector dtypes(idx_g.num_node_entries());
  mxnet::ShapeVector shapes(idx_g.num_node_entries());
  for (size_t i = 0; i < graph.outputs.size(); ++i) {
    auto eid = idx_g.entry_id(graph.outputs[i]);
    dtypes[eid] = out_type[i];
    shapes[eid] = out_shape[i];
  }
  graph.attrs["dtype_inputs"] = std::make_shared<nnvm::any>(std::move(dtype_inputs));
  graph.attrs["shape_inputs"] = std::make_shared<nnvm::any>(std::move(shape_inputs));
  graph.attrs["dtype"]        = std::make_shared<nnvm::any>(std::move(dtypes));
  graph.attrs["shape"]        = std::make_shared<nnvm::any>(std::move(shapes));

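  // In int8 mode, create the calibrator used during engine build; it runs
  // for node_param.calibration_iters batches.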
  std::unique_ptr<::onnx_to_tensorrt::TRTInt8Calibrator> calibrator;
  if (tensorrt_int8) {
    calibrator.reset(
      new ::onnx_to_tensorrt::TRTInt8Calibrator(params_map, std::move(input_buffers),
                                                max_batch_size, node_param.calibration_iters));
  }
  auto onnx_graph = op::nnvm_to_onnx::ConvertNnvmGraphToOnnx(graph, &params_map);
  uint32_t verbose = dmlc::GetEnv("MXNET_TENSORRT_VERBOSE", 0);
  auto log_lvl = nvinfer1::ILogger::Severity::kWARNING;
  if (verbose != 0) {
    log_lvl = nvinfer1::ILogger::Severity::kVERBOSE;
  }

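  // Build the TensorRT engine from the ONNX graph with a 1 GiB (1 << 30)
  // workspace limit.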
  auto trt_tuple = ::onnx_to_tensorrt::onnxToTrtCtx(onnx_graph, node_param.fp16_mode,
                                                    max_batch_size, 1 << 30,
                                                    calibrator.get(),
                                                    log_lvl);

  return OpStatePtr::Create<TRTEngineParam>(std::move(std::get<0>(trt_tuple)),
                                            std::move(std::get<1>(trt_tuple)),
                                            std::move(std::get<2>(trt_tuple)),
                                            inputs_to_idx, outputs_to_idx,
                                            max_batch_size,
                                            std::move(calibrator),
                                            std::move(std::get<3>(trt_tuple)));
}

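// Registration of the stateful _TensorRT subgraph operator. Note that CUDA
// graphs capture is only allowed when int8 calibration is not in use.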
NNVM_REGISTER_OP(_TensorRT)
    .describe(R"code(TRT operation (one engine)
)code" ADD_FILELINE)
    .set_num_inputs(TRTNumInputs)
    .set_num_outputs(DefaultSubgraphOpNumOutputs)
    .set_attr_parser(TRTParamParser)
    .set_attr<mxnet::FInferShape>("FInferShape", TRTInferShape)
    .set_attr<nnvm::FInferType>("FInferType", TRTInferType)
    .set_attr<nnvm::FListInputNames>("FListInputNames", TRTListInputNames)
    .set_attr<nnvm::FListOutputNames>("FListOutputNames", DefaultSubgraphOpListOutputs)
    .set_attr<FCreateOpState>("FCreateOpState", TRTCreateState)
    .set_attr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible",
        [](const NodeAttrs& attrs, const bool) {
          const TRTParam& param = nnvm::get<TRTParam>(attrs.parsed);
          return !param.int8_mode;
        })
    .set_attr<FInferStorageType>("FInferStorageType", TRTInferStorageType);

MXNET_REGISTER_SUBGRAPH_BACKEND(TensorRT);

MXNET_REGISTER_SUBGRAPH_PROPERTY(TensorRT, TensorrtProperty);

}  // namespace op
}  // namespace mxnet

#endif  // MXNET_USE_TENSORRT