# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=import-self, invalid-name, unused-argument, too-many-lines
"""TF: Tensorflow frontend."""
from __future__ import absolute_import as _abs
from __future__ import print_function

import warnings
# Numpy support
import numpy as np

import tvm
from .. import symbol as _sym
from .. import graph as _graph
from .. compiler import graph_util, build_module
from .common import get_nnvm_op, AttrConverter as AttrConvert

__all__ = ['from_tensorflow']

class AttrCvt(object):
    """A wrapper that handles the common attribute-conversion chores
    before delegating to AttrConverter.
    """
    def __init__(self, op_name, transforms=None,
                 excludes=None, disables=None, ignores=None,
                 extras=None, custom_check=None):
        self._op_name = op_name
        self._transforms = transforms if transforms else {}
        self._excludes = excludes if excludes else []
        self._disables = disables if disables else []
        self._ignores = ignores if ignores else []
        self._extras = extras if extras else {}
        self._custom_check = custom_check

    def __call__(self, inputs, attrs, *args):
        self._ignores.append('_output_shapes')
        self._ignores.append('_input_shapes')
        self._ignores.append('T')
        self._ignores.append('use_cudnn_on_gpu')
        self._ignores.append('_node_name')
        self._ignores.append('is_training')
        self._ignores.append('_target_layout')
        self._ignores.append('_input_0d_mismatch')
        # Retain the names
        try:
            attrs['name'] = attrs['_node_name']
        except KeyError:
            pass
        return AttrConvert(self._op_name, self._transforms, self._excludes,
                           self._disables, self._ignores, self._extras,
                           self._custom_check)(inputs, attrs, *args)

def _get_pad_pair(input1d, kernel1d, stride1d):
    if input1d % stride1d == 0:
        pad = max(kernel1d - stride1d, 0)
    else:
        pad = max(kernel1d - (input1d % stride1d), 0)

    pad_before = pad // 2
    pad_after = pad - pad_before

    return [pad_before, pad_after]
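
# Worked example (illustrative only, not part of the original converter): for a
# dimension of size 224 with a 3x3 kernel and stride 2, TensorFlow-style 'SAME'
# padding gives pad = max(3 - 2, 0) = 1, split as pad_before = 0 and
# pad_after = 1, so _get_pad_pair(224, 3, 2) returns [0, 1].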

def _math_name_picker(suffix):
    def _impl(attr):
        return 'broadcast_' + suffix
    return _impl

def _dimension_picker(prefix, suffix=''):
    def _impl(attr):
        kernel = attr['kernel_shape']
        if len(kernel) == 2:
            return prefix + '2d' + suffix
        raise tvm.error.OpAttributeUnImplemented(
            'Non-2D kernels are not supported for operator {}.'.format(prefix))
    return _impl

def _dimension_constraint():
    def _dim_check(attrs):
        if len(attrs['kernel_shape']) == 2:
            return True
        return False
    return _dim_check, "Only 2d kernel supported."

def _infer_channels(inputs, params, transpose=False):
    """A hack for getting 'channels' or 'units', since TensorFlow does not
    provide these attributes. We check the shape of the weights to get the number.
    """
    g = _graph.create(inputs)
    shape_dict = {k: v.shape for k, v in params.items()}
    _, out_shapes = graph_util.infer_shape(g, **shape_dict)
    channels = out_shapes[0][0] if not transpose else out_shapes[0][1]
    return channels

def _rsqrt():
    def _impl(inputs, attr, *args):
        return AttrCvt(op_name="__pow_scalar__", extras={'scalar': -0.5})(inputs, attr)
    return _impl

def _argx(func, func_name):
    """A common wrapper for argmin and argmax operations."""
    def _impl(inputs, attr, params):
        try:
            # In Tensorflow, the `axis` argument is a Tensor, not an attribute.
            # We only support the case where it comes from a scalar constant.
            axis_input_name = inputs[1].list_output_names()[0]
            axis_input_value = params[axis_input_name].asnumpy()[0]
        except (IndexError, KeyError):
            raise TypeError( \
                "Unsupported argument for `{}` : `axis` should be a constant".format(func_name))
        return func(inputs[0], axis=axis_input_value, keepdims=False)
    return _impl

def _elemwise(name):
    def _impl(inputs, attr, *args):
        assert len(inputs) == 2, "{} takes 2 inputs, {} given".format(name, len(inputs))
        op_name = _math_name_picker(name)(attr)
        return get_nnvm_op(op_name)(*inputs)
    return _impl

def _pooling(name):
    def _impl(inputs, attr, params):

        attr['data_format'] = attr['data_format'].decode("utf-8")
        flip_layout = False

        input_shape = attr['_input_shapes'][inputs[0]]

        if attr['data_format'] == 'NHWC':
            attr['kernel_shape'] = (attr['ksize'][1], attr['ksize'][2])
            attr['strides'] = (attr['strides'][1], attr['strides'][2])
        elif attr['data_format'] == 'NCHW':
            attr['kernel_shape'] = (attr['ksize'][2], attr['ksize'][3])
            attr['strides'] = (attr['strides'][2], attr['strides'][3])
        else:
            msg = 'Value {} in attribute "data_format" of operator Pooling is not valid.'
            raise tvm.error.OpAttributeInvalid(msg.format(attr['data_format']))

        if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC":
            tmp_shape = attr['_input_shapes'][inputs[0]]
            input_shape = [tmp_shape[ii] for ii in (0, 3, 1, 2)]
            inputs[0] = _sym.transpose(inputs[0], axes=(0, 3, 1, 2))
            attr['data_format'] = "NCHW"
            flip_layout = True

        # Fix padding
        attr['padding'] = attr['padding'].decode("utf-8")

        if attr['padding'] == 'VALID':
            attr['padding'] = [0, 0]
        elif attr['padding'] == 'SAME':
            stride_h, stride_w = attr['strides']
            kernel_h, kernel_w = attr['kernel_shape']
            if attr['data_format'] == 'NHWC':
                in_h = input_shape[1]
                in_w = input_shape[2]
            else:
                in_h = input_shape[2]
                in_w = input_shape[3]

            pad_v = _get_pad_pair(in_h, kernel_h, stride_h)
            pad_h = _get_pad_pair(in_w, kernel_w, stride_w)

            attr['padding'] = [pad_v[0], pad_h[0], pad_v[1], pad_h[1]]
        else:
            msg = 'Value {} in attribute "padding" of operator Pooling is not valid.'
            raise tvm.error.OpAttributeUnImplemented(msg.format(attr['padding']))

        if name == "avg_pool":
            attr['count_include_pad'] = False

        out = AttrCvt(
            op_name=_dimension_picker(name),
            transforms={
                'kernel_shape':'pool_size',
                'data_format':'layout'},
            ignores=['ksize'],
            extras={'ceil_mode': False},
            custom_check=_dimension_constraint())(inputs, attr)

        if flip_layout:
            out = _sym.transpose(out, axes=(0, 2, 3, 1))

        return out
    return _impl

def _conv(opname):
    def _impl(inputs, attr, params):
        attr['data_format'] = attr['data_format'].decode("utf-8")
        flip_layout = False

        # NCHW layout requires transposing the weights
        if attr['data_format'] == 'NCHW':
            tmp_shape = attr['_input_shapes'][inputs[1]]
            if opname == 'conv':
                tmp_shape = [tmp_shape[ii] for ii in (3, 2, 0, 1)]
                inputs[1] = _sym.transpose(inputs[1], axes=(3, 2, 0, 1))
            else:
                tmp_shape = [tmp_shape[ii] for ii in (2, 3, 0, 1)]
                inputs[1] = _sym.transpose(inputs[1], axes=(2, 3, 0, 1))
            attr['_input_shapes'][inputs[1]] = tmp_shape

        input_shape = attr['_input_shapes'][inputs[0]]
        weights_shape = attr['_input_shapes'][inputs[1]]

        if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC":
            input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)]
            inputs[0] = _sym.transpose(inputs[0], axes=(0, 3, 1, 2))
            if opname == 'conv':
                weights_shape = [weights_shape[ii] for ii in (3, 2, 0, 1)]
                inputs[1] = _sym.transpose(inputs[1], axes=(3, 2, 0, 1))
            else:
                weights_shape = [weights_shape[ii] for ii in (2, 3, 0, 1)]
                inputs[1] = _sym.transpose(inputs[1], axes=(2, 3, 0, 1))

            attr['data_format'] = "NCHW"
            attr['strides'] = [attr['strides'][ii] for ii in (0, 3, 1, 2)]
            flip_layout = True

        if attr['data_format'] == 'NHWC':
            kernel_h, kernel_w, _, depth_mult = weights_shape
            attr['kernel_shape'] = (weights_shape[0], weights_shape[1])
            if opname == 'conv':
                attr['channels'] = weights_shape[3]
            else:
                attr['channels'] = input_shape[3] * depth_mult

            if 'dilations' in attr:
                attr['dilations'] = (attr['dilations'][1], attr['dilations'][2])
            attr['strides'] = (attr['strides'][1], attr['strides'][2])
        elif attr['data_format'] == 'NCHW':
            _, depth_mult, kernel_h, kernel_w = weights_shape
            attr['kernel_shape'] = (weights_shape[2], weights_shape[3])
            if opname == 'conv':
                attr['channels'] = weights_shape[0]
            else:
                attr['channels'] = input_shape[1] * depth_mult
                if attr['channels'] < 0:
                    attr['channels'] *= -1

            if 'dilations' in attr:
                attr['dilations'] = (attr['dilations'][2], attr['dilations'][3])
            attr['strides'] = (attr['strides'][2], attr['strides'][3])
        else:
            msg = 'Value {} in attribute "data_format" of operator Conv is not valid.'
            raise tvm.error.OpAttributeInvalid(msg.format(attr['data_format']))


        if opname == 'depthwise':
            if depth_mult > 1:
                raise tvm.error.OpNotImplemented('depth_mult > 1 of operator DepthwiseConv2dNative'
                                                 ' is not supported.')
            attr['groups'] = attr['channels']

        # Fix padding
        attr['padding'] = attr['padding'].decode("utf-8")

        if attr['padding'] == 'VALID':
            attr['padding'] = [0, 0]
        elif attr['padding'] == 'SAME':
            stride_h, stride_w = attr['strides']
            kernel_h, kernel_w = attr['kernel_shape']
            if attr['data_format'] == 'NHWC':
                in_h = input_shape[1]
                in_w = input_shape[2]
            else:
                in_h = input_shape[2]
                in_w = input_shape[3]

            dilation_h = attr['dilations'][0]
            dilation_w = attr['dilations'][1]
            dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
            dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
            pad_v = _get_pad_pair(in_h, dilated_kernel_h, stride_h)
            pad_h = _get_pad_pair(in_w, dilated_kernel_w, stride_w)

            if attr['data_format'] == 'NHWC':
                inputs[0] = _sym.pad(data=inputs[0],
                                     pad_width=((0, 0),
                                                (pad_v[0], pad_v[1]),
                                                (pad_h[0], pad_h[1]),
                                                (0, 0)))
            else:
                inputs[0] = _sym.pad(data=inputs[0],
                                     pad_width=((0, 0),
                                                (0, 0),
                                                (pad_v[0], pad_v[1]),
                                                (pad_h[0], pad_h[1])))

            attr['padding'] = [0, 0]

        else:
            msg = 'Value {} in attribute "padding" of operator Conv is not valid.'
            raise tvm.error.OpAttributeInvalid(msg.format(attr['padding']))

        if 'kernel_layout' not in attr:
            if opname == 'conv':
                attr['kernel_layout'] = 'HWIO' if attr['data_format'] == 'NHWC' else 'OIHW'
            else:
                attr['kernel_layout'] = 'HWOI' if attr['data_format'] == 'NHWC' else 'OIHW'

        out = AttrCvt(
            op_name=_dimension_picker('conv'),
            transforms={
                'kernel_shape': 'kernel_size',
                'data_format': 'layout',
                'dilations': ('dilation', (0, 0)),
                'group': ('groups', 1)},
            extras={'use_bias': len(inputs) == 3},
            custom_check=_dimension_constraint())(inputs, attr)

        if flip_layout:
            out = _sym.transpose(out, axes=(0, 2, 3, 1))

        return out
    return _impl

def _decode_image():
    def _impl(inputs, attr, params):
        # Image decode wrapper: the user is expected to feed decoded input
        # to the next layer, so this layer is dropped.
        warnings.warn("DecodeJpeg: It's a pass through, "
                      "please handle preprocessing before input")
        return inputs[0]
    return _impl

def _cast():
    def _impl(inputs, attr, params):
        # Convert from tensorflow Dtype to str
        attr['DstT'] = attr['DstT'].name
        return AttrCvt(op_name='cast', transforms={'DstT': 'dtype'},
                       ignores=['SrcT', 'Truncate'])(inputs, attr)
    return _impl

def _expand_dims():
    def _impl(inputs, attr, params):
        dim_input = inputs.pop(1)
        axis = params[dim_input.list_output_names()[0]]
        params.pop(dim_input.list_output_names()[0])
        return _expand_dims_0d_aware(inputs[0], attr, axis=axis.asnumpy()[0])
    return _impl

def _resize_bilinear():
    def _impl(inputs, attr, params):
        attr['size'] = attr['_output_shapes'][0][1:3]
        inputs.pop(1)
        # NHWC
        attr['layout'] = 'NHWC'

        return AttrCvt(op_name="resize",
                       ignores=['Tdim'],
                       extras={'method': "BILINEAR"})(inputs, attr)
    return _impl

def _check_numerics():
    def _impl(inputs, attr, params):
        # Making a copy node assuming no need to verify
        return AttrCvt(op_name="copy", ignores=['message'])(inputs, attr)
    return _impl


def _matmul():
    def _impl(inputs, attr, params):
        channels = _infer_channels(inputs[1], params, not attr['transpose_b'])
        if attr['transpose_a']:
            inputs[0] = _sym.transpose(inputs[0], axes=(1, 0))
        if not attr['transpose_b']:
            inputs[1] = _sym.transpose(inputs[1], axes=(1, 0))
        return AttrCvt(op_name="dense",
                       extras={'use_bias': False, 'units': channels},
                       ignores=['transpose_a', 'transpose_b', 'T'])(inputs, attr)

    return _impl

def _undef():
    def _impl(inputs, attr, params):
        return _sym.__undef__()
    return _impl

def _identity():
    def _impl(inputs, attr, params):
        return inputs[0]
    return _impl

def _concatV2():
    def _impl(inputs, attr, params):
        pop_node = inputs.pop(len(inputs)-1)
        axis = params[pop_node.list_output_names()[0]]
        params.pop(pop_node.list_output_names()[0])
        return AttrCvt(
            op_name="concatenate", ignores=['T', 'N', 'Tidx'],
            extras={'axis': axis.asnumpy()[0]})(inputs, attr)
    return _impl

def _concat():
    def _impl(inputs, attr, params):
        pop_node = inputs.pop(0)
        axis = params[pop_node.list_output_names()[0]]
        params.pop(pop_node.list_output_names()[0])
        return AttrCvt(
            op_name="concatenate", ignores=['N'],
            extras={'axis': axis.asnumpy()[0]})(inputs, attr)
    return _impl

def _pack():
    def _impl(inputs, attr, params):
        axis = int(attr["axis"])
        inputs_reshaped = [_expand_dims_0d_aware(i, attr, axis=axis, num_newaxis=1) for i in inputs]
        return _sym.concatenate(*inputs_reshaped, axis=axis, name=attr["_node_name"])

    return _impl

def _slice():
    def _impl(inputs, attr, params):
        begin = params.pop(inputs[1].list_output_names()[0]).asnumpy().tolist()
        size = params.pop(inputs[2].list_output_names()[0]).asnumpy().tolist()
        data_shape = attr['_input_shapes'][inputs[0]]
        data_dim = len(data_shape)
        end = size
        for i in range(data_dim):
            if size[i] == -1:
                end[i] = data_shape[i] - begin[i]
            else:
                end[i] += begin[i]
        return _sym.strided_slice(inputs[0], begin=begin, end=end)
    return _impl

def _reshape():
    def _impl(inputs, attr, params):
        try:
            pop_node = inputs[1]
            shape_arg = params.pop(pop_node.list_output_names()[0])
            inputs.pop(1)

            return AttrCvt(
                op_name="reshape",
                extras={'shape':tuple(shape_arg.asnumpy())},
                ignores=['Tshape'])(inputs, attr)
        except KeyError:
            # The Shape operator has already been pruned; hence
            # try to infer the shape by precomputing it, if possible.
            if all(in_node in params for in_node in inputs[1].list_input_names()):
                graph = _graph.create(_sym.Group(inputs[1]))
                params_pre = {k: params[k] for k in inputs[1].list_input_names()}
                params_new = build_module._run_graph(graph, params_pre)
                inputs.pop(1)
                return AttrCvt(
                    op_name="reshape",
                    extras={'shape':tuple(params_new[0].asnumpy().flatten())},
                    ignores=['Tshape'])(inputs, attr)
            raise tvm.error.OpAttributeUnImplemented(
                'Attribute "dynamic shape" of operator Reshape is not supported.')
    return _impl

def _bias_add():
    def _impl(inputs, attr, params):
        if attr['data_format'].decode("utf-8") == 'NCHW':
            bias = _sym.reshape(inputs[1], newshape=(1, -1, 1, 1))
        else:
            bias = inputs[1]
        return _sym.broadcast_add(inputs[0], bias)
    return _impl

def _squeeze():
    def _impl(inputs, attr, params):
        return AttrCvt(
            op_name="squeeze",
            transforms={'squeeze_dims':'axis'},
            ignores=['T'])(inputs, attr)
    return _impl

def _fused_batch_norm():
    def _impl(inputs, attr, params):
        # Tensorflow: (data, gamma, beta, moving_mean, moving_variance)
        # NNVM:       (data, gamma, beta, moving_mean, moving_variance)
        axis = 3
        need_cast = False

        if 'data_format' in attr:
            attr['data_format'] = attr['data_format'].decode("utf-8")
            if attr['data_format'] == 'NCHW':
                axis = 1
        if 'U' in attr:
            need_cast = True
            inputs[0] = _sym.cast(inputs[0], dtype=attr['U'].name)

        out = AttrCvt(op_name='batch_norm',
                      transforms={'scale_after_normalization':'scale',
                                  'variance_epsilon':'epsilon'},
                      extras={'axis': axis},
                      ignores=['data_format', 'U'],
                      disables=['momentum'])(inputs, attr)

        if need_cast:
            out = _sym.cast(out, dtype=attr['T'].name)
        return out
    return _impl

def _batch_norm():
    def _impl(inputs, attr, params):
        # Rearrange inputs from
        # (data, moving_mean, moving_variance, beta, gamma)
        #     to
        # (data, gamma, beta, moving_mean, moving_var)
        new_inputs = [inputs[0], inputs[4], inputs[3], inputs[1], inputs[2]]

        axis = 3
        if 'data_format' in attr:
            attr['data_format'] = attr['data_format'].decode("utf-8")
            if attr['data_format'] == 'NCHW':
                axis = 1

        return AttrCvt(
            op_name='batch_norm',
            transforms={'scale_after_normalization':'scale', 'variance_epsilon':'epsilon'},
            extras={'axis': axis},
            ignores=['data_format'],
            disables=['momentum'])(new_inputs, attr)
    return _impl

def _relu6():
    def _impl(inputs, attr, params):
        return _sym.clip(inputs[0], a_min=0, a_max=6, name=attr['_node_name'])
    return _impl

def _shape():
    def _impl(inputs, attr, params):
        return np.array(attr['_input_shapes'][inputs[0]], dtype='int32')
    return _impl

def _fill():
    def _impl(inputs, attr, params):
        fill_arg = params.pop(inputs.pop(1).list_output_names()[0])
        new_inputs = []
        return AttrCvt(
            op_name='full',
            extras={'shape':inputs[0],
                    'fill_value':fill_arg.asnumpy()[0], 'dtype':attr['T'].name},
            ignores=['index_type', 'T'])(new_inputs, attr)
    return _impl

def _lrn():
    def _impl(inputs, attr, params):
        attr_new = {}
        depth_radius = attr.get('depth_radius', 5)
        size = (depth_radius * 2) + 1
        attr_new['axis'] = 3  # Fix axis, NHWC format
        attr_new['size'] = size
        attr_new['bias'] = attr.get('bias', 1)
        attr_new['alpha'] = attr.get('alpha', 1) * size
        attr_new['beta'] = attr.get('beta', 0.5)
        return AttrCvt(op_name='lrn')(inputs, attr_new)
    return _impl

def _sum():
    def _impl(inputs, attr, params):
        axis = params.pop(inputs[1].list_output_names()[0]).asnumpy()
        # convert to tuple to prevent an invalid parameter format error
        axis = tuple(axis)
        return AttrCvt(
            op_name='sum',
            extras={'axis': axis},
            transforms={'keep_dims':'keepdims'},
            ignores=['name', 'Tidx'])(inputs[0], attr)
    return _impl

def _square():
    def _impl(inputs, attr, params):
        return _sym.elemwise_mul(inputs[0], inputs[0])
    return _impl

def _gather_v2():
    "TensorFlow now uses only GatherV2 for gather operations."
    def _impl(inputs, attr, params):
        axis = params[inputs.pop(2).list_output_names()[0]].asnumpy()[0]
        new_input = []
        new_input.append(inputs.pop(0))
        new_input.append(inputs.pop(0))
        return AttrCvt(
            op_name="take",
            extras={'axis':axis},
            ignores=['Tindices', 'Tparams', 'validate_indices', \
                     'Taxis', '_class'])(new_input, attr)
    return _impl

def _infer_out_shapes(inputs, params):
    """A method to get the output shape of an intermediate node in the NNVM graph."""
    g = _graph.create(inputs)
    shape_dict = {k: v.shape for k, v in params.items()}
    _, out_shapes = graph_util.infer_shape(g, **shape_dict)
    return out_shapes

def _stridedSlice():
    def _impl(inputs, attr, params):
        """Strided Slice.
        Operator description: https://www.tensorflow.org/api_docs/python/tf/strided_slice
        Tensorflow mask validation: https://github.com/tensorflow/tensorflow/blob/master/
        tensorflow/core/util/strided_slice_op.cc#L147-L368
        """
        begin = params.pop(inputs[1].list_output_names()[0]).asnumpy().tolist()
        end = params.pop(inputs[2].list_output_names()[0]).asnumpy().tolist()
        stride = params.pop(inputs[3].list_output_names()[0]).asnumpy().tolist()
        begin_mask = int(attr.get('begin_mask', 0))
        end_mask = int(attr.get('end_mask', 0))
        ellipsis_mask = int(attr.get('ellipsis_mask', 0))
        new_axis_mask = int(attr.get('new_axis_mask', 0))
        shrink_axis_mask = int(attr.get('shrink_axis_mask', 0))
        data_shape = attr['_input_shapes'][inputs[0]]
        data_dim = len(data_shape)
        stride_dim = len(stride)

        def _transform_mask(stride_dim, ellipsis_mask):
            """Handle mask inputs to create new begin, end, stride and output shape"""
            m_begin = [0] * data_dim
            m_end = [0] * data_dim
            m_stride = [0] * data_dim
            fshape_indices = []
            # Count new axes after the ellipsis; considered while applying ellipsis_mask.
            ellipsis_seen = False
            new_axes_after_ellipsis = 0
            for i in range(stride_dim):
                mask = 1 << i
                if ellipsis_seen and (mask & new_axis_mask) != 0:
                    new_axes_after_ellipsis += 1
                if (mask & ellipsis_mask) != 0:
                    ellipsis_seen = True
            if not ellipsis_seen:
                # Used later for extending the stride attributes in the below loop.
                ellipsis_mask |= (1 << stride_dim)
                stride_dim += 1
            final_index = 0
            for index in range(stride_dim):
                mask = 1 << index
                if mask & ellipsis_mask:
                    # Identify the end index for applying ellipsis_mask
                    to_index = min(((data_dim - (stride_dim - index)) + 1 \
                                     + new_axes_after_ellipsis), data_dim)
                    for i in range(final_index, to_index):
                        m_begin[final_index] = 0
                        m_end[final_index] = data_shape[final_index]
                        m_stride[final_index] = 1
                        fshape_indices.append(final_index)
                        final_index += 1
                elif mask & new_axis_mask:
                    fshape_indices.append(-1)
                elif not mask & new_axis_mask:
                    if final_index == len(m_begin):
                        break
                    if mask & begin_mask:
                        m_begin[final_index] = data_shape[final_index] \
                            if stride[index] < 0 else 0
                    elif begin[index]:
                        m_begin[final_index] = begin[index]
                    if mask & end_mask:
                        m_end[final_index] = 0 if stride[index] < 0 \
                            else data_shape[final_index]
                    elif end[index]:
                        m_end[final_index] = end[index]
                    m_stride[final_index] = stride[index]
                    if mask & shrink_axis_mask:
                        # TensorFlow collapses axes selected by shrink_axis_mask to dimension 1
                        m_begin[final_index] = data_shape[final_index] + begin[index] \
                            if begin[index] < 0 else begin[index]
                        m_end[final_index] = begin[index] + 1
                        m_stride[final_index] = 1
                        fshape_indices.append(-2)
                    else:
                        fshape_indices.append(final_index)

                    final_index += 1
            return m_begin, m_end, m_stride, fshape_indices

        fshape_indices = None
        if begin_mask or end_mask or ellipsis_mask or new_axis_mask or shrink_axis_mask:
            begin, end, stride, fshape_indices = _transform_mask(stride_dim, ellipsis_mask)
        out = _sym.strided_slice(inputs[0], begin=begin, end=end, stride=stride)
        out_shape = _infer_out_shapes(out, params)[0]
        if not fshape_indices:
            fshape_indices = range(len(out_shape))

        # Create the final output shape.
        final_output = []
        for gather_index in fshape_indices:
            if gather_index == -1:
                final_output.append(1)
            elif gather_index == -2:
                pass
            else:
                final_output.append(out_shape[gather_index])
        # Prevent 0-dim tensors which are not accepted by nnvm
        if not final_output:
            final_output.append(1)
        return _sym.reshape(out, shape=tuple(final_output))
    return _impl

def _LSTMBlockCell():
    def _impl(inputs, in_state_c, in_state_h, attr, params):
        """LSTM Block cell.
        Calculations are described in: https://github.com/tensorflow/tensorflow/blob/
        r1.8/tensorflow/contrib/rnn/python/ops/lstm_ops.py#L41-L114

        Parameters
        ----------
        inputs : nnvm.Symbol
            Input data
        in_state_c: list of nnvm.Symbol
            Cell state input values for all the layers
        in_state_h: list of nnvm.Symbol
            Hidden state input values for all the layers
        attr : dict
            Dict of operator attributes
        params : dict
            Dict of pretrained weights and bias

        Returns
        -------
        sym : nnvm.Symbol
            Converted nnvm Symbol
        output: nnvm.Symbol
            Output state value.
        """
        in_data = inputs[0]
        in_weight = inputs[3]
        in_bias = inputs[7]
        forget_bias = attr.pop('forget_bias')
        input_shape = attr['_input_shapes'][inputs[0]]
        weight_shape = attr['_input_shapes'][inputs[3]]
        batch_size, input_size = input_shape[0], input_shape[1]
        num_hidden_layers = weight_shape[1]
        num_hidden = num_hidden_layers // 4

        in_data = _sym.reshape(in_data,
                               shape=(batch_size, input_size))
        ixh = _sym.concatenate(*[in_data, in_state_h], axis=1)
        in_weight = _sym.transpose(in_weight)
        gates = _sym.dense(ixh, in_weight, in_bias, use_bias=True,
                           units=num_hidden_layers)
        gate_list = _sym.split(gates, indices_or_sections=4, axis=1)
        in_gate = _sym.sigmoid(gate_list[0])
        in_transform = _sym.tanh(gate_list[1])
        forget_gate = _sym.sigmoid(gate_list[2])
        forget_gate = forget_gate + forget_bias
        out_gate = _sym.sigmoid(gate_list[3])
        next_c = _sym.broadcast_add(_sym.broadcast_mul(forget_gate, in_state_c),
                                    _sym.broadcast_mul(in_gate, in_transform))
        next_h = out_gate * _sym.tanh(next_c)
        out_state = _sym.concatenate(*[next_c, next_h])
        out_state = _sym.reshape(out_state,
                                 shape=(2, batch_size, num_hidden))
        return next_h, out_state
    return _impl


def _pad(name):
    def _impl(inputs, attr, params):
        padlist_key = inputs[1].list_output_names()[0]
        if padlist_key in params:
            padlist = params.pop(padlist_key).asnumpy()
        else:
            raise tvm.error.OpAttributeRequired(
                'Required attribute "{}" not found in operator Pad.'.format(padlist_key))
        paddings = tuple([tuple(l) for l in padlist])
        attr['pad_width'] = paddings
        attr['pad_value'] = 0
        new_inputs = [inputs[0]]
        if name == 'PadV2':
            constant_values = params.pop(inputs[2].list_output_names()[0]).asnumpy()
            attr['pad_value'] = constant_values[0]
        return AttrCvt(
            op_name='pad',
            ignores=['Tpaddings'],)(new_inputs, attr)
    return _impl


def _transpose():
    def _impl(inputs, attr, params):
        # If perm is not specified, axes is left empty;
        # otherwise its value is taken from params.
        param_name = inputs[1].list_output_names()[0]
        axes = params.get(param_name, tvm.nd.array([])).asnumpy()
        return _sym.transpose(inputs[0], axes=tuple(axes))
    return _impl

def _rank():
    def _impl(inputs, attr, params):
        input_shape = attr['_input_shapes'][inputs[0]]

        name = attr["_node_name"]
        params[name] = tvm.nd.array([len(input_shape)])
        return _sym.Variable(name=name, shape=params[name].shape)
    return _impl

def _range():
    def _impl(inputs, attr, params):
        start = params.pop(inputs[0].list_output_names()[0]).asnumpy()[0]
        limit = params.pop(inputs[1].list_output_names()[0]).asnumpy()[0]
        delta = params.pop(inputs[2].list_output_names()[0]).asnumpy()[0]

        name = attr["_node_name"]
        params[name] = tvm.nd.array([start, limit, delta])
        return _sym.Variable(name=name, shape=params[name].shape)
    return _impl

def _elu():
    def _impl(inputs, attr, params):
        alpha = 1.0
        return -alpha * _sym.relu(1 - _sym.exp(inputs[0])) + _sym.relu(inputs[0])
    return _impl

def _selu():
    def _impl(inputs, attr, params):
        alpha = 1.6732632423543772848170429916717
        gamma = 1.0507009873554804934193349852946
        return gamma * (-alpha * _sym.relu(1 - _sym.exp(inputs[0])) + _sym.relu(inputs[0]))
    return _impl

def _mean():
    def _impl(inputs, attr, params):
        axis = params.pop(inputs[1].list_output_names()[0])
        return AttrCvt(op_name="mean", ignores=['Tdim', 'Tidx'],
                       transforms={'keep_dims': 'keepdims'},
                       extras={'axis': tuple(axis.asnumpy())})(inputs[0], attr)
    return _impl

def _broadcast(name):
    def _impl(inputs, attr, params):
        op_name = _math_name_picker(name)(attr)
        return AttrCvt(
            op_name=op_name,
            ignores=['name', 'Tidx']
        )(inputs, attr)
    return _impl

def _split(has_size_vector):
    # TF documentation https://www.tensorflow.org/api_docs/python/tf/split
    def _impl(inputs, attr, params):
        try:
            # order and number of inputs are different:
            # if has_size_vector:
            #     https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/split-v
            # else:
            #     https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/split

            # in addition, `axis` and `num_or_size_splits` can be tensors in TensorFlow,
            # we can only support constants
            if has_size_vector:
                input_node_index = 0
                input_axis_index = 2
                size_splits_input_name = inputs[1].list_output_names()[0]
                size_splits = params[size_splits_input_name].asnumpy()
                section_beginnings = np.cumsum(size_splits)[:-1]
                indices_or_sections = tuple(section_beginnings)
            else:
                input_node_index = 1
                input_axis_index = 0
                indices_or_sections = attr['num_split']
            input_node = inputs[input_node_index]
            axis_input_name = inputs[input_axis_index].list_output_names()[0]
            axis_input_value = params[axis_input_name].asnumpy()[0]
        except (IndexError, KeyError):
            raise TypeError( \
                "Unsupported argument for split: `axis` and `num_or_size_splits` " \
                "should be constants")
        return _sym.split(input_node,
                          indices_or_sections=indices_or_sections,
                          axis=axis_input_value)
    return _impl

def _unpack():
    def _impl(inputs, attr, params):
        input_node = inputs[0]
        axis = attr['axis']
        input_shape = attr['_input_shapes'][input_node]
        axis_length = input_shape[axis]
        if axis_length < 0:
            raise TypeError("Unstack with unknown axis length")
        splitted = _sym.split(input_node,
                              indices_or_sections=axis_length,
                              axis=axis,
                              name=attr.get('_node_name', 'unstack'))

        return _sym.Group([_sym.squeeze(split_item, axis=axis) for split_item in splitted])
    return _impl

def _expand_dims_0d_aware(data, attr, axis, num_newaxis=1):
    if data in attr['_input_0d_mismatch']:
        return data if num_newaxis == 1 else \
            _sym.expand_dims(data, axis=axis, num_newaxis=num_newaxis-1)

    return _sym.expand_dims(data, axis=axis, num_newaxis=num_newaxis)

def _logical(name):
    def _impl(inputs, attr, params):
        return AttrCvt(op_name=name)(inputs, attr)
    return _impl

# compatible operators that do NOT require any conversion.
_identity_list = []

# _convert_map defines maps of name to converter functor(callable)
# for 1 to 1 mapping, use Renamer if nothing but name is different
# use AttrCvt if attributes need to be converted
# for 1 to N mapping(composed), use custom callable functions
# for N to 1 mapping, currently not supported(?)
_convert_map = {
    'ArgMax' : _argx(_sym.argmax, 'argmax'),
    'ArgMin' : _argx(_sym.argmin, 'argmin'),
    'AvgPool' : _pooling('avg_pool'),
    'BatchNormWithGlobalNormalization' : _batch_norm(),
    'BiasAdd' : _bias_add(),
    'Cast' : _cast(),
    'Ceil' : AttrCvt('ceil'),
    'CheckNumerics' : _check_numerics(),
    'Concat' : _concat(),
    'ConcatV2' : _concatV2(),
    'Conv2D' : _conv('conv'),
    'DecodeJpeg' : _decode_image(),
    'Elu' : _elu(),
    'ExpandDims' : _expand_dims(),
    'Floor' : AttrCvt('floor'),
    'Identity' : _identity(),
    'MatMul' : _matmul(),
    'MaxPool' : _pooling('max_pool'),
    'Add' : _elemwise('add'),
    'Sub' : _elemwise('sub'),
    'Mul' : _elemwise('mul'),
    'RealDiv' : _elemwise('div'),
    'Maximum' : _elemwise('max'),
    'Minimum' : _elemwise('min'),
    'Sum' : _sum(),
    'Square' : _square(),
    'Pack' : _pack(),
    'Slice' : _slice(),
    'LeakyRelu' : AttrCvt('leaky_relu'),
    'Relu' : AttrCvt('relu'),
    'Reshape' : _reshape(),
    'ResizeBilinear' : _resize_bilinear(),
    'Selu' : _selu(),
    'Softmax' : AttrCvt('softmax', {'axis': ('axis', 1)}),
    'Rsqrt' : _rsqrt(),
    'Squeeze' : _squeeze(),
    'FusedBatchNorm' : _fused_batch_norm(),
    'FusedBatchNormV2' : _fused_batch_norm(),
    'Relu6' : _relu6(),
    'DepthwiseConv2dNative' : _conv('depthwise'),
    'Shape' : _shape(),
    'Sigmoid' : AttrCvt('sigmoid'),
    'Fill' : _fill(),
    'GatherV2' : _gather_v2(),
    'StridedSlice' : _stridedSlice(),
    'LRN' : _lrn(),
    'Pad' : _pad('Pad'),
    'PadV2' : _pad('PadV2'),
    'Range' : _range(),
    'Rank' : _rank(),
    'Transpose' : _transpose(),
    'Tanh' : AttrCvt('tanh'),
    'Mean' : _mean(),
    'LogicalAnd' : _logical('logical_and'),
    'LogicalOr' : _logical('logical_or'),
    'LogicalNot' : _logical('logical_not'),
    'Less' : _broadcast('less'),
    'Greater' : _broadcast('greater'),
    'LessEqual' : _broadcast('less_equal'),
    'GreaterEqual' : _broadcast('greater_equal'),
    'Equal' : _broadcast('equal'),
    'NotEqual' : _broadcast('not_equal'),
    'Split' : _split(False),
    'SplitV' : _split(True),
    'Unpack' : _unpack(),
}

# _convert_map_rnn defines maps of rnn operator name to
# converter functor(callable) for 1 to 1 mapping.
_convert_map_rnn = {
    'LSTMBlockCell' : _LSTMBlockCell(),
}

class RecurrentNetworks(object):
    """Recurrent network layer handlers.

    Handle layer operations.
    ToDo: Operators like RNN/GRU layer concepts also can be handled here

    Parameters
    ----------
    nodes : list
        list of graph nodes used for tensorflow parsing.

    out_rnn : list
        List of RecurrentNetwork outputs. This output will be appended to the
        'head' nodes of the graph.

    graph : tensorflow graph definition object
        The loaded tensorflow GraphDef

    convert_map : dict
        Dict of name : callable, where name is the op's name that
        requires conversion to nnvm, and callable is a function which
        takes attrs and returns (new_op_name, new_attrs)
    """
    def __init__(self, nodes, out_rnn, graph, convert_map):
        self._graph = graph
        self._convert_map = convert_map
        self._nodes = nodes
        self._out_rnn = out_rnn
        self._cur_lstm_layer = 0
        self._layer_name_list = []
        self._recurrent_ops_layer_map = {
            'LSTMBlockCell' : self._LSTMBlockCellLayer(),
        }

    def _LSTMBlockCellLayer(self):
        """LSTMBlockCell layer handler.

        Parameters
        ----------
        op_name : str
            Operator name, e.g., LSTMBlockCell

        layer_name : str list
            Layer name is used for creating the state input placeholder.

        inputs : nnvm.Symbol
            Input data

        attrs : dict
            Dict of operator attributes

        params : dict
            Dict of pretrained weights and bias

        num_layers : int
            Total number of LSTM layers present in the graph

        Returns
        -------
        sym : nnvm.sym.Symbol
            The returned nnvm symbol
        """
        def _impl(op_name, layer_name, inputs, attrs, params, num_layers):
            in_state_c_name = layer_name+'_c'
            in_state_h_name = layer_name+'_h'

            def _init_state(num_layers, batch_size, num_hidden):
                """Create the initial states for the first layer in the graph."""
                in_state_c = _sym.Variable(in_state_c_name,
                                           shape=(num_layers, batch_size, num_hidden))
                in_state_h = _sym.Variable(in_state_h_name,
                                           shape=(num_layers, batch_size, num_hidden))
                return in_state_c, in_state_h

            def _get_cur_input_state(in_state_c, in_state_h, num_layers,
                                     layer, batch_size, num_hidden):
                """Select the appropriate states for the current layer"""
                in_state_c_tup = _sym.split(in_state_c,
                                            indices_or_sections=num_layers, axis=0)
                in_state_h_tup = _sym.split(in_state_h,
                                            indices_or_sections=num_layers, axis=0)
                cur_in_state_c = _sym.reshape(in_state_c_tup[layer],
                                              shape=(batch_size, num_hidden))
                cur_in_state_h = _sym.reshape(in_state_h_tup[layer],
                                              shape=(batch_size, num_hidden))
                return cur_in_state_c, cur_in_state_h

            def _LSTMBlockCellWrapper(inputs, attr, params,
                                      num_layers, layer):
                """LSTM cell wrapper to prepare the inputs"""
                input_shape = attr['_input_shapes'][inputs[0]]
                weight_shape = attr['_input_shapes'][inputs[3]]
                batch_size = input_shape[0]
                num_hidden = weight_shape[1] // 4

                if layer == 0:
                    # Create initial state placeholders for the first layer
                    in_state_c, in_state_h = _init_state(num_layers,
                                                         batch_size, num_hidden)
                else:
                    in_state_c = self._nodes[in_state_c_name]
                    in_state_h = self._nodes[in_state_h_name]

                cur_in_state_c, cur_in_state_h = _get_cur_input_state( \
                    in_state_c, in_state_h,
                    num_layers, layer,
                    batch_size, num_hidden)
                output, out_state = self._convert_map[op_name](inputs, cur_in_state_c,
                                                               cur_in_state_h,
                                                               attr, params)
                return output, out_state, in_state_c, in_state_h

            sym, cur_out_state, in_state_c, in_state_h = \
                _LSTMBlockCellWrapper(inputs, attrs, params,
                                      num_layers, self._cur_lstm_layer)
            self._nodes[in_state_c_name] = in_state_c
            self._nodes[in_state_h_name] = in_state_h
            cur_out_state = _sym.expand_dims(cur_out_state, axis=0, num_newaxis=1)
            self._out_rnn.append(cur_out_state)
            self._cur_lstm_layer += 1
            return sym
        return _impl

    def process_op(self, op_name, inputs, attrs, params):
        """Process recurrent layer operators.

        The '_recurrent_ops_layer_map' dict maps each layer-based operator to its
        layer handler. The total number of layers is calculated to form the input
        data shapes.

        Parameters
        ----------
        op_name : str
            Operator name, such as LSTMBlockCell

        inputs : nnvm.Symbol
            Input data

        attrs : dict
            Dict of operator attributes

        params : dict
            Dict of pretrained weights and bias

        Returns
        -------
        sym : nnvm.sym.Symbol
            The returned nnvm symbol
        """
        def _get_abs_layer_name(node):
            """Identify whether the layer name is already handled, and return
            the absolute name.
            """
            if not self._layer_name_list:
                self._layer_name_list.append(node.name)
                return node.name

            for _name in self._layer_name_list:
                if _name in node.name:
                    abs_name = _name
                else:
                    self._layer_name_list.append(node.name)
                    abs_name = node.name
                return abs_name

        # Find the number of layers of this same operator node in the graph
        # and also read the input name for the current op.
        num_layers = 0
        for _, node in enumerate(self._graph.node):
            if node.op == op_name:
                layer_name = _get_abs_layer_name(node)
                num_layers += 1

        sym = self._recurrent_ops_layer_map[op_name](op_name, layer_name, inputs, attrs,
                                                     params, num_layers)
        return sym

class GraphProto(object):
    """ A helper class for handling nnvm graph copying from Tensorflow GraphDef.
    Definition:
        https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/graph.proto
    """
    def __init__(self):
        self._nodes = {}
        self._params = {}
        self._output_shapes = {}
        self._num_param = 0
        self._num_rnn_layer = False
        self._outputs_are_0d = {}
        self._input_shapes = {}

    def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
        """Construct nnvm nodes from tensorflow graph definition - GraphDef.

        Follow the tensorflow graph definition to parse and convert it to NNVM.
        Some of the assumptions are listed below.

        -> All Placeholders are considered as graph input.
        -> All Const nodes are params.
        -> Last node is assumed as graph output.
        -> _output_shapes : Graph should be frozen with add_shapes=True.
                            Or user can pass input shape dictionary optionally.
        -> DecodeJpeg, ResizeBilinear: These are dummy operators.
                                       Hence user should handle preprocessing outside.
        -> CheckNumerics: No implementation as of now for this.
                          Just copies input to output.

        Parameters
        ----------
        graph : tensorflow graph definition object
            The loaded tensorflow GraphDef

        layout : target layout to be used (Optional)
            NCHW only supported now to enable NHWC models on GPU.

        shape : Dictionary of input dimensions (Optional)
            Graph level input shape dictionary.

        outputs : List of output tensor names (Optional)
            if not specified then the last node is assumed as graph output.

        Returns
        -------
        sym : nnvm.sym.Symbol
            The returned nnvm symbol
        params : dict
            A dict of name: tvm.nd.array pairs, used as pretrained weights
        """

        try:
            from tensorflow.python.framework import tensor_util
        except ImportError as e:
            raise ImportError(
                "Unable to import tensorflow which is required {}".format(e))

        missing_operators = self._parse_import_prerequisites(graph)

        if missing_operators:
            msg = 'The following operators are not supported in frontend TensorFlow: {}'
            ops = str(list(missing_operators)).strip('[,]')
            raise tvm.error.OpNotImplemented(msg.format(ops))

        for node in graph.node:
            if node.op == 'Placeholder':
                # Give priority to user argument.
                if shape and node.name in shape:
                    self._input_shapes[node.name] = list(shape[node.name])
                else:
                    self._input_shapes[node.name] = \
                        tensor_util.TensorShapeProtoToList(node.attr['shape'].shape)
                    for idx, dim in enumerate(self._input_shapes[node.name]):
                        if dim < 0:
                            self._input_shapes[node.name][idx] = 1
                            warnings.warn("Use 1 instead of -1 in shape of operator %s."
                                          % node.name)

                self._nodes[node.name] = _sym.Variable(name=node.name,
                                                       shape=self._input_shapes[node.name])
                self._output_shapes[node.name] = [self._input_shapes[node.name]]
                self._outputs_are_0d[node.name] = [ \
                    not tshape if isinstance(tshape, list) else False \
                    for tshape in self._output_shapes[node.name]]

            # Ignore the user's input shape for non-placeholder nodes.
            elif node.op == 'Const':
                tensor_value = node.attr['value'].tensor
                self._input_shapes[node.name] = \
                    tensor_util.TensorShapeProtoToList(tensor_value.tensor_shape)
                if shape and node.name in shape:
                    warnings.warn("Ignore the passed shape. "
                                  "Shape in graphdef will be used for operator %s." % node.name)

        final_op = None
        # Parse the nodes to re-create the TF graph using the Symbol API of NNVM
        for node in graph.node:
            # Tensorflow doesn't have a separate list for params extraction.
            # Operator name 'Const' is treated as a parameter to build the NNVM params dict.

            input_shapes = {}
            input_0d_mismatch = set()
            attr = self._parse_attr(node.attr)

            # Variable converted to Const will not have only value attr
            if 'value' in attr and node.op == 'Const':
                self._output_shapes[node.name] = [self._input_shapes[node.name]]
            elif '_output_shapes' in attr:
                self._output_shapes[node.name] = \
                    [tensor_util.TensorShapeProtoToList(tshape) \
                    for tshape in attr['_output_shapes']]
            else:
                # Keep the list indexable to avoid a key error.
                # The actual value will be filled in after node creation.
                # Shapes will be inferred if the graph is not frozen with add_shapes=True.
                self._output_shapes[node.name] = [None]

            self._outputs_are_0d[node.name] = [ \
                not tshape if isinstance(tshape, list) else False \
                for tshape in self._output_shapes[node.name]]

            if node.op == "Const":
                # All Const nodes are Param nodes; let's parse them.
                self._num_param += 1
                for key, value in node.attr.items():
                    self._parse_param(key, value, node.name)
                if node.name not in self._nodes:
                    raise NotImplementedError( \
                        "Const {} couldn't be converted to Param.".format(node.name))

                attr = self._parse_attr(node.attr)

            elif node.op != "Placeholder":
                # Pass the parsed shapes instead
                attr["_output_shapes"] = output_shapes = self._output_shapes[node.name]

                # Pass the node name too in attr
                attr["_node_name"] = node.name

                # Pass the target layout
                attr["_target_layout"] = layout

                # Fill shapes for all inputs in a list
                inputs = []
                for i in node.input:
                    # Some TensorFlow operators internally maintain execution layers
                    # and their output name includes the layer number along with
                    # graph node name. E.g. the node name is 'Model/RNN/cell_0/RnnCell', but the
                    # output tensor name is 'Model/RNN/cell_0/RnnCell:0'. In this case,
                    # the number has to be ignored for single-output nodes.
                    # On the other hand, for multi-output nodes the number is the output index,
                    # and the lack of the number implies 0.
                    tensor_name = i.split(':')
                    node_name = tensor_name[0]
                    if node_name in self._nodes:
                        in_sym = self._nodes[node_name]
                        if len(in_sym.list_output_names()) > 1:
                            tensor_slot = int(tensor_name[1]) if len(tensor_name) > 1 else 0
                            in_sym = in_sym[tensor_slot]
                            input_shape = self._output_shapes[node_name][tensor_slot]
                        else:
                            tensor_slot = 0
                            input_shape = self._output_shapes[node_name][0]
                        inputs.append(in_sym)
                        input_shapes[in_sym] = input_shape
                        # This means the node is 1d in NNVM and 0d in TF.
                        # See `_expand_dims_0d_aware`.
                        if self._outputs_are_0d[node_name][tensor_slot] and input_shape:
                            input_0d_mismatch.add(in_sym)
                attr['_input_shapes'] = input_shapes
                attr['_input_0d_mismatch'] = input_0d_mismatch

                inputs = self._fix_extranodes(node.op, attr, inputs)
                op = self._convert_operator(node.op, inputs, attr, graph)

                # Check if op is converted to param
                if isinstance(op, np.ndarray):
                    self._params[node.name] = tvm.nd.array(op)
                    op = _sym.Variable(name=node.name,
                                       shape=self._params[node.name].shape)

                # Assuming only one output.
                self._nodes[node.name] = op
                final_op = op

                # Infer shapes even without specifying "add_shapes=True"
                if output_shapes == [None]:
                    g = _graph.create(final_op)
                    self._output_shapes[node.name] = \
                        list(graph_util.infer_shape(g, **self._input_shapes))[-1]

                if self._output_shapes[node.name] and shape and node.name in shape:
                    assert self._output_shapes[node.name] == list(shape[node.name])

            # Infer shapes if passed explicitly
            node_output = self._nodes[node.name]
            if shape and (not self._output_shapes[node.name][0]
                          or -1 in self._output_shapes[node.name][0]):
                g = _graph.create(node_output)
                shape_dict = {k: v.shape for k, v in self._params.items()}
                shape_dict.update(shape)
                _, out_shapes = graph_util.infer_shape(g, **shape_dict)
                self._output_shapes[node.name] = out_shapes

        out = []
        if outputs is None:
            out.append(final_op)
        else:
            for out_name in outputs:
                if ":" in out_name:
                    out_name, out_num = out_name.split(":")
                    out_num = int(out_num)
                    out.append(self._nodes[out_name][out_num])
                else:
                    out.append(self._nodes[out_name])

        # Add the RNN outputs to the 'head' nodes of the nnvm graph as well
        if self._num_rnn_layer:
            out_rnn = _sym.concatenate(*self._out_rnn, axis=0)
            out.append(out_rnn)

        if isinstance(out, list):
            out = _sym.Group(out) if len(out) > 1 else out[0]

        return out, self._params

    def _parse_import_prerequisites(self, graph):
        """ Calculate the named preconditions from TensorFlow `graph`.
        Return prerequisites for parsing:
        a. Set of operator names which don't have their mapping in TVM, i.e.
            which are not supported
        """
        missing_operators = set()
        for node in graph.node:
            if node.op == "Placeholder":
                pass
            elif node.op == "Const":
                pass
            else:
                if any([node.op in t for t in [_identity_list, _convert_map, _convert_map_rnn]]):
                    pass
                else:
                    missing_operators.add(node.op)

        return missing_operators

    def _parse_param(self, key, value, name):
        try:
            from tensorflow.python.framework import tensor_util
        except ImportError as e:
            raise ImportError(
                "Unable to import tensorflow which is required {}".format(e))

        if key == 'value':
            np_array = tensor_util.MakeNdarray(value.tensor)

            if np_array.dtype == np.dtype(object):
                # Object types are generally tensorflow DT_STRING (DecodeJpeg op).
                # Just leave it as a placeholder.
                self._nodes[name] = _sym.Variable(name=name)
                return

            array_ndim = len(np_array.shape)
            if array_ndim == 0:
                new_array = np.empty([1], dtype=np_array.dtype)
                new_array[0] = np_array
                self._params[name] = tvm.nd.array(new_array)
            else:
                self._params[name] = tvm.nd.array(np_array)
            self._nodes[name] = _sym.Variable(name=name,
                                              shape=self._params[name].shape)
        else:
            if key not in ('dtype', '_output_shapes', '_class'):
                raise NotImplementedError \
                    ("Other attributes for a Const(param) Node {} ? .".format(key))

    def _get_attr(self, buf):
        """Returns the value of the attr of this buf with the given `name`.

        Args:
            buf: attrvalue protobuf.

        Returns:
            The value of the attr, as a Python object.

        Raises:
            ValueError: If this op does not have an attr with the given `name`.
        """
        fields = ["s", "i", "f", "b", "type", "shape", "tensor", "func"]

        x = buf

        ret = []

        try:
            from tensorflow.python.framework import dtypes
        except ImportError as e:
            raise ImportError(
                "Unable to import tensorflow which is required {}".format(e))

        # Treat an empty oneof value as an empty list.
        if not x.WhichOneof("value"):
            return ret
        if x.HasField("list"):
            for f in fields:
                if getattr(x.list, f):
                    if f == "type":
                        ret += [dtypes.as_dtype(x) for x in list(getattr(x.list, f))]
                    else:
                        ret += list(getattr(x.list, f))
        else:
            for f in fields:
                if x.HasField(f):
                    if f == "type":
                        ret = dtypes.as_dtype(getattr(x, f))
                    else:
                        ret = getattr(x, f)
        return ret

    def _parse_attr(self, attr_proto):
        """Convert a list of AttributeProto to a dict, with names as keys."""
        attrs = {}
        for key, value in attr_proto.items():
            attrs[key] = self._get_attr(value)

        return attrs

    def _convert_rnn_operator(self, op_name, inputs,
                              attrs, params, graph, convert_map):
        """Convert RNN and its variant operators to NNVM operators.
        This converter reads the input states of each layer and
        also maintains the output states of each layer in a list.

        Parameters
        ----------
        op_name : str
            Operator name, such as LSTMBlockCell
        inputs : list of nnvm.Symbol
            List of input symbols.
        attrs : dict
            Dict of operator attributes
        params : dict
            Dict of pretrained weights and bias
        graph : Tensorflow graph object
            Graph is used to find the number of upcoming same operators to
            calculate the number of layers.
        convert_map : dict
            Dict of name : callable, where name is the op's name that
            requires conversion to nnvm, and callable is a function which
            takes attrs and returns (new_op_name, new_attrs)

        Returns
        -------
        sym : nnvm.Symbol
            Converted nnvm Symbol
        """
        if not self._num_rnn_layer:
            self._out_rnn = []
            self.rnn = RecurrentNetworks(self._nodes, self._out_rnn, graph, convert_map)
            self._num_rnn_layer = True
        sym = self.rnn.process_op(op_name, inputs, attrs, params)
        return sym

    def _convert_operator(self, op_name, inputs, attrs,
                          graph, identity_list=None, convert_map=None):
        """Convert from Tensorflow operator to nnvm operator.
        The converter must specify conversions explicitly for incompatible names, and
        apply handlers to operator attributes.

        Parameters
        ----------
        op_name : str
            Operator name, such as Conv2D, AvgPool
        inputs : list of nnvm.Symbol
            List of input symbols.
        attrs : dict
            Dict of operator attributes
        identity_list : list
            List of operators that don't require conversion
        convert_map : dict
            Dict of name : callable, where name is the op's name that
            requires conversion to nnvm, and callable is a function which
            takes attrs and returns (new_op_name, new_attrs)

        Returns
        -------
        sym : nnvm.Symbol
            Converted nnvm Symbol
        """
        identity_list = identity_list if identity_list else _identity_list
        convert_map = convert_map if convert_map else _convert_map
        convert_map_rnn = _convert_map_rnn
        if op_name in identity_list:
            sym = get_nnvm_op(op_name)(*inputs, **attrs)
        elif op_name in convert_map:
            sym = convert_map[op_name](inputs, attrs, self._params)
        elif op_name in convert_map_rnn:
            sym = self._convert_rnn_operator(op_name, inputs, attrs,
                                             self._params, graph,
                                             convert_map_rnn)
        else:
            raise tvm.error.OpNotImplemented(
                'Operator {} is not supported in frontend TensorFlow.'.format(op_name))
        return sym

    def _fix_extranodes(self, op_name, attr, inputs):
        if op_name == "Softmax":
            # Sometimes the data needs to be flattened before it goes to softmax.
            # Need to revisit this with the latest softmax axis support.
            op = AttrCvt(op_name='flatten')(inputs, {})
            node_output = op.list_output_names()
            for k, i in zip(list(node_output), range(len(node_output))):
                self._nodes[k] = op[i]
            inputs = [op]

        return inputs

def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None):
    """Load a TensorFlow graph (a python GraphDef object) into an nnvm graph.
    The companion parameters will be handled automatically.

    Parameters
    ----------
    graph : GraphDef object
        Tensorflow GraphDef

    layout : target layout to be used (Optional)
        NCHW only supported now to enable NHWC models on GPU.

    shape : Dictionary of input dimensions (Optional)
        Graph level input shape dictionary.

    outputs : List of output tensor names (Optional)
        if not specified then the last node is assumed as graph output.

    Returns
    -------
    sym : nnvm.Symbol
        Compatible nnvm symbol

    params : dict of str to tvm.ndarray
        Dict of converted parameters stored in tvm.ndarray format
    """
    g = GraphProto()
    sym, params = g.from_tensorflow(graph, layout, shape, outputs)
    return sym, params
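
# Usage sketch (illustrative only, kept as a comment so nothing runs at import
# time): converting a frozen GraphDef with this frontend. The file name and
# input shape below are hypothetical placeholders.
#
#     import tensorflow as tf
#     import nnvm.frontend
#
#     with tf.gfile.GFile("frozen_model.pb", "rb") as f:   # hypothetical path
#         graph_def = tf.GraphDef()
#         graph_def.ParseFromString(f.read())
#
#     sym, params = nnvm.frontend.from_tensorflow(
#         graph_def, layout="NHWC", shape={"input": (1, 224, 224, 3)})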