1import inspect
2from itertools import chain
3
4import numpy as np
5from onnx import NodeProto
6from onnx import TensorProto
7from onnx import ValueInfoProto
8from onnx import numpy_helper
9from onnx.helper import make_graph
10from onnx.helper import make_tensor
11from onnx.helper import make_tensor_value_info
12from onnx.helper import mapping
13import tensorflow as tf
14from tensorflow.core.framework.attr_value_pb2 import AttrValue
15from tensorflow.core.framework.node_def_pb2 import NodeDef
16
17from onnx_tf.common import attr_converter
18from onnx_tf.common import attr_translator
19from onnx_tf.common import CONST_MINUS_ONE_INT32
20from onnx_tf.common import CONST_ONE_FP32
21from onnx_tf.common import CONST_ONE_INT32
22from onnx_tf.common import CONST_ZERO_INT32
23from onnx_tf.common import IS_PYTHON3
24from onnx_tf.common import logger
25from onnx_tf.common.data_type import any_dtype_to_onnx_dtype
26
class TensorflowNode(object):
  """Uniform, plain-Python view of a single graph node.

  An instance can be built from explicit field values, from an ONNX node
  (``OnnxNode`` or ``NodeProto``) or from a TensorFlow ``NodeDef``. Whatever
  the source, the instance exposes ``node``, ``name``, ``inputs``,
  ``outputs``, ``attr``, ``domain`` and ``op_type``.
  """

  def __init__(self,
               node=None,
               name=None,
               inputs=None,
               outputs=None,
               attr=None,
               domain=None,
               op_type=None):
    """Build the node from explicit fields or from an existing node object.

    :param node: Source node. ``None`` means "use the keyword fields below";
      otherwise an ``OnnxNode``/``NodeProto`` or a TensorFlow ``NodeDef``,
      in which case the keyword fields are ignored.
    :param name: Node name (defaults to "").
    :param inputs: List of input names (defaults to []).
    :param outputs: List of output names. When empty or omitted, names are
      derived with `get_outputs_names`.
    :param attr: Attribute dict (defaults to {}).
    :param domain: Operator domain (defaults to "").
    :param op_type: Operator type (defaults to "").
    :raises TypeError: If `node` is neither None nor one of the supported
      node types.
    """
    # Reference to the original TensorFlow protobuf object. It is assigned up
    # front so the attribute exists on every construction path; only
    # `_load_tf_node` replaces it with a real object.
    self.node = None
    if node is None:
      self.name = name or ""
      self.inputs = inputs or []
      self.attr = attr or {}
      self.domain = domain or ""
      self.op_type = op_type or ""
      # Must come last: default output names are derived from `self.name`
      # and `self.attr`.
      self.outputs = outputs or self.get_outputs_names()
    elif isinstance(node, (OnnxNode, NodeProto)):
      self._load_onnx_node(node)
    elif isinstance(node, NodeDef):
      self._load_tf_node(node)
    else:
      # Fail loudly instead of returning an object with no attributes set.
      raise TypeError(
          "Unsupported node type {}; expected OnnxNode, NodeProto or "
          "NodeDef".format(type(node).__name__))

  def _load_onnx_node(self, node):
    """Copy the fields of an ONNX node onto this instance.

    :param node: An ``OnnxNode``, or a ``NodeProto`` which is wrapped in an
      ``OnnxNode`` first.
    """
    if isinstance(node, NodeProto):
      node = OnnxNode(node)
    self.name = node.name
    self.inputs = node.inputs
    self.outputs = node.outputs
    self.attr = node.attrs
    self.domain = node.domain
    self.op_type = node.op_type

  def _load_tf_node(self, node):
    """Copy the fields of a TensorFlow ``NodeDef`` onto this instance.

    :param node: The ``NodeDef`` to read; it is also kept as `self.node`.
    """
    self.node = node
    self.name = node.name
    self.inputs = list(node.input)
    self.attr = {}
    for key, val in node.attr.items():
      new_val = attr_translator.translate_tf(key, val)
      # The translator may hand back an untouched AttrValue; those still
      # need to go through the generic converter.
      if isinstance(new_val, AttrValue):
        new_val = attr_converter.convert_tf(new_val)
      self.attr[key] = new_val
    # A dotted op name is split into "<domain>.<op_type>"; an undotted one
    # has an empty domain.
    splitted_op_name = node.op.split(".")
    self.domain = "" if len(splitted_op_name) == 1 else ".".join(
        splitted_op_name[:-1])
    self.op_type = splitted_op_name[-1]
    self.outputs = self.get_outputs_names()

  def get_outputs_names(self, num=None):
    """ Helper method to get outputs names.
    e.g. tf.split: [Split, Split:1, Split:2]

    :param num: Force to get `num` outputs names. When None, the count is
      the length of ``attr["_output_shapes"]`` if that attribute exists,
      otherwise 1 (and a warning is logged).
    :return: List of outputs names. The first output is the bare node name;
      output ``i > 0`` is ``"<name>:<i>"``.
    """
    if num is None:
      if "_output_shapes" in self.attr:
        num = len(self.attr["_output_shapes"])
      else:
        num = 1
        logger.warning("_output_shapes is not in node.attr. "
                       "The num of output is set to 1 for commonly. "
                       "It will cause problem with case of multiple outputs.")
    return [
        self.name + ":{}".format(i) if i > 0 else self.name for i in range(num)
    ]
