# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
# pylint: disable=invalid-name,too-many-locals,no-self-use
""" Support import export formats."""
import numpy as np
from .... import symbol
from .... import ndarray as nd
from ....base import string_types
from ._import_helper import _convert_map as convert_map

class GraphProto(object): # pylint: disable=too-few-public-methods
    """A helper class for handling mxnet symbol copying from pb2.GraphProto.
    Definition: https://github.com/onnx/onnx/blob/master/onnx/onnx.proto
    """
    def __init__(self):
        # Maps onnx tensor/node output names to the mxnet symbols produced for them.
        self._nodes = {}
        # Maps initializer (pretrained weight) names to nd.array values.
        self._params = {}
        self._num_input = 0
        self._num_param = 0
        # Auxiliary states (e.g. BatchNorm running statistics) split out of _params.
        self.aux_dict = {}
        # Trainable arguments split out of _params.
        self.arg_dict = {}
        self.model_metadata = {}
        self.opset_version = 0

    def _convert_operator(self, node_name, op_name, attrs, inputs):
        """Convert from onnx operator to mxnet operator.
        The converter must specify conversions explicitly for incompatible name, and
        apply handlers to operator attributes.

        Parameters
        ----------
        :param node_name : str or None
            name of the node to be translated.
        :param op_name : str
            Operator name, such as Convolution, FullyConnected
        :param attrs : dict
            Dict of operator attributes
        :param inputs: list
            list of inputs to the operator

        Returns
        -------
        :return mxnet_sym
            Converted mxnet symbol

        Raises
        ------
        NotImplementedError
            If no converter is registered for ``op_name``.
        RuntimeError
            If the converted operator name does not exist in ``mxnet.symbol``.
        """
        if op_name not in convert_map:
            raise NotImplementedError("Operator {} not implemented.".format(op_name))
        # A converter may remap the operator name, rewrite attributes and
        # reorder/augment inputs — or return an already-built mxnet symbol
        # in place of the operator name.
        op_name, new_attrs, inputs = convert_map[op_name](attrs, inputs, self)
        if isinstance(op_name, string_types):
            new_op = getattr(symbol, op_name, None)
            if not new_op:
                raise RuntimeError("Unable to map op_name {} to sym".format(op_name))
            if node_name is None:
                mxnet_sym = new_op(*inputs, **new_attrs)
            else:
                mxnet_sym = new_op(*inputs, name=node_name, **new_attrs)
            return mxnet_sym
        # Converter returned a finished symbol instead of an operator name.
        return op_name

    def from_onnx(self, graph, opset_version):
        """Construct symbol from onnx graph.

        Parameters
        ----------
        graph : onnx protobuf object
            The loaded onnx graph
        opset_version : int
            ONNX opset version of the model, stored for converters that need
            version-dependent behavior.

        Returns
        -------
        sym :symbol.Symbol
            The returned mxnet symbol
        arg_params : dict
            A dict of name: nd.array pairs, trainable arguments
        aux_params : dict
            A dict of name: nd.array pairs, auxiliary states

        Raises
        ------
        ValueError
            If an initializer tensor has an empty name.
        """
        self.opset_version = opset_version
        # get input, output shapes
        self.model_metadata = self.get_graph_metadata(graph)
        # parse network inputs, aka parameters
        for init_tensor in graph.initializer:
            if not init_tensor.name.strip():
                raise ValueError("Tensor's name is required.")
            self._params[init_tensor.name] = self._parse_array(init_tensor)

        # converting GraphProto message
        for i in graph.input:
            if i.name in self._params:
                # i is a param instead of input
                self._nodes[i.name] = symbol.Variable(name=i.name,
                                                      shape=self._params[i.name].shape)
            else:
                self._nodes[i.name] = symbol.Variable(name=i.name)

        # constructing nodes, nodes are stored as directed acyclic graph
        # converting NodeProto message
        for node in graph.node:
            op_name = node.op_type
            # An empty node name means "unnamed"; pass None so mxnet auto-names it.
            node_name = node.name.strip()
            node_name = node_name if node_name else None
            onnx_attr = self._parse_attr(node.attribute)
            inputs = [self._nodes[i] for i in node.input]
            mxnet_sym = self._convert_operator(node_name, op_name, onnx_attr, inputs)

            # zip (not enumerate) so indexing stays within the mxnet symbol's
            # outputs even if node.output lists more names.
            for out_name, idx in zip(node.output, range(len(mxnet_sym.list_outputs()))):
                self._nodes[out_name] = mxnet_sym[idx]

            # splitting params into args and aux params
            for args in mxnet_sym.list_arguments():
                if args in self._params:
                    self.arg_dict.update({args: nd.array(self._params[args])})
            for aux in mxnet_sym.list_auxiliary_states():
                if aux in self._params:
                    self.aux_dict.update({aux: nd.array(self._params[aux])})

        # now return the outputs; group multiple outputs into a single symbol
        out = [self._nodes[i.name] for i in graph.output]
        if len(out) > 1:
            out = symbol.Group(out)
        else:
            out = out[0]
        return out, self.arg_dict, self.aux_dict

    def get_graph_metadata(self, graph):
        """Get the model metadata from a given onnx graph.

        Parameters
        ----------
        graph : onnx protobuf object
            The loaded onnx graph

        Returns
        -------
        dict
            ``{'input_tensor_data': [(name, shape, elem_type), ...],
               'output_tensor_data': [(name, shape), ...]}``
            where inputs that are initializers (weights) are excluded.
        """
        _params = set()
        for tensor_vals in graph.initializer:
            _params.add(tensor_vals.name)

        input_data = []
        for graph_input in graph.input:
            # graph.input also lists weight initializers; only true inputs count.
            if graph_input.name not in _params:
                shape = [val.dim_value for val in graph_input.type.tensor_type.shape.dim]
                dtype = graph_input.type.tensor_type.elem_type
                input_data.append((graph_input.name, tuple(shape), dtype))

        output_data = []
        for graph_out in graph.output:
            shape = [val.dim_value for val in graph_out.type.tensor_type.shape.dim]
            output_data.append((graph_out.name, tuple(shape)))
        metadata = {'input_tensor_data' : input_data,
                    'output_tensor_data' : output_data
                   }
        return metadata

    def graph_to_gluon(self, graph, ctx, opset_version):
        """Construct SymbolBlock from onnx graph.

        Parameters
        ----------
        graph : onnx protobuf object
            The loaded onnx graph
        ctx : Context or list of Context
            Loads the model into one or many context(s).
        opset_version : int
            ONNX opset version of the model, forwarded to :meth:`from_onnx`.

        Returns
        -------
        sym_block :gluon.nn.SymbolBlock
            The returned gluon SymbolBlock
        """
        sym, arg_params, aux_params = self.from_onnx(graph, opset_version)
        metadata = self.get_graph_metadata(graph)
        data_names = [input_tensor[0] for input_tensor in metadata['input_tensor_data']]
        data_inputs = [symbol.var(data_name) for data_name in data_names]

        # Imported here to avoid a circular import at module load time.
        from ....gluon import SymbolBlock
        net = SymbolBlock(outputs=sym, inputs=data_inputs)
        net_params = net.collect_params()
        # Overwrite declared shapes with the pretrained ones before loading,
        # since the symbol's inferred shapes may be unset at this point.
        for param in arg_params:
            if param in net_params:
                net_params[param].shape = arg_params[param].shape
                net_params[param]._load_init(arg_params[param], ctx=ctx)
        for param in aux_params:
            if param in net_params:
                net_params[param].shape = aux_params[param].shape
                net_params[param]._load_init(aux_params[param], ctx=ctx)
        return net

    def _parse_array(self, tensor_proto):
        """Grab data in TensorProto and convert to numpy array.

        Raises
        ------
        ImportError
            If the ``onnx`` package is not installed.
        """
        try:
            from onnx.numpy_helper import to_array
        except ImportError:
            raise ImportError("Onnx and protobuf need to be installed. "
                              "Instructions to install - https://github.com/onnx/onnx")
        if len(tuple(tensor_proto.dims)) > 0:
            np_array = to_array(tensor_proto).reshape(tuple(tensor_proto.dims))
        else:
            # If onnx's params are scalar values without dims mentioned,
            # wrap the scalar in a 1-element array so nd.array accepts it.
            np_array = np.array([to_array(tensor_proto)])
        return nd.array(np_array)

    def _parse_attr(self, attr_proto):
        """Convert a list of AttributeProto to a dict, with names as keys.

        Raises
        ------
        NotImplementedError
            If an attribute holds tensors or subgraphs (unsupported here).
        ValueError
            If an attribute matches none of the known field types.
        """
        attrs = {}
        for a in attr_proto:
            # Scalar fields: float, int, string.
            for f in ['f', 'i', 's']:
                if a.HasField(f):
                    attrs[a.name] = getattr(a, f)
                    # Needed for supporting python version > 3.5: proto strings
                    # arrive as bytes and must be decoded.
                    if isinstance(attrs[a.name], bytes):
                        attrs[a.name] = attrs[a.name].decode(encoding='utf-8')
            # Repeated fields: each attribute may carry exactly one type.
            for f in ['floats', 'ints', 'strings']:
                if list(getattr(a, f)):
                    assert a.name not in attrs, "Only one type of attr is allowed"
                    attrs[a.name] = tuple(getattr(a, f))
            # Tensor / graph singular fields are passed through untouched.
            for f in ['t', 'g']:
                if a.HasField(f):
                    attrs[a.name] = getattr(a, f)
            for f in ['tensors', 'graphs']:
                if list(getattr(a, f)):
                    raise NotImplementedError("Field {} is not supported in mxnet.".format(f))
            if a.name not in attrs:
                raise ValueError("Cannot parse attribute: \n{}\n.".format(a))
        return attrs