import inspect
from itertools import chain

import numpy as np
from onnx import NodeProto
from onnx import TensorProto
from onnx import ValueInfoProto
from onnx import numpy_helper
from onnx.helper import make_graph
from onnx.helper import make_tensor
from onnx.helper import make_tensor_value_info
from onnx.helper import mapping
import tensorflow as tf
from tensorflow.core.framework.attr_value_pb2 import AttrValue
from tensorflow.core.framework.node_def_pb2 import NodeDef

from onnx_tf.common import attr_converter
from onnx_tf.common import attr_translator
from onnx_tf.common import CONST_MINUS_ONE_INT32
from onnx_tf.common import CONST_ONE_FP32
from onnx_tf.common import CONST_ONE_INT32
from onnx_tf.common import CONST_ZERO_INT32
from onnx_tf.common import IS_PYTHON3
from onnx_tf.common import logger
from onnx_tf.common.data_type import any_dtype_to_onnx_dtype


class TensorflowNode(object):
  """Uniform view of a single graph node.

  A node can be built from a TensorFlow ``NodeDef``, from an ONNX
  ``NodeProto`` / ``OnnxNode``, or from explicit keyword fields when
  ``node`` is None. In every case the instance exposes ``name``,
  ``inputs``, ``outputs``, ``attr``, ``domain`` and ``op_type``.
  """

  def __init__(self,
               node=None,
               name=None,
               inputs=None,
               outputs=None,
               attr=None,
               domain=None,
               op_type=None):
    """Build the node from a protobuf object or from explicit fields.

    :param node: None, an ``OnnxNode``, an ONNX ``NodeProto`` or a TF
      ``NodeDef``. When given, the remaining keyword arguments are ignored.
    :param name: Node name (explicit construction only).
    :param inputs: Input names (explicit construction only).
    :param outputs: Output names. When omitted, they are derived from
      ``name`` and ``attr["_output_shapes"]`` via ``get_outputs_names``.
    :param attr: Attribute dict (explicit construction only).
    :param domain: Op domain (explicit construction only).
    :param op_type: Op type (explicit construction only).
    :raises TypeError: If ``node`` is of an unsupported type.
    """
    if node is None:
      # No source protobuf; the reference to the original object is None.
      self.node = None
      self.name = name or ""
      self.inputs = inputs or []
      self.attr = attr or {}
      self.domain = domain or ""
      self.op_type = op_type or ""
      # Must run after name/attr are set: get_outputs_names reads both.
      self.outputs = outputs or self.get_outputs_names()
    elif isinstance(node, (OnnxNode, NodeProto)):
      self._load_onnx_node(node)
    elif isinstance(node, NodeDef):
      self._load_tf_node(node)
    else:
      # Previously this left the instance without any attributes at all.
      raise TypeError(
          "Unsupported node type for TensorflowNode: {}".format(type(node)))

  def _load_onnx_node(self, node):
    """Populate fields from an ``OnnxNode`` or a raw ONNX ``NodeProto``."""
    if isinstance(node, NodeProto):
      node = OnnxNode(node)
    # Only TF NodeDefs are kept as ``self.node``; set it so the attribute
    # always exists, as it does on the other construction paths.
    self.node = None
    self.name = node.name
    self.inputs = node.inputs
    self.outputs = node.outputs
    self.attr = node.attrs
    self.domain = node.domain
    self.op_type = node.op_type

  def _load_tf_node(self, node):
    """Populate fields from a TensorFlow ``NodeDef``.

    Attributes are translated with ``attr_translator.translate_tf``. Any
    result that is still an ``AttrValue`` is then converted with
    ``attr_converter.convert_tf``. A dotted ``node.op`` such as ``"a.b.Op"``
    is split into domain ``"a.b"`` and op_type ``"Op"``.
    """
    self.node = node
    self.name = node.name
    self.inputs = list(node.input)
    self.attr = {}
    for key, val in node.attr.items():
      new_val = attr_translator.translate_tf(key, val)
      if isinstance(new_val, AttrValue):
        new_val = attr_converter.convert_tf(new_val)
      self.attr[key] = new_val
    domain, _, op_type = node.op.rpartition(".")
    self.domain = domain
    self.op_type = op_type
    self.outputs = self.get_outputs_names()

  def get_outputs_names(self, num=None):
    """ Helper method to get outputs names.
    e.g. tf.split: [Split, Split:1, Split:2]

    :param num: Force to get `num` outputs names. When None, the count is
      taken from ``attr["_output_shapes"]``, falling back to 1 (with a
      warning) when that attribute is missing.
    :return: List of outputs names.
    """
    if num is None:
      if "_output_shapes" in self.attr:
        num = len(self.attr["_output_shapes"])
      else:
        num = 1
        logger.warning("_output_shapes is not in node.attr. "
                       "The num of output is set to 1 for commonly. "
                       "It will cause problem with case of multiple outputs.")
    return [
        self.name + ":{}".format(i) if i > 0 else self.name for i in range(num)
    ]


