1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18# coding: utf-8
19# pylint: disable=invalid-name,too-many-locals,no-self-use
20""" Support import export formats."""
21import numpy as np
22from .... import symbol
23from .... import ndarray as nd
24from ....base import string_types
25from ._import_helper import _convert_map as convert_map
26
class GraphProto(object): # pylint: disable=too-few-public-methods
    """A helper class for handling mxnet symbol copying from pb2.GraphProto.
    Definition: https://github.com/onnx/onnx/blob/master/onnx/onnx.proto

    Conversion state lives on the instance: ``_params`` holds the parsed
    initializer tensors, ``_nodes`` maps ONNX value names to mxnet symbols,
    and ``arg_dict`` / ``aux_dict`` collect the parameters consumed by the
    converted graph.  None of this state is reset by :meth:`from_onnx`, so an
    instance is presumably meant to convert a single graph.
    """
    def __init__(self):
        # ONNX value name -> mxnet symbol (graph inputs and node outputs).
        self._nodes = {}
        # Initializer name -> parsed nd.array.
        self._params = {}
        # NOTE(review): these two counters are not read anywhere in this class;
        # they may be used by external code - confirm before removing.
        self._num_input = 0
        self._num_param = 0
        # Initializers split by how the converted symbols consume them.
        self.aux_dict = {}
        self.arg_dict = {}
        # Input/output description produced by get_graph_metadata().
        self.model_metadata = {}
        # Opset version of the model being imported; set by from_onnx().
        self.opset_version = 0

    def _convert_operator(self, node_name, op_name, attrs, inputs):
        """Convert an ONNX operator into an mxnet symbol.

        The handler registered in ``convert_map`` for ``op_name`` translates
        the operator name, attributes and inputs (it also receives this
        instance).  If the handler returns a string, it is looked up as a
        function of the mxnet ``symbol`` module and called here; any other
        object it returns (presumably a symbol the handler built itself) is
        passed back unchanged.

        Parameters
        ----------
        node_name : str or None
            Name to give the created symbol; ``None`` lets mxnet choose one.
        op_name : str
            ONNX operator type (``NodeProto.op_type``).
        attrs : dict
            Operator attributes, as produced by :meth:`_parse_attr`.
        inputs : list
            Input symbols of the operator.

        Returns
        -------
        mxnet_sym
            The converted symbol.

        Raises
        ------
        NotImplementedError
            If no handler is registered for ``op_name``.
        RuntimeError
            If the handler names a function missing from ``symbol``.
        """
        if op_name not in convert_map:
            raise NotImplementedError("Operator {} not implemented.".format(op_name))
        op_name, new_attrs, inputs = convert_map[op_name](attrs, inputs, self)

        if not isinstance(op_name, string_types):
            # The handler produced the result itself.
            return op_name

        new_op = getattr(symbol, op_name, None)
        if not new_op:
            raise RuntimeError("Unable to map op_name {} to sym".format(op_name))
        if node_name is None:
            return new_op(*inputs, **new_attrs)
        return new_op(name=node_name, *inputs, **new_attrs)

    def from_onnx(self, graph, opset_version):
        """Construct symbol from onnx graph.

        Parameters
        ----------
        graph : onnx protobuf object
            The loaded onnx graph
        opset_version : int
            Opset version of the model.  Stored as ``self.opset_version``,
            where the conversion handlers (which receive this instance) can
            read it.

        Returns
        -------
        sym : symbol.Symbol
            The returned mxnet symbol; a ``symbol.Group`` when the graph has
            more than one output.
        arg_params : dict
            name -> nd.array for initializers used as symbol arguments.
        aux_params : dict
            name -> nd.array for initializers used as auxiliary states.

        Raises
        ------
        ValueError
            If an initializer tensor has an empty (or blank) name.
        """
        self.opset_version = opset_version
        # Input/output names and shapes; reused by graph_to_gluon().
        self.model_metadata = self.get_graph_metadata(graph)

        # Initializers carry the pretrained weights.
        for init_tensor in graph.initializer:
            if not init_tensor.name.strip():
                raise ValueError("Tensor's name is required.")
            self._params[init_tensor.name] = self._parse_array(init_tensor)

        # Every graph input becomes a Variable; inputs backed by an
        # initializer are given that initializer's shape.
        for graph_input in graph.input:
            name = graph_input.name
            if name in self._params:
                self._nodes[name] = symbol.Variable(name=name,
                                                    shape=self._params[name].shape)
            else:
                self._nodes[name] = symbol.Variable(name=name)

        # The input lookup below assumes graph.node is topologically sorted:
        # each input is a graph input or an output of an earlier node.
        for node in graph.node:
            node_name = node.name.strip() or None
            onnx_attr = self._parse_attr(node.attribute)
            inputs = [self._nodes[input_name] for input_name in node.input]
            mxnet_sym = self._convert_operator(node_name, node.op_type, onnx_attr, inputs)

            # Pair ONNX output names with symbol outputs positionally; surplus
            # entries on either side are ignored.
            num_outputs = len(mxnet_sym.list_outputs())
            for idx, output_name in enumerate(list(node.output)[:num_outputs]):
                self._nodes[output_name] = mxnet_sym[idx]

            # Split the initializers this symbol consumes into args and aux params.
            for arg_name in mxnet_sym.list_arguments():
                if arg_name in self._params:
                    self.arg_dict[arg_name] = nd.array(self._params[arg_name])
            for aux_name in mxnet_sym.list_auxiliary_states():
                if aux_name in self._params:
                    self.aux_dict[aux_name] = nd.array(self._params[aux_name])

        outputs = [self._nodes[graph_output.name] for graph_output in graph.output]
        sym = symbol.Group(outputs) if len(outputs) > 1 else outputs[0]
        return sym, self.arg_dict, self.aux_dict

    def get_graph_metadata(self, graph):
        """Get the model metadata from a given onnx graph.

        Returns
        -------
        dict
            ``'input_tensor_data'``: list of ``(name, shape, elem_type)`` for
            each graph input that is not an initializer;
            ``'output_tensor_data'``: list of ``(name, shape)`` for each graph
            output.  Shapes are tuples built from the ``dim_value`` fields.
        """
        # NOTE(review): a dimension declared symbolically (dim_param) presumably
        # reads back as dim_value 0 here - confirm how callers treat that.
        initializer_names = {tensor.name for tensor in graph.initializer}

        input_data = []
        for graph_input in graph.input:
            if graph_input.name in initializer_names:
                continue
            tensor_type = graph_input.type.tensor_type
            shape = tuple(dim.dim_value for dim in tensor_type.shape.dim)
            input_data.append((graph_input.name, shape, tensor_type.elem_type))

        output_data = [
            (graph_out.name,
             tuple(dim.dim_value for dim in graph_out.type.tensor_type.shape.dim))
            for graph_out in graph.output
        ]
        return {'input_tensor_data': input_data,
                'output_tensor_data': output_data}

    def graph_to_gluon(self, graph, ctx, opset_version):
        """Construct SymbolBlock from onnx graph.

        Parameters
        ----------
        graph : onnx protobuf object
            The loaded onnx graph
        ctx : Context or list of Context
            Loads the model into one or many context(s).
        opset_version : int
            Opset version of the model, forwarded to :meth:`from_onnx`.

        Returns
        -------
        sym_block :gluon.nn.SymbolBlock
            The returned gluon SymbolBlock
        """
        sym, arg_params, aux_params = self.from_onnx(graph, opset_version)
        # from_onnx() has just stored the metadata for this same graph.
        data_names = [input_tensor[0]
                      for input_tensor in self.model_metadata['input_tensor_data']]
        data_inputs = [symbol.var(data_name) for data_name in data_names]

        from ....gluon import SymbolBlock
        net = SymbolBlock(outputs=sym, inputs=data_inputs)
        net_params = net.collect_params()
        # Load the pretrained values of every parameter the block exposes,
        # argument parameters first, then auxiliary states.
        for params in (arg_params, aux_params):
            for name, value in params.items():
                if name in net_params:
                    net_params[name].shape = value.shape
                    net_params[name]._load_init(value, ctx=ctx)
        return net

    def _parse_array(self, tensor_proto):
        """Convert an ONNX TensorProto into an mxnet nd.array.

        A tensor without ``dims`` (a scalar) is returned as a 1-element array
        rather than a 0-d one.

        Raises
        ------
        ImportError
            If ``onnx.numpy_helper`` cannot be imported.
        """
        try:
            from onnx.numpy_helper import to_array
        except ImportError:
            # No exception chaining here: the module keeps Python 2 support.
            raise ImportError("Onnx and protobuf need to be installed. "
                              + "Instructions to install - https://github.com/onnx/onnx")
        dims = tuple(tensor_proto.dims)
        if dims:
            np_array = to_array(tensor_proto).reshape(dims)
        else:
            # If onnx's params are scalar values without dims mentioned.
            np_array = np.array([to_array(tensor_proto)])
        return nd.array(np_array)

    def _parse_attr(self, attr_proto):
        """Convert a list of AttributeProto to a dict, with names as keys.

        Scalar values (``f``, ``i``, ``s``) are stored as-is, with ``bytes``
        decoded to ``str`` as UTF-8.  Non-empty repeated values (``floats``,
        ``ints``, ``strings``) are stored as tuples; their elements are not
        decoded.  Tensor (``t``) and graph (``g``) values are stored as the
        raw protobuf objects.

        Raises
        ------
        ValueError
            If an attribute holds both a scalar and a list value, or holds no
            value that can be parsed (for example an empty list).
        NotImplementedError
            If an attribute holds a list of tensors or graphs.
        """
        attrs = {}
        for a in attr_proto:
            for f in ['f', 'i', 's']:
                if a.HasField(f):
                    attrs[a.name] = getattr(a, f)
                    # Needed for supporting python version  > 3.5
                    if isinstance(attrs[a.name], bytes):
                        attrs[a.name] = attrs[a.name].decode(encoding='utf-8')
            for f in ['floats', 'ints', 'strings']:
                values = list(getattr(a, f))
                if values:
                    # Explicit check instead of assert: asserts vanish under -O.
                    if a.name in attrs:
                        raise ValueError("Only one type of attr is allowed")
                    attrs[a.name] = tuple(values)
            for f in ['t', 'g']:
                if a.HasField(f):
                    attrs[a.name] = getattr(a, f)
            for f in ['tensors', 'graphs']:
                if list(getattr(a, f)):
                    raise NotImplementedError("Field {} is not supported in mxnet.".format(f))
            if a.name not in attrs:
                raise ValueError("Cannot parse attribute: \n{}\n.".format(a))
        return attrs
232