95
96
class TensorflowGraph(object):
  """Wraps a TensorFlow ``GraphDef`` as a list of ``TensorflowNode`` objects.

  On construction the graph is (if needed) re-imported to obtain
  ``_output_shapes`` attributes, a few utility ``Const`` nodes are prepended,
  and the graph's output node names are determined.
  """

  def __init__(self, graph_def, outputs=(), graph_name="graph"):
    """
    :param graph_def: The TensorFlow GraphDef to wrap.
    :param outputs: Names of the output nodes. When empty, they are computed
      with `get_output_node_names`.
    :param graph_name: Name reported by the `graph_name` property.
    """
    self._graph_name = graph_name
    self._graph_def = self._process_graph_def(graph_def)
    self._nodes = self._create_util_nodes() + [
        TensorflowNode(node) for node in self.graph_def.node
    ]
    self._nodes_dict = {n.name: n for n in self._nodes}
    self._outputs = outputs or self.get_output_node_names(self.graph_def)

  @staticmethod
  def _create_util_nodes():
    """Return ``Const`` nodes holding the int32 values -1, 0 and 1.

    Each node carries a one-element int32 array as its "value" attribute.
    """
    util_nodes = [(CONST_MINUS_ONE_INT32, np.array([-1]).astype(np.int32)),
                  (CONST_ZERO_INT32, np.array([0]).astype(np.int32)),
                  (CONST_ONE_INT32, np.array([1]).astype(np.int32))]
    return [
        TensorflowNode(
            op_type="Const",
            name=name,
            attr={
                "value": value,
                "dtype": any_dtype_to_onnx_dtype(value.dtype),
                "_output_shapes": [value.shape]
            }) for name, value in util_nodes
    ]

  def get_node_by_name(self, name):
    """Return the node called `name`.

    :raises ValueError: If no node with that name exists in the graph.
    """
    node = self._nodes_dict.get(name, None)
    if node is None:
      raise ValueError(
          "Node {} is not found in the graph provided".format(name))
    return node

  def _process_graph_def(self, graph_def):
    """Return `graph_def`, adding ``_output_shapes`` attributes if missing.

    Only the first node is inspected. An empty graph has nothing to infer
    and is returned unchanged.
    """
    if len(graph_def.node) == 0:
      return graph_def
    # Look at the raw attribute map directly: wrapping the node in a
    # TensorflowNode just for this check would log a misleading
    # "_output_shapes is not in node.attr" warning.
    if "_output_shapes" not in graph_def.node[0].attr:
      graph_def = self._add_infer_shapes(graph_def)
    return graph_def

  @staticmethod
  def _add_infer_shapes(graph_def):
    """Re-import `graph_def` in a session configured with infer_shapes=True.

    :return: The ``graph_def`` reported by that session.
    """
    # NOTE(review): presumably the session's graph_def carries
    # `_output_shapes` because of the infer_shapes option -- confirm against
    # the TensorFlow version in use.
    with tf.Graph().as_default():
      with tf.Session(
          config=tf.ConfigProto(
              graph_options=tf.GraphOptions(infer_shapes=True))) as sess:
        tf.import_graph_def(graph_def, name="")
      return sess.graph_def

  @staticmethod
  def get_output_node_names(graph_def):
    """Get output node names from GraphDef.

    A node is an output when no other node consumes it, whether through its
    first output ("name"), another output port ("name:1") or a control
    dependency ("^name").

    Args:
      graph_def: GraphDef object.

    Returns:
      List of output node names, in the order the nodes appear in the graph.
    """
    consumed = set()
    for node in graph_def.node:
      for input_name in node.input:
        consumed.add(input_name)
        # Reduce "^name" and "name:port" references to the producing node's
        # name so they match the names collected below.
        producer = input_name[1:] if input_name.startswith("^") else input_name
        consumed.add(producer.split(":", 1)[0])
    seen = set()
    output_names = []
    for node in graph_def.node:
      if node.name in consumed or node.name in seen:
        continue
      seen.add(node.name)
      output_names.append(node.name)
    return output_names

  def update_nodes(self, nodes):
    """Replace the node list and rebuild the name -> node lookup."""
    self._nodes = nodes
    self._nodes_dict = {n.name: n for n in self._nodes}

  @property
  def graph_def(self):
    """The (possibly shape-annotated) GraphDef this object wraps."""
    return self._graph_def

  @property
  def graph_name(self):
    """Name given to the graph at construction time."""
    return self._graph_name

  @property
  def nodes(self):
    """List of nodes: utility constants first, then the graph's nodes."""
    return self._nodes

  @property
  def nodes_dict(self):
    """Mapping from node name to node."""
    return self._nodes_dict

  @property
  def outputs(self):
    """Names of the graph's output nodes."""
    return self._outputs
