1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17# pylint: disable=import-self, invalid-name, unused-argument, too-many-lines
18"""TF: Tensorflow frontend."""
19from __future__ import absolute_import as _abs
20from __future__ import print_function
21
22import warnings
23# Numpy support
24import numpy as np
25
26import tvm
27from .. import symbol as _sym
28from .. import graph as _graph
29from .. compiler import graph_util, build_module
30from .common import get_nnvm_op, AttrConverter as AttrConvert
31
32__all__ = ['from_tensorflow']
33
34class AttrCvt(object):
35    """A Wrapper to handle some common jobs:
36    """
37    def __init__(self, op_name, transforms=None,
38                 excludes=None, disables=None, ignores=None,
39                 extras=None, custom_check=None):
40        self._op_name = op_name
41        self._transforms = transforms if transforms else {}
42        self._excludes = excludes if excludes else []
43        self._disables = disables if disables else []
44        self._ignores = ignores if ignores else []
45        self._extras = extras if extras else {}
46        self._custom_check = custom_check
47
48    def __call__(self, inputs, attrs, *args):
49        self._ignores.append('_output_shapes')
50        self._ignores.append('_input_shapes')
51        self._ignores.append('T')
52        self._ignores.append('use_cudnn_on_gpu')
53        self._ignores.append('_node_name')
54        self._ignores.append('is_training')
55        self._ignores.append('_target_layout')
56        self._ignores.append('_input_0d_mismatch')
57        # Retain the names
58        try:
59            attrs['name'] = attrs['_node_name']
60        except KeyError:
61            pass
62        return AttrConvert(self._op_name, self._transforms, self._excludes,
63                           self._disables, self._ignores, self._extras,
64                           self._custom_check)(inputs, attrs, *args)
65
66def _get_pad_pair(input1d, kernel1d, stride1d):
67    if input1d % stride1d == 0:
68        pad = max(kernel1d - stride1d, 0)
69    else:
70        pad = max(kernel1d - (input1d % stride1d), 0)
71
72    pad_before = pad // 2
73    pad_after = pad - pad_before
74
75    return [pad_before, pad_after]
76
77def _math_name_picker(surfix):
78    def _impl(attr):
79        return 'broadcast_' + surfix
80    return _impl
81
82def _dimension_picker(prefix, surfix=''):
83    def _impl(attr):
84        kernel = attr['kernel_shape']
85        if len(kernel) == 2:
86            return prefix + '2d' + surfix
87        raise tvm.error.OpAttributeUnImplemented(
88            'Non-2D kernels are not supported for operator {}.'.format(prefix))
89    return _impl
90
91def _dimension_constraint():
92    def _dim_check(attrs):
93        if len(attrs['kernel_shape']) == 2:
94            return True
95        return False
96    return _dim_check, "Only 2d kernel supported."
97
98def _infer_channels(inputs, params, transpose=False):
99    """A hack for getting 'channles' or 'units' since tensorflow don't provide
100    these attributes. We check the shape of weights provided to get the number.
101    """
102    g = _graph.create(inputs)
103    shape_dict = {k: v.shape for k, v in params.items()}
104    _, out_shapes = graph_util.infer_shape(g, **shape_dict)
105    channels = out_shapes[0][0] if not transpose else out_shapes[0][1]
106    return channels
107
108def _rsqrt():
109    def _impl(inputs, attr, *args):
110        return AttrCvt(op_name="__pow_scalar__", extras={'scalar': -0.5})(inputs, attr)
111    return _impl
112
113def _argx(func, func_name):
114    """ A common wrapper for argmin and argmax operations """
115    def _impl(inputs, attr, params):
116        try:
117            # In Tensorflow, `axis` argument is a Tensor, not attribute. We
118            # support the case where it inputs from a scalar constant.
119            axis_input_name = inputs[1].list_output_names()[0]
120            axis_input_vlaue = params[axis_input_name].asnumpy()[0]
121        except (IndexError, KeyError):
122            raise TypeError( \
123                "Unsupported argument for `{}` : `axis` should be a constant".format(func_name))
124        return func(inputs[0], axis=axis_input_vlaue, keepdims=False)
125    return _impl
126
127def _elemwise(name):
128    def _impl(inputs, attr, *args):
129        assert len(inputs) == 2, "{} take 2 inputs, {} given".format(name, len(inputs))
130        op_name = _math_name_picker(name)(attr)
131        return get_nnvm_op(op_name)(*inputs)
132    return _impl
133
134def _pooling(name):
135    def _impl(inputs, attr, params):
136
137        attr['data_format'] = attr['data_format'].decode("utf-8")
138        flip_layout = False
139
140        input_shape = attr['_input_shapes'][inputs[0]]
141
142        if attr['data_format'] == 'NHWC':
143            attr['kernel_shape'] = (attr['ksize'][1], attr['ksize'][2])
144            attr['strides'] = (attr['strides'][1], attr['strides'][2])
145        elif attr['data_format'] == 'NCHW':
146            attr['kernel_shape'] = (attr['ksize'][2], attr['ksize'][3])
147            attr['strides'] = (attr['strides'][2], attr['strides'][3])
148        else:
149            msg = 'Value {} in attribute "data_format" of operator Pooling is not valid.'
150            raise tvm.error.OpAttributeInvalid(msg.format(attr['data_format']))
151
152        if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC":
153            tmp_shape = attr['_input_shapes'][inputs[0]]
154            input_shape = [tmp_shape[ii] for ii in (0, 3, 1, 2)]
155            inputs[0] = _sym.transpose(inputs[0], axes=(0, 3, 1, 2))
156            attr['data_format'] = "NCHW"
157            flip_layout = True
158
159        # Fix padding
160        attr['padding'] = attr['padding'].decode("utf-8")
161
162        if attr['padding'] == 'VALID':
163            attr['padding'] = [0, 0]
164        elif attr['padding'] == 'SAME':
165            stride_h, stride_w = attr['strides']
166            kernel_h, kernel_w = attr['kernel_shape']
167            if attr['data_format'] == 'NHWC':
168                in_h = input_shape[1]
169                in_w = input_shape[2]
170            else:
171                in_h = input_shape[2]
172                in_w = input_shape[3]
173
174            pad_v = _get_pad_pair(in_h, kernel_h, stride_h)
175            pad_h = _get_pad_pair(in_w, kernel_w, stride_w)
176
177            attr['padding'] = [pad_v[0], pad_h[0], pad_v[1], pad_h[1]]
178        else:
179            msg = 'Value {} in attribute "padding" of operator Pooling is not valid.'
180            raise tvm.error.OpAttributeUnImplemented(msg.format(attr['padding']))
181
182        if name == "avg_pool":
183            attr['count_include_pad'] = False
184
185        out = AttrCvt(
186            op_name=_dimension_picker(name),
187            transforms={
188                'kernel_shape':'pool_size',
189                'data_format':'layout'},
190            ignores=['ksize'],
191            extras={'ceil_mode': False},
192            custom_check=_dimension_constraint())(inputs, attr)
193
194        if flip_layout:
195            out = _sym.transpose(out, axes=(0, 2, 3, 1))
196
197        return out
198    return _impl
199
200def _conv(opname):
201    def _impl(inputs, attr, params):
202        attr['data_format'] = attr['data_format'].decode("utf-8")
203        flip_layout = False
204
205        # NCHW Layout require weights transpose
206        if attr['data_format'] == 'NCHW':
207            tmp_shape = attr['_input_shapes'][inputs[1]]
208            if opname == 'conv':
209                tmp_shape = [tmp_shape[ii] for ii in (3, 2, 0, 1)]
210                inputs[1] = _sym.transpose(inputs[1], axes=(3, 2, 0, 1))
211            else:
212                tmp_shape = [tmp_shape[ii] for ii in (2, 3, 0, 1)]
213                inputs[1] = _sym.transpose(inputs[1], axes=(2, 3, 0, 1))
214            attr['_input_shapes'][inputs[1]] = tmp_shape
215
216        input_shape = attr['_input_shapes'][inputs[0]]
217        weights_shape = attr['_input_shapes'][inputs[1]]
218
219        if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC":
220            input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)]
221            inputs[0] = _sym.transpose(inputs[0], axes=(0, 3, 1, 2))
222            if opname == 'conv':
223                weights_shape = [weights_shape[ii] for ii in (3, 2, 0, 1)]
224                inputs[1] = _sym.transpose(inputs[1], axes=(3, 2, 0, 1))
225            else:
226                weights_shape = [weights_shape[ii] for ii in (2, 3, 0, 1)]
227                inputs[1] = _sym.transpose(inputs[1], axes=(2, 3, 0, 1))
228
229            attr['data_format'] = "NCHW"
230            attr['strides'] = [attr['strides'][ii] for ii in (0, 3, 1, 2)]
231            flip_layout = True
232
233        if attr['data_format'] == 'NHWC':
234            kernel_h, kernel_w, _, depth_mult = weights_shape
235            attr['kernel_shape'] = (weights_shape[0], weights_shape[1])
236            if opname == 'conv':
237                attr['channels'] = weights_shape[3]
238            else:
239                attr['channels'] = input_shape[3] * depth_mult
240
241            if 'dilations' in attr:
242                attr['dilations'] = (attr['dilations'][1], attr['dilations'][2])
243            attr['strides'] = (attr['strides'][1], attr['strides'][2])
244        elif attr['data_format'] == 'NCHW':
245            _, depth_mult, kernel_h, kernel_w = weights_shape
246            attr['kernel_shape'] = (weights_shape[2], weights_shape[3])
247            if opname == 'conv':
248                attr['channels'] = weights_shape[0]
249            else:
250                attr['channels'] = input_shape[1] * depth_mult
251                if attr['channels'] < 0:
252                    attr['channels'] *= -1
253
254            if 'dilations' in attr:
255                attr['dilations'] = (attr['dilations'][2], attr['dilations'][3])
256            attr['strides'] = (attr['strides'][2], attr['strides'][3])
257        else:
258            msg = 'Value {} in attribute "data_format" of operator Conv is not valid.'
259            raise tvm.error.OpAttributeInvalid(msg.format(attr['data_format']))
260
261
262        if opname == 'depthwise':
263            if depth_mult > 1:
264                raise tvm.error.OpNotImplemented('depth_mult > 1 of operator DepthwiseConv2dNative'
265                                                 ' is not supported.')
266            attr['groups'] = attr['channels']
267
268        # Fix padding
269        attr['padding'] = attr['padding'].decode("utf-8")
270
271        if attr['padding'] == 'VALID':
272            attr['padding'] = [0, 0]
273        elif attr['padding'] == 'SAME':
274            stride_h, stride_w = attr['strides']
275            kernel_h, kernel_w = attr['kernel_shape']
276            if attr['data_format'] == 'NHWC':
277                in_h = input_shape[1]
278                in_w = input_shape[2]
279            else:
280                in_h = input_shape[2]
281                in_w = input_shape[3]
282
283            dilation_h = attr['dilations'][0]
284            dilation_w = attr['dilations'][1]
285            dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
286            dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
287            pad_v = _get_pad_pair(in_h, dilated_kernel_h, stride_h)
288            pad_h = _get_pad_pair(in_w, dilated_kernel_w, stride_w)
289
290            if attr['data_format'] == 'NHWC':
291                inputs[0] = _sym.pad(data=inputs[0],
292                                     pad_width=((0, 0),
293                                                (pad_v[0], pad_v[1]),
294                                                (pad_h[0], pad_h[1]),
295                                                (0, 0)))
296            else:
297                inputs[0] = _sym.pad(data=inputs[0],
298                                     pad_width=((0, 0),
299                                                (0, 0),
300                                                (pad_v[0], pad_v[1]),
301                                                (pad_h[0], pad_h[1])))
302
303            attr['padding'] = [0, 0]
304
305        else:
306            msg = 'Value {} in attribute "padding" of operator Conv is not valid.'
307            raise tvm.error.OpAttributeInvalid(msg.format(attr['padding']))
308
309        if 'kernel_layout' not in attr:
310            if opname == 'conv':
311                attr['kernel_layout'] = 'HWIO' if attr['data_format'] == 'NHWC' else 'OIHW'
312            else:
313                attr['kernel_layout'] = 'HWOI' if attr['data_format'] == 'NHWC' else 'OIHW'
314
315        out = AttrCvt(
316            op_name=_dimension_picker('conv'),
317            transforms={
318                'kernel_shape': 'kernel_size',
319                'data_format': 'layout',
320                'dilations': ('dilation', (0, 0)),
321                'group': ('groups', 1)},
322            extras={'use_bias': len(inputs) == 3},
323            custom_check=_dimension_constraint())(inputs, attr)
324
325        if flip_layout:
326            out = _sym.transpose(out, axes=(0, 2, 3, 1))
327
328        return out
329    return _impl
330
331def _decode_image():
332    def _impl(inputs, attr, params):
333        # Image decode wrapper: Expecting user to feed decoded input to next layer drop this layer.
334        warnings.warn("DecodeJpeg: It's a pass through, "
335                      "please handle preprocessing before input")
336        return inputs[0]
337    return _impl
338
339def _cast():
340    def _impl(inputs, attr, params):
341        # Convert from tensorflow Dtype to str
342        attr['DstT'] = attr['DstT'].name
343        return AttrCvt(op_name='cast', transforms={'DstT': 'dtype'},
344                       ignores=['SrcT', 'Truncate'])(inputs, attr)
345    return _impl
346
347def _expand_dims():
348    def _impl(inputs, attr, params):
349        dim_input = inputs.pop(1)
350        axis = params[dim_input.list_output_names()[0]]
351        params.pop(dim_input.list_output_names()[0])
352        return _expand_dims_0d_aware(inputs[0], attr, axis=axis.asnumpy()[0])
353    return _impl
354
355def _resize_bilinear():
356    def _impl(inputs, attr, params):
357        attr['size'] = attr['_output_shapes'][0][1:3]
358        inputs.pop(1)
359        # NHWC
360        attr['layout'] = 'NHWC'
361
362        return AttrCvt(op_name="resize",
363                       ignores=['Tdim'],
364                       extras={'method': "BILINEAR"})(inputs, attr)
365    return _impl
366
367def _check_numerics():
368    def _impl(inputs, attr, params):
369        # Making a copy node assuming no need to verify
370        return AttrCvt(op_name="copy", ignores=['message'])(inputs, attr)
371    return _impl
372
373
374def _matmul():
375    def _impl(inputs, attr, params):
376        channels = _infer_channels(inputs[1], params, not attr['transpose_b'])
377        if attr['transpose_a']:
378            inputs[0] = _sym.transpose(inputs[0], axes=(1, 0))
379        if not attr['transpose_b']:
380            inputs[1] = _sym.transpose(inputs[1], axes=(1, 0))
381        return AttrCvt(op_name="dense",
382                       extras={'use_bias': False, 'units': channels},
383                       ignores=['transpose_a', 'transpose_b', 'T'])(inputs, attr)
384
385    return _impl
386
387def _undef():
388    def _impl(inputs, attr, params):
389        return _sym.__undef__()
390    return _impl
391
392def _identity():
393    def _impl(inputs, attr, params):
394        return inputs[0]
395    return _impl
396
397def _concatV2():
398    def _impl(inputs, attr, params):
399        pop_node = inputs.pop(len(inputs)-1)
400        axis = params[pop_node.list_output_names()[0]]
401        params.pop(pop_node.list_output_names()[0])
402        return AttrCvt(
403            op_name="concatenate", ignores=['T', 'N', 'Tidx'],
404            extras={'axis': axis.asnumpy()[0]})(inputs, attr)
405    return _impl
406
407def _concat():
408    def _impl(inputs, attr, params):
409        pop_node = inputs.pop(0)
410        axis = params[pop_node.list_output_names()[0]]
411        params.pop(pop_node.list_output_names()[0])
412        return AttrCvt(
413            op_name="concatenate", ignores=['N'],
414            extras={'axis': axis.asnumpy()[0]})(inputs, attr)
415    return _impl
416
417def _pack():
418    def _impl(inputs, attr, params):
419        axis = int(attr["axis"])
420        inputs_reshaped = [_expand_dims_0d_aware(i, attr, axis=axis, num_newaxis=1) for i in inputs]
421        return _sym.concatenate(*inputs_reshaped, axis=axis, name=attr["_node_name"])
422
423    return _impl
424
425def _slice():
426    def _impl(inputs, attr, params):
427        begin = params.pop(inputs[1].list_output_names()[0]).asnumpy().tolist()
428        size = params.pop(inputs[2].list_output_names()[0]).asnumpy().tolist()
429        data_shape = attr['_input_shapes'][inputs[0]]
430        data_dim = len(data_shape)
431        end = size
432        for i in range(data_dim):
433            if size[i] == -1:
434                end[i] = data_shape[i] - begin[i]
435            else:
436                end[i] += begin[i]
437        return _sym.strided_slice(inputs[0], begin=begin, end=size)
438    return _impl
439
440def _reshape():
441    def _impl(inputs, attr, params):
442        try:
443            pop_node = inputs[1]
444            shape_arg = params.pop(pop_node.list_output_names()[0])
445            inputs.pop(1)
446
447            return AttrCvt(
448                op_name="reshape",
449                extras={'shape':tuple(shape_arg.asnumpy())},
450                ignores=['Tshape'])(inputs, attr)
451        except KeyError:
452            # Shape operator is already pruned, hence
453            # try to infer shape by precompute prune if possible.
454            if all(in_node in params for in_node in inputs[1].list_input_names()):
455                graph = _graph.create(_sym.Group(inputs[1]))
456                params_pre = {k: params[k] for k in inputs[1].list_input_names()}
457                params_new = build_module._run_graph(graph, params_pre)
458                inputs.pop(1)
459                return AttrCvt(
460                    op_name="reshape",
461                    extras={'shape':tuple(params_new[0].asnumpy().flatten())},
462                    ignores=['Tshape'])(inputs, attr)
463            raise tvm.error.OpAttributeUnimplemented(
464                'Attribute "dynamic shape" of operator Reshape is not supported.')
465    return _impl
466
467def _bias_add():
468    def _impl(inputs, attr, params):
469        if attr['data_format'].decode("utf-8") == 'NCHW':
470            bias = _sym.reshape(inputs[1], newshape=(1, -1, 1, 1))
471        else:
472            bias = inputs[1]
473        return _sym.broadcast_add(inputs[0], bias)
474    return _impl
475
476def _squeeze():
477    def _impl(inputs, attr, params):
478        return AttrCvt(
479            op_name="squeeze",
480            transforms={'squeeze_dims':'axis'},
481            ignores=['T'])(inputs, attr)
482    return _impl
483
484def _fused_batch_norm():
485    def _impl(inputs, attr, params):
486        # Tensorflow: (data, gamma, beta, moving_mean, moving_variance)
487        # NNVM:       (data, gamma, beta, moving_mean, moving_varience)
488        axis = 3
489        need_cast = False
490
491        if 'data_format' in attr:
492            attr['data_format'] = attr['data_format'].decode("utf-8")
493            if attr['data_format'] == 'NCHW':
494                axis = 1
495        if 'U' in attr:
496            need_cast = True
497            inputs[0] = _sym.cast(inputs[0], dtype=attr['U'].name)
498
499        out = AttrCvt(op_name='batch_norm',
500                      transforms={'scale_after_normalization':'scale',
501                                  'variance_epsilon':'epsilon'},
502                      extras={'axis': axis},
503                      ignores=['data_format', 'U'],
504                      disables=['momentum'])(inputs, attr)
505
506        if need_cast:
507            out = _sym.cast(out, dtype=attr['T'].name)
508        return out
509    return _impl
510
511def _batch_norm():
512    def _impl(inputs, attr, params):
513        # Rearrange inputs from
514        # (data, moving_mean, moving_variance, beta, gamma)
515        #     to
516        # (data, gamma, beta, moving_mean, moving_var)
517        new_inputs = [inputs[0], inputs[4], inputs[3], inputs[1], inputs[2]]
518
519        axis = 3
520        if 'data_format' in attr:
521            attr['data_format'] = attr['data_format'].decode("utf-8")
522            if attr['data_format'] == 'NCHW':
523                axis = 1
524
525        return AttrCvt(
526            op_name='batch_norm',
527            transforms={'scale_after_normalization':'scale', 'variance_epsilon':'epsilon'},
528            extras={'axis': axis},
529            ignores=['data_format'],
530            disables=['momentum'])(new_inputs, attr)
531    return _impl
532
533def _relu6():
534    def _impl(inputs, attr, params):
535        return _sym.clip(inputs[0], a_min=0, a_max=6, name=attr['_node_name'])
536    return _impl
537
538def _shape():
539    def _impl(inputs, attr, params):
540        return np.array(attr['_input_shapes'][inputs[0]], dtype='int32')
541    return _impl
542
543def _fill():
544    def _impl(inputs, attr, params):
545        fill_arg = params.pop(inputs.pop(1).list_output_names()[0])
546        new_inputs = []
547        return AttrCvt(
548            op_name='full',
549            extras={'shape':inputs[0],
550                    'fill_value':fill_arg.asnumpy()[0], 'dtype':attr['T'].name},
551            ignores=['index_type', 'T'])(new_inputs, attr)
552    return _impl
553
554def _lrn():
555    def _impl(inputs, attr, params):
556        attr_new = {}
557        depth_radius = attr.get('depth_radius', 5)
558        size = (depth_radius * 2) + 1
559        attr_new['axis'] = 3 # Fix axis, NHWC format
560        attr_new['size'] = size
561        attr_new['bias'] = attr.get('bias', 1)
562        attr_new['alpha'] = attr.get('alpha', 1) * size
563        attr_new['beta'] = attr.get('beta', 0.5)
564        return AttrCvt(op_name='lrn')(inputs, attr_new)
565    return _impl
566
567def _sum():
568    def _impl(inputs, attr, params):
569        axis = params.pop(inputs[1].list_output_names()[0]).asnumpy()
570        # convert to tuple for preventing invalid parameter format error
571        axis = tuple(axis)
572        return AttrCvt(
573            op_name='sum',
574            extras={'axis': axis},
575            transforms={'keep_dims':'keepdims'},
576            ignores=['name', 'Tidx'])(inputs[0], attr)
577    return _impl
578
579def _square():
580    def _impl(inputs, attr, params):
581        return _sym.elemwise_mul(inputs[0], inputs[0])
582    return _impl
583
584def _gather_v2():
585    "Tensorflow now support only gatherv2"
586    def _impl(inputs, attr, params):
587        axis = params[inputs.pop(2).list_output_names()[0]].asnumpy()[0]
588        new_input = []
589        new_input.append(inputs.pop(0))
590        new_input.append(inputs.pop(0))
591        return AttrCvt(
592            op_name="take",
593            extras={'axis':axis},
594            ignores=['Tindices', 'Tparams', 'validate_indices', \
595                     'Taxis', '_class'])(new_input, attr)
596    return _impl
597
598def _infer_out_shapes(inputs, params):
599    """A method to get the output shape of an intermediate node in the NNVM graph."""
600    g = _graph.create(inputs)
601    shape_dict = {k: v.shape for k, v in params.items()}
602    _, out_shapes = graph_util.infer_shape(g, **shape_dict)
603    return out_shapes
604
605def _stridedSlice():
606    def _impl(inputs, attr, params):
607        """Strided Slice.
608        Operator description: https://www.tensorflow.org/api_docs/python/tf/strided_slice
609        Tensorflow mask validation: https://github.com/tensorflow/tensorflow/blob/master/
610        tensorflow/core/util/strided_slice_op.cc#L147-L368
611        """
612        begin = params.pop(inputs[1].list_output_names()[0]).asnumpy().tolist()
613        end = params.pop(inputs[2].list_output_names()[0]).asnumpy().tolist()
614        stride = params.pop(inputs[3].list_output_names()[0]).asnumpy().tolist()
615        begin_mask = int(attr.get('begin_mask', 0))
616        end_mask = int(attr.get('end_mask', 0))
617        ellipsis_mask = int(attr.get('ellipsis_mask', 0))
618        new_axis_mask = int(attr.get('new_axis_mask', 0))
619        shrink_axis_mask = int(attr.get('shrink_axis_mask', 0))
620        data_shape = attr['_input_shapes'][inputs[0]]
621        data_dim = len(data_shape)
622        stride_dim = len(stride)
623
624        def _transform_mask(stride_dim, ellipsis_mask):
625            """Handle mask inputs to create new begin, end, stride and output shape"""
626            m_begin = [0] * data_dim
627            m_end = [0] * data_dim
628            m_stride = [0] * data_dim
629            fshape_indices = []
630            #Count new axis after ellipsis_mask, consider while applying ellipsis_mask.
631            ellipsis_seen = False
632            new_axes_after_ellipsis = 0
633            for i in range(stride_dim):
634                mask = 1 << i
635                if ellipsis_seen and (mask & new_axis_mask) != 0:
636                    new_axes_after_ellipsis += 1
637                if (mask & ellipsis_mask) != 0:
638                    ellipsis_seen = True
639            if not ellipsis_seen:
640                #Used later for extending the stride attributes in the below loop.
641                ellipsis_mask |= (1 << stride_dim)
642                stride_dim += 1
643            final_index = 0
644            for index in range(stride_dim):
645                mask = 1 << index
646                if mask & ellipsis_mask:
647                    #Identify the end index for applying ellipsis_mask
648                    to_index = min(((data_dim - (stride_dim-index)) + 1 \
649                                     + new_axes_after_ellipsis), data_dim)
650                    for i in range(final_index, to_index):
651                        m_begin[final_index] = 0
652                        m_end[final_index] = data_shape[final_index]
653                        m_stride[final_index] = 1
654                        fshape_indices.append(final_index)
655                        final_index += 1
656                elif mask &new_axis_mask:
657                    fshape_indices.append(-1)
658                elif not mask & new_axis_mask:
659                    if final_index == len(m_begin):
660                        break
661                    if mask & begin_mask:
662                        m_begin[final_index] = data_shape[final_index] \
663                                                     if stride[index] < 0 else 0
664                    elif begin[index]:
665                        m_begin[final_index] = begin[index]
666                    if mask & end_mask:
667                        m_end[final_index] = 0 if stride[index] < 0 \
668                                                 else data_shape[final_index]
669                    elif end[index]:
670                        m_end[final_index] = end[index]
671                    m_stride[final_index] = stride[index]
672                    if mask & shrink_axis_mask:
673                        #Tensorflow make axis with shrink_axis_mask as dimension 1
674                        m_begin[final_index] = data_shape[final_index] + begin[index] \
675                                                 if begin[index] < 0 else begin[index]
676                        m_end[final_index] = begin[index] + 1
677                        m_stride[final_index] = 1
678                        fshape_indices.append(-2)
679                    else:
680                        fshape_indices.append(final_index)
681
682                    final_index += 1
683            return m_begin, m_end, m_stride, fshape_indices
684
685        fshape_indices = None
686        if begin_mask or end_mask or ellipsis_mask or new_axis_mask or shrink_axis_mask:
687            begin, end, stride, fshape_indices = _transform_mask(stride_dim, ellipsis_mask)
688        out = _sym.strided_slice(inputs[0], begin=begin, end=end, stride=stride)
689        out_shape = _infer_out_shapes(out, params)[0]
690        if not fshape_indices:
691            fshape_indices = range(len(out_shape))
692
693        #Create final output shape.
694        final_output = []
695        for gather_index in fshape_indices:
696            if gather_index == -1:
697                final_output.append(1)
698            elif gather_index == -2:
699                pass
700            else:
701                final_output.append(out_shape[gather_index])
702        # Prevent 0-dim tensors which are not accepted by nnvm
703        if not final_output:
704            final_output.append(1)
705        return _sym.reshape(out, shape=tuple(final_output))
706    return _impl
707
708def _LSTMBlockCell():
709    def _impl(inputs, in_state_c, in_state_h, attr, params):
710        """LSTM Block cell.
711        Calculations are described in: https://github.com/tensorflow/tensorflow/blob/
712        r1.8/tensorflow/contrib/rnn/python/ops/lstm_ops.py#L41-L114
713
714        Parameters
715        ----------
716        inputs : nnvm.Symbol
717            Input data
718        in_state_c: list of nnvm.Symbol
719            Cell state input values for all the layers
720        in_state_h: list of nnvm.Symbol
721            Hidden state input values for all the layers
722        attrs : dict
723            Dict of operator attributes
724        params : dict
725            List of pretrained weights and bias
726
727        Returns
728        -------
729        sym : nnvm.Symbol
730            Converted nnvm Symbol
731        output: nnvm.Symbol
732            Output state value.
733        """
734        in_data = inputs[0]
735        in_weight = inputs[3]
736        in_bias = inputs[7]
737        forget_bias = attr.pop('forget_bias')
738        input_shape = attr['_input_shapes'][inputs[0]]
739        weight_shape = attr['_input_shapes'][inputs[3]]
740        batch_size, input_size = input_shape[0], input_shape[1]
741        num_hidden_layers = weight_shape[1]
742        num_hidden = num_hidden_layers // 4
743
744        in_data = _sym.reshape(in_data,
745                               shape=(batch_size, input_size))
746        ixh = _sym.concatenate(*[in_data, in_state_h], axis=1)
747        in_weight = _sym.transpose(in_weight)
748        gates = _sym.dense(ixh, in_weight, in_bias, use_bias=True,
749                           units=num_hidden_layers)
750        gate_list = _sym.split(gates, indices_or_sections=4, axis=1)
751        in_gate = _sym.sigmoid(gate_list[0])
752        in_transform = _sym.tanh(gate_list[1])
753        forget_gate = _sym.sigmoid(gate_list[2])
754        forget_gate = forget_gate + forget_bias
755        out_gate = _sym.sigmoid(gate_list[3])
756        next_c = _sym.broadcast_add(_sym.broadcast_mul(forget_gate, in_state_c),
757                                    _sym.broadcast_mul(in_gate, in_transform))
758        next_h = out_gate * _sym.tanh(next_c)
759        out_state = _sym.concatenate(*[next_c, next_h])
760        out_state = _sym.reshape(out_state,
761                                 shape=(2, batch_size, num_hidden))
762        return next_h, out_state
763    return _impl
764
765
766def _pad(name):
767    def _impl(inputs, attr, params):
768        padlist_key = inputs[1].list_output_names()[0]
769        if padlist_key in params:
770            padlist = params.pop(padlist_key).asnumpy()
771        else:
772            raise tvm.error.OpAttributeRequired(
773                'Required attribute "{}" not found in operator Pad.'.format(padlist_key))
774        paddings = tuple([tuple(l) for l in padlist])
775        attr['pad_width'] = paddings
776        attr['pad_value'] = 0
777        new_inputs = [inputs[0]]
778        if name == 'PadV2':
779            constant_values = params.pop(inputs[2].list_output_names()[0]).asnumpy()
780            attr['pad_value'] = constant_values[0]
781        return AttrCvt(
782            op_name='pad',
783            ignores=['Tpaddings'],)(new_inputs, attr)
784    return _impl
785
786
787def _transpose():
788    def _impl(inputs, attr, params):
789        # If perm is not specified, axes is left empty,
790        # otherwise its value is get from params
791        param_name = inputs[1].list_output_names()[0]
792        axes = params.get(param_name, tvm.nd.array([])).asnumpy()
793        return _sym.transpose(inputs[0], axes=tuple(axes))
794    return _impl
795
796def _rank():
797    def _impl(inputs, attr, params):
798        input_shape = attr['_input_shapes'][inputs[0]]
799
800        name = attr["_node_name"]
801        params[name] = tvm.nd.array([len(input_shape)])
802        return _sym.Variable(name=name, shape=params[name].shape)
803    return _impl
804
805def _range():
806    def _impl(inputs, attr, params):
807        start = params.pop(inputs[0].list_output_names()[0]).asnumpy()[0]
808        limit = params.pop(inputs[1].list_output_names()[0]).asnumpy()[0]
809        delta = params.pop(inputs[2].list_output_names()[0]).asnumpy()[0]
810
811        name = attr["_node_name"]
812        params[name] = tvm.nd.array([start, limit, delta])
813        return _sym.Variable(name=name, shape=params[name].shape)
814    return _impl
815
816def _elu():
817    def _impl(inputs, attr, params):
818        alpha = 1.0
819        return -alpha * _sym.relu(1 - _sym.exp(inputs[0])) + _sym.relu(inputs[0])
820    return _impl
821
822def _selu():
823    def _impl(inputs, attr, params):
824        alpha = 1.6732632423543772848170429916717
825        gamma = 1.0507009873554804934193349852946
826        return gamma * (-alpha * _sym.relu(1 - _sym.exp(inputs[0])) + _sym.relu(inputs[0]))
827    return _impl
828
829def _mean():
830    def _impl(inputs, attr, params):
831        axis = params.pop(inputs[1].list_output_names()[0])
832        return AttrCvt(op_name="mean", ignores=['Tdim', 'Tidx'],
833                       transforms={'keep_dims': 'keepdims'},
834                       extras={'axis': tuple(axis.asnumpy())})(inputs[0], attr)
835    return _impl
836
837def _broadcast(name):
838    def _impl(inputs, attr, params):
839        op_name = _math_name_picker(name)(attr)
840        return AttrCvt(
841            op_name=op_name,
842            ignores=['name', 'Tidx']
843        )(inputs, attr)
844    return _impl
845
846def _split(has_size_vector):
847    # TF documentation https://www.tensorflow.org/api_docs/python/tf/split
848    def _impl(inputs, attr, params):
849        try:
850            # order and number of inputs are different:
851            # if has_size_vector:
852            #     https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/split-v
853            # else:
854            #     https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/split
855
856            # in addition, `axis` and `num_or_size_splits` can be tensors in TensorFlow,
857            # we can only support constants
858            if has_size_vector:
859                input_node_index = 0
860                input_axis_index = 2
861                size_splits_input_name = inputs[1].list_output_names()[0]
862                size_splits = params[size_splits_input_name].asnumpy()
863                section_beginnings = np.cumsum(size_splits)[:-1]
864                indices_or_sections = tuple(section_beginnings)
865            else:
866                input_node_index = 1
867                input_axis_index = 0
868                indices_or_sections = attr['num_split']
869            input_node = inputs[input_node_index]
870            axis_input_name = inputs[input_axis_index].list_output_names()[0]
871            axis_input_value = params[axis_input_name].asnumpy()[0]
872        except (IndexError, KeyError):
873            raise TypeError( \
874                "Unsupported argument for split: `axis` and `num_or_size_splits` " \
875                "should be constants")
876        return _sym.split(input_node,
877                          indices_or_sections=indices_or_sections,
878                          axis=axis_input_value)
879    return _impl
880
881def _unpack():
882    def _impl(inputs, attr, params):
883        input_node = inputs[0]
884        axis = attr['axis']
885        input_shape = attr['_input_shapes'][input_node]
886        axis_length = input_shape[axis]
887        if axis_length < 0:
888            raise TypeError("Unstack with unknown axis length")
889        splitted = _sym.split(input_node,
890                              indices_or_sections=axis_length,
891                              axis=axis,
892                              name=attr.get('_node_name', 'unstack'))
893
894        return _sym.Group([_sym.squeeze(split_item, axis=axis) for split_item in splitted])
895    return _impl
896
897def _expand_dims_0d_aware(data, attr, axis, num_newaxis=1):
898    if data in attr['_input_0d_mismatch']:
899        return data if num_newaxis == 1 else \
900            _sym.expand_dims(data, axis=axis, num_newaxis=num_newaxis-1)
901
902    return _sym.expand_dims(data, axis=axis, num_newaxis=num_newaxis)
903
904def _logical(name):
905    def _impl(inputs, attr, params):
906        return AttrCvt(op_name=name)(inputs, attr)
907    return _impl
908
909# compatible operators that do NOT require any conversion.
910_identity_list = []
911
912# _convert_map defines maps of name to converter functor(callable)
913# for 1 to 1 mapping, use Renamer if nothing but name is different
914# use AttrCvt if attributes need to be converted
915# for 1 to N mapping(composed), use custom callable functions
916# for N to 1 mapping, currently not supported(?)
917_convert_map = {
918    'ArgMax'                            : _argx(_sym.argmax, 'argmax'),
919    'ArgMin'                            : _argx(_sym.argmin, 'argmin'),
920    'AvgPool'                           : _pooling('avg_pool'),
921    'BatchNormWithGlobalNormalization'  : _batch_norm(),
922    'BiasAdd'                           : _bias_add(),
923    'Cast'                              : _cast(),
924    'Ceil'                              : AttrCvt('ceil'),
925    'CheckNumerics'                     : _check_numerics(),
926    'Concat'                            : _concat(),
927    'ConcatV2'                          : _concatV2(),
928    'Conv2D'                            : _conv('conv'),
929    'DecodeJpeg'                        : _decode_image(),
930    'Elu'                               : _elu(),
931    'ExpandDims'                        : _expand_dims(),
932    'Floor'                             : AttrCvt('floor'),
933    'Identity'                          : _identity(),
934    'MatMul'                            : _matmul(),
935    'MaxPool'                           : _pooling('max_pool'),
936    'Add'                               : _elemwise('add'),
937    'Sub'                               : _elemwise('sub'),
938    'Mul'                               : _elemwise('mul'),
939    'RealDiv'                           : _elemwise('div'),
940    'Maximum'                           : _elemwise('max'),
941    'Minimum'                           : _elemwise('min'),
942    'Sum'                               : _sum(),
943    'Square'                            : _square(),
944    'Pack'                              : _pack(),
945    'Slice'                             : _slice(),
946    'LeakyRelu'                         : AttrCvt('leaky_relu'),
947    'Relu'                              : AttrCvt('relu'),
948    'Reshape'                           : _reshape(),
949    'ResizeBilinear'                    : _resize_bilinear(),
950    'Selu'                              : _selu(),
951    'Softmax'                           : AttrCvt('softmax', {'axis': ('axis', 1)}),
952    'Rsqrt'                             : _rsqrt(),
953    'Squeeze'                           : _squeeze(),
954    'FusedBatchNorm'                    : _fused_batch_norm(),
955    'FusedBatchNormV2'                  : _fused_batch_norm(),
956    'Relu6'                             : _relu6(),
957    'DepthwiseConv2dNative'             : _conv('depthwise'),
958    'Shape'                             : _shape(),
959    'Sigmoid'                           : AttrCvt('sigmoid'),
960    'Fill'                              : _fill(),
961    'GatherV2'                          : _gather_v2(),
962    'StridedSlice'                      : _stridedSlice(),
963    'LRN'                               : _lrn(),
964    'Pad'                               : _pad('Pad'),
965    'PadV2'                             : _pad('PadV2'),
966    'Range'                             : _range(),
967    'Rank'                              : _rank(),
968    'Transpose'                         : _transpose(),
969    'Tanh'                              : AttrCvt('tanh'),
970    'Mean'                              : _mean(),
971    'LogicalAnd'                        : _logical('logical_and'),
972    'LogicalOr'                         : _logical('logical_or'),
973    'LogicalNot'                        : _logical('logical_not'),
974    'Less'                              : _broadcast('less'),
975    'Greater'                           : _broadcast('greater'),
976    'LessEqual'                         : _broadcast('less_equal'),
977    'GreaterEqual'                      : _broadcast('greater_equal'),
978    'Equal'                             : _broadcast('equal'),
979    'NotEqual'                          : _broadcast('not_equal'),
980    'Split'                             : _split(False),
981    'SplitV'                            : _split(True),
982    'Unpack'                            : _unpack(),
983}
984
985# _convert_map_rnn defines maps of rnn operator name to
986# converter functor(callable) for 1 to 1 mapping.
987_convert_map_rnn = {
988    'LSTMBlockCell'                     : _LSTMBlockCell(),
989}
990
991class RecurrentNetworks(object):
992    """Recurrent network layer handlers.
993
994    Handle Layer operations.
995    ToDo: Operators like RNN/GRU layer concepts also can be handled here
996
997    Parameters
998    ----------
999    nodes : list
1000        list of graph nodes used for tensorflow parsing.
1001
1002    out_rnn : list
1003        List of RecurrentNetwork outputs. This output will be appended to the
1004        'head' nodes of the graph.
1005
1006    graph : tensorflow graph definition object
1007        The loaded tensorflow GraphDef
1008
1009    convert_map : dict
1010        Dict of name : callable, where name is the op's name that
1011        require conversion to nnvm, callable are functions which
1012        take attrs and return (new_op_name, new_attrs)
1013    """
1014    def __init__(self, nodes, out_rnn, graph, convert_map):
1015        self._graph = graph
1016        self._convert_map = convert_map
1017        self._nodes = nodes
1018        self._out_rnn = out_rnn
1019        self._cur_lstm_layer = 0
1020        self._layer_name_list = []
1021        self._recurrent_ops_layer_map = {
1022            'LSTMBlockCell'               : self._LSTMBlockCellLayer(),
1023        }
1024
1025    def _LSTMBlockCellLayer(self):
1026        """LSTMBlockCell layer handler.
1027
1028        Parameters
1029        ----------
1030        op_name : str
1031            Operator name, eg:LSTMBlockCell
1032
1033        layer_name : str list
1034            Layer name is used for creating the state input placeholder.
1035
1036        inputs : nnvm.Symbol
1037            Input data
1038
1039        attrs : dict
1040            Dict of operator attributes
1041
1042        params : dict
1043            List of pretrained weights and bias
1044
1045        num_layers : int
1046            Total number of LSTM layer presented in the graph
1047
1048        Returns
1049        -------
1050        sym : nnvm.sym.Symbol
1051            The returned nnvm symbol
1052        """
1053        def _impl(op_name, layer_name, inputs, attrs, params, num_layers):
1054            in_state_c_name = layer_name+'_c'
1055            in_state_h_name = layer_name+'_h'
1056
1057            def _init_state(num_layers, batch_size, num_hidden):
1058                """Create the initial states for the first layer in the graph."""
1059                in_state_c = _sym.Variable(in_state_c_name,
1060                                           shape=(num_layers, batch_size, num_hidden))
1061                in_state_h = _sym.Variable(in_state_h_name,
1062                                           shape=(num_layers, batch_size, num_hidden))
1063                return in_state_c, in_state_h
1064
1065            def _get_cur_input_state(in_state_c, in_state_h, num_layers,
1066                                     layer, batch_size, num_hidden):
1067                """Select the appropriate states for the current layer"""
1068                in_state_c_tup = _sym.split(in_state_c,
1069                                            indices_or_sections=num_layers, axis=0)
1070                in_state_h_tup = _sym.split(in_state_h,
1071                                            indices_or_sections=num_layers, axis=0)
1072                cur_in_state_c = _sym.reshape(in_state_c_tup[layer],
1073                                              shape=(batch_size, num_hidden))
1074                cur_in_state_h = _sym.reshape(in_state_h_tup[layer],
1075                                              shape=(batch_size, num_hidden))
1076                return cur_in_state_c, cur_in_state_h
1077
1078            def _LSTMBlockCellWrapper(inputs, attr, params,
1079                                      num_layers, layer):
1080                """LSTM cell warapper to prepare the inputs"""
1081                input_shape = attr['_input_shapes'][inputs[0]]
1082                weight_shape = attr['_input_shapes'][inputs[3]]
1083                batch_size = input_shape[0]
1084                num_hidden = weight_shape[1] // 4
1085
1086                if layer == 0:
1087                    #Create initial states placeholder in case of first layer
1088                    in_state_c, in_state_h = _init_state(num_layers,
1089                                                         batch_size, num_hidden)
1090                else:
1091                    in_state_c = self._nodes[in_state_c_name]
1092                    in_state_h = self._nodes[in_state_h_name]
1093
1094                cur_in_state_c, cur_in_state_h = _get_cur_input_state( \
1095                                                    in_state_c, in_state_h,
1096                                                    num_layers, layer,
1097                                                    batch_size, num_hidden)
1098                output, out_state = self._convert_map[op_name](inputs, cur_in_state_c,
1099                                                               cur_in_state_h,
1100                                                               attr, params)
1101                return output, out_state, in_state_c, in_state_h
1102
1103            sym, cur_out_state, in_state_c, in_state_h = \
1104                    _LSTMBlockCellWrapper(inputs, attrs, params,
1105                                          num_layers, self._cur_lstm_layer)
1106            self._nodes[in_state_c_name] = in_state_c
1107            self._nodes[in_state_h_name] = in_state_h
1108            cur_out_state = _sym.expand_dims(cur_out_state, axis=0, num_newaxis=1)
1109            self._out_rnn.append(cur_out_state)
1110            self._cur_lstm_layer += 1
1111            return sym
1112        return _impl
1113
1114    def process_op(self, op_name, inputs, attrs, params):
1115        """Process recurrent layer operators.
1116
1117        List '_recurrent_ops_layer_map' map each Layer based operators with its
1118        layer handlers. Total number of layers are calculated to form the input
1119        data shapes.
1120
1121        Parameters
1122        ----------
1123        op_name : str
1124            Operator name, such as LSTMBlockCell
1125
1126        inputs : nnvm.Symbol
1127            Input data
1128
1129        attrs : dict
1130            Dict of operator attributes
1131
1132        params : dict
1133            List of pretrained weights and bias
1134
1135        Returns
1136        -------
1137        sym : nnvm.sym.Symbol
1138            The returned nnvm symbol
1139        """
1140        def _get_abs_layer_name(node):
1141            """Identify the layer name is already handled. Return the absolute name
1142            """
1143            if not self._layer_name_list:
1144                self._layer_name_list.append(node.name)
1145                return node.name
1146
1147            for _name in self._layer_name_list:
1148                if _name in node.name:
1149                    abs_name = _name
1150                else:
1151                    self._layer_name_list.append(node.name)
1152                    abs_name = node.name
1153            return abs_name
1154
1155        #Find number of layers of this same operator node in the graph
1156        #and also read the inputs name for the current op.
1157        num_layers = 0
1158        for _, node in enumerate(self._graph.node):
1159            if node.op == op_name:
1160                layer_name = _get_abs_layer_name(node)
1161                num_layers += 1
1162
1163        sym = self._recurrent_ops_layer_map[op_name](op_name, layer_name, inputs, attrs,
1164                                                     params, num_layers)
1165        return sym
1166
1167class GraphProto(object):
1168    """ A helper class for handling nnvm graph copying from Tensorflow GraphDef.
1169    Definition:
1170        https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/graph.proto
1171    """
1172    def __init__(self):
1173        self._nodes = {}
1174        self._params = {}
1175        self._output_shapes = {}
1176        self._num_param = 0
1177        self._num_rnn_layer = False
1178        self._outputs_are_0d = {}
1179        self._input_shapes = {}
1180
1181    def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
1182        """Construct nnvm nodes from tensorflow graph definition - GraphDef.
1183
1184        Follow the tensorflow graph definition to parse and convert it to NNVM.
1185        Some of the assumptions listed below.
1186
1187            -> All Placeholders are considered as graph input.
1188            -> All Const nodes are params.
1189            -> Last node is assumed as graph output.
1190            -> _output_shapes : Graph should be frozen with add_shapes=True.
1191                                Or user can pass input shape dictionary optionally.
1192            -> DecodeJpeg, ResizeBilinear: These are dummy operators.
1193                                           Hence user should handle preprocessing outside.
1194            -> CheckNumerics: No implementation as of now for this.
1195                              Just copies input to output.
1196
1197        Parameters
1198        ----------
1199        graph : tensorflow graph definition object
1200            The loaded tensorflow GraphDef
1201
1202        layout : target layout to be used (Optional)
1203            NCHW only supported now to enable NHWC models on GPU.
1204
1205        shape : Dictionary of input dimensions (Optional)
1206            Graph level input shape dictionary.
1207
1208        outputs : List of output tensor names (Optional)
1209            if not specified then the last node is assumed as graph output.
1210
1211        Returns
1212        -------
1213        sym : nnvm.sym.Symbol
1214            The returned nnvm symbol
1215        params : dict
1216            A dict of name: tvm.nd.array pairs, used as pretrained weights
1217        """
1218
1219        try:
1220            from tensorflow.python.framework import tensor_util
1221        except ImportError as e:
1222            raise ImportError(
1223                "Unable to import tensorflow which is required {}".format(e))
1224
1225        missing_operators = self._parse_import_prerequisites(graph)
1226
1227        if missing_operators:
1228            msg = 'The following operators are not supported in frontend TensorFlow: {}'
1229            ops = str(list(missing_operators)).strip('[,]')
1230            raise tvm.error.OpNotImplemented(msg.format(ops))
1231
1232        for node in graph.node:
1233            if node.op == 'Placeholder':
1234                # Give priority to user argument.
1235                if shape and node.name in shape:
1236                    self._input_shapes[node.name] = list(shape[node.name])
1237                else:
1238                    self._input_shapes[node.name] = \
1239                        tensor_util.TensorShapeProtoToList(node.attr['shape'].shape)
1240                    for idx, dim in enumerate(self._input_shapes[node.name]):
1241                        if dim < 0:
1242                            self._input_shapes[node.name][idx] = 1
1243                            warnings.warn("Use 1 instead of -1 in shape of operator %s."
1244                                          % node.name)
1245
1246                self._nodes[node.name] = _sym.Variable(name=node.name,
1247                                                       shape=self._input_shapes[node.name])
1248                self._output_shapes[node.name] = [self._input_shapes[node.name]]
1249                self._outputs_are_0d[node.name] = [ \
1250                    not tshape if isinstance(tshape, list) else False \
1251                    for tshape in self._output_shapes[node.name]]
1252
1253            # Ignore user's input shape for Non placeholder
1254            elif node.op == 'Const':
1255                tensor_value = node.attr['value'].tensor
1256                self._input_shapes[node.name] = \
1257                    tensor_util.TensorShapeProtoToList(tensor_value.tensor_shape)
1258                if shape and node.name in shape:
1259                    warnings.warn("Ignore the passed shape. "
1260                                  "Shape in graphdef will be used for operator %s." % node.name)
1261
1262        final_op = None
1263        # Parse the nodes to re-create TF graph using Symbol API of NNVM
1264        for node in graph.node:
1265            # Tensorflow doesn't have separate list for params extraction.
1266            # Operator name 'Const' is treated as a parameter to build NNVM params dict.
1267
1268            input_shapes = {}
1269            input_0d_mismatch = set()
1270            attr = self._parse_attr(node.attr)
1271
1272            #  Variable converted to Const will not have only value attr
1273            if 'value' in attr and node.op == 'Const':
1274                self._output_shapes[node.name] = [self._input_shapes[node.name]]
1275            elif '_output_shapes' in attr:
1276                self._output_shapes[node.name] = \
1277                    [tensor_util.TensorShapeProtoToList(tshape) \
1278                    for tshape in attr['_output_shapes']]
1279            else:
1280                # Keep the list indexable to avoid key error.
1281                # Actual value will be filled after node creation.
1282                # Will infer shapes if the graph is not frozen with add_shapes=True
1283                self._output_shapes[node.name] = [None]
1284
1285            self._outputs_are_0d[node.name] = [ \
1286                not tshape if isinstance(tshape, list) else False \
1287                for tshape in self._output_shapes[node.name]]
1288
1289            if node.op == "Const":
1290                # All Const nodes are Param nodes, lets parse
1291                self._num_param += 1
1292                for key, value in node.attr.items():
1293                    self._parse_param(key, value, node.name)
1294                if node.name not in self._nodes:
1295                    raise NotImplementedError( \
1296                        "Const {} couldn't be converted to Param.".format(node.name))
1297
1298                attr = self._parse_attr(node.attr)
1299
1300            elif node.op != "Placeholder":
1301                # Pass the parsed shapes instead
1302                attr["_output_shapes"] = output_shapes = self._output_shapes[node.name]
1303
1304                # Pass the node name too in attr
1305                attr["_node_name"] = node.name
1306
1307                # Pass the target layout
1308                attr["_target_layout"] = layout
1309
1310                # Fill shapes for all inputs in a list
1311                inputs = []
1312                for i in node.input:
1313                    # Some TensorFlow operators internally maintain execution layers
1314                    # and their output name includes the layer number along with
1315                    # graph node name. E.g. the node name is 'Model/RNN/cell_0/RnnCell', but the
1316                    # output tensor name is 'Model/RNN/cell_0/RnnCell:0'. In this case,
1317                    # the number has to be ignored for single-output nodes.
1318                    # On the other hand, for multi-output nodes the number is the output index,
1319                    # and the lack of the number implies 0.
1320                    tensor_name = i.split(':')
1321                    node_name = tensor_name[0]
1322                    if node_name in self._nodes:
1323                        in_sym = self._nodes[node_name]
1324                        if len(in_sym.list_output_names()) > 1:
1325                            tensor_slot = int(tensor_name[1]) if len(tensor_name) > 1 else 0
1326                            in_sym = in_sym[tensor_slot]
1327                            input_shape = self._output_shapes[node_name][tensor_slot]
1328                        else:
1329                            tensor_slot = 0
1330                            input_shape = self._output_shapes[node_name][0]
1331                        inputs.append(in_sym)
1332                        input_shapes[in_sym] = input_shape
1333                        # This means the node is 1d in NNVM and 0d in TF.
1334                        # See `_expand_dims_0d_aware`.
1335                        if self._outputs_are_0d[node_name][tensor_slot] and input_shape:
1336                            input_0d_mismatch.add(in_sym)
1337                attr['_input_shapes'] = input_shapes
1338                attr['_input_0d_mismatch'] = input_0d_mismatch
1339
1340                inputs = self._fix_extranodes(node.op, attr, inputs)
1341                op = self._convert_operator(node.op, inputs, attr, graph)
1342
1343                # Check if op is converted to param
1344                if isinstance(op, np.ndarray):
1345                    self._params[node.name] = tvm.nd.array(op)
1346                    op = _sym.Variable(name=node.name,
1347                                       shape=self._params[node.name].shape)
1348
1349                # Assuming only one output.
1350                self._nodes[node.name] = op
1351                final_op = op
1352
1353                # Infer shapes even without specifying "add_shapes=True"
1354                if output_shapes == [None]:
1355                    g = _graph.create(final_op)
1356                    self._output_shapes[node.name] = \
1357                        list(graph_util.infer_shape(g, **self._input_shapes))[-1]
1358
1359                if self._output_shapes[node.name] and shape and node.name in shape:
1360                    assert self._output_shapes[node.name] == list(shape[node.name])
1361
1362            # Infer shapes if passed explicitely
1363            node_output = self._nodes[node.name]
1364            if shape and (not self._output_shapes[node.name][0]
1365                          or -1 in self._output_shapes[node.name][0]):
1366                g = _graph.create(node_output)
1367                shape_dict = {k: v.shape for k, v in self._params.items()}
1368                shape_dict.update(shape)
1369                _, out_shapes = graph_util.infer_shape(g, **shape_dict)
1370                self._output_shapes[node.name] = out_shapes
1371
1372        out = []
1373        if outputs is None:
1374            out.append(final_op)
1375        else:
1376            for out_name in outputs:
1377                if ":" in out_name:
1378                    out_name, out_num = out_name.split(":")
1379                    out_num = int(out_num)
1380                    out.append(self._nodes[out_name][out_num])
1381                else:
1382                    out.append(self._nodes[out_name])
1383
1384        #Add the RNN outputs also with 'head' nodes of the nnvm graph
1385        if self._num_rnn_layer:
1386            out_rnn = _sym.concatenate(*self._out_rnn, axis=0)
1387            out.append(out_rnn)
1388
1389        if isinstance(out, list):
1390            out = _sym.Group(out) if len(out) > 1 else out[0]
1391
1392        return out, self._params
1393
1394    def _parse_import_prerequisites(self, graph):
1395        """ Calculate the named preconditions from TensorFlow `graph`.
1396            Return prerequisites for parsing:
1397            a. Set of operator names which don't have their mapping in TVM, i.e.
1398                which are not supported
1399        """
1400        missing_operators = set()
1401        for node in graph.node:
1402            if node.op == "Placeholder":
1403                pass
1404            elif node.op == "Const":
1405                pass
1406            else:
1407                if any([node.op in t for t in [_identity_list, _convert_map, _convert_map_rnn]]):
1408                    pass
1409                else:
1410                    missing_operators.add(node.op)
1411
1412        return missing_operators
1413
1414    def _parse_param(self, key, value, name):
1415        try:
1416            from tensorflow.python.framework import tensor_util
1417        except ImportError as e:
1418            raise ImportError(
1419                "Unable to import tensorflow which is required {}".format(e))
1420
1421        if key == 'value':
1422            np_array = tensor_util.MakeNdarray(value.tensor)
1423
1424            if np_array.dtype == np.dtype(object):
1425                # Object types are generally tensorflow DT_STRING (DecodeJpeg op).
1426                # Just leave it as placeholder.
1427                self._nodes[name] = _sym.Variable(name=name)
1428                return
1429
1430            array_ndim = len(np_array.shape)
1431            if array_ndim == 0:
1432                new_array = np.empty([1], dtype=np_array.dtype)
1433                new_array[0] = np_array
1434                self._params[name] = tvm.nd.array(new_array)
1435            else:
1436                self._params[name] = tvm.nd.array(np_array)
1437            self._nodes[name] = _sym.Variable(name=name,
1438                                              shape=self._params[name].shape)
1439        else:
1440            if key not in ('dtype', '_output_shapes', '_class'):
1441                raise NotImplementedError \
1442                    ("Other attributes for a Const(param) Node {} ? .".format(key))
1443
1444    def _get_attr(self, buf):
1445        """Returns the value of the attr of this buf with the given `name`.
1446
1447        Args:
1448          buf: attrvalue protobuf.
1449
1450        Returns:
1451          The value of the attr, as a Python object.
1452
1453        Raises:
1454          ValueError: If this op does not have an attr with the given `name`.
1455        """
1456        fields = ["s", "i", "f", "b", "type", "shape", "tensor", "func"]
1457
1458        x = buf
1459
1460        ret = []
1461
1462        try:
1463            from tensorflow.python.framework import dtypes
1464        except ImportError as e:
1465            raise ImportError(
1466                "Unable to import tensorflow which is required {}".format(e))
1467
1468        # Treat an empty oneof value as an empty list.
1469        if not x.WhichOneof("value"):
1470            return ret
1471        if x.HasField("list"):
1472            for f in fields:
1473                if getattr(x.list, f):
1474                    if f == "type":
1475                        ret += [dtypes.as_dtype(x) for x in list(getattr(x.list, f))]
1476                    else:
1477                        ret += list(getattr(x.list, f))
1478        else:
1479            for f in fields:
1480                if x.HasField(f):
1481                    if f == "type":
1482                        ret = dtypes.as_dtype(getattr(x, f))
1483                    else:
1484                        ret = getattr(x, f)
1485        return ret
1486
1487    def _parse_attr(self, attr_proto):
1488        """Convert a list of AttributeProto to a dict, with names as keys."""
1489        attrs = {}
1490        for key, value in attr_proto.items():
1491            attrs[key] = self._get_attr(value)
1492
1493        return attrs
1494
1495    def _convert_rnn_operator(self, op_name, inputs,
1496                              attrs, params, graph, convert_map):
1497        """Convert RNN and its variant operators to NNVM operators.
1498        This converter read the input states of each layers and
1499        also maintain the output states of each layer in a list.
1500
1501        Parameters
1502        ----------
1503        op_name : str
1504            Operator name, such as LSTMBlockCell
1505        inputs : list of nnvm.Symbol
1506            List of input symbols.
1507        attrs : dict
1508            Dict of operator attributes
1509        params : dict
1510            List of pretrained weights and bias
1511        graph : Tensorflow graph object
1512            Graph is to find the number of upcoming same operator to
1513            calculate the number of layers.
1514        convert_map : dict
1515            Dict of name : callable, where name is the op's name that
1516            require conversion to nnvm, callable are functions which
1517            take attrs and return (new_op_name, new_attrs)
1518
1519        Returns
1520        -------
1521        sym : nnvm.Symbol
1522            Converted nnvm Symbol
1523        """
1524        if not self._num_rnn_layer:
1525            self._out_rnn = []
1526            self.rnn = RecurrentNetworks(self._nodes, self._out_rnn, graph, convert_map)
1527            self._num_rnn_layer = True
1528        sym = self.rnn.process_op(op_name, inputs, attrs, params)
1529        return sym
1530
1531    def _convert_operator(self, op_name, inputs, attrs,
1532                          graph, identity_list=None, convert_map=None):
1533        """Convert from Tensorflow operator to nnvm operator.
1534        The converter must specify conversions explicitly for incompatible name, and
1535        apply handlers to operator attributes.
1536
1537        Parameters
1538        ----------
1539        op_name : str
1540            Operator name, such as Conv2D, AvgPool
1541        inputs : list of nnvm.Symbol
1542            List of input symbols.
1543        attrs : dict
1544            Dict of operator attributes
1545        identity_list : list
1546            List of operators that don't require conversion
1547        convert_map : dict
1548            Dict of name : callable, where name is the op's name that
1549            require conversion to nnvm, callable are functions which
1550            take attrs and return (new_op_name, new_attrs)
1551
1552        Returns
1553        -------
1554        sym : nnvm.Symbol
1555            Converted nnvm Symbol
1556        """
1557        identity_list = identity_list if identity_list else _identity_list
1558        convert_map = convert_map if convert_map else _convert_map
1559        convert_map_rnn = _convert_map_rnn
1560        if op_name in identity_list:
1561            sym = get_nnvm_op(op_name)(*inputs, **attrs)
1562        elif op_name in convert_map:
1563            sym = convert_map[op_name](inputs, attrs, self._params)
1564        elif op_name in convert_map_rnn:
1565            sym = self._convert_rnn_operator(op_name, inputs, attrs,
1566                                             self._params, graph,
1567                                             convert_map_rnn)
1568        else:
1569            raise tvm.error.OpNotImplemented(
1570                'Operator {} is not supported in frontend TensorFlow.'.format(op_name))
1571        return sym
1572
1573    def _fix_extranodes(self, op_name, attr, inputs):
1574        if op_name == "Softmax":
1575            # Require some times flatten of data before it goes to softmax
1576            # Need to relook into this with latest softmax axis support.
1577            op = AttrCvt(op_name='flatten')(inputs, {})
1578            node_output = op.list_output_names()
1579            for k, i in zip(list(node_output), range(len(node_output))):
1580                self._nodes[k] = op[i]
1581            inputs = [op]
1582
1583        return inputs
1584
1585def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None):
1586    """Load tensorflow graph which is a python tensorflow graph object into nnvm graph.
1587    The companion parameters will be handled automatically.
1588
1589    Parameters
1590    ----------
1591    graph : GraphDef object
1592        Tensorflow GraphDef
1593
1594    layout : target layout to be used (Optional)
1595        NCHW only supported now to enable NHWC models on GPU.
1596
1597    shape : Dictionary of input dimensions (Optional)
1598        Graph level input shape dictionary.
1599
1600    outputs : List of output tensor names (Optional)
1601        if not specified then the last node is assumed as graph output.
1602
1603    Returns
1604    -------
1605    sym : nnvm.Symbol
1606        Compatible nnvm symbol
1607
1608    params : dict of str to tvm.ndarray
1609        Dict of converted parameters stored in tvm.ndarray format
1610    """
1611    g = GraphProto()
1612    sym, params = g.from_tensorflow(graph, layout, shape, outputs)
1613    return sym, params
1614