class TensorflowGraph(object):
  """Wrapper around a TF ``GraphDef`` exposing ``TensorflowNode`` objects.

  A few utility constant nodes (int32 -1, 0 and 1) are prepended to the
  node list.
  """

  def __init__(self, graph_def, outputs=(), graph_name="graph"):
    self._graph_name = graph_name
    self._graph_def = self._process_graph_def(graph_def)
    self._nodes = self._create_util_nodes() + [
        TensorflowNode(node) for node in self.graph_def.node
    ]
    self._nodes_dict = {n.name: n for n in self._nodes}
    self._outputs = outputs or self.get_output_node_names(self.graph_def)

  @staticmethod
  def _create_util_nodes():
    """Return Const ``TensorflowNode``s for int32 [-1], [0] and [1]."""
    util_nodes = [(CONST_MINUS_ONE_INT32, np.array([-1]).astype(np.int32)),
                  (CONST_ZERO_INT32, np.array([0]).astype(np.int32)),
                  (CONST_ONE_INT32, np.array([1]).astype(np.int32))]
    return [
        TensorflowNode(
            op_type="Const",
            name=name,
            attr={
                "value": value,
                "dtype": any_dtype_to_onnx_dtype(value.dtype),
                "_output_shapes": [value.shape]
            }) for name, value in util_nodes
    ]

  def get_node_by_name(self, name):
    """Return the node called ``name``.

    :raises ValueError: If no such node exists.
    """
    node = self._nodes_dict.get(name, None)
    if node is None:
      raise ValueError(
          "Node {} is not found in the graph provided".format(name))
    return node

  def _process_graph_def(self, graph_def):
    """Run shape inference unless ``_output_shapes`` is already present.

    Only the first node is inspected. An empty graph is returned unchanged
    (indexing ``node[0]`` on it would raise IndexError).
    """
    if not graph_def.node:
      return graph_def
    if "_output_shapes" not in TensorflowNode(graph_def.node[0]).attr:
      graph_def = self._add_infer_shapes(graph_def)
    return graph_def

  @staticmethod
  def _add_infer_shapes(graph_def):
    """Re-import ``graph_def`` with ``infer_shapes=True`` (TF1 Session API)
    and return the resulting GraphDef, which carries ``_output_shapes``."""
    with tf.Graph().as_default():
      with tf.Session(
          config=tf.ConfigProto(
              graph_options=tf.GraphOptions(infer_shapes=True))) as sess:
        tf.import_graph_def(graph_def, name="")
    return sess.graph_def

  @staticmethod
  def _input_to_node_name(input_name):
    """Map a NodeDef input reference to the name of the producing node.

    Strips a leading ``^`` (control dependency) and a trailing ``:N``
    output-port suffix, e.g. ``"^a"`` -> ``"a"``, ``"split:1"`` -> ``"split"``.
    """
    name = input_name[1:] if input_name.startswith("^") else input_name
    base, sep, port = name.rpartition(":")
    if sep and port.isdigit():
      return base
    return name

  @staticmethod
  def get_output_node_names(graph_def):
    """Get output node names from GraphDef.

    A node is an output when no node in the graph consumes it. Input
    references are normalized first, so ``"name:1"`` and ``"^name"``
    both count as uses of node ``"name"``.

    Args:
      graph_def: GraphDef object.

    Returns:
      List of output node names, without duplicates, in the order the nodes
      appear in ``graph_def``.
    """
    consumed = set()
    for node in graph_def.node:
      consumed.update(
          TensorflowGraph._input_to_node_name(i) for i in node.input)
    output_names = []
    seen = set()
    for node in graph_def.node:
      if node.name not in consumed and node.name not in seen:
        seen.add(node.name)
        output_names.append(node.name)
    return output_names

  def update_nodes(self, nodes):
    """Replace the node list and rebuild the name -> node index."""
    self._nodes = nodes
    self._nodes_dict = {n.name: n for n in self._nodes}

  @property
  def graph_def(self):
    return self._graph_def

  @property
  def graph_name(self):
    return self._graph_name

  @property
  def nodes(self):
    return self._nodes

  @property
  def nodes_dict(self):
    return self._nodes_dict

  @property
  def outputs(self):
    return self._outputs


