1from onnx_tf.common import IS_PYTHON3
2
3
def convert_tf(attr):
  """Return the Python value held by a TensorFlow ``AttrValue`` message.

  Public entry point; the actual conversion is delegated to
  ``__convert_tf_attr_value``.
  """
  value = __convert_tf_attr_value(attr)
  return value
6
7
def convert_onnx(attr):
  """Return the Python value held by an ONNX ``AttributeProto`` message.

  Public entry point; the actual conversion is delegated to
  ``__convert_onnx_attribute_proto``.
  """
  value = __convert_onnx_attribute_proto(attr)
  return value
10
11
def __convert_tf_attr_value(attr):
  """Convert a TensorFlow ``AttrValue`` object to a Python object.

  Whichever member of the value is set is returned unconverted: scalar
  fields (``s``, ``i``, ``f``, ``b``, ``type``) as their stored value and
  message fields (``shape``, ``tensor``, ``func``) as the protobuf message.
  A ``list`` value is delegated to ``__convert_tf_list_value``.

  Args:
    attr: a TensorFlow ``AttrValue`` protobuf message.

  Returns:
    The value of the set field.

  Raises:
    ValueError: if none of the supported fields is set.
  """
  if attr.HasField('list'):
    return __convert_tf_list_value(attr.list)
  if attr.HasField('s'):
    return attr.s
  elif attr.HasField('i'):
    return attr.i
  elif attr.HasField('f'):
    return attr.f
  elif attr.HasField('b'):
    return attr.b
  elif attr.HasField('type'):
    return attr.type
  elif attr.HasField('shape'):
    # Return the shape member itself; ``attr.type`` is a different field.
    return attr.shape
  elif attr.HasField('tensor'):
    return attr.tensor
  elif attr.HasField('func'):
    # Mirrors the ``func`` support in the list converter.
    return attr.func
  else:
    raise ValueError("Unsupported Tensorflow attribute: {}".format(attr))
33
34
def __convert_tf_list_value(list_value):
  """Convert a TensorFlow ``AttrValue.ListValue`` object to a Python object.

  ``ListValue`` keeps its items in one of several repeated fields. The first
  non-empty field, checked in the order ``s, i, f, b, tensor, type, shape,
  func``, is returned unconverted (as the protobuf repeated container).

  Args:
    list_value: a TensorFlow ``AttrValue.ListValue`` protobuf message.

  Returns:
    The first non-empty repeated field, or ``[]`` when every field is empty.
  """
  for field_name in ('s', 'i', 'f', 'b', 'tensor', 'type', 'shape', 'func'):
    values = getattr(list_value, field_name)
    if values:
      return values
  # An empty list attribute (e.g. ``squeeze_dims: list {}``) leaves every
  # repeated field empty; that is a valid value, so do not reject it.
  return []
56
57
def __convert_onnx_attribute_proto(attr_proto):
  """
  Convert an ONNX AttributeProto into an appropriate Python object
  for the type.

  - ``f`` / ``i`` are returned as-is; ``s`` is decoded from UTF-8 on
    Python 3.
  - ``floats`` / ``ints`` / ``strings`` are returned as Python lists
    (strings decoded like ``s``).
  - ``tensors`` / ``graphs`` are returned as Python lists of protos.
  - An empty list attribute is recognised through its declared ``type``
    and returned as ``[]``.
  NB: Tensor, graph and sparse-tensor attributes get returned as the
  straight proto.

  Raises:
    ValueError: if no supported field is set and the declared type is not
      a list type.
  """
  if attr_proto.HasField('f'):
    return attr_proto.f
  elif attr_proto.HasField('i'):
    return attr_proto.i
  elif attr_proto.HasField('s'):
    return str(attr_proto.s, 'utf-8') if IS_PYTHON3 else attr_proto.s
  elif attr_proto.HasField('t'):
    return attr_proto.t  # this is a proto!
  elif attr_proto.HasField('g'):
    return attr_proto.g
  elif attr_proto.floats:
    return list(attr_proto.floats)
  elif attr_proto.ints:
    return list(attr_proto.ints)
  elif attr_proto.strings:
    str_list = list(attr_proto.strings)
    if IS_PYTHON3:
      str_list = [str(x, 'utf-8') for x in str_list]
    return str_list
  elif attr_proto.tensors:
    return list(attr_proto.tensors)
  elif attr_proto.graphs:
    return list(attr_proto.graphs)
  elif attr_proto.HasField('sparse_tensor'):
    return attr_proto.sparse_tensor
  elif attr_proto.type in (attr_proto.FLOATS, attr_proto.INTS,
                           attr_proto.STRINGS, attr_proto.TENSORS,
                           attr_proto.GRAPHS):
    # Empty repeated fields are falsy, so an empty list attribute falls
    # through every check above; its declared type still identifies it.
    return []
  else:
    raise ValueError("Unsupported ONNX attribute: {}".format(attr_proto))
87