from onnx_tf.common import IS_PYTHON3


def convert_tf(attr):
    """Convert a TensorFlow ``AttrValue`` message into a Python object.

    Thin public entry point; see ``__convert_tf_attr_value`` for the mapping.
    """
    return __convert_tf_attr_value(attr)


def convert_onnx(attr):
    """Convert an ONNX ``AttributeProto`` message into a Python object.

    Thin public entry point; see ``__convert_onnx_attribute_proto`` for the
    mapping.
    """
    return __convert_onnx_attribute_proto(attr)


def __convert_tf_attr_value(attr):
    """Convert a TensorFlow ``AttrValue`` object to a Python object.

    List values are delegated to ``__convert_tf_list_value`` and come back as
    a Python list.  Scalar fields (``s``, ``i``, ``f``, ``b``, ``type``) are
    returned as the raw protobuf field values.  ``shape``, ``tensor`` and
    ``func`` are returned as the underlying proto messages.

    Raises:
        ValueError: if none of the supported fields is set on ``attr``.
    """
    if attr.HasField('list'):
        return __convert_tf_list_value(attr.list)
    if attr.HasField('s'):
        return attr.s
    if attr.HasField('i'):
        return attr.i
    if attr.HasField('f'):
        return attr.f
    if attr.HasField('b'):
        return attr.b
    if attr.HasField('type'):
        return attr.type
    if attr.HasField('shape'):
        # Return the shape message itself, not ``attr.type``.
        return attr.shape
    if attr.HasField('tensor'):
        return attr.tensor
    if attr.HasField('func'):
        # Mirrors the ``func`` handling of the list variant below.
        return attr.func
    raise ValueError("Unsupported Tensorflow attribute: {}".format(attr))


def __convert_tf_list_value(list_value):
    """Convert a TensorFlow ``ListValue`` object to a Python list.

    The repeated fields are checked in a fixed order.  The first non-empty one
    is copied into a plain Python list, matching how ONNX list attributes are
    returned by ``__convert_onnx_attribute_proto``.  A ``ListValue`` with no
    entries in any checked field is a valid empty list attribute, so it yields
    ``[]`` instead of raising.
    """
    # Order preserved from the original if/elif chain.
    for field_name in ('s', 'i', 'f', 'b', 'tensor', 'type', 'shape', 'func'):
        values = getattr(list_value, field_name)
        if values:
            return list(values)
    return []


def __convert_onnx_attribute_proto(attr_proto):
    """Convert an ONNX ``AttributeProto`` into an appropriate Python object.

    - Scalars (``f``, ``i``) are returned as-is.
    - ``s`` / ``strings`` are decoded from UTF-8 on Python 3 (bytes on
      Python 2).
    - Repeated numeric/string fields become Python lists.
    - NB: tensor, graph and sparse-tensor attributes are returned as the
      straight proto messages (a list of them for ``tensors`` / ``graphs``).

    Raises:
        ValueError: if none of the supported fields is populated.
    """
    if attr_proto.HasField('f'):
        return attr_proto.f
    if attr_proto.HasField('i'):
        return attr_proto.i
    if attr_proto.HasField('s'):
        return str(attr_proto.s, 'utf-8') if IS_PYTHON3 else attr_proto.s
    if attr_proto.HasField('t'):
        return attr_proto.t  # this is a proto!
    if attr_proto.HasField('g'):
        return attr_proto.g
    if attr_proto.floats:
        return list(attr_proto.floats)
    if attr_proto.ints:
        return list(attr_proto.ints)
    if attr_proto.strings:
        if IS_PYTHON3:
            return [str(raw, 'utf-8') for raw in attr_proto.strings]
        return list(attr_proto.strings)
    if attr_proto.HasField('sparse_tensor'):
        return attr_proto.sparse_tensor
    # Repeated proto-message fields; checked last so that every previously
    # supported field keeps its original precedence.
    if attr_proto.tensors:
        return list(attr_proto.tensors)
    if attr_proto.graphs:
        return list(attr_proto.graphs)
    raise ValueError("Unsupported ONNX attribute: {}".format(attr_proto))