# TODO: Move this into ONNX main library
class OnnxNode(object):
  """
  Reimplementation of NodeProto from ONNX, but in a form
  more convenient to work with from Python.
  """

  def __init__(self, node):
    self.name = str(node.name)
    self.op_type = str(node.op_type)
    self.domain = str(node.domain)
    # Each attribute is converted, then translated, keyed by attribute name.
    self.attrs = {
        attr.name: attr_translator.translate_onnx(
            attr.name, attr_converter.convert_onnx(attr))
        for attr in node.attribute
    }
    self.inputs = list(node.input)
    self.outputs = list(node.output)
    self.node_proto = node


class OnnxGraph(object):
  """ A helper class for making ONNX graph.
  This class holds all information ONNX graph needs.
  """

  def __init__(self, name=None, graph_proto=None):
    """Start empty, or copy the contents of an existing ``GraphProto``.

    :param name: Graph name. Used only when ``graph_proto`` is None.
    :param graph_proto: Optional ONNX ``GraphProto`` to load from.
    """
    if graph_proto is not None:
      self._name = graph_proto.name
      self._inputs_proto = list(graph_proto.input)
      self._outputs_proto = list(graph_proto.output)
      self._nodes_proto = list(graph_proto.node)
      self._consts_proto = list(graph_proto.initializer)
      self._value_info_proto = list(graph_proto.value_info)
      self._consts = {
          init.name: numpy_helper.to_array(init)
          for init in graph_proto.initializer
      }
    else:
      self._name = name or ""
      self._inputs_proto = []
      self._outputs_proto = []
      self._nodes_proto = []
      self._consts = {}
      self._consts_proto = []
      self._value_info_proto = []
    # Either way, data_type_cast_map is empty when initialized.
    self._data_type_cast_map = {}

    self._add_utility_constants()

  def _add_utility_constants(self):
    """Register a float32 [1.0] constant as a const, initializer and input."""
    util_consts = {CONST_ONE_FP32: np.array([1.0]).astype(np.float32)}
    # Add a few useful utility constants:
    for name, value in util_consts.items():
      self.add_const_explicit(name=name, value=value)
      self.add_const_proto_explicit(
          name=name, value=value, np_dtype=value.dtype)
      self.add_input_proto_explicit(
          name=name, shape=value.shape, np_dtype=value.dtype)

  # This list holds the protobuf objects of type ValueInfoProto
  # representing the inputs to the converted ONNX graph.
  @property
  def inputs_proto(self):
    return self._inputs_proto

  @inputs_proto.setter
  def inputs_proto(self, inputs_proto):
    self._inputs_proto = inputs_proto

  @property
  def all_node_inputs(self):
    """Flat list of every input name referenced by any node proto."""
    return list(chain.from_iterable(p.input for p in self._nodes_proto))

  @property
  def outputs(self):
    """Names of the graph output protos."""
    return [p.name for p in self._outputs_proto]

  @property
  def outputs_proto(self):
    return self._outputs_proto

  # This list holds the protobuf objects of type NodeProto
  # representing the ops in the converted ONNX graph.
  @property
  def nodes_proto(self):
    return self._nodes_proto

  @nodes_proto.setter
  def nodes_proto(self, nodes_proto):
    self._nodes_proto = nodes_proto

  # This dictionary contains a map from the name of the constant
  # op to the array of values it holds. This is useful because
  # tensorflow is less eager to know about input values at
  # graph construction time than ONNX. That is to say, some ONNX
  # attributes are input tensors in TF. This dictionary extracts
  # those values of constant tensors that are known at graph
  # construction time.
  @property
  def consts(self):
    return self._consts

  @consts.setter
  def consts(self, consts):
    self._consts = consts

  # Sometimes the constants are used as inputs to ops. This list
  # holds initializers that create global constant tensors available
  # to be accessed by ops as inputs (as opposed to attributes, which
  # are supplied by the `consts` map above).
  @property
  def consts_proto(self):
    return self._consts_proto

  @consts_proto.setter
  def consts_proto(self, consts_proto):
    self._consts_proto = consts_proto

  # A map from node name to new data type. It is used to
  # process protos to match ONNX type constraints.
  @property
  def data_type_cast_map(self):
    return self._data_type_cast_map

  @data_type_cast_map.setter
  def data_type_cast_map(self, data_type_cast_map):
    self._data_type_cast_map = data_type_cast_map

  # This list holds the protobuf objects of type ValueInfoProto
  # representing all nodes' outputs in the converted ONNX graph.
  @property
  def value_info_proto(self):
    return self._value_info_proto

  def add_input_proto_explicit(self,
                               name,
                               shape,
                               np_dtype=None,
                               tf_dtype=None,
                               onnx_dtype=None):
    """Append a graph-input ValueInfoProto; the dtype may be np/tf/onnx."""
    onnx_dtype = any_dtype_to_onnx_dtype(
        np_dtype=np_dtype, tf_dtype=tf_dtype, onnx_dtype=onnx_dtype)
    input_proto = make_tensor_value_info(name, onnx_dtype, shape)
    self._inputs_proto.append(input_proto)

  def add_input_proto(self, node):
    """Append a graph input for ``node``.

    The dtype comes from ``attr["dtype"]``. For Const nodes the shape comes
    from the value's shape; otherwise it comes from ``attr["shape"]``.
    """
    name = node.name
    onnx_dtype = node.attr["dtype"]
    shape = node.attr["shape"] if node.op_type != "Const" else node.attr[
        'value'].shape
    self.add_input_proto_explicit(name, shape, onnx_dtype=onnx_dtype)

  def add_output_proto(self, node):
    """Append one graph output per entry in ``attr["_output_shapes"]``.

    The element type is ``attr["T"]``, or BOOL when "T" is absent.
    """
    output_onnx_type = node.attr.get("T", TensorProto.BOOL)
    for i, output_shape in enumerate(node.attr["_output_shapes"]):
      output_name = node.name + ":{}".format(i) if i > 0 else node.name
      self._outputs_proto.append(
          make_tensor_value_info(output_name, output_onnx_type, output_shape))

  def add_node_proto(self, node_proto):
    """Append one NodeProto or a list/tuple of them."""
    if not isinstance(node_proto, (list, tuple)):
      node_proto = [node_proto]
    self._nodes_proto.extend(node_proto)

  def remove_node_proto(self, names):
    """Drop node protos whose name is ``names`` (or is in the list/tuple)."""
    if not isinstance(names, (list, tuple)):
      names = [names]
    self._nodes_proto = [p for p in self._nodes_proto if p.name not in names]

  def add_const_explicit(self, name, value):
    self._consts[name] = value

  def add_const(self, node):
    self.add_const_explicit(node.name, node.attr["value"])

  def add_const_proto_explicit(self,
                               name,
                               value,
                               np_dtype=None,
                               tf_dtype=None,
                               onnx_dtype=None):
    """Append an initializer TensorProto holding ``value``.

    A 0-d ``value`` is stored with dims ``(1,)``; otherwise the dims are
    ``value.shape``.
    """
    onnx_dtype = any_dtype_to_onnx_dtype(
        np_dtype=np_dtype, tf_dtype=tf_dtype, onnx_dtype=onnx_dtype)

    if len(value.shape) == 0:
      raw_values = [value.tolist()]
      dims = (1,)
    else:
      raw_values = value.flatten().tolist()
      dims = value.shape

    const_proto = make_tensor(
        name=name, data_type=onnx_dtype, dims=dims, vals=raw_values)
    self._consts_proto.append(const_proto)

  def add_const_proto(self, node):
    self.add_const_proto_explicit(
        node.name, node.attr["value"], onnx_dtype=node.attr["dtype"])

  def add_value_info_proto(self, node):
    """Append one ValueInfoProto per entry in ``attr["_output_shapes"]``.

    The element type is ``attr["T"]``, or BOOL when "T" is absent.
    """
    node_onnx_type = node.attr.get("T", TensorProto.BOOL)
    for i, output_shape in enumerate(node.attr["_output_shapes"]):
      node_name = node.name + ":{}".format(i) if i > 0 else node.name
      value_info_proto = make_tensor_value_info(node_name, node_onnx_type,
                                                output_shape)
      self._value_info_proto.append(value_info_proto)

  # Remove protos from inputs_proto and consts_proto
  # if they are not used as a node input or a graph output in ONNX.
  def _clean_graph(self):
    # Set membership keeps this linear in the number of protos.
    used_names = set(self.all_node_inputs)
    used_names.update(self.outputs)
    self._inputs_proto = [p for p in self.inputs_proto if p.name in used_names]
    self._consts_proto = [p for p in self.consts_proto if p.name in used_names]

  def _fix_data_type(self):
    """Apply ``data_type_cast_map`` to the input and initializer protos."""
    self.inputs_proto = self._data_type_caster(self.inputs_proto,
                                               self.data_type_cast_map)
    self.consts_proto = self._data_type_caster(self.consts_proto,
                                               self.data_type_cast_map)

  @classmethod
  def _data_type_caster(cls, protos, data_type_cast_map):
    """Cast to a new data type if node name is in data_type_cast_map.
    Be used to process protos to match ONNX type constraints.

    :param protos: Target protos.
      ValueInfoProto for inputs and TensorProto for consts.
    :param data_type_cast_map: A {node.name: new_data_type} dict.
    :return: Processed protos. A TensorProto is rebuilt from its typed
      value field. A ValueInfoProto is updated in place.
    """
    if not data_type_cast_map:
      return protos
    result = []
    for proto in protos:
      new_proto = proto
      new_data_type = data_type_cast_map.get(proto.name)
      if new_data_type is not None:
        if isinstance(proto, TensorProto):
          if proto.data_type != new_data_type:
            field = mapping.STORAGE_TENSOR_TYPE_TO_FIELD[
                mapping.TENSOR_TYPE_TO_STORAGE_TENSOR_TYPE[proto.data_type]]
            vals = getattr(proto, field)
            new_proto = make_tensor(
                name=proto.name,
                data_type=new_data_type,
                dims=proto.dims,
                vals=vals)
        elif isinstance(proto, ValueInfoProto):
          if proto.type.tensor_type.elem_type != new_data_type:
            new_proto.type.tensor_type.elem_type = new_data_type
      result.append(new_proto)
    return result

  def make_graph_proto(self):
    """Prune unused inputs/consts, apply dtype casts and build a GraphProto.

    ``initializer`` and ``value_info`` are passed only when the installed
    ``onnx.helper.make_graph`` accepts them.
    """
    self._clean_graph()
    self._fix_data_type()

    if IS_PYTHON3:
      params = list(inspect.signature(make_graph).parameters.keys())
    else:
      params = inspect.getargspec(make_graph).args

    kwargs = {
        "initializer": self.consts_proto,
        "value_info": self.value_info_proto
    }
    supported_kwargs = {k: v for k, v in kwargs.items() if k in params}

    return make_graph(self.nodes_proto, self._name, self.inputs_proto,
                      self.outputs_proto, **supported_kwargs)