# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=import-self, invalid-name, unused-argument, too-many-lines
"""ONNX: Open Neural Network Exchange frontend."""
from __future__ import absolute_import as _abs
import numpy as np
import tvm
from .. import symbol as _sym
from .common import get_nnvm_op, Renamer, SymbolTable, AttrConverter as AttrCvt
from .onnx_caffe2_utils import dimension_picker, dimension_constraint, \
    infer_channels, revert_caffe2_pad

__all__ = ['from_onnx']


def onnx_storage_order2layout(storage_order):
    """Map the ONNX storage_order attribute (0 or 1) to an NNVM layout string."""
    if storage_order not in (0, 1):
        raise tvm.error.OpAttributeInvalid('Mode of storage_order must be either 0 or 1')

    return 'NCHW' if storage_order == 0 else 'NHWC'


class OnnxOpConverter(object):
    """ A helper class for holding onnx op converters.
    """

    @classmethod
    def get_converter(cls, opset):
        """ Get the converter matching a given opset.

        :param opset: opset from the model.
        :return: converter, which should be `_impl_vx`, where x is the largest
            supported version that is smaller than or equal to the opset.
        """
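        # Illustrative (hypothetical) example: a converter that defines
        # _impl_v1 and _impl_v7 resolves opset 8 to _impl_v7, opset 7 to
        # _impl_v7, and opset 1 to _impl_v1.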
        versions = [
            int(d.replace('_impl_v', '')) for d in dir(cls) if '_impl_v' in d
        ]
        versions = sorted(versions + [opset])
        version = versions[
            max([i for i, v in enumerate(versions) if v == opset]) - 1]
        if hasattr(cls, '_impl_v{}'.format(version)):
            return getattr(cls, '_impl_v{}'.format(version))
        raise NotImplementedError(
            'opset version {} of {} not implemented'.format(
                version, cls.__name__))


class Elemwise(OnnxOpConverter):
    """ A helper class for elemwise op converters.
    """

    name = ''

    @classmethod
    def _math_name_picker(cls, suffix):

        def _impl(attr):
            if attr.get('broadcast', 0):
                return 'broadcast_' + suffix
            return 'elemwise_' + suffix

        return _impl

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        assert len(inputs) == 2, "Math op takes 2 inputs, {} given".format(
            len(inputs))
        op_name = cls._math_name_picker(cls.name)(attr)
        axis = int(attr.get('axis', 0))
        conv_ops = ["conv2d", "conv2d_transpose"]
        if op_name == 'broadcast_add' and inputs[0].attr('op_name') in conv_ops:
            # TODO(zhreshold): remove hard coded infershape
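            # e.g. a 1-D (C,) bias added to a 4-D NCHW conv output needs to be
            # expanded to (C, 1, 1) so that it broadcasts over H and W.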
            inputs[1] = _sym.expand_dims(inputs[1], axis=axis, num_newaxis=2)
        return get_nnvm_op(op_name)(*inputs)


class Pool(OnnxOpConverter):
    """ A helper class for pool op converters.
    """

    name = ''

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        return AttrCvt(
            op_name=dimension_picker(cls.name),
            transforms={
                'kernel_shape': 'pool_size',
                'pads': ('padding', (0, 0), revert_caffe2_pad)
            },
            # very weird attributes here in onnx, force check
            ignores=['dilations'],
            # TODO(zhreshold): make sure ceil_mode in onnx, and layout?
            extras={'ceil_mode': False},
            custom_check=dimension_constraint())(inputs, attr, params)


class Absolute(OnnxOpConverter):

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        return _sym.relu(inputs[0]) + _sym.relu(_sym.negative(inputs[0]))


class Add(Elemwise):
    name = 'add'


class AveragePool(Pool):
    name = 'avg_pool'


class BatchNorm(OnnxOpConverter):

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        # TODO(zhreshold): 'spatial' is not properly handled here.
        return AttrCvt(
            op_name='batch_norm',
            disables=['momentum'],
            ignores=['spatial', 'is_test', 'consumed_inputs'])(inputs, attr,
                                                               params)


class Conv(OnnxOpConverter):

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        # get number of channels
        channels = infer_channels(inputs[1], params)
        attr['channels'] = channels
        return AttrCvt(
            op_name=dimension_picker('conv'),
            transforms={
                'kernel_shape': 'kernel_size',
                'dilations': ('dilation', (0, 0)),
                'pads': ('padding', (0, 0), revert_caffe2_pad),
                'group': ('groups', 1)
            },
            extras={'use_bias': len(inputs) == 3},
            custom_check=dimension_constraint())(inputs, attr, params)


class ConvTranspose(OnnxOpConverter):

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        # get number of channels
        channels = infer_channels(inputs[1], params, True)
        attr['channels'] = channels
        groups = attr.pop('group')
        attr['groups'] = groups
        return AttrCvt(
            op_name=dimension_picker('conv', '_transpose'),
            transforms={
                'kernel_shape': 'kernel_size',
                'dilations': ('dilation', (0, 0)),
                'pads': ('padding', (0, 0), revert_caffe2_pad)
            },
            disables=['output_shape'],
            extras={'use_bias': len(inputs) == 3},
            custom_check=dimension_constraint())(inputs, attr, params)


class Div(Elemwise):
    name = 'div'


class Elu(OnnxOpConverter):

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        alpha = float(attr.get('alpha', 1.0))
        return -alpha * _sym.relu(1 - _sym.exp(inputs[0])) + _sym.relu(
            inputs[0])


class Gemm(OnnxOpConverter):
    """ Operator converter for Gemm.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        assert len(inputs) == 3, "Gemm op takes 3 inputs, {} given".format(
            len(inputs))
        # Y = alpha * A * B + beta * C
        alpha = float(attr.get('alpha', 1.0))
        beta = float(attr.get('beta', 1.0))
        transA = int(attr.get('transA', 0))
        transB = int(attr.get('transB', 0))
        # get number of channels
        channels = infer_channels(inputs[1], params, not transB)
        if transA:
            inputs[0] = _sym.transpose(inputs[0], axes=(1, 0))
        if not transB:
            inputs[1] = _sym.transpose(inputs[1], axes=(1, 0))
        inputs[0] = _sym.flatten(inputs[0])
        return _sym.dense(
            alpha * inputs[0], inputs[1], beta * inputs[2], units=channels)


class MaxPool(Pool):
    """ Operator converter for MaxPool
    """
    name = 'max_pool'

    @classmethod
    def _impl_v8(cls, inputs, attr, params):
        return AttrCvt(
            op_name=dimension_picker(cls.name),
            transforms={
                'kernel_shape': 'pool_size',
                'pads': ('padding', (0, 0), revert_caffe2_pad),
                'storage_order': ('layout', 'NCHW', onnx_storage_order2layout),
            },
            # very weird attributes here in onnx, force check
            ignores=['dilations', 'auto_pad'],
            # TODO(higumachan): make sure ceil_mode in onnx, and layout?
            extras={'ceil_mode': False},
            custom_check=dimension_constraint())(inputs, attr, params)

    @classmethod
    def _impl_v10(cls, inputs, attr, params):
        return AttrCvt(
            op_name=dimension_picker(cls.name),
            transforms={
                'kernel_shape': 'pool_size',
                'pads': ('padding', (0, 0), revert_caffe2_pad),
                'storage_order': ('layout', 'NCHW', onnx_storage_order2layout),
                'ceil_mode': 'ceil_mode'
            },
            # very weird attributes here in onnx, force check
            ignores=['dilations', 'auto_pad'],
            custom_check=dimension_constraint())(inputs, attr, params)

class Mul(Elemwise):
    name = 'mul'


class Pad(OnnxOpConverter):
    """ Operator converter for Pad.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
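        # ONNX pads come as a flat list [x1_begin, x2_begin, ..., x1_end,
        # x2_end, ...]; fold them into per-axis (begin, end) pairs for pad_width.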
        pad_width = []
        pads = attr.pop('paddings')
        dims = int(len(pads) / 2)
        for i in range(dims):
            pad_width.append((pads[i], pads[i+dims]))
        attr['pad_width'] = pad_width

        return AttrCvt(
            op_name='pad',
            transforms={
                'value': 'pad_value',
            },
            ignores=['mode'],
            custom_check=(lambda attrs: attrs.get('mode', 'constant').decode("utf-8") == 'constant',
                          'split mode != constant'))(inputs, attr, params)

    @classmethod
    def _impl_v2(cls, inputs, attr, params):
        pad_width = []
        pads = attr.pop('pads')
        dims = int(len(pads) / 2)
        for i in range(dims):
            pad_width.append((pads[i], pads[i+dims]))
        attr['pad_width'] = pad_width

        return AttrCvt(
            op_name='pad',
            transforms={
                'value': 'pad_value',
            },
            ignores=['mode'],
            custom_check=(lambda attrs: attrs.get('mode', 'constant').decode("utf-8") == 'constant',
                          'split mode != constant'))(inputs, attr, params)


class ParametricSoftPlus(OnnxOpConverter):

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        alpha = float(attr.get('alpha', 1.0))
        beta = float(attr.get('beta', 1.0))
        return _sym.log(_sym.exp(beta * inputs[0]) + 1) * alpha


class Prelu(OnnxOpConverter):

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        assert len(inputs) == 2, "Prelu needs 2 inputs, {} given".format(
            len(inputs))
        return _sym.prelu(inputs[0], inputs[1])


class Reciprocal(OnnxOpConverter):

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        return 1.0 / inputs[0]


class Reshape(OnnxOpConverter):
    """ Operator converter for Reshape.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        return _sym.reshape(inputs[0], shape=attr['shape'])

    @classmethod
    def _impl_v5(cls, inputs, attr, params):
        if inputs[1].list_output_names()[0] in params:
            shape = tuple(params[inputs[1].list_output_names()[0]].asnumpy())
            out = _sym.reshape(inputs[0], shape=shape)
        else:
            out = _sym.reshape_like(inputs[0], inputs[1])

        return out

class Scale(OnnxOpConverter):

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        scale = float(attr.get('scale', 1.0))
        return inputs[0] * scale


class Selu(OnnxOpConverter):

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        alpha = float(attr.get('alpha', 1.6732))
        gamma = float(attr.get('gamma', 1.0507))
        return gamma * (
            -alpha * _sym.relu(1 - _sym.exp(inputs[0])) + _sym.relu(inputs[0]))


class ScaledTanh(OnnxOpConverter):

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        alpha = float(attr.get('alpha', 1.0))
        beta = float(attr.get('beta', 1.0))
        return _sym.tanh(beta * inputs[0]) * alpha


class SoftPlus(OnnxOpConverter):

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        return _sym.log(_sym.exp(inputs[0]) + 1)


class Softsign(OnnxOpConverter):

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        return inputs[0] / (1 + Absolute.get_converter(1)(inputs, attr, params))


class Sub(Elemwise):
    name = 'sub'


class Sum(OnnxOpConverter):

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        # Onnx Sum Operator
        for in_index in range(len(inputs) - 1):
            inputs[in_index + 1] = _sym.broadcast_add(inputs[in_index],
                                                      inputs[in_index + 1])

        return inputs[len(inputs) - 1]


class ThresholdedRelu(OnnxOpConverter):

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        alpha = float(attr.get('alpha', 1.0))
        alpha_tensor = _sym.full_like(inputs[0], fill_value=float(alpha))
        return _sym.elemwise_mul(inputs[0], _sym.greater(inputs[0], alpha_tensor))

class ImageScaler(OnnxOpConverter):

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        channelScale = attr['scale']
        bias_attr = attr['bias']
        bias = SymbolTable().new_const(np.array(bias_attr).reshape([3, 1, 1]))
        scaledChannel = _sym.__mul_scalar__(inputs[0], scalar=channelScale)
        ret = _sym.broadcast_add(scaledChannel, bias)
        return ret


def _broadcast_constraint():

    def _broadcast_check(attrs):
        if attrs.get('axis', None):
            return False
        return True

    return _broadcast_check, "Specifying broadcast axis not allowed."


def _fully_connected(opset):

    def _impl(inputs, attr, params):
        # get number of channels
        channels = infer_channels(inputs[1], params)
        attr['units'] = channels
        return AttrCvt('dense', ignores=['axis', 'axis_w'])(inputs, attr)

    return _impl


class Upsample(OnnxOpConverter):
    """ Operator converter for Upsample (nearest mode).
    """

    @classmethod
    def _impl_v9(cls, inputs, attr, params):
        scales = attr.get('scales')
        if not scales:
            # From opset 9 onwards the scales are provided as a second input
            # instead of an attribute.
            assert len(inputs) == 2, "Upsample op takes 2 inputs, {} given".format(len(inputs))
            input_name = inputs[1].list_input_names()[0]
            scales = params[input_name].asnumpy()
            inputs = inputs[:1]
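        # NNVM upsampling takes a single integer scale, so only 4-D NCHW inputs
        # with scale 1 on batch/channel and equal H/W scales are supported.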
        assert len(scales) == 4 and scales[0] == 1.0 and scales[1] == 1.0 and scales[2] == scales[3]
        mode = attr.get('mode')
        if mode == b'nearest':
            method = "NEAREST_NEIGHBOR"
        elif mode == b'linear':
            method = "BILINEAR"
        else:
            raise tvm.error.OpAttributeInvalid(
                'Value {} in attribute "mode" of operator Upsample is not valid.'.format(mode))
        return _sym.upsampling(inputs[0], scale=int(scales[-1]), method=method, layout='NCHW')


class Shape(OnnxOpConverter):
    """ Operator converter for Shape.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        # Result of this operator is prominently used by reshape operator.
        # Just pass the input as it is so that reshape_like can be used there.
        print("Shape: Differently implemented in NNVM as a bypass (dummy operator)")
        return inputs[0]

class Cast(OnnxOpConverter):
    """ Operator converter for Cast.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        return AttrCvt(op_name='cast', transforms={'to': 'dtype'})(inputs, attr)

    @classmethod
    def _impl_v5(cls, inputs, attr, params):
        try:
            from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
            attr['to'] = TENSOR_TYPE_TO_NP_TYPE[attr['to']]
        except ImportError as e:
            raise ImportError(
                "Unable to import onnx.mapping which is required {}".format(e))
        return AttrCvt(op_name='cast', transforms={'to': 'dtype'})(inputs, attr)


class Unsqueeze(OnnxOpConverter):
    """ Operator converter for Unsqueeze.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        for axes in attr['axes']:
            inputs[0] = _sym.expand_dims(inputs[0], axis=axes, num_newaxis=1)
        return inputs[0]


class Split(OnnxOpConverter):
    """ Operator converter for Split.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        attr['indices_or_sections'] = []
        index = 0
        for i in attr['split'][:-1]:
            index += i
            attr['indices_or_sections'].append(index)
        return AttrCvt(
            op_name='split',
            ignores=['split'])(inputs, attr, params)


class Slice(OnnxOpConverter):
    """ Operator converter for Slice.
    """
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        if isinstance(attr['starts'], int):
            attr['starts'] = (attr['starts'],)
            attr['ends'] = (attr['ends'],)

        try:
            # Update the starts and ends according to axes if required.
            if isinstance(attr['axes'], int):
                attr['axes'] = (attr['axes'],)

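            # e.g. axes=(2,), starts=(1,), ends=(3,) gets expanded to
            # axes=(0, 1, 2), starts=(0, 0, 1), ends=(INT32_MAX, INT32_MAX, 3)
            # so strided_slice leaves the leading axes untouched.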
            if (max(attr['axes']) + 1) != len(attr['axes']):
                new_axes = []
                new_starts = []
                new_ends = []
                pop_index = 0
                for i in range(max(attr['axes']) + 1):
                    if i in attr['axes']:
                        new_axes.append(i)
                        new_starts.append(attr['starts'][pop_index])
                        new_ends.append(attr['ends'][pop_index])
                        pop_index += 1
                    else:
                        new_axes.append(i)
                        new_starts.append(0)
                        new_ends.append(np.iinfo(np.int32).max)
                attr['axes'] = new_axes
                attr['starts'] = new_starts
                attr['ends'] = new_ends
        except KeyError:
            pass

        return AttrCvt(op_name='strided_slice',
                       transforms={'starts': 'begin',
                                   'ends': 'end'},
                       ignores=['axes'])(inputs, attr)

class Gather(OnnxOpConverter):
    """ Operator converter for Gather.
    """
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        axis = attr.get('axis', 0)
        return AttrCvt(op_name='take',
                       extras={'axis':axis})(inputs, attr)

class LRN(OnnxOpConverter):
    """ Operator converter for Local Response Normalization.
    """
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        """LRN support only NCHW format
        https://github.com/onnx/onnx/blob/master/docs/Operators.md#LRN
        """
        axis = 1
        alpha = attr.get('alpha', 0.0001)
        beta = attr.get('beta', 0.75)
        bias = attr.get('bias', 1.0)
        nsize = attr.get('size')
        return _sym.lrn(inputs[0], size=nsize, axis=axis,
                        alpha=alpha, beta=beta, bias=bias)

class Maximum(OnnxOpConverter):
    """ Operator converter for Maximum.
    """
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        if not isinstance(inputs, list) or len(inputs) < 2:
            raise ValueError("Expect minimum 2 inputs")
        _max = inputs[0]
        for i in range(1, len(inputs)):
            _max = AttrCvt(op_name='broadcast_max')([_max, inputs[i]], {})
        return _max

class Minimum(OnnxOpConverter):
    """ Operator converter for Minimum.
    """
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        if not isinstance(inputs, list) or len(inputs) < 2:
            raise ValueError("Expect minimum 2 inputs")
        _min = inputs[0]
        for i in range(1, len(inputs)):
            _min = AttrCvt(op_name='broadcast_min')([_min, inputs[i]], {})
        return _min

class Mean(OnnxOpConverter):
    """ Operator converter for Mean.
    """
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        if not isinstance(inputs, list) or len(inputs) < 2:
            raise ValueError("Expect minimum 2 inputs")
        count = len(inputs)
        _sum = inputs[0]
        for i in range(1, count):
            _sum = AttrCvt(op_name='broadcast_add')([_sum, inputs[i]], {})
        return _sum / count

class HardSigmoid(OnnxOpConverter):
    """ Operator converter for HardSigmoid.
    """
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        alpha = attr.get('alpha', 0.2)
        beta = attr.get('beta', 0.5)
        transformX = (inputs[0] * alpha) + beta
        attr = {'a_min':0, 'a_max':1}
        return AttrCvt(op_name='clip')([transformX], attr)

class ArgMax(OnnxOpConverter):
    """ Operator converter for ArgMax.
    """
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        axis = attr.get('axis', 0)
        keepdims = attr.get('keepdims', True)
        attr = {'axis':axis, 'keepdims':keepdims}
        return AttrCvt(op_name='argmax')(inputs, attr)

class ArgMin(OnnxOpConverter):
    """ Operator converter for ArgMin.
    """
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        axis = attr.get('axis', 0)
        keepdims = attr.get('keepdims', True)
        attr = {'axis':axis, 'keepdims':keepdims}
        return AttrCvt(op_name='argmin')(inputs, attr)

class Softmax(OnnxOpConverter):
    """ Operator converter for Softmax.
    """
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        # set default value when axis is not set in the model
        if 'axis' not in attr:
            attr['axis'] = 1
        return AttrCvt(
            op_name='softmax',
            transforms={
                'axis': ('axis', 1),
            })(inputs, attr, params)

class ConstantFill(OnnxOpConverter):
    """ Operator converter for ConstantFill.
    """
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        is_full = True
        num_inputs = len(inputs)
        if 'shape' in attr:
            if num_inputs > 0:
                raise ImportError(
                    "Can't set shape and input tensor at a time")
            shape = attr.pop('shape')
        else:
            if num_inputs == 0:
                raise ImportError(
                    "Either shape attribute or input should be set")
            if 'input_as_shape' in attr and attr['input_as_shape']:
                shape = params[inputs[0].list_output_names()[0]].asnumpy()
            else:
                is_full = False

        if not is_full:
            if 'extra_shape' in attr:
                raise ImportError(
                    "Extra Shape not supported with fill_like")

            out = AttrCvt(
                op_name='full_like',
                transforms={'value': 'fill_value'},
                ignores=['dtype'])(inputs, attr)
            return _sym.cast(out, dtype=attr['dtype'].decode("utf-8"))
        if 'extra_shape' in attr:
            shape = shape + attr.pop('extra_shape')

        return AttrCvt(
            op_name='full',
            transforms={'value': 'fill_value'},
            extras={'shape':shape})(inputs, attr)

# compatible operators that do NOT require any conversion.
_identity_list = []


# _convert_map defines maps of name to converter functor(callable)
# for 1 to 1 mapping, use Renamer if nothing but name is different
# use AttrCvt if attributes need to be converted
# for 1 to N mapping(composed), use custom callable functions
# for N to 1 mapping, currently not supported(?)
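# For example, 'Relu' below is a plain Renamer('relu') (1 to 1), while 'Gemm'
# uses Gemm.get_converter(opset), which expands into transpose/flatten/dense
# (1 to N).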
def _get_convert_map(opset):
    return {
        # defs/experimental
        'Identity': Renamer('copy'),
        # 'Affine'
        'ThresholdedRelu': ThresholdedRelu.get_converter(opset),
        'ScaledTanh': ScaledTanh.get_converter(opset),
        'ParametricSoftplus': ParametricSoftPlus.get_converter(opset),
        'ConstantFill': ConstantFill.get_converter(opset),
        # 'GivenTensorFill'
        'FC': AttrCvt('dense', ignores=['axis', 'axis_w']),
        'Scale': Scale.get_converter(opset),
        # 'GRUUnit'
        # 'ATen'
        'ImageScaler': ImageScaler.get_converter(opset),
        # 'MeanVarianceNormalization'
        # 'Crop'
        # 'Embedding'
        'Upsample' : Upsample.get_converter(opset),
        'SpatialBN': BatchNorm.get_converter(opset),

        # defs/generator
        # 'Constant' # Implemented
        # 'RandomUniform'
        # 'RandomNormal'
        # 'RandomUniformLike'
        # 'RandomNormalLike'

        # defs/logical

        # defs/math
        'Add': Add.get_converter(opset),
        'Sub': Sub.get_converter(opset),
        'Mul': Mul.get_converter(opset),
        'Div': Div.get_converter(opset),
        'Neg': Renamer('negative'),
        'Abs': Absolute.get_converter(opset),
        'Reciprocal': Reciprocal.get_converter(opset),
        'Floor': Renamer('floor'),
        'Ceil': Renamer('ceil'),
        'Sqrt': Renamer('sqrt'),
        'Relu': Renamer('relu'),
        'LeakyRelu': Renamer('leaky_relu'),
        'Selu': Selu.get_converter(opset),
        'Elu': Elu.get_converter(opset),
        'Exp': Renamer('exp'),
        'Log': Renamer('log'),
        'Tanh': Renamer('tanh'),
        'Pow': Renamer('broadcast_pow'),
        'PRelu': Prelu.get_converter(opset),
        'Sigmoid': Renamer('sigmoid'),
        'HardSigmoid': HardSigmoid.get_converter(opset),
        'Max': Maximum.get_converter(opset),
        'Min': Minimum.get_converter(opset),
        'Sum': Sum.get_converter(opset),
        'Mean': Mean.get_converter(opset),
        'Clip': AttrCvt('clip', transforms={'min': 'a_min', 'max': 'a_max'}),
        # softmax default axis is different in onnx
        'Softmax': Softmax.get_converter(opset),
        'LogSoftmax': AttrCvt('log_softmax', {'axis': ('axis', 1)}),
        # 'Hardmax'
        'Softsign': Softsign.get_converter(opset),
        'SoftPlus': SoftPlus.get_converter(opset),
        'Gemm': Gemm.get_converter(opset),
        'MatMul': Renamer('matmul'),

        # defs/nn
        'AveragePool': AveragePool.get_converter(opset),
        'MaxPool': MaxPool.get_converter(opset),
        'Conv': Conv.get_converter(opset),
        'ConvTranspose': ConvTranspose.get_converter(opset),
        'GlobalAveragePool': Renamer('global_avg_pool2d'),
        'GlobalMaxPool': Renamer('global_max_pool2d'),
        'BatchNormalization': BatchNorm.get_converter(opset),
        # 'InstanceNormalization'
        # 'LpNormalization'
        'Dropout': AttrCvt('dropout', {'ratio': 'rate'}, ignores=['is_test']),
        'Flatten': Renamer('flatten'),
        'LRN': LRN.get_converter(opset),

        # defs/reduction
        'ReduceMax': AttrCvt('max', {'axes': 'axis'}),
        'ReduceMin': AttrCvt('min', {'axes': 'axis'}),
        'ReduceSum': AttrCvt('sum', {'axes': 'axis'}),
        'ReduceMean': AttrCvt('mean', {'axes': 'axis'}),
        # 'ReduceProd'
        # 'ReduceLogSumExp'
        'ArgMax': ArgMax.get_converter(opset),
        'ArgMin': ArgMin.get_converter(opset),

        # defs/tensor
        'Cast': Cast.get_converter(opset),
        'Reshape': Reshape.get_converter(opset),
        'Concat': Renamer('concatenate'),
        'Split': Split.get_converter(opset),
        'Slice': Slice.get_converter(opset),
        'Transpose': AttrCvt('transpose', {'perm': 'axes'}),
        'Gather': Gather.get_converter(opset),
        'Squeeze': AttrCvt('squeeze', {'axes': 'axis'}),
        'Unsqueeze': Unsqueeze.get_converter(opset),
        'Pad': Pad.get_converter(opset),
        'Shape': Shape.get_converter(opset),
    }


class GraphProto(object):
    """A helper class for handling nnvm graph copying from pb2.GraphProto.
    Definition: https://github.com/onnx/onnx/blob/master/onnx/onnx.proto
    """

    def __init__(self):
        self._nodes = {}
        self._params = {}
        self._renames = {}
        self._num_input = 0
        self._num_param = 0

    def from_onnx(self, graph, opset):
        """Construct nnvm nodes from onnx graph.
        The input names in an onnx graph are vague, often just "1", "2"...
        For convenience, we rename the `real` input names to "input_0",
        "input_1"... and rename parameters to "param_0", "param_1"...

        Parameters
        ----------
        graph : onnx protobuf object
            The loaded onnx graph
        opset : opset version

        Returns
        -------
        sym : nnvm.sym.Symbol
            The returned nnvm symbol
        params : dict
            A dict of name: tvm.nd.array pairs, used as pretrained weights
        """
        # parse network inputs to nnvm, aka parameters
        for init_tensor in graph.initializer:
            if not init_tensor.name.strip():
                raise ValueError("Tensor's name is required.")
            self._params[init_tensor.name] = self._parse_array(init_tensor)
        for i in graph.input:
            # from onnx v0.2, GraphProto.input has type ValueInfoProto,
            #  and the name is 'i.name'
            i_name = self._parse_value_proto(i)
            if i_name in self._params:
                # i is a param instead of input
                self._num_param += 1
                self._params[i_name] = self._params.pop(i_name)
                self._nodes[i_name] = _sym.Variable(
                    name=i_name, shape=self._params[i_name].shape)
            else:
                self._num_input += 1
                self._nodes[i_name] = _sym.Variable(name=i_name)
        # get list of unsupported ops
        convert_map = _get_convert_map(opset)
        unsupported_ops = set()
        for node in graph.node:
            op_name = node.op_type
            if op_name not in convert_map and \
               op_name != 'Constant' and \
               op_name not in _identity_list:
                unsupported_ops.add(op_name)
        if unsupported_ops:
            msg = 'The following operators are not supported for frontend ONNX: '
            msg += ', '.join(unsupported_ops)
            raise tvm.error.OpNotImplemented(msg)
        # construct nodes, nodes are stored as directed acyclic graph
        for node in graph.node:
            op_name = node.op_type
            attr = self._parse_attr(node.attribute)
            inputs = [self._nodes[self._renames.get(i, i)] for i in node.input]
            if op_name == "Constant":
                t_proto = self._parse_attr(node.attribute)["value"]
                self._num_param += 1
                self._params[node.output[0]] = self._parse_array(t_proto)
                self._nodes[node.output[0]] = _sym.Variable(name=node.output[0],
                                                            shape=list(t_proto.dims))
            else:
                op = self._convert_operator(op_name, inputs, attr, opset)
                node_output = self._fix_outputs(op_name, node.output)
                assert len(node_output) == len(op.list_output_names()), (
                    "Number of outputs mismatch: {} vs {} in {}.".format(
                        len(node_output), len(op.list_output_names()), op_name))
                for k, i in zip(list(node_output), range(len(node_output))):
                    self._nodes[k] = op[i]
        # now return the outputs
        out = [self._nodes[self._parse_value_proto(i)] for i in graph.output]
        if len(out) > 1:
            out = _sym.Group(out)
        else:
            out = out[0]
        return out, self._params

    def _parse_value_proto(self, value_proto):
        """Parse ValueProto or raw str."""
        try:
            name = value_proto.name
        except AttributeError:
            name = value_proto
        return name

    def _parse_array(self, tensor_proto):
        """Grab data in TensorProto and convert to numpy array."""
        try:
            from onnx.numpy_helper import to_array
        except ImportError as e:
            raise ImportError(
                "Unable to import onnx which is required {}".format(e))
        np_array = to_array(tensor_proto).reshape(tuple(tensor_proto.dims))
        return tvm.nd.array(np_array)

    def _parse_attr(self, attr_proto):
        """Convert a list of AttributeProto to a dict, with names as keys."""
        attrs = {}
        for a in attr_proto:
            for f in ['f', 'i', 's']:
                if a.HasField(f):
                    attrs[a.name] = getattr(a, f)
            for f in ['floats', 'ints', 'strings']:
                if list(getattr(a, f)):
                    assert a.name not in attrs, "Only one type of attr is allowed"
                    attrs[a.name] = tuple(getattr(a, f))
            for f in ['t']:
                if a.HasField(f):
                    attrs[a.name] = getattr(a, f)
            for f in ['tensors']:
                if list(getattr(a, f)):
                    assert a.name not in attrs, "Only one type of attr is allowed"
                    attrs[a.name] = tuple(getattr(a, f))
            for f in ['g']:
                if a.HasField(f):
                    raise NotImplementedError(
                        "Field {} is not supported in nnvm.".format(f))
            for f in ['graphs']:
                if list(getattr(a, f)):
                    raise NotImplementedError(
                        "Field {} is not supported in nnvm.".format(f))
            if a.name not in attrs:
                raise ValueError("Cannot parse attribute: \n{}\n.".format(a))
        return attrs

    def _convert_operator(self,
                          op_name,
                          inputs,
                          attrs,
                          opset,
                          identity_list=None,
                          convert_map=None):
        """Convert from onnx operator to nnvm operator.
        The converter must specify conversions explicitly for incompatible names, and
        apply handlers to operator attributes.

        Parameters
        ----------
        op_name : str
            Operator name, such as Convolution, FullyConnected
        inputs : list of nnvm.Symbol
            List of input symbols.
        attrs : dict
            Dict of operator attributes
        opset : int
            Opset version
        identity_list : list
            List of operators that don't require conversion
        convert_map : dict
            Dict of name : callable, where name is the op's name that
            requires conversion to nnvm; the callable takes
            (inputs, attrs, params) and returns a converted nnvm symbol.

        Returns
        -------
        sym : nnvm.Symbol
            Converted nnvm Symbol
        """
        identity_list = identity_list if identity_list else _identity_list
        convert_map = convert_map if convert_map else _get_convert_map(opset)
        if op_name in identity_list:
            sym = get_nnvm_op(op_name)(*inputs, **attrs)
        elif op_name in convert_map:
            sym = convert_map[op_name](inputs, attrs, self._params)
        else:
            raise tvm.error.OpNotImplemented(
                'Operator {} is not supported in frontend ONNX.'.format(op_name))
        return sym

    def _fix_outputs(self, op_name, outputs):
        """A hack to handle dropout or similar operators that have more than one
        output in ONNX.
        """
        if op_name == 'Dropout':
            if len(outputs) == 1:
                return outputs
            # TODO(zhreshold): support dropout mask?
            outputs = outputs[:-1]
        return outputs


def from_onnx(model):
    """Load an onnx graph, which is a python protobuf object, into an nnvm graph.
    The companion parameters will be handled automatically.
    The input names in an onnx graph are vague, often just "1", "2"...
    For convenience, we rename the `real` input names to "input_0",
    "input_1"... and rename parameters to "param_0", "param_1"...

    Parameters
    ----------
    model : protobuf object
        ONNX ModelProto after ONNX v1.1.0

    Returns
    -------
    sym : nnvm.Symbol
        Compatible nnvm symbol

    params : dict of str to tvm.ndarray
        Dict of converted parameters stored in tvm.ndarray format
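
    Examples
    --------
    A minimal usage sketch; ``model.onnx`` is only a placeholder path::

        import onnx
        import nnvm

        onnx_model = onnx.load('model.onnx')  # hypothetical model file
        sym, params = nnvm.frontend.from_onnx(onnx_model)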
    """
    g = GraphProto()
    graph = model.graph
    try:
        opset = model.opset_import[0].version if model.opset_import else 1
    except AttributeError:
        opset = 1
    sym, params = g.from_onnx(graph, opset)
    return sym, params