/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file nnvm_to_onnx.cc
 * \brief Conversion from NNVM to ONNX for TensorRT
 * \author Marek Kolodziej, Clement Fuji Tsang
 */

#if MXNET_USE_TENSORRT

#include "./nnvm_to_onnx-inl.h"

#include <mxnet/base.h>
#include <nnvm/graph.h>
#include <nnvm/pass_functions.h>

#include "../../../common/utils.h"
#include "../../../ndarray/ndarray_function.h"
#include "../../pad-inl.h"
#include "../../nn/activation-inl.h"
#include "../../nn/batch_norm-inl.h"
#include "../../nn/convolution-inl.h"
#include "../../nn/deconvolution-inl.h"
#include "../../nn/fully_connected-inl.h"
#include "../../nn/pooling-inl.h"
#include "../../nn/concat-inl.h"
#include "../../softmax_output-inl.h"
#include "../../tensor/matrix_op-inl.h"

#if MXNET_USE_TENSORRT_ONNX_CHECKER
#include <onnx/checker.h>
#endif  // MXNET_USE_TENSORRT_ONNX_CHECKER

namespace mxnet {
namespace op {
namespace nnvm_to_onnx {

std::string ConvertNnvmGraphToOnnx(
    const nnvm::Graph& g,
    std::unordered_map<std::string, NDArray>* params_map) {

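  // Atomic counter used to give each converted TensorRT subgraph a unique ONNX graph name.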
  static std::atomic_ulong subgraph_count = { 0 };

  std::string serialized_onnx_graph;

  const nnvm::IndexedGraph& ig = g.indexed_graph();
  const auto& dtypes = g.GetAttr<DTypeVector>("dtype");
  const auto& shapes = g.GetAttr<ShapeVector>("shape");
  const auto& dtype_inputs = g.GetAttr<DTypeVector>("dtype_inputs");
  const auto& shape_inputs = g.GetAttr<ShapeVector>("shape_inputs");

  ModelProto model_proto;

  // We currently serialize our models as ONNX IR version 3, opset 8, since that combination is
  // best supported by the currently linked version of the onnx-tensorrt library.
  // More information on ONNX versions and opsets can be found at:
  // https://github.com/onnx/onnx/blob/master/docs/IR.md

  auto opset_proto = model_proto.add_opset_import();
  const int64 onnx_opset = 8;
  const int64 onnx_major_version = 3;

  // Declare our ONNX versions in our protobuf model.
  opset_proto->set_version(onnx_opset);
  model_proto.set_ir_version(onnx_major_version);

  GraphProto* graph_proto = model_proto.mutable_graph();
  auto subgraph_name_id = subgraph_count.fetch_add(1);
  graph_proto->set_name("MXNetTRTSubgraph" + std::to_string(subgraph_name_id));

  auto placeholder_shapes = GetPlaceholderShapes(shape_inputs, ig);
  auto placeholder_dtypes = GetPlaceholderDTypes(dtype_inputs, ig);
  auto output_lookup = GetOutputLookup(ig);

  for (uint32_t node_idx = 0; node_idx < ig.num_nodes(); ++node_idx) {
    const IndexedGraph::Node& node = ig[node_idx];
    const nnvm::Node* source = node.source;
    // If this is an op
    if (!source->is_variable()) {
      auto mightNeedPreprocessNode = preprocess_map.find(source->op()->name);
      // if this op is defined in preprocess_map
      if (mightNeedPreprocessNode != preprocess_map.end()) {
        mightNeedPreprocessNode->second(source->attrs, source->inputs, params_map);
      }
    }
  }

  uint32_t current_input = 0;
  // Can't do a foreach over IndexedGraph since it doesn't implement begin(), etc.
  for (uint32_t node_idx = 0; node_idx < ig.num_nodes(); ++node_idx) {
    const IndexedGraph::Node& node = ig[node_idx];
    const nnvm::Node* source = node.source;
    const NodeAttrs& attrs = source->attrs;
    const Op* op = source->op();

    std::string node_name = attrs.name;
    // Here, "variable" actually means anything that's not an op, i.e. a constant (weights) or a
    // placeholder.
    if (source->is_variable()) {
      // Is this a placeholder?
      if (params_map->count(node_name) == 0) {
        // This fixes the problem with a SoftmaxOutput node during inference, but it's hacky.
        // Need to figure out how to properly fix it.
        if (node_name.find("label") != std::string::npos) {
          current_input++;
          continue;
        }
        ConvertPlaceholder(node_name, placeholder_shapes, placeholder_dtypes, graph_proto);
      } else {
        // If it's not a placeholder, then by exclusion it's a constant.
        ConvertConstant(graph_proto, node_name, params_map);
      }  // is_placeholder
    } else {
      // It's an op, rather than a "variable" (constant or placeholder).
      if (converter_map.count(op->name) == 0) {
        LOG(FATAL) << "Conversion for node of type " << op->name << " (node "
                   << node_name << ") is not supported yet.";
      }
      // Look up the converter for this op by name and invoke it. find() is safe here because we
      // just checked above that converter_map contains this op name.
      converter_map.find(op->name)->second(graph_proto, node_name, attrs, ig, node.inputs);
      // See if the current node is an output node.
      auto out_iter = output_lookup.find(node_name);
      // We found an output.
      if (out_iter != output_lookup.end()) {
        ConvertOutput(graph_proto, out_iter, node_name, shapes, dtypes, ig);
      }  // output found
    }  // conversion function exists
  }  // loop over i from 0 to num_nodes

  model_proto.SerializeToString(&serialized_onnx_graph);

#if MXNET_USE_TENSORRT_ONNX_CHECKER
  onnx::checker::check_model(model_proto);
#endif  // MXNET_USE_TENSORRT_ONNX_CHECKER

  return serialized_onnx_graph;
}

void DefaultConnectInputsOutputs(NodeProto* node_proto,
                                 const array_view<IndexedGraph::NodeEntry>& inputs,
                                 const nnvm::IndexedGraph& ig,
                                 const std::string& node_name) {
  for (const nnvm::IndexedGraph::NodeEntry& entry : inputs) {
    std::string in_node_name = ig[entry.node_id].source->attrs.name;
    // As before, we're not adding labels, e.g. for SoftmaxOutput, but I wish there was a less
    // hacky way to do it than name matching.
    if (in_node_name.find("label") != std::string::npos) {
      continue;
    }
    node_proto->add_input(in_node_name);
  }
  // The node's output will have the same name as the node name.
  node_proto->add_output(node_name);
}

TensorProto* const Make1DTensor(GraphProto* const graph_proto, const int64_t& size,
                                const std::string& name, const TensorProto_DataType& dtype) {
  TensorProto* const initializer_proto = graph_proto->add_initializer();
  initializer_proto->set_name(name);
  initializer_proto->set_data_type(dtype);
  initializer_proto->add_dims(static_cast<int64>(size));

  ValueInfoProto* const input_proto = graph_proto->add_input();
  input_proto->set_name(name);
  auto var = input_proto->mutable_type()->mutable_tensor_type();
  var->set_elem_type(dtype);
  var->mutable_shape()->add_dim()->set_dim_value(static_cast<int64>(size));
  return initializer_proto;
}

// Kept for when the ONNX version is updated
/*
void ConvertSlice(GraphProto* const graph_proto, const Node* node, const Graph& g) {
  const auto& params = nnvm::get<SliceParam>(node->attrs.parsed);
  int64 nb_slices = static_cast<int64>(params.begin.ndim());

  // starts
  auto init_starts = Make1DTensor(graph_proto, nb_slices, node->attrs.name + "_starts",
                                  TensorProto_DataType_INT64);
  for (auto& opt : params.begin) {
    if (opt.has_value()) {
      init_starts->add_int64_data(static_cast<int64>(opt.value()));
    } else {
      init_starts->add_int64_data(static_cast<int64>(0));
    }
  }

  // ends
  auto init_ends = Make1DTensor(graph_proto, nb_slices, node->attrs.name + "_ends",
                                TensorProto_DataType_INT64);
  for (auto& opt : params.end) {
    if (opt.has_value()) {
      init_ends->add_int64_data(static_cast<int64>(opt.value()));
    } else {
      init_ends->add_int64_data(static_cast<int64>(INT_MAX));
    }
  }

  // axes
  auto init_axes = Make1DTensor(graph_proto, nb_slices, node->attrs.name + "_axes",
                                TensorProto_DataType_INT64);
  for (int64_t i = 0; i < nb_slices; ++i) {
    init_axes->add_int64_data(static_cast<int64>(i));
  }

  // slice node
  NodeProto* node_proto = graph_proto->add_node();
  node_proto->set_name(node->attrs.name);
  node_proto->set_op_type("Slice");
  node_proto->add_input(node->inputs[0].node->attrs.name);
  node_proto->add_input(node->attrs.name + "_starts");
  node_proto->add_input(node->attrs.name + "_ends");
  node_proto->add_input(node->attrs.name + "_axes");

  // steps
  if (params.step.ndim() != 0) {
    auto init_steps = Make1DTensor(graph_proto, nb_slices, node->attrs.name + "_steps",
                                   TensorProto_DataType_INT64);
    for (auto& opt : params.step) {
      if (opt.has_value()) {
        init_steps->add_int64_data(static_cast<int64>(opt.value()));
      } else {
        init_steps->add_int64_data(static_cast<int64>(1));
      }
    }
    node_proto->add_input(node->attrs.name + "_steps");
  }

  node_proto->add_output(node->attrs.name);
}
*/

void ConvertIdentity(GraphProto* graph_proto, const std::string& node_name, const NodeAttrs& attrs,
                     const nnvm::IndexedGraph& ig,
                     const array_view<IndexedGraph::NodeEntry>& inputs) {
  NodeProto* node_proto = graph_proto->add_node();
  node_proto->set_name(node_name);
  node_proto->set_op_type("Identity");
  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
}

template <class ConvDeconvParam>
void ConvDeconvConvertHelper(NodeProto* node_proto, const NodeAttrs& attrs,
                             const nnvm::IndexedGraph& ig,
                             const array_view<IndexedGraph::NodeEntry>& inputs,
                             const ConvDeconvParam& param,
                             ConvDeconvType type) {
  if (type == ConvDeconvType::Convolution) {
    node_proto->set_op_type("Conv");
  } else {
    node_proto->set_op_type("ConvTranspose");
  }

  const mxnet::TShape kernel = param.kernel;
  const mxnet::TShape stride = param.stride;
  const mxnet::TShape dilate = param.dilate;
  const mxnet::TShape pad = param.pad;
  const uint32_t num_group = param.num_group;
  // const bool no_bias = conv_param.no_bias;
  const dmlc::optional<int> layout = param.layout;

  // dilations
  AttributeProto* const dilations = node_proto->add_attribute();
  dilations->set_name("dilations");
  dilations->set_type(AttributeProto::INTS);
  for (const dim_t kval : dilate) {
    dilations->add_ints(static_cast<int64>(kval));
  }

  // group
  AttributeProto* const group = node_proto->add_attribute();
  group->set_name("group");
  group->set_type(AttributeProto::INT);
  group->set_i(static_cast<int64>(num_group));

  // kernel shape
  AttributeProto* const kernel_shape = node_proto->add_attribute();
  kernel_shape->set_name("kernel_shape");
  kernel_shape->set_type(AttributeProto::INTS);

  for (const dim_t kval : kernel) {
    kernel_shape->add_ints(static_cast<int64>(kval));
  }

  // pads
  AttributeProto* const pads = node_proto->add_attribute();
  pads->set_name("pads");
  pads->set_type(AttributeProto::INTS);

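  // MXNet stores one symmetric pad value per spatial axis; ONNX expects separate begin and end
  // pads (x1_begin, x2_begin, ..., x1_end, x2_end, ...), so emit the pad values twice.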
  for (int i = 0; i < 2; i++) {
    for (dim_t kval : pad) {
      pads->add_ints(static_cast<int64>(kval));
    }
  }

  // strides
  AttributeProto* const strides = node_proto->add_attribute();
  strides->set_name("strides");
  strides->set_type(AttributeProto::INTS);
  for (const dim_t kval : stride) {
    strides->add_ints(static_cast<int64>(kval));
  }
}

void ConvertConvolution(GraphProto* graph_proto, const std::string& node_name,
                        const NodeAttrs& attrs,
                        const nnvm::IndexedGraph& ig,
                        const array_view<IndexedGraph::NodeEntry>& inputs) {
  NodeProto* node_proto = graph_proto->add_node();
  node_proto->set_name(node_name);
  const auto& conv_param = nnvm::get<op::ConvolutionParam>(attrs.parsed);
  ConvDeconvConvertHelper(node_proto, attrs, ig, inputs, conv_param,
                          ConvDeconvType::Convolution);
  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
}  // end ConvertConvolution

void ConvertDeconvolution(GraphProto* graph_proto, const std::string& node_name,
                          const NodeAttrs& attrs,
                          const nnvm::IndexedGraph& ig,
                          const array_view<IndexedGraph::NodeEntry>& inputs) {
  NodeProto* node_proto = graph_proto->add_node();
  node_proto->set_name(node_name);
  const auto& deconv_param = nnvm::get<op::DeconvolutionParam>(attrs.parsed);
  ConvDeconvConvertHelper(node_proto, attrs, ig, inputs, deconv_param,
                          ConvDeconvType::Deconvolution);
  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
}  // end ConvertDeconvolution

void ConvertPooling(GraphProto* graph_proto, const std::string& node_name,
                    const NodeAttrs& attrs,
                    const nnvm::IndexedGraph& ig,
                    const array_view<IndexedGraph::NodeEntry>& inputs) {
  NodeProto* node_proto = graph_proto->add_node();
  node_proto->set_name(node_name);
  const auto& pooling_param = nnvm::get<op::PoolingParam>(attrs.parsed);

  const mxnet::TShape kernel = pooling_param.kernel;
  const mxnet::TShape stride = pooling_param.stride;
  const mxnet::TShape pad = pooling_param.pad;
  const int pool_type = pooling_param.pool_type;
  const bool global_pool = pooling_param.global_pool;

  if (global_pool) {
    if (pool_type == pool_enum::kMaxPooling) {
      node_proto->set_op_type("GlobalMaxPool");
    } else if (pool_type == pool_enum::kAvgPooling) {
      node_proto->set_op_type("GlobalAveragePool");
    } else {
      LOG(FATAL) << "Unsupported pool type " << pool_type << " for node '" << attrs.name << "'";
    }
    DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
    return;
  }

  // kernel_shape
  AttributeProto* const kernel_shape = node_proto->add_attribute();
  kernel_shape->set_name("kernel_shape");
  kernel_shape->set_type(AttributeProto::INTS);
  for (dim_t kval : kernel) {
    kernel_shape->add_ints(static_cast<int64>(kval));
  }

  // pads
  AttributeProto* const pads = node_proto->add_attribute();
  pads->set_name("pads");
  pads->set_type(AttributeProto::INTS);

  // Convert from MXNet symmetric pads to ONNX non-symmetric pads by running through the padding
  // values twice.
  for (int i = 0; i < 2; i++) {
    for (dim_t kval : pad) {
      pads->add_ints(static_cast<int64>(kval));
    }
  }

  // strides
  AttributeProto* const strides = node_proto->add_attribute();
  strides->set_name("strides");
  strides->set_type(AttributeProto::INTS);
  for (dim_t kval : stride) {
    strides->add_ints(static_cast<int64>(kval));
  }

  // ceil_mode
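  // MXNet's 'full' pooling convention rounds output sizes up, which corresponds to
  // ceil_mode = 1 in ONNX; the 'valid' convention rounds down (ceil_mode = 0).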
  AttributeProto* const ceil_mode = node_proto->add_attribute();
  ceil_mode->set_name("ceil_mode");
  ceil_mode->set_type(AttributeProto::INT);
  ceil_mode->set_i(static_cast<int64>(pooling_param.pooling_convention == pool_enum::kFull));

  if (pool_type == pool_enum::kMaxPooling) {
    node_proto->set_op_type("MaxPool");
  } else if (pool_type == pool_enum::kAvgPooling) {
    node_proto->set_op_type("AveragePool");
  } else {
    LOG(FATAL) << "Unsupported pool type " << pool_type << " for node '" << attrs.name << "'";
  }

  // count_include_pad
  AttributeProto* const count_include_pad = node_proto->add_attribute();
  count_include_pad->set_name("count_include_pad");
  count_include_pad->set_type(AttributeProto::INT);
  if (pooling_param.count_include_pad.has_value()) {
    count_include_pad->set_i(pooling_param.count_include_pad.value());
  } else {
    count_include_pad->set_i(1);
  }
  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
}  // end ConvertPooling

void ConvertRelu(GraphProto* graph_proto, const std::string& node_name, const NodeAttrs& /*attrs*/,
                 const nnvm::IndexedGraph& ig,
                 const array_view<IndexedGraph::NodeEntry>& inputs) {
  NodeProto* node_proto = graph_proto->add_node();
  node_proto->set_name(node_name);
  node_proto->set_op_type("Relu");
  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
}

void ConvertActivation(GraphProto* graph_proto, const std::string& node_name,
                       const NodeAttrs& attrs,
                       const nnvm::IndexedGraph& ig,
                       const array_view<IndexedGraph::NodeEntry>& inputs) {
  NodeProto* node_proto = graph_proto->add_node();
  node_proto->set_name(node_name);
  const auto& act_param = nnvm::get<op::ActivationParam>(attrs.parsed);
  std::string act_type;
  switch (act_param.act_type) {
    case op::activation::kReLU:
      act_type = "Relu";
      break;
    case op::activation::kSigmoid:
      act_type = "Sigmoid";
      break;
    case op::activation::kTanh:
      act_type = "Tanh";
      break;
    case op::activation::kSoftReLU:
      // act_type = "SoftReLU";
      throw dmlc::Error("SoftReLU is not supported in ONNX");
      break;
    default:
      throw dmlc::Error("Unknown activation type");
  }

  node_proto->set_op_type(act_type);
  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
}

void ConvertFullyConnected(GraphProto* graph_proto, const std::string& node_name,
                           const NodeAttrs& attrs,
                           const nnvm::IndexedGraph& ig,
                           const array_view<IndexedGraph::NodeEntry>& inputs) {
  const auto& act_param = nnvm::get<op::FullyConnectedParam>(attrs.parsed);
  // The ONNX spec doesn't support Gemm with inputs of different ranks, so when the input is
  // not flattened to 2-D we replace FullyConnected with Transpose+MatMul+Add.
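  // That is, Y = X * transpose(W) + B; MatMul broadcasts over any extra leading dimensions of
  // X, which the strictly two-dimensional Gemm cannot.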
  if (!act_param.flatten && !act_param.no_bias) {
    NodeProto* transpose_node_proto = graph_proto->add_node();
    NodeProto* matmul_node_proto = graph_proto->add_node();
    NodeProto* add_node_proto = graph_proto->add_node();
    transpose_node_proto->set_name(node_name + "_Transpose");
    matmul_node_proto->set_name(node_name + "_MatMul");
    add_node_proto->set_name(node_name + "_Add");

    transpose_node_proto->set_op_type("Transpose");
    matmul_node_proto->set_op_type("MatMul");
    add_node_proto->set_op_type("Add");

    std::string input_node_name = ig[inputs[op::conv::kData].node_id].source->attrs.name;
    std::string weight_node_name = ig[inputs[op::conv::kWeight].node_id].source->attrs.name;
    std::string bias_node_name = ig[inputs[op::conv::kBias].node_id].source->attrs.name;

    transpose_node_proto->add_input(weight_node_name);
    transpose_node_proto->add_output(node_name + "_Transpose");

    matmul_node_proto->add_input(input_node_name);
    matmul_node_proto->add_input(node_name + "_Transpose");
    matmul_node_proto->add_output(node_name + "_MatMul");

    add_node_proto->add_input(node_name + "_MatMul");
    add_node_proto->add_input(bias_node_name);
    // Add's output is the output of the Transpose+MatMul+Add subgraph.
    add_node_proto->add_output(node_name);
  } else {
    NodeProto* node_proto = graph_proto->add_node();
    node_proto->set_name(node_name);
    if (act_param.no_bias) {
      node_proto->set_op_type("MatMul");
    } else {
      node_proto->set_op_type("Gemm");

      AttributeProto* const alpha = node_proto->add_attribute();
      alpha->set_name("alpha");
      alpha->set_type(AttributeProto::FLOAT);
      alpha->set_f(1.0f);

      AttributeProto* const beta = node_proto->add_attribute();
      beta->set_name("beta");
      beta->set_type(AttributeProto::FLOAT);
      beta->set_f(1.0f);

      AttributeProto* const transA = node_proto->add_attribute();
      transA->set_name("transA");
      transA->set_type(AttributeProto::INT);
      transA->set_i(0);

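      // MXNet's FullyConnected stores weights as (num_hidden, input_dim), so the weight matrix
      // (Gemm input B) must be transposed.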
      AttributeProto* const transB = node_proto->add_attribute();
      transB->set_name("transB");
      transB->set_type(AttributeProto::INT);
      transB->set_i(1);
    }
    DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
  }
}

void ConvertSlice(GraphProto* graph_proto, const std::string& node_name, const NodeAttrs& attrs,
                  const nnvm::IndexedGraph& ig,
                  const array_view<IndexedGraph::NodeEntry>& inputs) {
  NodeProto* node_proto = graph_proto->add_node();
  node_proto->set_name(node_name);
  const auto& params = nnvm::get<SliceParam>(attrs.parsed);
  node_proto->set_op_type("Slice");

  // starts
  AttributeProto* const starts = node_proto->add_attribute();
  starts->set_name("starts");
  starts->set_type(AttributeProto::INTS);

  // ends
  AttributeProto* const ends = node_proto->add_attribute();
  ends->set_name("ends");
  ends->set_type(AttributeProto::INTS);

  // axes
  AttributeProto* const axes = node_proto->add_attribute();
  axes->set_name("axes");
  axes->set_type(AttributeProto::INTS);

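  // Axis 0 is skipped here, presumably so that the batch dimension is left to TensorRT's
  // implicit batch handling rather than being sliced.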
  for (int64_t i = 1; i < params.begin.ndim(); ++i) {
    if (params.begin[i].has_value()) {
      starts->add_ints(static_cast<int64>(params.begin[i].value()));
    } else {
      starts->add_ints(static_cast<int64>(0));
    }
    if (params.end[i].has_value()) {
      ends->add_ints(static_cast<int64>(params.end[i].value()));
    } else {
      ends->add_ints(static_cast<int64>(INT_MAX));
    }
    axes->add_ints(static_cast<int64>(i));
  }
  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
}

void ConvertSoftmaxOutput(GraphProto* graph_proto, const std::string& node_name,
                          const NodeAttrs& /*attrs*/,
                          const nnvm::IndexedGraph& ig,
                          const array_view<IndexedGraph::NodeEntry>& inputs) {
  NodeProto* node_proto = graph_proto->add_node();
  node_proto->set_name(node_name);
  node_proto->set_op_type("Softmax");

  // Set to 1 by default since MXNet doesn't provide such an attribute for softmax in its
  // node params. This attribute is only relevant when the input is coerced to 2D, and in that
  // case dimension 0 is assumed to be the batch dimension.
  AttributeProto* const axis = node_proto->add_attribute();
  axis->set_name("axis");
  axis->set_type(AttributeProto::INT);
  axis->set_i(1);
  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
}


void ConvertFlatten(GraphProto* graph_proto, const std::string& node_name,
                    const NodeAttrs& /*attrs*/,
                    const nnvm::IndexedGraph& ig,
                    const array_view<IndexedGraph::NodeEntry>& inputs) {
  NodeProto* node_proto = graph_proto->add_node();
  node_proto->set_name(node_name);
  node_proto->set_op_type("Flatten");

  // Set to 1 by default since MXNet doesn't provide such an attribute for Flatten in its
  // node params. This attribute is only relevant when the input is coerced to 2D, and in that
  // case dimension 0 is assumed to be the batch dimension.
  AttributeProto* const axis = node_proto->add_attribute();
  axis->set_name("axis");
  axis->set_type(AttributeProto::INT);
  axis->set_i(1);
  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
}

void ConvertBatchNorm(GraphProto* graph_proto, const std::string& node_name,
                      const NodeAttrs& attrs,
                      const nnvm::IndexedGraph& ig,
                      const array_view<IndexedGraph::NodeEntry>& inputs) {
  NodeProto* node_proto = graph_proto->add_node();
  node_proto->set_name(node_name);
  node_proto->set_op_type("BatchNormalization");
  const auto& param = nnvm::get<op::BatchNormParam>(attrs.parsed);

  AttributeProto* const epsilon = node_proto->add_attribute();
  epsilon->set_name("epsilon");
  epsilon->set_type(AttributeProto::FLOAT);
  epsilon->set_f(static_cast<float>(param.eps));

  AttributeProto* const momentum = node_proto->add_attribute();
  momentum->set_name("momentum");
  momentum->set_type(AttributeProto::FLOAT);
  momentum->set_f(param.momentum);

  AttributeProto* const spatial = node_proto->add_attribute();
  spatial->set_name("spatial");
  spatial->set_type(AttributeProto::INT);
  // MXNet computes mean and variance per feature for batchnorm. Enabling spatial mode
  // (the default in ONNX IR v3) implies running batchnorm over all spatial features, so we
  // need to explicitly disable it for MXNet's BatchNorm.
  spatial->set_i(0);
  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
}

void ConvertElementwiseAdd(GraphProto* graph_proto, const std::string& node_name,
                           const NodeAttrs& /*attrs*/,
                           const nnvm::IndexedGraph& ig,
                           const array_view<IndexedGraph::NodeEntry>& inputs) {
  NodeProto* node_proto = graph_proto->add_node();
  node_proto->set_name(node_name);
  node_proto->set_op_type("Add");
  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
}

void ConvertElementwiseSub(GraphProto* graph_proto, const std::string& node_name,
                           const NodeAttrs& /*attrs*/,
                           const nnvm::IndexedGraph& ig,
                           const array_view<IndexedGraph::NodeEntry>& inputs) {
  NodeProto* node_proto = graph_proto->add_node();
  node_proto->set_name(node_name);
  node_proto->set_op_type("Sub");
  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
}

void ConvertElementwiseMul(GraphProto* graph_proto, const std::string& node_name,
                           const NodeAttrs& /*attrs*/,
                           const nnvm::IndexedGraph& ig,
                           const array_view<IndexedGraph::NodeEntry>& inputs) {
  NodeProto* node_proto = graph_proto->add_node();
  node_proto->set_name(node_name);
  node_proto->set_op_type("Mul");
  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
}

void ConvertConcatenate(GraphProto* graph_proto, const std::string& node_name,
                        const NodeAttrs& attrs,
                        const nnvm::IndexedGraph& ig,
                        const array_view<IndexedGraph::NodeEntry>& inputs) {
  NodeProto* node_proto = graph_proto->add_node();
  node_proto->set_name(node_name);
  const auto& _param = nnvm::get<ConcatParam>(attrs.parsed);
  node_proto->set_op_type("Concat");
  // axis
  AttributeProto* const axis = node_proto->add_attribute();
  axis->set_name("axis");
  axis->set_type(AttributeProto::INT);
  axis->set_i(static_cast<int64_t>(_param.dim));
  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
}

inline TensorProto_DataType ConvertDType(int dtype) {
  switch (dtype) {
    case mshadow::kFloat64:
      return TensorProto_DataType_DOUBLE;
    case mshadow::kFloat32:
      return TensorProto_DataType_FLOAT;
    case mshadow::kFloat16:
      return TensorProto_DataType_FLOAT16;
    case mshadow::kUint8:
      return TensorProto_DataType_UINT8;
    case mshadow::kInt32:
      return TensorProto_DataType_INT32;
    case mshadow::kInt8:
      return TensorProto_DataType_INT8;
    case mshadow::kInt64:
      return TensorProto_DataType_INT64;
    default:
      return TensorProto_DataType_UNDEFINED;
  }
}

std::unordered_map<std::string, TShape> GetPlaceholderShapes(
    const ShapeVector& shape_inputs, const nnvm::IndexedGraph& ig) {
  std::unordered_map<std::string, mxnet::TShape> placeholder_shapes;
  for (uint32_t i = 0; i < shape_inputs.size(); ++i) {
    std::string name = ig[ig.input_nodes()[i]].source->attrs.name;
    mxnet::TShape shp = shape_inputs[i];
    if (!mxnet::op::shape_is_none(shp)) {
      // TODO(@reminisce): confirm
      placeholder_shapes.emplace(name, shp);
    }
  }
  return placeholder_shapes;
}

std::unordered_map<std::string, int> GetPlaceholderDTypes(
    const DTypeVector& dtype_inputs, const nnvm::IndexedGraph& ig) {
  std::unordered_map<std::string, int> placeholder_dtypes;
  for (uint32_t i = 0; i < dtype_inputs.size(); ++i) {
    std::string name = ig[ig.input_nodes()[i]].source->attrs.name;
    int dtype = dtype_inputs[i];
    placeholder_dtypes.emplace(name, dtype);
  }
  return placeholder_dtypes;
}

std::unordered_map<std::string, uint32_t> GetOutputLookup(
    const nnvm::IndexedGraph& ig) {
  std::unordered_map<std::string, uint32_t> output_lookup;
  const std::vector<nnvm::IndexedGraph::NodeEntry>& graph_outputs =
      ig.outputs();
  for (uint32_t i = 0; i < graph_outputs.size(); ++i) {
    const uint32_t id = graph_outputs[i].node_id;
    const IndexedGraph::Node ig_node = ig[id];
    const nnvm::Node* const source = ig_node.source;
    const std::string name = source->attrs.name;
    output_lookup.emplace(name, i);
  }
  return output_lookup;
}

void ConvertPlaceholder(
    const std::string& node_name,
    const std::unordered_map<std::string, TShape>& placeholder_shapes,
    const std::unordered_map<std::string, int>& placeholder_dtypes,
    GraphProto* const graph_proto) {
  auto val_info_proto = graph_proto->add_input();
  auto type_proto = val_info_proto->mutable_type()->mutable_tensor_type();
  auto shape_proto = type_proto->mutable_shape();

  val_info_proto->set_name(node_name);
  auto entry_shape = placeholder_shapes.find(node_name)->second;
  auto entry_dtype = placeholder_dtypes.find(node_name)->second;
  type_proto->set_elem_type(ConvertDType(entry_dtype));
  for (const auto& elem : entry_shape) {
    TensorShapeProto_Dimension* const tsp_dim = shape_proto->add_dim();
    tsp_dim->set_dim_value(static_cast<int64>(elem));
  }
}

void ConvertConstant(
    GraphProto* const graph_proto, const std::string& node_name,
    const std::unordered_map<std::string, NDArray>* const params_map) {
  TensorProto* const initializer_proto = graph_proto->add_initializer();

  // Create initializer for constants
  initializer_proto->set_name(node_name);

  const NDArray nd = params_map->find(node_name)->second;
  const TBlob& blob = nd.data();
  const TShape shape = blob.shape_;
  const auto dtype = ConvertDType(nd.dtype());
  initializer_proto->set_data_type(dtype);

  for (auto& dim : shape) {
    initializer_proto->add_dims(static_cast<int64>(dim));
  }

  auto size = shape.Size();

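  // Copy the (possibly GPU-resident) parameter data to a host buffer before writing it into
  // the protobuf.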
  if (dtype == TensorProto_DataType_FLOAT) {
    std::shared_ptr<float[]> shared_data_ptr(new float[size]);
    float* const data_ptr = shared_data_ptr.get();
    nd.SyncCopyToCPU(static_cast<void*>(data_ptr), size);

    for (size_t blob_idx = 0; blob_idx < size; ++blob_idx) {
      initializer_proto->add_float_data(data_ptr[blob_idx]);
    }
  } else if (dtype == TensorProto_DataType_FLOAT16) {
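    // ONNX has no dedicated float16 field: TensorProto stores float16 values as their raw
    // 16-bit patterns, one per entry of the int32_data field.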
    std::shared_ptr<uint16_t[]> shared_data_ptr(new uint16_t[size]);
    uint16_t* const data_ptr = shared_data_ptr.get();
    nd.SyncCopyToCPU(static_cast<void*>(data_ptr), size);
    for (size_t blob_idx = 0; blob_idx < size; ++blob_idx) {
      // Widen each 16-bit pattern to its own int32 entry.
      initializer_proto->add_int32_data(static_cast<int32_t>(data_ptr[blob_idx]));
    }
  } else {
    LOG(FATAL) << "dtype not supported for variables: " << node_name;
  }

  // Create inputs for constants.
  ValueInfoProto* const input_proto = graph_proto->add_input();
  input_proto->set_name(node_name);

  input_proto->mutable_type()->mutable_tensor_type()->set_elem_type(dtype);
  for (auto& dim : shape) {
    auto new_dim = input_proto->mutable_type()->mutable_tensor_type()->mutable_shape()->add_dim();
    new_dim->set_dim_value(static_cast<int64>(dim));
  }
}

void ConvertOutput(
    GraphProto* const graph_proto,
    const std::unordered_map<std::string, uint32_t>::iterator& out_iter,
    const std::string& node_name, const ShapeVector& shapes,
    const DTypeVector& dtypes, const nnvm::IndexedGraph& ig) {
  uint32_t out_idx = ig.entry_id(ig.outputs()[out_iter->second]);
  int dtype = dtypes[out_idx];
  auto graph_out = graph_proto->add_output();
  auto tensor_type = graph_out->mutable_type()->mutable_tensor_type();
  auto tensor_shape_proto = tensor_type->mutable_shape();
  graph_out->set_name(node_name);

  // Also support fp16.
  tensor_type->set_elem_type(ConvertDType(dtype));

  for (int64_t dim_shp : shapes[out_idx]) {
    TensorShapeProto_Dimension* const tsp_dim = tensor_shape_proto->add_dim();
    tsp_dim->set_dim_value(static_cast<int64>(dim_shp));
  }
}

void ConvertClip(GraphProto* graph_proto, const std::string& node_name, const NodeAttrs& attrs,
                 const nnvm::IndexedGraph& ig,
                 const array_view<IndexedGraph::NodeEntry>& inputs) {
  NodeProto* node_proto = graph_proto->add_node();
  node_proto->set_name(node_name);
  const auto& param = nnvm::get<ClipParam>(attrs.parsed);

  node_proto->set_op_type("Clip");

  // max
  AttributeProto* const a_max = node_proto->add_attribute();
  a_max->set_name("max");
  a_max->set_type(AttributeProto::FLOAT);
  a_max->set_f(static_cast<float>(param.a_max));

  // min
  AttributeProto* const a_min = node_proto->add_attribute();
  a_min->set_name("min");
  a_min->set_type(AttributeProto::FLOAT);
  a_min->set_f(static_cast<float>(param.a_min));
  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
}

void ConvertPad(GraphProto* graph_proto, const std::string& node_name, const NodeAttrs& attrs,
                const nnvm::IndexedGraph& ig,
                const array_view<IndexedGraph::NodeEntry>& inputs) {
  NodeProto* node_proto = graph_proto->add_node();
  node_proto->set_name(node_name);
  const auto& param = nnvm::get<PadParam>(attrs.parsed);

  node_proto->set_op_type("Pad");

  // mode
  AttributeProto* const mode = node_proto->add_attribute();
  mode->set_name("mode");
  mode->set_type(AttributeProto::STRING);
  switch (param.mode) {
    case op::pad_enum::kConstant:
      mode->set_s("constant");
      break;
    case op::pad_enum::kEdge:
      mode->set_s("edge");
      break;
    case op::pad_enum::kReflect:
      mode->set_s("reflect");
      break;
    default:
      throw dmlc::Error("Unknown padding mode");
  }

  // pads
  AttributeProto* const pads = node_proto->add_attribute();
  pads->set_name("pads");
  pads->set_type(AttributeProto::INTS);

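  // MXNet's pad_width interleaves (begin, end) values per axis, while ONNX expects all begin
  // values followed by all end values, so make two passes with starting offsets 0 and 1.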
  for (int st = 0; st < 2; ++st) {
    for (auto it = param.pad_width.begin() + st;
         it != param.pad_width.end(); it += 2) {
      pads->add_ints(static_cast<int64>(*it));
    }
  }

  // value
  AttributeProto* const value = node_proto->add_attribute();
  value->set_name("value");
  value->set_type(AttributeProto::FLOAT);
  value->set_f(param.constant_value);
  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
}

void ConvertDropout(GraphProto* graph_proto, const std::string& node_name, const NodeAttrs& attrs,
                    const nnvm::IndexedGraph& ig,
                    const array_view<IndexedGraph::NodeEntry>& inputs) {
  NodeProto* node_proto = graph_proto->add_node();
  node_proto->set_name(node_name);
  node_proto->set_op_type("Dropout");
  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
}

void PreprocessBatchNorm(const NodeAttrs& attrs,
                         const std::vector<nnvm::NodeEntry>& inputs,
                         std::unordered_map<std::string, NDArray>* params_map) {
  const auto& param = nnvm::get<op::BatchNormParam>(attrs.parsed);
  if (param.fix_gamma) {
    // If fix_gamma is set, preprocess the params map so that the gamma associated with this
    // batch norm layer is set to 1. NDArray's scalar assignment broadcasts 1.0f to every
    // element of the gamma array.
    std::string gammaNodeName = inputs[batchnorm::kGamma].node->attrs.name;
    (*params_map)[gammaNodeName] = 1.0f;
  }
}

}  // namespace nnvm_to_onnx
}  // namespace op
}  // namespace mxnet

#endif  // MXNET_USE_TENSORRT