1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file tensorrt.cc
22 * \brief TensorRT operation registration
23 * \author Marek Kolodziej, Clement Fuji Tsang, Serge Panev
24 */
25
26 #if MXNET_USE_TENSORRT
27
28 #include "./tensorrt-inl.h"
29
30 #include <NvInfer.h>
31
32 namespace mxnet {
33 namespace op {
34
TRTNumInputs(const nnvm::NodeAttrs & attrs)35 inline uint32_t TRTNumInputs(const nnvm::NodeAttrs& attrs) {
36 const TRTParam& param = nnvm::get<TRTParam>(attrs.parsed);
37 const auto inputs_to_idx = param.inputs_to_idx;
38 return inputs_to_idx.size();
39 }
40
TRTListInputNames(const nnvm::NodeAttrs & attrs)41 inline std::vector<std::string> TRTListInputNames(const nnvm::NodeAttrs& attrs) {
42 std::vector<std::string> outputs;
43 const TRTParam& param = nnvm::get<TRTParam>(attrs.parsed);
44 const auto inputs_to_idx = param.inputs_to_idx;
45 for (auto& p : inputs_to_idx) {
46 outputs[p.second] = p.first;
47 }
48 return outputs;
49 }
50
TRTInferShape(const nnvm::NodeAttrs & attrs,std::vector<TShape> * in_shapes,std::vector<TShape> * out_shapes)51 inline bool TRTInferShape(const nnvm::NodeAttrs& attrs,
52 std::vector<TShape> *in_shapes,
53 std::vector<TShape> *out_shapes) {
54 using namespace exec;
55 const nnvm::Symbol subgraph_sym = *(attrs.subgraphs[0]);
56 const TRTParam& param = nnvm::get<TRTParam>(attrs.parsed);
57 auto params_map = param.params_map;
58 auto inputs_to_idx = param.inputs_to_idx;
59 nnvm::Graph g;
60 g.outputs = subgraph_sym.outputs;
61 const auto& idx_g = g.indexed_graph();
62 CHECK_EQ(idx_g.input_nodes().size(), in_shapes->size() + params_map.size());
63 CHECK_EQ(idx_g.outputs().size(), out_shapes->size());
64
65 // Put the input and output shapes to the shape vector.
66 mxnet::ShapeVector shapes(idx_g.num_node_entries());
67 const auto &input_nids = idx_g.input_nodes();
68 CHECK_EQ(input_nids.size(), in_shapes->size() + params_map.size());
69 for (size_t i = 0; i < input_nids.size(); i++) {
70 auto node = idx_g[input_nids[i]].source;
71 auto eid = idx_g.entry_id(input_nids[i], 0);
72 auto it_params = params_map.find(node->attrs.name);
73 auto it_inputs = inputs_to_idx.find(node->attrs.name);
74 if (it_params != params_map.end()) {
75 shapes[eid] = it_params->second.shape();
76 } else if (it_inputs != inputs_to_idx.end()) {
77 shapes[eid] = in_shapes->at(it_inputs->second);
78 } else {
79 LOG(FATAL) << node->attrs.name << " shape information is missing for attributes inference";
80 }
81 }
82 CHECK_EQ(g.outputs.size(), out_shapes->size());
83 for (size_t i = 0; i < out_shapes->size(); i++) {
84 auto eid = idx_g.entry_id(g.outputs[i]);
85 shapes[eid] = out_shapes->at(i);
86 }
87
88 // Infer shape of the graph.
89 g.attrs["shape"] = std::make_shared<dmlc::any>(std::move(shapes));
90 g = exec::InferShape(std::move(g));
91 // Copy the inferred shape back to the input shapes and the output shapes.
92 shapes = g.GetAttr<mxnet::ShapeVector>("shape");
93 // assign to in_shapes
94 for (size_t i = 0; i < input_nids.size(); ++i) {
95 const auto node = idx_g[input_nids[i]].source;
96 const auto eid = idx_g.entry_id(input_nids[i], 0);
97 auto it = inputs_to_idx.find(node->attrs.name);
98 if (it != inputs_to_idx.end()) {
99 SHAPE_ASSIGN_CHECK(*in_shapes, it->second, shapes[eid]);
100 }
101 }
102 // assign to out_shapes
103 for (size_t i = 0; i < g.outputs.size(); ++i) {
104 const auto eid = idx_g.entry_id(g.outputs[i]);
105 SHAPE_ASSIGN_CHECK(*out_shapes, i, shapes[eid]);
106 }
107 // Check if we have inferred the shapes correctly.
108 return g.GetAttr<size_t>("shape_num_unknown_nodes") == 0;
109 }
110
TRTInferType(const nnvm::NodeAttrs & attrs,std::vector<int> * in_types,std::vector<int> * out_types)111 inline bool TRTInferType(const nnvm::NodeAttrs& attrs,
112 std::vector<int> *in_types,
113 std::vector<int> *out_types) {
114 const nnvm::Symbol subgraph_sym = *(attrs.subgraphs[0]);
115 const TRTParam& param = nnvm::get<TRTParam>(attrs.parsed);
116 auto params_map = param.params_map;
117 auto inputs_to_idx = param.inputs_to_idx;
118
119 nnvm::Graph g;
120 g.outputs = subgraph_sym.outputs;
121 const auto& idx_g = g.indexed_graph();
122 CHECK_EQ(idx_g.input_nodes().size(), in_types->size() + params_map.size());
123 CHECK_EQ(idx_g.outputs().size(), out_types->size());
124
125 // Put the input and output data types to the dtype vector.
126 nnvm::DTypeVector types(idx_g.num_node_entries(), -1);
127 const auto &input_nids = idx_g.input_nodes();
128 CHECK_EQ(input_nids.size(), in_types->size() + params_map.size());
129 for (size_t i = 0; i < input_nids.size(); i++) {
130 auto node = idx_g[input_nids[i]].source;
131 auto eid = idx_g.entry_id(input_nids[i], 0);
132 auto it_params = params_map.find(node->attrs.name);
133 auto it_inputs = inputs_to_idx.find(node->attrs.name);
134 if (it_params != params_map.end()) {
135 types[eid] = -1;
136 } else if (it_inputs != inputs_to_idx.end()) {
137 types[eid] = in_types->at(it_inputs->second);
138 } else {
139 LOG(FATAL) << node->attrs.name
140 << " dtype information is missing for attributes inference";
141 }
142 }
143 CHECK_EQ(g.outputs.size(), out_types->size());
144 for (size_t i = 0; i < out_types->size(); i++) {
145 auto eid = idx_g.entry_id(g.outputs[i]);
146 types[eid] = out_types->at(i);
147 }
148
149 // Infer data type of the graph.
150 g.attrs["dtype"] = std::make_shared<dmlc::any>(std::move(types));
151 g = exec::InferType(std::move(g));
152
153 types = g.GetAttr<nnvm::DTypeVector>("dtype");
154 // assign to in_types
155 for (size_t i = 0; i < input_nids.size(); ++i) {
156 const auto node = idx_g[input_nids[i]].source;
157 const auto eid = idx_g.entry_id(input_nids[i], 0);
158 auto it = inputs_to_idx.find(node->attrs.name);
159 if (it != inputs_to_idx.end()) {
160 TYPE_ASSIGN_CHECK(*in_types, it->second, types[eid]);
161 }
162 }
163 // assign to out_types
164 for (size_t i = 0; i < g.outputs.size(); ++i) {
165 const auto eid = idx_g.entry_id(g.outputs[i]);
166 TYPE_ASSIGN_CHECK(*out_types, i, types[eid]);
167 }
168
169 // Check if we have inferred the dtypes correctly.
170 return g.GetAttr<size_t>("dtype_num_unknown_nodes") == 0;
171 }
172
TRTInferStorageType(const nnvm::NodeAttrs & attrs,const int dev_mask,DispatchMode * dispatch_mode,std::vector<int> * in_stypes,std::vector<int> * out_stypes)173 inline bool TRTInferStorageType(const nnvm::NodeAttrs& attrs,
174 const int dev_mask,
175 DispatchMode* dispatch_mode,
176 std::vector<int>* in_stypes,
177 std::vector<int>* out_stypes) {
178 const nnvm::Symbol subgraph_sym = *(attrs.subgraphs[0]);
179 const TRTParam& param = nnvm::get<TRTParam>(attrs.parsed);
180 auto params_map = param.params_map;
181 auto inputs_to_idx = param.inputs_to_idx;
182 nnvm::Graph g;
183 g.outputs = subgraph_sym.outputs;
184 const auto& idx_g = g.indexed_graph();
185 CHECK_EQ(idx_g.input_nodes().size(), in_stypes->size() + params_map.size());
186 CHECK_EQ(idx_g.outputs().size(), out_stypes->size());
187 exec::DevMaskVector dev_masks(idx_g.num_node_entries(), dev_mask);
188
189 // Put the input and output storages to the storage vector.
190 StorageTypeVector stypes(idx_g.num_node_entries(), kUndefinedStorage);
191 const auto &input_nids = idx_g.input_nodes();
192 CHECK_EQ(input_nids.size(), in_stypes->size() + params_map.size());
193 for (size_t i = 0; i < input_nids.size(); i++) {
194 auto node = idx_g[input_nids[i]].source;
195 auto eid = idx_g.entry_id(input_nids[i], 0);
196 auto it_params = params_map.find(node->attrs.name);
197 auto it_inputs = inputs_to_idx.find(node->attrs.name);
198 if (it_params != params_map.end()) {
199 stypes[eid] = it_params->second.storage_type();
200 } else if (it_inputs != inputs_to_idx.end()) {
201 stypes[eid] = in_stypes->at(it_inputs->second);
202 } else {
203 LOG(FATAL) << node->attrs.name
204 << " storage type information is missing for attributes inference";
205 }
206 }
207 CHECK_EQ(g.outputs.size(), out_stypes->size());
208 for (size_t i = 0; i < out_stypes->size(); i++) {
209 auto eid = idx_g.entry_id(g.outputs[i]);
210 stypes[eid] = out_stypes->at(i);
211 }
212
213 // Infer storage type of the graph.
214 bool dev_match = g.attrs.count("dev_mask") &&
215 g.GetAttr<exec::DevMaskVector>("dev_mask") == dev_masks;
216 if (!dev_match) {
217 g.attrs["dev_mask"] = std::make_shared<dmlc::any>(std::move(dev_masks));
218 }
219 g.attrs["storage_type"] = std::make_shared<dmlc::any>(std::move(stypes));
220 g = exec::InferStorageType(std::move(g));
221
222 stypes = g.GetAttr<StorageTypeVector>("storage_type");
223 // assign to in_types
224 for (size_t i = 0; i < input_nids.size(); ++i) {
225 const auto node = idx_g[input_nids[i]].source;
226 const auto eid = idx_g.entry_id(input_nids[i], 0);
227 auto it = inputs_to_idx.find(node->attrs.name);
228 if (it != inputs_to_idx.end()) {
229 STORAGE_TYPE_ASSIGN_CHECK(*in_stypes, it->second, stypes[eid]);
230 }
231 }
232
233 DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
234 // assign to out_types
235 for (size_t i = 0; i < g.outputs.size(); ++i) {
236 const auto eid = idx_g.entry_id(g.outputs[i]);
237 STORAGE_TYPE_ASSIGN_CHECK(*out_stypes, i, stypes[eid]);
238 }
239 // Check if we have inferred the storages correctly.
240 return g.GetAttr<size_t>("storage_type_num_unknown_nodes") == 0;
241 }
242
TRTParamParser(nnvm::NodeAttrs * attrs)243 void TRTParamParser(nnvm::NodeAttrs* attrs) {
244 TRTParam& _param = nnvm::get<TRTParam>(attrs->parsed);
245 std::string prefix = "subgraph_param_";
246 std::string str_dtype, str_shape, str_pointer, _tmp;
247 for (auto it = attrs->dict.begin(); it != attrs->dict.end();) {
248 std::string attrs_name = it->first;
249 if (std::equal(prefix.begin(), prefix.end(), attrs_name.begin())) {
250 std::string param_name = attrs_name.substr(prefix.size(),
251 attrs_name.size() - prefix.size());
252 // TODO(cfujitsang): find a less dirty way to give weights
253 NDArray *cache = reinterpret_cast<NDArray*>(stol(it->second));
254 _param.params_map.emplace(param_name, cache->Copy(Context()));
255 _param.params_map[param_name].WaitToRead();
256 it = attrs->dict.erase(it);
257 } else {
258 ++it;
259 }
260 }
261 attrs->parsed = std::move(_param);
262 }
263
TRTCreateState(const nnvm::NodeAttrs & attrs,Context ctx,const std::vector<TShape> & in_shape,const std::vector<int> & in_type)264 OpStatePtr TRTCreateState(const nnvm::NodeAttrs& attrs, Context ctx,
265 const std::vector<TShape>& in_shape,
266 const std::vector<int>& in_type) {
267 const auto& node_param = nnvm::get<TRTParam>(attrs.parsed);
268 const bool tensorrt_int8 = node_param.int8_mode;
269 nnvm::Graph graph;
270 graph.outputs = attrs.subgraphs[0]->outputs;
271 uint32_t max_batch_size = dmlc::GetEnv("MXNET_TENSORRT_MAX_BATCH_SIZE", in_shape[0][0]);
272 if (max_batch_size < in_shape[0][0]) {
273 LOG(INFO) << "Warning: max batch size changed to be is: " << in_shape[0][0]
274 << " instead of: " << max_batch_size;
275 max_batch_size = in_shape[0][0];
276 }
277 std::unordered_map<std::string, NDArray> params_map = node_param.params_map;
278 const auto& inputs_to_idx = node_param.inputs_to_idx;
279 const auto& outputs_to_idx = node_param.outputs_to_idx;
280 const auto& idx_g = graph.indexed_graph();
281 const auto& input_nids = idx_g.input_nodes();
282
283 // needed by the int8 calibrator
284 std::unordered_map<std::string, std::pair<void*, size_t>> input_buffers;
285 mxnet::ShapeVector shape_inputs(input_nids.size());
286 nnvm::DTypeVector dtype_inputs(input_nids.size());
287 for (size_t i = 0; i < input_nids.size(); ++i) {
288 auto node = idx_g[input_nids[i]].source;
289 auto it_params = params_map.find(node->attrs.name);
290 auto it_inputs = inputs_to_idx.find(node->attrs.name);
291 if (it_params != params_map.end()) {
292 shape_inputs[i] = it_params->second.shape();
293 dtype_inputs[i] = it_params->second.dtype();
294 } else if (it_inputs != inputs_to_idx.end()) {
295 shape_inputs[i] = in_shape[it_inputs->second];
296 dtype_inputs[i] = in_type[it_inputs->second];
297 if (tensorrt_int8) {
298 int dtype_size;
299 if (dtype_inputs[i] == mshadow::kFloat32) {
300 dtype_size = 4;
301 } else if (dtype_inputs[i] == mshadow::kFloat16) {
302 dtype_size = 2;
303 } else {
304 LOG(FATAL) << "TensorRT op supports only float32 and float16 inputs.";
305 }
306 size_t buffer_size = shape_inputs[i].Size() * dtype_size;
307 void *ptr;
308 MSHADOW_CUDA_CALL(cudaMalloc(&ptr, buffer_size));
309 input_buffers.emplace(node->attrs.name,
310 std::make_pair(ptr, buffer_size));
311 }
312 } else {
313 LOG(FATAL) << node->attrs.name << " attribute is missing for attributes inference";
314 }
315 }
316 mxnet::ShapeVector out_shape(graph.outputs.size());
317 nnvm::DTypeVector out_type(graph.outputs.size(), -1);
318 mxnet::ShapeVector _in_shape(in_shape.begin(), in_shape.end());
319 nnvm::DTypeVector _in_type(in_type.begin(), in_type.end());
320 TRTInferShape(attrs, &_in_shape, &out_shape);
321 TRTInferType(attrs, &_in_type, &out_type);
322 nnvm::DTypeVector dtypes(idx_g.num_node_entries());
323 mxnet::ShapeVector shapes(idx_g.num_node_entries());
324 for (size_t i = 0; i < graph.outputs.size(); ++i) {
325 auto eid = idx_g.entry_id(graph.outputs[i]);
326 dtypes[eid] = out_type[i];
327 shapes[eid] = out_shape[i];
328 }
329 graph.attrs["dtype_inputs"] = std::make_shared<nnvm::any>(std::move(dtype_inputs));
330 graph.attrs["shape_inputs"] = std::make_shared<nnvm::any>(std::move(shape_inputs));
331 graph.attrs["dtype"] = std::make_shared<nnvm::any>(std::move(dtypes));
332 graph.attrs["shape"] = std::make_shared<nnvm::any>(std::move(shapes));
333
334 std::unique_ptr<::onnx_to_tensorrt::TRTInt8Calibrator> calibrator;
335 if (tensorrt_int8) {
336 calibrator.reset(
337 new ::onnx_to_tensorrt::TRTInt8Calibrator(params_map, std::move(input_buffers),
338 max_batch_size, node_param.calibration_iters));
339 }
340 auto onnx_graph = op::nnvm_to_onnx::ConvertNnvmGraphToOnnx(graph, ¶ms_map);
341 uint32_t verbose = dmlc::GetEnv("MXNET_TENSORRT_VERBOSE", 0);
342 auto log_lvl = nvinfer1::ILogger::Severity::kWARNING;
343 if (verbose != 0) {
344 log_lvl = nvinfer1::ILogger::Severity::kVERBOSE;
345 }
346
347 auto trt_tuple = ::onnx_to_tensorrt::onnxToTrtCtx(onnx_graph, node_param.fp16_mode,
348 max_batch_size, 1 << 30,
349 calibrator.get(),
350 log_lvl);
351
352 return OpStatePtr::Create<TRTEngineParam>(std::move(std::get<0>(trt_tuple)),
353 std::move(std::get<1>(trt_tuple)),
354 std::move(std::get<2>(trt_tuple)),
355 inputs_to_idx, outputs_to_idx,
356 max_batch_size,
357 std::move(calibrator),
358 std::move(std::get<3>(trt_tuple)));
359 }
360
// Register the _TensorRT subgraph op: all attribute inference is delegated to
// the wrapped subgraph via the TRT* helpers above.
NNVM_REGISTER_OP(_TensorRT)
.describe(R"code(TRT operation (one engine)
)code" ADD_FILELINE)
// Runtime inputs only; weights are baked into the engine via params_map.
.set_num_inputs(TRTNumInputs)
.set_num_outputs(DefaultSubgraphOpNumOutputs)
.set_attr_parser(TRTParamParser)
.set_attr<mxnet::FInferShape>("FInferShape", TRTInferShape)
.set_attr<nnvm::FInferType>("FInferType", TRTInferType)
.set_attr<nnvm::FListInputNames>("FListInputNames", TRTListInputNames)
.set_attr<nnvm::FListOutputNames>("FListOutputNames", DefaultSubgraphOpListOutputs)
.set_attr<FCreateOpState>("FCreateOpState", TRTCreateState)
// CUDA graphs capture is disabled in int8 mode -- presumably because the
// calibration passes are not replay-safe; confirm against the executor.
.set_attr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible",
    [](const NodeAttrs& attrs, const bool) {
      const TRTParam& param = nnvm::get<TRTParam>(attrs.parsed);
      return !param.int8_mode;
    })
.set_attr<FInferStorageType>("FInferStorageType", TRTInferStorageType);

// Expose TensorRT as a subgraph backend selectable at partition time.
MXNET_REGISTER_SUBGRAPH_BACKEND(TensorRT);

MXNET_REGISTER_SUBGRAPH_PROPERTY(TensorRT, TensorrtProperty);
382 } // namespace op
383 } // namespace mxnet
384
385 #endif // MXNET_USE_TENSORRT
386