184
185
186# TODO: Move this into ONNX main library
class OnnxNode(object):
  """
  Plain-Python mirror of an ONNX NodeProto.

  The string fields are copied as ``str``, inputs and outputs as lists, and
  every attribute is converted and then translated into the ``attrs`` dict,
  keyed by attribute name. The source proto stays reachable through
  ``node_proto``.
  """

  def __init__(self, node):
    self.node_proto = node

    self.name = str(node.name)
    self.op_type = str(node.op_type)
    self.domain = str(node.domain)

    self.inputs = list(node.input)
    self.outputs = list(node.output)

    # Each attribute is first converted to a Python value, then passed
    # through the name-aware translator.
    translated = {}
    for attribute in node.attribute:
      converted = attr_converter.convert_onnx(attribute)
      translated[attribute.name] = attr_translator.translate_onnx(
          attribute.name, converted)
    self.attrs = translated
204
205
class OnnxGraph(object):
  """ A helper class for making ONNX graph.
  This class holds all information ONNX graph needs: input, output, node,
  initializer and value-info protos, plus a name -> array map of known
  constants and a name -> data type map of pending type casts.
  """

  def __init__(self, name=None, graph_proto=None):
    """
    :param name: Graph name, used only when `graph_proto` is not given.
    :param graph_proto: Optional existing graph proto to start from; its
      name, inputs, outputs, nodes, initializers and value infos are copied.
    """
    if graph_proto:
      self._name = graph_proto.name
      self._inputs_proto = list(graph_proto.input)
      self._outputs_proto = list(graph_proto.output)
      self._nodes_proto = list(graph_proto.node)
      self._consts_proto = list(graph_proto.initializer)
      self._value_info_proto = list(graph_proto.value_info)
      self._consts = {
          init.name: numpy_helper.to_array(init)
          for init in graph_proto.initializer
      }
    else:
      self._name = name or ""
      self._inputs_proto = []
      self._outputs_proto = []
      self._nodes_proto = []
      self._consts = {}
      self._consts_proto = []
      self._value_info_proto = []
    # Either way, data_type_cast_map is empty when initialized.
    self._data_type_cast_map = {}

    self._add_utility_constants()

  def _add_utility_constants(self):
    """Register a few useful utility constants (const, initializer, input).

    A constant whose name is already known (e.g. loaded from `graph_proto`)
    is left alone so that no second initializer/input with the same name is
    appended.
    """
    util_consts = {CONST_ONE_FP32: np.array([1.0]).astype(np.float32)}
    for name, value in util_consts.items():
      if name in self._consts:
        continue
      self.add_const_explicit(name=name, value=value)
      self.add_const_proto_explicit(
          name=name, value=value, np_dtype=value.dtype)
      self.add_input_proto_explicit(
          name=name, shape=value.shape, np_dtype=value.dtype)

  @property
  def inputs_proto(self):
    """List of ValueInfoProto objects describing the graph inputs."""
    return self._inputs_proto

  @inputs_proto.setter
  def inputs_proto(self, inputs_proto):
    self._inputs_proto = inputs_proto

  @property
  def all_node_inputs(self):
    """Flat list of every input name of every node (duplicates kept)."""
    return [
        input_name for node_proto in self._nodes_proto
        for input_name in node_proto.input
    ]

  @property
  def outputs(self):
    """Names of the graph outputs."""
    return [output_proto.name for output_proto in self._outputs_proto]

  @property
  def outputs_proto(self):
    """List of protos describing the graph outputs."""
    return self._outputs_proto

  @property
  def nodes_proto(self):
    """List of NodeProto objects: the ops of the converted ONNX graph."""
    return self._nodes_proto

  @nodes_proto.setter
  def nodes_proto(self, nodes_proto):
    self._nodes_proto = nodes_proto

  # This dictionary contains a map from the name of the constant
  # op to the array of values it holds. This is useful because
  # tensorflow is less eager to know about input values at
  # graph construction time than ONNX. That is to say, some ONNX
  # attributes are input tensors in TF. This dictionary extracts
  # those values of constant tensors that are known at graph
  # construction time.
  @property
  def consts(self):
    """Map from constant name to the array of values it holds."""
    return self._consts

  @consts.setter
  def consts(self, consts):
    self._consts = consts

  # Sometimes the constants are used as inputs to ops. This list
  # holds initializers that create global constant tensors available
  # to be accessed by ops as inputs (as opposed to attributes, which
  # are supplied by the `consts` map above).
  @property
  def consts_proto(self):
    """List of initializer protos for constants used as op inputs."""
    return self._consts_proto

  @consts_proto.setter
  def consts_proto(self, consts_proto):
    self._consts_proto = consts_proto

  @property
  def data_type_cast_map(self):
    """Map from node name to the new data type its protos are cast to.

    Used to process protos to match ONNX type constraints.
    """
    return self._data_type_cast_map

  @data_type_cast_map.setter
  def data_type_cast_map(self, data_type_cast_map):
    self._data_type_cast_map = data_type_cast_map

  @property
  def value_info_proto(self):
    """List of ValueInfoProto objects describing all nodes' outputs."""
    return self._value_info_proto

  def add_input_proto_explicit(self,
                               name,
                               shape,
                               np_dtype=None,
                               tf_dtype=None,
                               onnx_dtype=None):
    """Append an input value-info built from a name, shape and data type.

    The data type may be given as a numpy, TensorFlow or ONNX type; it is
    resolved to an ONNX type by `any_dtype_to_onnx_dtype`.
    """
    onnx_dtype = any_dtype_to_onnx_dtype(
        np_dtype=np_dtype, tf_dtype=tf_dtype, onnx_dtype=onnx_dtype)
    input_proto = make_tensor_value_info(name, onnx_dtype, shape)
    self._inputs_proto.append(input_proto)

  def add_input_proto(self, node):
    """Append an input value-info for `node`.

    The shape comes from ``attr["shape"]``, except for ``Const`` nodes where
    it is the shape of ``attr["value"]``.
    """
    if node.op_type != "Const":
      shape = node.attr["shape"]
    else:
      shape = node.attr['value'].shape
    self.add_input_proto_explicit(
        node.name, shape, onnx_dtype=node.attr["dtype"])

  @staticmethod
  def _make_output_value_infos(node):
    """Build one value-info per entry of ``node.attr["_output_shapes"]``.

    The element type is ``node.attr["T"]``, falling back to BOOL when the
    node has no "T" attribute. The first output is named after the node;
    output ``i > 0`` is named ``"<name>:<i>"``.
    """
    onnx_type = node.attr.get("T", TensorProto.BOOL)
    return [
        make_tensor_value_info(
            node.name + ":{}".format(i) if i > 0 else node.name, onnx_type,
            output_shape)
        for i, output_shape in enumerate(node.attr["_output_shapes"])
    ]

  def add_output_proto(self, node):
    """Append a graph-output value-info for every output of `node`."""
    self._outputs_proto.extend(self._make_output_value_infos(node))

  def add_node_proto(self, node_proto):
    """Append one node proto, or each proto of a list/tuple of them."""
    if not isinstance(node_proto, (list, tuple)):
      node_proto = [node_proto]
    self._nodes_proto.extend(node_proto)

  def remove_node_proto(self, names):
    """Drop every node proto whose name is `names` or is listed in it.

    :param names: A single node name, or a list/tuple of node names.
    """
    if not isinstance(names, (list, tuple)):
      names = [names]
    self._nodes_proto = [
        node_proto for node_proto in self._nodes_proto
        if node_proto.name not in names
    ]

  def add_const_explicit(self, name, value):
    """Record `value` as the known value of the constant called `name`."""
    self._consts[name] = value

  def add_const(self, node):
    """Record ``node.attr["value"]`` as the known value of `node`."""
    self.add_const_explicit(node.name, node.attr["value"])

  def add_const_proto_explicit(self,
                               name,
                               value,
                               np_dtype=None,
                               tf_dtype=None,
                               onnx_dtype=None):
    """Append an initializer tensor built from `value`.

    :param name: Name of the initializer.
    :param value: Array-like exposing ``shape``, ``tolist`` and ``flatten``.
      A zero-dimensional value is stored as a one-element tensor of
      shape (1,).
    :param np_dtype, tf_dtype, onnx_dtype: The data type, in whichever form
      is available; resolved by `any_dtype_to_onnx_dtype`.
    """
    onnx_dtype = any_dtype_to_onnx_dtype(
        np_dtype=np_dtype, tf_dtype=tf_dtype, onnx_dtype=onnx_dtype)

    if len(value.shape) == 0:
      # Scalars are promoted to a single-element, rank-1 tensor.
      flat_values = [value.tolist()]
      dims = (1,)
    else:
      flat_values = value.flatten().tolist()
      dims = value.shape

    const_proto = make_tensor(
        name=name, data_type=onnx_dtype, dims=dims, vals=flat_values)
    self._consts_proto.append(const_proto)

  def add_const_proto(self, node):
    """Append an initializer for `node` from its "value"/"dtype" attrs."""
    self.add_const_proto_explicit(
        node.name, node.attr["value"], onnx_dtype=node.attr["dtype"])

  def add_value_info_proto(self, node):
    """Append a value-info entry for every output of `node`."""
    self._value_info_proto.extend(self._make_output_value_infos(node))

  def _clean_graph(self):
    """Drop input and initializer protos that nothing refers to.

    A proto is kept only if its name is an input of some node or the name
    of a graph output.
    """
    # A set makes each membership test O(1) instead of scanning a list.
    referenced = set(self.all_node_inputs)
    referenced.update(self.outputs)
    self._inputs_proto = [
        proto for proto in self.inputs_proto if proto.name in referenced
    ]
    self._consts_proto = [
        proto for proto in self.consts_proto if proto.name in referenced
    ]

  def _fix_data_type(self):
    """Apply `data_type_cast_map` to the input and initializer protos."""
    self.inputs_proto = self._data_type_caster(self.inputs_proto,
                                               self.data_type_cast_map)
    self.consts_proto = self._data_type_caster(self.consts_proto,
                                               self.data_type_cast_map)

  @classmethod
  def _data_type_caster(cls, protos, data_type_cast_map):
    """Cast to a new data type if node name is in data_type_cast_map.
    Be used to process protos to match ONNX type constraints.

    :param protos: Target protos.
      ValueInfoProto for inputs and TensorProto for consts.
    :param data_type_cast_map: A {node.name: new_data_type} dict.
    :return: Processed protos. The very same `protos` object when the map
      is empty; otherwise a new list. A cast TensorProto is rebuilt as a
      new tensor, while a cast ValueInfoProto is updated in place.
    """
    if not data_type_cast_map:
      return protos
    result = []
    for proto in protos:
      new_proto = proto
      if proto.name in data_type_cast_map:
        new_data_type = data_type_cast_map[proto.name]
        if isinstance(proto, TensorProto):
          if proto.data_type != new_data_type:
            # Values are read from the typed storage field matching the
            # tensor's current data type.
            # NOTE(review): a tensor stored via raw_data would presumably
            # yield no values here -- confirm no such tensor reaches this.
            field = mapping.STORAGE_TENSOR_TYPE_TO_FIELD[
                mapping.TENSOR_TYPE_TO_STORAGE_TENSOR_TYPE[proto.data_type]]
            vals = getattr(proto, field)
            new_proto = make_tensor(
                name=proto.name,
                data_type=new_data_type,
                dims=proto.dims,
                vals=vals)
        elif isinstance(proto, ValueInfoProto):
          if proto.type.tensor_type.elem_type != new_data_type:
            new_proto.type.tensor_type.elem_type = new_data_type
      result.append(new_proto)
    return result

  def make_graph_proto(self):
    """Assemble and return the graph proto.

    Unreferenced inputs/initializers are removed and pending type casts are
    applied first. ``initializer`` and ``value_info`` are passed to
    `make_graph` only if its signature accepts them.
    """
    self._clean_graph()
    self._fix_data_type()

    if IS_PYTHON3:
      params = list(inspect.signature(make_graph).parameters.keys())
    else:
      params = inspect.getargspec(make_graph).args

    optional_kwargs = {
        "initializer": self.consts_proto,
        "value_info": self.value_info_proto
    }
    supported_kwargs = {
        key: val for key, val in optional_kwargs.items() if key in params
    }

    return make_graph(self.nodes_proto, self._name, self.inputs_proto,
                      self.outputs_proto, **supported_kwargs)
461