/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file nnvm_to_onnx.cc
 * \brief Conversion from NNVM to ONNX for TensorRT
 * \author Marek Kolodziej, Clement Fuji Tsang
*/

#if MXNET_USE_TENSORRT

#include "./nnvm_to_onnx-inl.h"

#include <mxnet/base.h>
#include <nnvm/graph.h>
#include <nnvm/pass_functions.h>

#include "../../../common/utils.h"
#include "../../../ndarray/ndarray_function.h"
#include "../../pad-inl.h"
#include "../../nn/activation-inl.h"
#include "../../nn/batch_norm-inl.h"
#include "../../nn/convolution-inl.h"
#include "../../nn/deconvolution-inl.h"
#include "../../nn/fully_connected-inl.h"
#include "../../nn/pooling-inl.h"
#include "../../nn/concat-inl.h"
#include "../../softmax_output-inl.h"
#include "../../tensor/matrix_op-inl.h"

#if MXNET_USE_TENSORRT_ONNX_CHECKER
#include <onnx/checker.h>
#endif  // MXNET_USE_TENSORRT_ONNX_CHECKER

namespace mxnet {
namespace op {
namespace nnvm_to_onnx {

std::string ConvertNnvmGraphToOnnx(
    const nnvm::Graph& g,
    std::unordered_map<std::string, NDArray>* params_map) {

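  // Process-wide counter used to give each converted TensorRT subgraph a unique ONNX graph name.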
  static std::atomic_ulong subgraph_count = { 0 };

  std::string serialized_onnx_graph;

  const nnvm::IndexedGraph& ig = g.indexed_graph();
  const auto& dtypes = g.GetAttr<DTypeVector>("dtype");
  const auto& shapes = g.GetAttr<ShapeVector>("shape");
  const auto& dtype_inputs = g.GetAttr<DTypeVector>("dtype_inputs");
  const auto& shape_inputs = g.GetAttr<ShapeVector>("shape_inputs");

  ModelProto model_proto;

  // We're currently serializing our models in ONNX IR version 3, opset 8, as these are best
  // supported by the currently linked version of the onnx-tensorrt library.
  // More information on ONNX versions and opsets can be found at:
  // https://github.com/onnx/onnx/blob/master/docs/IR.md

  auto opset_proto = model_proto.add_opset_import();
  const int64 onnx_opset = 8;
  const int64 onnx_major_version = 3;

  // Declare our ONNX versions in our protobuf model.
  opset_proto->set_version(onnx_opset);
  model_proto.set_ir_version(onnx_major_version);

  GraphProto* graph_proto = model_proto.mutable_graph();
  auto subgraph_name_id = subgraph_count.fetch_add(1);
  graph_proto->set_name("MXNetTRTSubgraph" + std::to_string(subgraph_name_id));

  auto placeholder_shapes = GetPlaceholderShapes(shape_inputs, ig);
  auto placeholder_dtypes = GetPlaceholderDTypes(dtype_inputs, ig);
  auto output_lookup = GetOutputLookup(ig);

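  // First pass: run any registered per-op preprocessing on the params map (e.g. forcing gamma to 1
  // for fix_gamma BatchNorm) before any nodes are converted.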
  for (uint32_t node_idx = 0; node_idx < ig.num_nodes(); ++node_idx) {
      const IndexedGraph::Node& node = ig[node_idx];
      const nnvm::Node* source = node.source;
      // If this is an op
      if (!source->is_variable()) {
        auto mightNeedPreprocessNode = preprocess_map.find(source->op()->name);
        // if this op is defined in preprocess_map
        if (mightNeedPreprocessNode != preprocess_map.end()) {
          mightNeedPreprocessNode->second(source->attrs, source->inputs, params_map);
        }
      }
  }

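  // Second pass: convert each graph entry into ONNX placeholders, initializers, and op nodes.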
  uint32_t current_input = 0;
  // Can't do a foreach over IndexedGraph since it doesn't implement begin(), etc.
  for (uint32_t node_idx = 0; node_idx < ig.num_nodes(); ++node_idx) {
    const IndexedGraph::Node& node = ig[node_idx];
    const nnvm::Node* source = node.source;
    const NodeAttrs& attrs = source->attrs;
    const Op* op = source->op();

    std::string node_name = attrs.name;
    // Here, "variable" actually means anything that's not an op, i.e. a constant (weights) or a
    // placeholder.
    if (source->is_variable()) {
      // Is this a placeholder?
      if (params_map->count(node_name) == 0) {
        // This fixes the problem with a SoftmaxOutput node during inference, but it's hacky.
        // Need to figure out how to properly fix it.
        if (node_name.find("label") != std::string::npos) {
          current_input++;
          continue;
        }
        ConvertPlaceholder(node_name, placeholder_shapes, placeholder_dtypes, graph_proto);
      } else {
        // If it's not a placeholder, then by exclusion it's a constant.
        ConvertConstant(graph_proto, node_name, params_map);
      }  // is_placeholder
    } else {
      // It's an op, rather than a "variable" (constant or placeholder)
      if (converter_map.count(op->name) == 0) {
        LOG(FATAL) << "Conversion for node of type " << op->name << " (node "
                   << node_name << ") is not supported yet.";
      }
      // Find the function pointer to a converter based on the op name, and invoke the converter.
      // This looks unsafe because find may not succeed, but it always does here: we've just
      // checked above that converter_map contains this op name.
      converter_map.find(op->name)->second(graph_proto, node_name, attrs, ig, node.inputs);
      // See if the current node is an output node
      auto out_iter = output_lookup.find(node_name);
      // We found an output
      if (out_iter != output_lookup.end()) {
        ConvertOutput(graph_proto, out_iter, node_name, shapes, dtypes, ig);
      }  // output found
    }    // conversion function exists
  }      // loop over i from 0 to num_nodes

  model_proto.SerializeToString(&serialized_onnx_graph);

#if MXNET_USE_TENSORRT_ONNX_CHECKER
  onnx::checker::check_model(model_proto);
#endif  // MXNET_USE_TENSORRT_ONNX_CHECKER

  return serialized_onnx_graph;
}

void DefaultConnectInputsOutputs(NodeProto *node_proto,
                                 const array_view<IndexedGraph::NodeEntry>& inputs,
                                 const nnvm::IndexedGraph& ig,
                                 const std::string& node_name) {
  for (const nnvm::IndexedGraph::NodeEntry& entry : inputs) {
    std::string in_node_name = ig[entry.node_id].source->attrs.name;
    // As before, we're not adding labels e.g. for SoftmaxOutput, but I wish there was a less
    // hacky way to do it than name matching.
    if (in_node_name.find("label") != std::string::npos) {
      continue;
    }
    node_proto->add_input(in_node_name);
  }
  // The node's output will have the same name as the node name.
  node_proto->add_output(node_name);
}

TensorProto* const Make1DTensor(GraphProto* const graph_proto, const int64_t& size,
                                const std::string& name, const TensorProto_DataType& dtype) {
  TensorProto* const initializer_proto = graph_proto->add_initializer();
  initializer_proto->set_name(name);
  initializer_proto->set_data_type(dtype);
  initializer_proto->add_dims(static_cast<int64>(size));

  ValueInfoProto* const input_proto = graph_proto->add_input();
  input_proto->set_name(name);
  auto var = input_proto->mutable_type()->mutable_tensor_type();
  var->set_elem_type(dtype);
  var->mutable_shape()->add_dim()->set_dim_value(static_cast<int64>(size));
  return initializer_proto;
}

// Keep this for when the ONNX version is updated
/*
void ConvertSlice(GraphProto* const graph_proto, const Node* node, const Graph& g) {
  const auto& params = nnvm::get<SliceParam>(node->attrs.parsed);
  int64 nb_slices = static_cast<int64>(params.begin.ndim());

  // starts
  auto init_starts = Make1DTensor(graph_proto, nb_slices, node->attrs.name + "_starts",
                                  TensorProto_DataType_INT64);
  for (auto& opt : params.begin) {
    if (opt.has_value()) {
      init_starts->add_int64_data(static_cast<int64>(opt.value()));
    } else {
      init_starts->add_int64_data(static_cast<int64>(0));
    }
  }

  // ends
  auto init_ends = Make1DTensor(graph_proto, nb_slices, node->attrs.name + "_ends",
                                TensorProto_DataType_INT64);
  for (auto& opt : params.end) {
    if (opt.has_value()) {
      init_ends->add_int64_data(static_cast<int64>(opt.value()));
    } else {
      init_ends->add_int64_data(static_cast<int64>(INT_MAX));
    }
  }

  // axes
  auto init_axes = Make1DTensor(graph_proto, nb_slices, node->attrs.name + "_axes",
                                TensorProto_DataType_INT64);
  for (int64_t i = 0; i < nb_slices; ++i) {
    init_axes->add_int64_data(static_cast<int64>(i));
  }

  // slice node
  NodeProto* node_proto = graph_proto->add_node();
  node_proto->set_name(node->attrs.name);
  node_proto->set_op_type("Slice");
  node_proto->add_input(node->inputs[0].node->attrs.name);
  node_proto->add_input(node->attrs.name + "_starts");
  node_proto->add_input(node->attrs.name + "_ends");
  node_proto->add_input(node->attrs.name + "_axes");

  // steps
  if (params.step.ndim() != 0) {
    auto init_steps = Make1DTensor(graph_proto, nb_slices, node->attrs.name + "_steps",
                                   TensorProto_DataType_INT64);
    for (auto& opt : params.step) {
      if (opt.has_value()) {
        init_steps->add_int64_data(static_cast<int64>(opt.value()));
      } else {
        init_steps->add_int64_data(static_cast<int64>(1));
      }
    }
    node_proto->add_input(node->attrs.name + "_steps");
  }

  node_proto->add_output(node->attrs.name);
}
*/

void ConvertIdentity(GraphProto *graph_proto, const std::string& node_name, const NodeAttrs& attrs,
                     const nnvm::IndexedGraph& ig,
                     const array_view<IndexedGraph::NodeEntry>& inputs) {
  NodeProto* node_proto = graph_proto->add_node();
  node_proto->set_name(node_name);
  node_proto->set_op_type("Identity");
  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
}

template <class ConvDeconvParam>
void ConvDeconvConvertHelper(NodeProto *node_proto, const NodeAttrs& attrs,
                             const nnvm::IndexedGraph& ig,
                             const array_view<IndexedGraph::NodeEntry>& inputs,
                             const ConvDeconvParam& param,
                             ConvDeconvType type) {
  if (type == ConvDeconvType::Convolution) {
    node_proto->set_op_type("Conv");
  } else {
    node_proto->set_op_type("ConvTranspose");
  }

  const mxnet::TShape kernel = param.kernel;
  const mxnet::TShape stride = param.stride;
  const mxnet::TShape dilate = param.dilate;
  const mxnet::TShape pad = param.pad;
  const uint32_t num_group = param.num_group;
  // const bool no_bias = conv_param.no_bias;
  const dmlc::optional<int> layout = param.layout;

  // dilations
  AttributeProto* const dilations = node_proto->add_attribute();
  dilations->set_name("dilations");
  dilations->set_type(AttributeProto::INTS);
  for (const dim_t kval : dilate) {
    dilations->add_ints(static_cast<int64>(kval));
  }

  // group
  AttributeProto* const group = node_proto->add_attribute();
  group->set_name("group");
  group->set_type(AttributeProto::INT);
  group->set_i(static_cast<int64>(num_group));

  // kernel shape
  AttributeProto* const kernel_shape = node_proto->add_attribute();
  kernel_shape->set_name("kernel_shape");
  kernel_shape->set_type(AttributeProto::INTS);

  for (const dim_t kval : kernel) {
    kernel_shape->add_ints(static_cast<int64>(kval));
  }

  // pads
  AttributeProto* const pads = node_proto->add_attribute();
  pads->set_name("pads");
  pads->set_type(AttributeProto::INTS);

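  // MXNet pads are symmetric (one value per spatial dim), while ONNX expects separate begin and
  // end pads, so the pad vector is written out twice.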
  for (int i = 0; i < 2; i++) {
    for (dim_t kval : pad) {
      pads->add_ints(static_cast<int64>(kval));
    }
  }

  // strides
  AttributeProto* const strides = node_proto->add_attribute();
  strides->set_name("strides");
  strides->set_type(AttributeProto::INTS);
  for (const dim_t kval : stride) {
    strides->add_ints(static_cast<int64>(kval));
  }
}

void ConvertConvolution(GraphProto *graph_proto, const std::string& node_name,
                        const NodeAttrs& attrs,
                        const nnvm::IndexedGraph& ig,
                        const array_view<IndexedGraph::NodeEntry>& inputs) {
  NodeProto* node_proto = graph_proto->add_node();
  node_proto->set_name(node_name);
  const auto& conv_param = nnvm::get<op::ConvolutionParam>(attrs.parsed);
  ConvDeconvConvertHelper(node_proto, attrs, ig, inputs, conv_param,
      ConvDeconvType::Convolution);
  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
}  // end ConvertConvolution

void ConvertDeconvolution(GraphProto *graph_proto, const std::string& node_name,
                          const NodeAttrs& attrs,
                          const nnvm::IndexedGraph& ig,
                          const array_view<IndexedGraph::NodeEntry>& inputs) {
  NodeProto* node_proto = graph_proto->add_node();
  node_proto->set_name(node_name);
  const auto& deconv_param = nnvm::get<op::DeconvolutionParam>(attrs.parsed);
  ConvDeconvConvertHelper(node_proto, attrs, ig, inputs, deconv_param,
      ConvDeconvType::Deconvolution);
  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
}  // end ConvertDeconvolution

void ConvertPooling(GraphProto *graph_proto, const std::string& node_name,
                    const NodeAttrs& attrs,
                    const nnvm::IndexedGraph& ig,
                    const array_view<IndexedGraph::NodeEntry>& inputs) {
  NodeProto* node_proto = graph_proto->add_node();
  node_proto->set_name(node_name);
  const auto& pooling_param = nnvm::get<op::PoolingParam>(attrs.parsed);

  const mxnet::TShape kernel = pooling_param.kernel;
  const mxnet::TShape stride = pooling_param.stride;
  const mxnet::TShape pad = pooling_param.pad;
  const int pool_type = pooling_param.pool_type;
  const bool global_pool = pooling_param.global_pool;

  if (global_pool) {
    if (pool_type == pool_enum::kMaxPooling) {
      node_proto->set_op_type("GlobalMaxPool");
    } else if (pool_type == pool_enum::kAvgPooling) {
      node_proto->set_op_type("GlobalAveragePool");
    } else {
      LOG(FATAL) << "Pool type of node '" << attrs.name << "' unsupported: " << pool_type;
    }
    DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
    return;
  }

  // kernel_shape
  AttributeProto* const kernel_shape = node_proto->add_attribute();
  kernel_shape->set_name("kernel_shape");
  kernel_shape->set_type(AttributeProto::INTS);
  for (dim_t kval : kernel) {
    kernel_shape->add_ints(static_cast<int64>(kval));
  }

  // pads
  AttributeProto* const pads = node_proto->add_attribute();
  pads->set_name("pads");
  pads->set_type(AttributeProto::INTS);

  // Convert from MXNet symmetric pads to ONNX non-symmetric pads by running through padding twice.
  for (int i = 0; i < 2; i++) {
    for (dim_t kval : pad) {
      pads->add_ints(static_cast<int64>(kval));
    }
  }

  // strides
  AttributeProto* const strides = node_proto->add_attribute();
  strides->set_name("strides");
  strides->set_type(AttributeProto::INTS);
  for (dim_t kval : stride) {
    strides->add_ints(static_cast<int64>(kval));
  }

  // ceil_mode
  AttributeProto* const ceil_mode = node_proto->add_attribute();
  ceil_mode->set_name("ceil_mode");
  ceil_mode->set_type(AttributeProto::INT);
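  // MXNet's 'full' pooling convention rounds output sizes up, which corresponds to ONNX ceil_mode=1.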
  ceil_mode->set_i(static_cast<int64>(pooling_param.pooling_convention == pool_enum::kFull));

  if (pool_type == pool_enum::kMaxPooling) {
    node_proto->set_op_type("MaxPool");
  } else if (pool_type == pool_enum::kAvgPooling) {
    node_proto->set_op_type("AveragePool");
  } else {
    LOG(FATAL) << "Pool type of node '" << attrs.name << "' unsupported: " << pool_type;
  }

  // count_include_pad
  AttributeProto* const count_include_pad = node_proto->add_attribute();
  count_include_pad->set_name("count_include_pad");
  count_include_pad->set_type(AttributeProto::INT);
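  // Default to 1 (include padding in the average) when the parameter is not set.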
  if (pooling_param.count_include_pad.has_value()) {
    count_include_pad->set_i(pooling_param.count_include_pad.value());
  } else {
    count_include_pad->set_i(1);
  }
  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
}  // end ConvertPooling

void ConvertRelu(GraphProto *graph_proto, const std::string& node_name, const NodeAttrs& /*attrs*/,
                 const nnvm::IndexedGraph& ig,
                 const array_view<IndexedGraph::NodeEntry>& inputs) {
  NodeProto* node_proto = graph_proto->add_node();
  node_proto->set_name(node_name);
  node_proto->set_op_type("Relu");
  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
}

void ConvertActivation(GraphProto *graph_proto, const std::string& node_name,
                       const NodeAttrs& attrs,
                       const nnvm::IndexedGraph& ig,
                       const array_view<IndexedGraph::NodeEntry>& inputs) {
  NodeProto* node_proto = graph_proto->add_node();
  node_proto->set_name(node_name);
  const auto& act_param = nnvm::get<op::ActivationParam>(attrs.parsed);
  std::string act_type;
  switch (act_param.act_type) {
    case op::activation::kReLU:
      act_type = "Relu";
      break;
    case op::activation::kSigmoid:
      act_type = "Sigmoid";
      break;
    case op::activation::kTanh:
      act_type = "Tanh";
      break;
    case op::activation::kSoftReLU:
      // act_type = "SoftReLU";
      throw dmlc::Error("SoftReLU is not supported in ONNX");
      break;
    default:
      throw dmlc::Error("Unknown activation type");
  }

  node_proto->set_op_type(act_type);
  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
}

void ConvertFullyConnected(GraphProto *graph_proto, const std::string& node_name,
                           const NodeAttrs& attrs,
                           const nnvm::IndexedGraph& ig,
                           const array_view<IndexedGraph::NodeEntry>& inputs) {
  const auto& act_param = nnvm::get<op::FullyConnectedParam>(attrs.parsed);
  // The ONNX spec doesn't support GEMMs with inputs of different dims, so we replace it
  // with Transpose+MatMul+Add.
  if (!act_param.flatten && !act_param.no_bias) {
    NodeProto* transpose_node_proto = graph_proto->add_node();
    NodeProto* matmul_node_proto = graph_proto->add_node();
    NodeProto* add_node_proto = graph_proto->add_node();
    transpose_node_proto->set_name(node_name+"_Transpose");
    matmul_node_proto->set_name(node_name+"_MatMul");
    add_node_proto->set_name(node_name+"_Add");

    transpose_node_proto->set_op_type("Transpose");
    matmul_node_proto->set_op_type("MatMul");
    add_node_proto->set_op_type("Add");

    std::string input_node_name = ig[inputs[op::conv::kData].node_id].source->attrs.name;
    std::string weight_node_name = ig[inputs[op::conv::kWeight].node_id].source->attrs.name;
    std::string bias_node_name = ig[inputs[op::conv::kBias].node_id].source->attrs.name;

    transpose_node_proto->add_input(weight_node_name);
    transpose_node_proto->add_output(node_name+"_Transpose");

    matmul_node_proto->add_input(input_node_name);
    matmul_node_proto->add_input(node_name+"_Transpose");
    matmul_node_proto->add_output(node_name+"_MatMul");

    add_node_proto->add_input(node_name+"_MatMul");
    add_node_proto->add_input(bias_node_name);
    // Add's output is the output of the Transpose+MatMul+Add subgraph
    add_node_proto->add_output(node_name);
  } else {
    NodeProto* node_proto = graph_proto->add_node();
    node_proto->set_name(node_name);
    if (act_param.no_bias) {
        node_proto->set_op_type("MatMul");
    } else {
        node_proto->set_op_type("Gemm");

        AttributeProto* const alpha = node_proto->add_attribute();
        alpha->set_name("alpha");
        alpha->set_type(AttributeProto::FLOAT);
        alpha->set_f(1.0f);

        AttributeProto* const beta = node_proto->add_attribute();
        beta->set_name("beta");
        beta->set_type(AttributeProto::FLOAT);
        beta->set_f(1.0f);

        AttributeProto* const transA = node_proto->add_attribute();
        transA->set_name("transA");
        transA->set_type(AttributeProto::INT);
        transA->set_i(0);

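        // MXNet FullyConnected stores its weight as (num_hidden, input_dim), so the weight matrix
        // (B) is transposed in the GEMM.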
        AttributeProto* const transB = node_proto->add_attribute();
        transB->set_name("transB");
        transB->set_type(AttributeProto::INT);
        transB->set_i(1);
    }
    DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
  }
}

void ConvertSlice(GraphProto *graph_proto, const std::string& node_name, const NodeAttrs& attrs,
                  const nnvm::IndexedGraph& ig,
                  const array_view<IndexedGraph::NodeEntry>& inputs) {
  NodeProto* node_proto = graph_proto->add_node();
  node_proto->set_name(node_name);
  const auto& params = nnvm::get<SliceParam>(attrs.parsed);
  node_proto->set_op_type("Slice");

  // starts
  AttributeProto* const starts = node_proto->add_attribute();
  starts->set_name("starts");
  starts->set_type(AttributeProto::INTS);

  // ends
  AttributeProto* const ends = node_proto->add_attribute();
  ends->set_name("ends");
  ends->set_type(AttributeProto::INTS);

  // axes
  AttributeProto* const axes = node_proto->add_attribute();
  axes->set_name("axes");
  axes->set_type(AttributeProto::INTS);

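  // Start at axis 1: axis 0 (the batch dimension) is deliberately never sliced.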
  for (int64_t i = 1; i < params.begin.ndim(); ++i) {
    if (params.begin[i].has_value()) {
      starts->add_ints(static_cast<int64>(params.begin[i].value()));
    } else {
      starts->add_ints(static_cast<int64>(0));
    }
    if (params.end[i].has_value()) {
      ends->add_ints(static_cast<int64>(params.end[i].value()));
    } else {
      ends->add_ints(static_cast<int64>(INT_MAX));
    }
    axes->add_ints(static_cast<int64>(i));
  }
  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
}

void ConvertSoftmaxOutput(GraphProto *graph_proto, const std::string& node_name,
                          const NodeAttrs& /*attrs*/,
                          const nnvm::IndexedGraph& ig,
                          const array_view<IndexedGraph::NodeEntry>& inputs) {
  NodeProto* node_proto = graph_proto->add_node();
  node_proto->set_name(node_name);
  node_proto->set_op_type("Softmax");

  // Setting by default to 1 since MXNet doesn't provide such an attribute for softmax in its
  // node params. This attribute is only relevant when the input is coerced to 2D, and in that
  // case dimension 0 is assumed to be the batch dimension.
  AttributeProto* const axis = node_proto->add_attribute();
  axis->set_name("axis");
  axis->set_type(AttributeProto::INT);
  axis->set_i(1);
  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
}


void ConvertFlatten(GraphProto *graph_proto, const std::string& node_name,
                    const NodeAttrs& /*attrs*/,
                    const nnvm::IndexedGraph& ig,
                    const array_view<IndexedGraph::NodeEntry>& inputs) {
  NodeProto* node_proto = graph_proto->add_node();
  node_proto->set_name(node_name);
  node_proto->set_op_type("Flatten");

  // Setting by default to 1 since MXNet doesn't provide such an attribute for Flatten in its
  // node params. This attribute is only relevant when the input is coerced to 2D, and in that
  // case dimension 0 is assumed to be the batch dimension.
  AttributeProto* const axis = node_proto->add_attribute();
  axis->set_name("axis");
  axis->set_type(AttributeProto::INT);
  axis->set_i(1);
  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
}

void ConvertBatchNorm(GraphProto *graph_proto, const std::string& node_name,
                      const NodeAttrs& attrs,
                      const nnvm::IndexedGraph& ig,
                      const array_view<IndexedGraph::NodeEntry>& inputs) {
  NodeProto* node_proto = graph_proto->add_node();
  node_proto->set_name(node_name);
  node_proto->set_op_type("BatchNormalization");
  const auto& param = nnvm::get<op::BatchNormParam>(attrs.parsed);

  AttributeProto* const epsilon = node_proto->add_attribute();
  epsilon->set_name("epsilon");
  epsilon->set_type(AttributeProto::FLOAT);
  epsilon->set_f(static_cast<float>(param.eps));

  AttributeProto* const momentum = node_proto->add_attribute();
  momentum->set_name("momentum");
  momentum->set_type(AttributeProto::FLOAT);
  momentum->set_f(param.momentum);

  AttributeProto* const spatial = node_proto->add_attribute();
  spatial->set_name("spatial");
  spatial->set_type(AttributeProto::INT);
  // MXNet computes mean and variance per feature for batchnorm.  Enabling spatial mode
  // (default in ONNX3) implies running batchnorm on all spatial features so we need to explicitly
  // disable this for MXNet's BatchNorm.
  spatial->set_i(0);
  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
}

void ConvertElementwiseAdd(GraphProto *graph_proto, const std::string& node_name,
                           const NodeAttrs& /*attrs*/,
                           const nnvm::IndexedGraph& ig,
                           const array_view<IndexedGraph::NodeEntry>& inputs) {
  NodeProto* node_proto = graph_proto->add_node();
  node_proto->set_name(node_name);
  node_proto->set_op_type("Add");
  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
}

void ConvertElementwiseSub(GraphProto *graph_proto, const std::string& node_name,
                           const NodeAttrs& /*attrs*/,
                           const nnvm::IndexedGraph& ig,
                           const array_view<IndexedGraph::NodeEntry>& inputs) {
  NodeProto* node_proto = graph_proto->add_node();
  node_proto->set_name(node_name);
  node_proto->set_op_type("Sub");
  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
}

void ConvertElementwiseMul(GraphProto *graph_proto, const std::string& node_name,
                           const NodeAttrs& /*attrs*/,
                           const nnvm::IndexedGraph& ig,
                           const array_view<IndexedGraph::NodeEntry>& inputs) {
  NodeProto* node_proto = graph_proto->add_node();
  node_proto->set_name(node_name);
  node_proto->set_op_type("Mul");
  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
}

void ConvertConcatenate(GraphProto *graph_proto, const std::string& node_name,
                        const NodeAttrs& attrs,
                        const nnvm::IndexedGraph& ig,
                        const array_view<IndexedGraph::NodeEntry>& inputs) {
  NodeProto* node_proto = graph_proto->add_node();
  node_proto->set_name(node_name);
  const auto& _param = nnvm::get<ConcatParam>(attrs.parsed);
  node_proto->set_op_type("Concat");
  // axis
  AttributeProto* const axis = node_proto->add_attribute();
  axis->set_name("axis");
  axis->set_type(AttributeProto::INT);
  axis->set_i(static_cast<int64_t>(_param.dim));
  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
}

inline TensorProto_DataType ConvertDType(int dtype) {
  switch (dtype) {
    case mshadow::kFloat64:
      return TensorProto_DataType_DOUBLE;
    case mshadow::kFloat32:
      return TensorProto_DataType_FLOAT;
    case mshadow::kFloat16:
      return TensorProto_DataType_FLOAT16;
    case mshadow::kUint8:
      return TensorProto_DataType_UINT8;
    case mshadow::kInt32:
      return TensorProto_DataType_INT32;
    case mshadow::kInt8:
      return TensorProto_DataType_INT8;
    case mshadow::kInt64:
      return TensorProto_DataType_INT64;
    default:
      return TensorProto_DataType_UNDEFINED;
  }
}

std::unordered_map<std::string, TShape> GetPlaceholderShapes(
    const ShapeVector& shape_inputs, const nnvm::IndexedGraph& ig) {
  std::unordered_map<std::string, mxnet::TShape> placeholder_shapes;
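  // shape_inputs is ordered to match ig.input_nodes(), so index i gives the shape of the i-th input.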
  for (uint32_t i = 0; i < shape_inputs.size(); ++i) {
    std::string name = ig[ig.input_nodes()[i]].source->attrs.name;
    mxnet::TShape shp = shape_inputs[i];
    if (!mxnet::op::shape_is_none(shp)) {
      // TODO(@reminisce): confirm
      placeholder_shapes.emplace(name, shp);
    }
  }
  return placeholder_shapes;
}

std::unordered_map<std::string, int> GetPlaceholderDTypes(
    const DTypeVector& dtype_inputs, const nnvm::IndexedGraph& ig) {
  std::unordered_map<std::string, int> placeholder_dtypes;
  for (uint32_t i = 0; i < dtype_inputs.size(); ++i) {
    std::string name = ig[ig.input_nodes()[i]].source->attrs.name;
    int dtype = dtype_inputs[i];
    placeholder_dtypes.emplace(name, dtype);
  }
  return placeholder_dtypes;
}

std::unordered_map<std::string, uint32_t> GetOutputLookup(
    const nnvm::IndexedGraph& ig) {
  std::unordered_map<std::string, uint32_t> output_lookup;
  const std::vector<nnvm::IndexedGraph::NodeEntry>& graph_outputs =
      ig.outputs();
  for (uint32_t i = 0; i < graph_outputs.size(); ++i) {
    const uint32_t id = graph_outputs[i].node_id;
    const IndexedGraph::Node ig_node = ig[id];
    const nnvm::Node* const source = ig_node.source;
    const std::string name = source->attrs.name;
    output_lookup.emplace(name, i);
  }
  return output_lookup;
}

void ConvertPlaceholder(
    const std::string& node_name,
    const std::unordered_map<std::string, TShape>& placeholder_shapes,
    const std::unordered_map<std::string, int>& placeholder_dtypes,
    GraphProto* const graph_proto) {
  auto val_info_proto = graph_proto->add_input();
  auto type_proto = val_info_proto->mutable_type()->mutable_tensor_type();
  auto shape_proto = type_proto->mutable_shape();

  val_info_proto->set_name(node_name);
  auto entry_shape = placeholder_shapes.find(node_name)->second;
  auto entry_dtype = placeholder_dtypes.find(node_name)->second;
  type_proto->set_elem_type(ConvertDType(entry_dtype));
  for (const auto& elem : entry_shape) {
    TensorShapeProto_Dimension* const tsp_dim = shape_proto->add_dim();
    tsp_dim->set_dim_value(static_cast<int64>(elem));
  }
}

void ConvertConstant(
    GraphProto* const graph_proto, const std::string& node_name,
    const std::unordered_map<std::string, NDArray>* const params_map) {
  TensorProto* const initializer_proto = graph_proto->add_initializer();

  // Create initializer for constants
  initializer_proto->set_name(node_name);

  const NDArray nd = params_map->find(node_name)->second;
  const TBlob& blob = nd.data();
  const TShape shape = blob.shape_;
  const auto dtype = ConvertDType(nd.dtype());
  initializer_proto->set_data_type(dtype);

  for (auto& dim : shape) {
    initializer_proto->add_dims(static_cast<int64>(dim));
  }

  auto size = shape.Size();

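  // Copy the parameter to host memory (it may live on the GPU) before serializing it into the
  // initializer.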
  if (dtype == TensorProto_DataType_FLOAT) {
    std::shared_ptr<float[]> shared_data_ptr(new float[size]);
    float* const data_ptr = shared_data_ptr.get();
    nd.SyncCopyToCPU(static_cast<void*>(data_ptr), size);

    for (size_t blob_idx = 0; blob_idx < size; ++blob_idx) {
      initializer_proto->add_float_data(data_ptr[blob_idx]);
    }
  } else if (dtype == TensorProto_DataType_FLOAT16) {
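    // ONNX expects float16 tensor data to be carried in the int32_data field.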
    std::shared_ptr<uint16_t[]> shared_data_ptr(new uint16_t[size]);
    uint16_t* const data_ptr = shared_data_ptr.get();
    nd.SyncCopyToCPU(static_cast<void*>(data_ptr), size);
    for (size_t blob_idx = 0; blob_idx < size; ++blob_idx) {
      initializer_proto->add_int32_data(
          reinterpret_cast<int32_t*>(data_ptr)[blob_idx]);
    }
  } else {
    LOG(FATAL) << "dtype not supported for variables: " << node_name;
  }

  // Create inputs for constants.
  ValueInfoProto* const input_proto = graph_proto->add_input();
  input_proto->set_name(node_name);

  input_proto->mutable_type()->mutable_tensor_type()->set_elem_type(dtype);
  for (auto& dim : shape) {
    auto new_dim = input_proto->mutable_type()->mutable_tensor_type()->mutable_shape()->add_dim();
    new_dim->set_dim_value(static_cast<int64>(dim));
  }
}

void ConvertOutput(
    GraphProto* const graph_proto,
    const std::unordered_map<std::string, uint32_t>::iterator& out_iter,
    const std::string& node_name, const ShapeVector& shapes,
    const DTypeVector& dtypes, const nnvm::IndexedGraph &ig) {
  uint32_t out_idx = ig.entry_id(ig.outputs()[out_iter->second]);
  int dtype = dtypes[out_idx];
  auto graph_out = graph_proto->add_output();
  auto tensor_type = graph_out->mutable_type()->mutable_tensor_type();
  auto tensor_shape_proto = tensor_type->mutable_shape();
  graph_out->set_name(node_name);

  // Also support fp16.
  tensor_type->set_elem_type(ConvertDType(dtype));

  for (int64_t dim_shp : shapes[out_idx]) {
    TensorShapeProto_Dimension* const tsp_dim = tensor_shape_proto->add_dim();
    tsp_dim->set_dim_value(static_cast<int64>(dim_shp));
  }
}

void ConvertClip(GraphProto *graph_proto, const std::string& node_name, const NodeAttrs& attrs,
                 const nnvm::IndexedGraph& ig,
                 const array_view<IndexedGraph::NodeEntry>& inputs) {
  NodeProto* node_proto = graph_proto->add_node();
  node_proto->set_name(node_name);
  const auto& param = nnvm::get<ClipParam>(attrs.parsed);

  node_proto->set_op_type("Clip");

  // max
  AttributeProto* const a_max = node_proto->add_attribute();
  a_max->set_name("max");
  a_max->set_type(AttributeProto::FLOAT);
  a_max->set_f(static_cast<float>(param.a_max));

  // min
  AttributeProto* const a_min = node_proto->add_attribute();
  a_min->set_name("min");
  a_min->set_type(AttributeProto::FLOAT);
  a_min->set_f(static_cast<float>(param.a_min));
  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
}

void ConvertPad(GraphProto *graph_proto, const std::string& node_name, const NodeAttrs& attrs,
                const nnvm::IndexedGraph& ig,
                const array_view<IndexedGraph::NodeEntry>& inputs) {
  NodeProto* node_proto = graph_proto->add_node();
  node_proto->set_name(node_name);
  const auto& param = nnvm::get<PadParam>(attrs.parsed);

  node_proto->set_op_type("Pad");

  // mode
  AttributeProto* const mode = node_proto->add_attribute();
  mode->set_name("mode");
  mode->set_type(AttributeProto::STRING);
  switch (param.mode) {
    case op::pad_enum::kConstant:
      mode->set_s("constant");
      break;
    case op::pad_enum::kEdge:
      mode->set_s("edge");
      break;
    case op::pad_enum::kReflect:
      mode->set_s("reflect");
      break;
    default:
      throw dmlc::Error("Unsupported padding mode");
  }

  // pads
  AttributeProto* const pads = node_proto->add_attribute();
  pads->set_name("pads");
  pads->set_type(AttributeProto::INTS);

  std::vector<int64> pad_begin;
  std::vector<int64> pad_end;
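  // MXNet's pad_width interleaves begin/end values per axis; ONNX wants all begin pads followed by
  // all end pads, hence the two passes with stride 2.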
  for (int st = 0; st < 2; ++st) {
    for (auto it = param.pad_width.begin() + st;
         it != param.pad_width.end(); it += 2) {
      pads->add_ints(static_cast<int64>(*it));
    }
  }

  // value
  AttributeProto* const value = node_proto->add_attribute();
  value->set_name("value");
  value->set_type(AttributeProto::FLOAT);
  value->set_f(param.constant_value);
  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
}

void ConvertDropout(GraphProto *graph_proto, const std::string& node_name, const NodeAttrs& attrs,
                    const nnvm::IndexedGraph& ig,
                    const array_view<IndexedGraph::NodeEntry>& inputs) {
  NodeProto* node_proto = graph_proto->add_node();
  node_proto->set_name(node_name);
  node_proto->set_op_type("Dropout");
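  // No attributes are emitted; at inference time Dropout behaves as an identity.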
  DefaultConnectInputsOutputs(node_proto, inputs, ig, node_name);
}

void PreprocessBatchNorm(const NodeAttrs &attrs,
                         const std::vector<nnvm::NodeEntry> &inputs,
                         std::unordered_map<std::string, NDArray> *params_map) {
  const auto& param = nnvm::get<op::BatchNormParam>(attrs.parsed);
  if (param.fix_gamma) {
    // If fix_gamma is specified in MXNet, we need to preprocess the params map so that the gamma
    // associated with this batch norm layer is set to 1.
    std::string gammaNodeName = inputs[batchnorm::kGamma].node->attrs.name;
    (*params_map)[gammaNodeName] = 1.0f;
  }
}

}  // namespace nnvm_to_onnx
}  // namespace op
}  // namespace mxnet

#endif  // MXNET_USE_TENSORRT