# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=import-self, invalid-name, unused-argument, too-many-lines
"""ONNX: Open Neural Network Exchange frontend."""
from __future__ import absolute_import as _abs
import numpy as np
import tvm
from .. import symbol as _sym
from .common import get_nnvm_op, Renamer, SymbolTable, AttrConverter as AttrCvt
from .onnx_caffe2_utils import dimension_picker, dimension_constraint, \
    infer_channels, revert_caffe2_pad

__all__ = ['from_onnx']


def onnx_storage_order2layout(storage_order):
    """Map the ONNX storage_order attribute (0 or 1) to an NNVM layout string."""
    if storage_order not in (0, 1):
        raise tvm.error.OpAttributeInvalid('Value of storage_order must be either 0 or 1')

    return 'NCHW' if storage_order == 0 else 'NHWC'


class OnnxOpConverter(object):
    """ A helper class for holding onnx op converters.
    """

    @classmethod
    def get_converter(cls, opset):
        """ Get the converter that matches a given opset.

        :param opset: opset from the model.
        :return: converter, which should be `_impl_vx`, where x is the
            largest supported version that is smaller than or equal to the
            given opset.
        """
        versions = [
            int(d.replace('_impl_v', '')) for d in dir(cls) if '_impl_v' in d
        ]
        versions = sorted(versions + [opset])
        version = versions[
            max([i for i, v in enumerate(versions) if v == opset]) - 1]
        if hasattr(cls, '_impl_v{}'.format(version)):
            return getattr(cls, '_impl_v{}'.format(version))
        raise NotImplementedError(
            'opset version {} of {} not implemented'.format(
                version, cls.__name__))
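

# Illustrative sketch (added for exposition; ``ExampleOp`` is hypothetical):
# how the opset dispatch in ``get_converter`` resolves. With ``_impl_v1``
# and ``_impl_v8`` defined and a model of opset 9, the sorted version list
# is [1, 8, 9]; the largest supported version <= 9 is 8, so ``_impl_v8``
# is returned.
#
#     class ExampleOp(OnnxOpConverter):
#         @classmethod
#         def _impl_v1(cls, inputs, attr, params): ...
#         @classmethod
#         def _impl_v8(cls, inputs, attr, params): ...
#
#     ExampleOp.get_converter(9)   # -> ExampleOp._impl_v8
#     ExampleOp.get_converter(1)   # -> ExampleOp._impl_v1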


class Elemwise(OnnxOpConverter):
    """ A helper class for elemwise op converters.
    """

    name = ''

    @classmethod
    def _math_name_picker(cls, suffix):

        def _impl(attr):
            if attr.get('broadcast', 0):
                return 'broadcast_' + suffix
            return 'elemwise_' + suffix

        return _impl

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        assert len(inputs) == 2, "Math op takes 2 inputs, {} given".format(
            len(inputs))
        op_name = cls._math_name_picker(cls.name)(attr)
        axis = int(attr.get('axis', 0))
        conv_ops = ["conv2d", "conv2d_transpose"]
        if op_name == 'broadcast_add' and inputs[0].attr('op_name') in conv_ops:
            # TODO(zhreshold): remove hard coded infershape
            # Expand the lower-rank operand (e.g. a per-channel bias of
            # shape (C,)) with two new axes so it broadcasts against the
            # NCHW convolution output.
            inputs[1] = _sym.expand_dims(inputs[1], axis=axis, num_newaxis=2)
        return get_nnvm_op(op_name)(*inputs)


class Pool(OnnxOpConverter):
    """ A helper class for pool op converters.
    """

    name = ''

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        return AttrCvt(
            op_name=dimension_picker(cls.name),
            transforms={
                'kernel_shape': 'pool_size',
                'pads': ('padding', (0, 0), revert_caffe2_pad)
            },
            # very weird attributes here in onnx, force check
            ignores=['dilations'],
            # TODO(zhreshold): make sure ceil_mode in onnx, and layout?
            extras={'ceil_mode': False},
            custom_check=dimension_constraint())(inputs, attr, params)


class Absolute(OnnxOpConverter):
    """ Operator converter for Absolute.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        return _sym.relu(inputs[0]) + _sym.relu(_sym.negative(inputs[0]))


class Add(Elemwise):
    """ Operator converter for Add.
    """
    name = 'add'


class AveragePool(Pool):
    """ Operator converter for AveragePool.
    """
    name = 'avg_pool'


class BatchNorm(OnnxOpConverter):
    """ Operator converter for BatchNorm.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        # TODO(zhreshold): 'spatial' is not properly handled here.
        return AttrCvt(
            op_name='batch_norm',
            disables=['momentum'],
            ignores=['spatial', 'is_test', 'consumed_inputs'])(inputs, attr,
                                                               params)


class Conv(OnnxOpConverter):
    """ Operator converter for Conv.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        # get number of channels
        channels = infer_channels(inputs[1], params)
        attr['channels'] = channels
        return AttrCvt(
            op_name=dimension_picker('conv'),
            transforms={
                'kernel_shape': 'kernel_size',
                'dilations': ('dilation', (0, 0)),
                'pads': ('padding', (0, 0), revert_caffe2_pad),
                'group': ('groups', 1)
            },
            extras={'use_bias': len(inputs) == 3},
            custom_check=dimension_constraint())(inputs, attr, params)


class ConvTranspose(OnnxOpConverter):
    """ Operator converter for ConvTranspose.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        # get number of channels
        channels = infer_channels(inputs[1], params, True)
        attr['channels'] = channels
        groups = attr.pop('group')
        attr['groups'] = groups
        return AttrCvt(
            op_name=dimension_picker('conv', '_transpose'),
            transforms={
                'kernel_shape': 'kernel_size',
                'dilations': ('dilation', (0, 0)),
                'pads': ('padding', (0, 0), revert_caffe2_pad)
            },
            disables=['output_shape'],
            extras={'use_bias': len(inputs) == 3},
            custom_check=dimension_constraint())(inputs, attr, params)


class Div(Elemwise):
    """ Operator converter for Div.
    """
    name = 'div'


class Elu(OnnxOpConverter):
    """ Operator converter for Elu.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        alpha = float(attr.get('alpha', 1.0))
        return -alpha * _sym.relu(1 - _sym.exp(inputs[0])) + _sym.relu(
            inputs[0])


class Gemm(OnnxOpConverter):
    """ Operator converter for Gemm.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        assert len(inputs) == 3, "Gemm op takes 3 inputs, {} given".format(
            len(inputs))
        # Y = alpha * A * B + beta * C
        alpha = float(attr.get('alpha', 1.0))
        beta = float(attr.get('beta', 1.0))
        transA = int(attr.get('transA', 0))
        transB = int(attr.get('transB', 0))
        # get number of channels
        channels = infer_channels(inputs[1], params, not transB)
        if transA:
            inputs[0] = _sym.transpose(inputs[0], axes=(1, 0))
        if not transB:
            inputs[1] = _sym.transpose(inputs[1], axes=(1, 0))
        inputs[0] = _sym.flatten(inputs[0])
        return _sym.dense(
            alpha * inputs[0], inputs[1], beta * inputs[2], units=channels)
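

# Illustrative note on the Gemm conversion above (added for exposition):
# nnvm's dense computes X * W^T, so B is transposed only when transB is
# *not* set, and A is flattened first because dense expects a 2-D input.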


class MaxPool(Pool):
    """ Operator converter for MaxPool.
    """
    name = 'max_pool'

    @classmethod
    def _impl_v8(cls, inputs, attr, params):
        return AttrCvt(
            op_name=dimension_picker(cls.name),
            transforms={
                'kernel_shape': 'pool_size',
                'pads': ('padding', (0, 0), revert_caffe2_pad),
                'storage_order': ('layout', 'NCHW', onnx_storage_order2layout),
            },
            # very weird attributes here in onnx, force check
            ignores=['dilations', 'auto_pad'],
            # TODO(higumachan): make sure ceil_mode in onnx, and layout?
            extras={'ceil_mode': False},
            custom_check=dimension_constraint())(inputs, attr, params)

    @classmethod
    def _impl_v10(cls, inputs, attr, params):
        return AttrCvt(
            op_name=dimension_picker(cls.name),
            transforms={
                'kernel_shape': 'pool_size',
                'pads': ('padding', (0, 0), revert_caffe2_pad),
                'storage_order': ('layout', 'NCHW', onnx_storage_order2layout),
                'ceil_mode': 'ceil_mode'
            },
            # very weird attributes here in onnx, force check
            ignores=['dilations', 'auto_pad'],
            custom_check=dimension_constraint())(inputs, attr, params)


class Mul(Elemwise):
    """ Operator converter for Mul.
    """
    name = 'mul'


class Pad(OnnxOpConverter):
    """ Operator converter for Pad.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        pad_width = []
        pads = attr.pop('paddings')
        dims = int(len(pads) / 2)
        for i in range(dims):
            pad_width.append((pads[i], pads[i+dims]))
        attr['pad_width'] = pad_width

        return AttrCvt(
            op_name='pad',
            transforms={
                'value': 'pad_value',
            },
            ignores=['mode'],
            custom_check=(lambda attrs: attrs.get('mode', 'constant').decode("utf-8") == 'constant',
                          'Pad mode must be constant'))(inputs, attr, params)

    @classmethod
    def _impl_v2(cls, inputs, attr, params):
        pad_width = []
        pads = attr.pop('pads')
        dims = int(len(pads) / 2)
        for i in range(dims):
            pad_width.append((pads[i], pads[i+dims]))
        attr['pad_width'] = pad_width

        return AttrCvt(
            op_name='pad',
            transforms={
                'value': 'pad_value',
            },
            ignores=['mode'],
            custom_check=(lambda attrs: attrs.get('mode', 'constant').decode("utf-8") == 'constant',
                          'Pad mode must be constant'))(inputs, attr, params)
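

# Illustrative note on the Pad conversion above (added for exposition):
# ONNX 'pads' is laid out as (x1_begin, x2_begin, ..., x1_end, x2_end, ...),
# so for a 4-D input pads=[0, 0, 1, 1, 0, 0, 2, 2] becomes
# pad_width=[(0, 0), (0, 0), (1, 2), (1, 2)].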


class ParametricSoftPlus(OnnxOpConverter):
    """ Operator converter for ParametricSoftPlus.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        alpha = float(attr.get('alpha', 1.0))
        beta = float(attr.get('beta', 1.0))
        return _sym.log(_sym.exp(beta * inputs[0]) + 1) * alpha


class Prelu(OnnxOpConverter):
    """ Operator converter for Prelu.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        assert len(inputs) == 2, "Prelu needs 2 inputs, {} given".format(
            len(inputs))
        return _sym.prelu(inputs[0], inputs[1])


class Reciprocal(OnnxOpConverter):
    """ Operator converter for Reciprocal.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        return 1.0 / inputs[0]


class Reshape(OnnxOpConverter):
    """ Operator converter for Reshape.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        return _sym.reshape(inputs[0], shape=attr['shape'])

    @classmethod
    def _impl_v5(cls, inputs, attr, params):
        if inputs[1].list_output_names()[0] in params:
            # The shape input is a known constant: reshape to it directly.
            shape = tuple(params[inputs[1].list_output_names()[0]].asnumpy())
            out = _sym.reshape(inputs[0], shape=shape)
        else:
            out = _sym.reshape_like(inputs[0], inputs[1])

        return out


class Scale(OnnxOpConverter):
    """ Operator converter for Scale.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        scale = float(attr.get('scale', 1.0))
        return inputs[0] * scale


class Selu(OnnxOpConverter):
    """ Operator converter for Selu.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        alpha = float(attr.get('alpha', 1.6732))
        gamma = float(attr.get('gamma', 1.0507))
        return gamma * (
            -alpha * _sym.relu(1 - _sym.exp(inputs[0])) + _sym.relu(inputs[0]))


class ScaledTanh(OnnxOpConverter):
    """ Operator converter for ScaledTanh.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        alpha = float(attr.get('alpha', 1.0))
        beta = float(attr.get('beta', 1.0))
        return _sym.tanh(beta * inputs[0]) * alpha


class SoftPlus(OnnxOpConverter):
    """ Operator converter for SoftPlus.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        return _sym.log(_sym.exp(inputs[0]) + 1)


class Softsign(OnnxOpConverter):
    """ Operator converter for Softsign.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        return inputs[0] / (1 + Absolute.get_converter(1)(inputs, attr, params))


class Sub(Elemwise):
    """ Operator converter for Sub.
    """
    name = 'sub'


class Sum(OnnxOpConverter):
    """ Operator converter for Sum.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        # Onnx Sum Operator
        for in_index in range(len(inputs) - 1):
            inputs[in_index + 1] = _sym.broadcast_add(inputs[in_index],
                                                      inputs[in_index + 1])

        return inputs[len(inputs) - 1]
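

# Illustrative note on the Sum conversion above (added for exposition):
# the n-ary ONNX Sum is folded into a chain of binary broadcast_add ops,
# e.g. Sum(a, b, c) -> broadcast_add(broadcast_add(a, b), c).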


class ThresholdedRelu(OnnxOpConverter):
    """ Operator converter for ThresholdedRelu.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        alpha = float(attr.get('alpha', 1.0))
        alpha_tensor = _sym.full_like(inputs[0], fill_value=float(alpha))
        return _sym.elemwise_mul(inputs[0], _sym.greater(inputs[0], alpha_tensor))


class ImageScaler(OnnxOpConverter):
    """ Operator converter for ImageScaler.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        channelScale = attr['scale']
        bias_attr = attr['bias']
        bias = SymbolTable().new_const(np.array(bias_attr).reshape([3, 1, 1]))
        scaledChannel = _sym.__mul_scalar__(inputs[0], scalar=channelScale)
        ret = _sym.broadcast_add(scaledChannel, bias)
        return ret


def _broadcast_constraint():

    def _broadcast_check(attrs):
        if attrs.get('axis', None):
            return False
        return True

    return _broadcast_check, "Specifying broadcast axis not allowed."


def _fully_connected(opset):

    def _impl(inputs, attr, params):
        # get number of channels
        channels = infer_channels(inputs[1], params)
        attr['units'] = channels
        return AttrCvt('dense', ignores=['axis', 'axis_w'])(inputs, attr)

    return _impl


class Upsample(OnnxOpConverter):
    """ Operator converter for Upsample (nearest mode).
    """

    @classmethod
    def _impl_v9(cls, inputs, attr, params):
        scales = attr.get('scales')
        if not scales:
            # From opset 9 on, the scales come in as a second input
            # instead of an attribute.
            assert len(inputs) == 2, "Upsample op takes 2 inputs, {} given".format(len(inputs))
            input_name = inputs[1].list_input_names()[0]
            scales = params[input_name].asnumpy()
            inputs = inputs[:1]
        assert len(scales) == 4 and scales[0] == 1.0 and scales[1] == 1.0 and scales[2] == scales[3]
        mode = attr.get('mode')
        if mode == b'nearest':
            method = "NEAREST_NEIGHBOR"
        elif mode == b'linear':
            method = "BILINEAR"
        else:
            raise tvm.error.OpAttributeInvalid(
                'Value {} in attribute "mode" of operator Upsample is not valid.'.format(mode))
        return _sym.upsampling(inputs[0], scale=int(scales[-1]), method=method, layout='NCHW')


class Shape(OnnxOpConverter):
    """ Operator converter for Shape.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        # Result of this operator is prominently used by reshape operator.
        # Just pass the input as it is so that reshape_like can be used there.
        print("Shape: Differently implemented in NNVM as a bypass (dummy operator)")
        return inputs[0]


class Cast(OnnxOpConverter):
    """ Operator converter for Cast.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        return AttrCvt(op_name='cast', transforms={'to': 'dtype'})(inputs, attr)

    @classmethod
    def _impl_v5(cls, inputs, attr, params):
        try:
            from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
            attr['to'] = TENSOR_TYPE_TO_NP_TYPE[attr['to']]
        except ImportError as e:
            raise ImportError(
                "Unable to import onnx.mapping, which is required: {}".format(e))
        return AttrCvt(op_name='cast', transforms={'to': 'dtype'})(inputs, attr)


class Unsqueeze(OnnxOpConverter):
    """ Operator converter for Unsqueeze.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        for axes in attr['axes']:
            inputs[0] = _sym.expand_dims(inputs[0], axis=axes, num_newaxis=1)
        return inputs[0]


class Split(OnnxOpConverter):
    """ Operator converter for Split.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        attr['indices_or_sections'] = []
        index = 0
        for i in attr['split'][:-1]:
            index += i
            attr['indices_or_sections'].append(index)
        return AttrCvt(
            op_name='split',
            ignores=['split'])(inputs, attr, params)
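

# Illustrative note on the Split conversion above (added for exposition):
# ONNX gives per-output sizes while nnvm's split takes cut indices, so
# e.g. split=[2, 3, 4] becomes indices_or_sections=[2, 5] (running sums,
# with the final total dropped).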


class Slice(OnnxOpConverter):
    """ Operator converter for Slice.
    """
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        if isinstance(attr['starts'], int):
            attr['starts'] = (attr['starts'],)
            attr['ends'] = (attr['ends'],)

        try:
            # Update the starts and ends according to axes if required.
            if isinstance(attr['axes'], int):
                attr['axes'] = (attr['axes'],)

            if (max(attr['axes']) + 1) != len(attr['axes']):
                # Fill axes that are not sliced with full-range bounds,
                # e.g. axes=(1,), starts=(2,), ends=(5,) becomes
                # axes=[0, 1], starts=[0, 2], ends=[INT32_MAX, 5].
                new_axes = []
                new_starts = []
                new_ends = []
                pop_index = 0
                for i in range(max(attr['axes']) + 1):
                    if i in attr['axes']:
                        new_axes.append(i)
                        new_starts.append(attr['starts'][pop_index])
                        new_ends.append(attr['ends'][pop_index])
                        pop_index += 1
                    else:
                        new_axes.append(i)
                        new_starts.append(0)
                        new_ends.append(np.iinfo(np.int32).max)
                attr['axes'] = new_axes
                attr['starts'] = new_starts
                attr['ends'] = new_ends
        except KeyError:
            pass

        return AttrCvt(op_name='strided_slice',
                       transforms={'starts': 'begin',
                                   'ends': 'end'},
                       ignores=['axes'])(inputs, attr)


class Gather(OnnxOpConverter):
    """ Operator converter for Gather.
    """
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        axis = attr.get('axis', 0)
        return AttrCvt(op_name='take',
                       extras={'axis': axis})(inputs, attr)


class LRN(OnnxOpConverter):
    """ Operator converter for Local Response Normalization.
    """
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        """LRN supports only the NCHW format.
        https://github.com/onnx/onnx/blob/master/docs/Operators.md#LRN
        """
        axis = 1
        alpha = attr.get('alpha', 0.0001)
        beta = attr.get('beta', 0.75)
        bias = attr.get('bias', 1.0)
        nsize = attr.get('size')
        return _sym.lrn(inputs[0], size=nsize, axis=axis,
                        alpha=alpha, beta=beta, bias=bias)


class Maximum(OnnxOpConverter):
    """ Operator converter for Maximum.
    """
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        if not isinstance(inputs, list) or len(inputs) < 2:
            raise ValueError("Expect minimum 2 inputs")
        _max = inputs[0]
        for i in range(1, len(inputs)):
            _max = AttrCvt(op_name='broadcast_max')([_max, inputs[i]], {})
        return _max


class Minimum(OnnxOpConverter):
    """ Operator converter for Minimum.
    """
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        if not isinstance(inputs, list) or len(inputs) < 2:
            raise ValueError("Expect minimum 2 inputs")
        _min = inputs[0]
        for i in range(1, len(inputs)):
            _min = AttrCvt(op_name='broadcast_min')([_min, inputs[i]], {})
        return _min


class Mean(OnnxOpConverter):
    """ Operator converter for Mean.
    """
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        if not isinstance(inputs, list) or len(inputs) < 2:
            raise ValueError("Expect minimum 2 inputs")
        count = len(inputs)
        _sum = inputs[0]
        for i in range(1, count):
            _sum = AttrCvt(op_name='broadcast_add')([_sum, inputs[i]], {})
        return _sum / count


class HardSigmoid(OnnxOpConverter):
    """ Operator converter for HardSigmoid.
    """
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        alpha = attr.get('alpha', 0.2)
        beta = attr.get('beta', 0.5)
        transformX = (inputs[0] * alpha) + beta
        attr = {'a_min': 0, 'a_max': 1}
        return AttrCvt(op_name='clip')([transformX], attr)


class ArgMax(OnnxOpConverter):
    """ Operator converter for ArgMax.
    """
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        axis = attr.get('axis', 0)
        keepdims = attr.get('keepdims', True)
        attr = {'axis': axis, 'keepdims': keepdims}
        return AttrCvt(op_name='argmax')(inputs, attr)


class ArgMin(OnnxOpConverter):
    """ Operator converter for ArgMin.
    """
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        axis = attr.get('axis', 0)
        keepdims = attr.get('keepdims', True)
        attr = {'axis': axis, 'keepdims': keepdims}
        return AttrCvt(op_name='argmin')(inputs, attr)


class Softmax(OnnxOpConverter):
    """ Operator converter for Softmax.
    """
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        # set default value when axis is not set in the model
        if 'axis' not in attr:
            attr['axis'] = 1
        return AttrCvt(
            op_name='softmax',
            transforms={
                'axis': ('axis', 1),
            })(inputs, attr, params)


class ConstantFill(OnnxOpConverter):
    """ Operator converter for ConstantFill.
    """
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        is_full = True
        num_inputs = len(inputs)
        if 'shape' in attr:
            if num_inputs > 0:
                raise ImportError(
                    "Can't set shape and input tensor at the same time")
            shape = attr.pop('shape')
        else:
            if num_inputs == 0:
                raise ImportError(
                    "Either shape attribute or input should be set")
            if 'input_as_shape' in attr and attr['input_as_shape']:
                shape = params[inputs[0].list_output_names()[0]].asnumpy()
            else:
                is_full = False

        if not is_full:
            if 'extra_shape' in attr:
                raise ImportError(
                    "Extra Shape not supported with fill_like")

            out = AttrCvt(
                op_name='full_like',
                transforms={'value': 'fill_value'},
                ignores=['dtype'])(inputs, attr)
            return _sym.cast(out, dtype=attr['dtype'].decode("utf-8"))
        if 'extra_shape' in attr:
            shape = shape + attr.pop('extra_shape')

        return AttrCvt(
            op_name='full',
            transforms={'value': 'fill_value'},
            extras={'shape': shape})(inputs, attr)


# compatible operators that do NOT require any conversion.
_identity_list = []


# _convert_map maps an ONNX op name to a converter functor (callable).
# For 1-to-1 mappings, use Renamer if nothing but the name differs,
# or AttrCvt if attributes need to be converted.
# For 1-to-N (composed) mappings, use custom callable functions.
# N-to-1 mappings are currently not supported.
def _get_convert_map(opset):
    return {
        # defs/experimental
        'Identity': Renamer('copy'),
        # 'Affine'
        'ThresholdedRelu': ThresholdedRelu.get_converter(opset),
        'ScaledTanh': ScaledTanh.get_converter(opset),
        'ParametricSoftplus': ParametricSoftPlus.get_converter(opset),
        'ConstantFill': ConstantFill.get_converter(opset),
        # 'GivenTensorFill'
        'FC': AttrCvt('dense', ignores=['axis', 'axis_w']),
        'Scale': Scale.get_converter(opset),
        # 'GRUUnit'
        # 'ATen'
        'ImageScaler': ImageScaler.get_converter(opset),
        # 'MeanVarianceNormalization'
        # 'Crop'
        # 'Embedding'
        'Upsample': Upsample.get_converter(opset),
        'SpatialBN': BatchNorm.get_converter(opset),

        # defs/generator
        # 'Constant' # Implemented
        # 'RandomUniform'
        # 'RandomNormal'
        # 'RandomUniformLike'
        # 'RandomNormalLike'

        # defs/logical

        # defs/math
        'Add': Add.get_converter(opset),
        'Sub': Sub.get_converter(opset),
        'Mul': Mul.get_converter(opset),
        'Div': Div.get_converter(opset),
        'Neg': Renamer('negative'),
        'Abs': Absolute.get_converter(opset),
        'Reciprocal': Reciprocal.get_converter(opset),
        'Floor': Renamer('floor'),
        'Ceil': Renamer('ceil'),
        'Sqrt': Renamer('sqrt'),
        'Relu': Renamer('relu'),
        'LeakyRelu': Renamer('leaky_relu'),
        'Selu': Selu.get_converter(opset),
        'Elu': Elu.get_converter(opset),
        'Exp': Renamer('exp'),
        'Log': Renamer('log'),
        'Tanh': Renamer('tanh'),
        'Pow': Renamer('broadcast_pow'),
        'PRelu': Prelu.get_converter(opset),
        'Sigmoid': Renamer('sigmoid'),
        'HardSigmoid': HardSigmoid.get_converter(opset),
        'Max': Maximum.get_converter(opset),
        'Min': Minimum.get_converter(opset),
        'Sum': Sum.get_converter(opset),
        'Mean': Mean.get_converter(opset),
        'Clip': AttrCvt('clip', transforms={'min': 'a_min', 'max': 'a_max'}),
        # softmax default axis is different in onnx
        'Softmax': Softmax.get_converter(opset),
        'LogSoftmax': AttrCvt('log_softmax', {'axis': ('axis', 1)}),
        # 'Hardmax'
        'Softsign': Softsign.get_converter(opset),
        'SoftPlus': SoftPlus.get_converter(opset),
        'Gemm': Gemm.get_converter(opset),
        'MatMul': Renamer('matmul'),

        # defs/nn
        'AveragePool': AveragePool.get_converter(opset),
        'MaxPool': MaxPool.get_converter(opset),
        'Conv': Conv.get_converter(opset),
        'ConvTranspose': ConvTranspose.get_converter(opset),
        'GlobalAveragePool': Renamer('global_avg_pool2d'),
        'GlobalMaxPool': Renamer('global_max_pool2d'),
        'BatchNormalization': BatchNorm.get_converter(opset),
        # 'InstanceNormalization'
        # 'LpNormalization'
        'Dropout': AttrCvt('dropout', {'ratio': 'rate'}, ignores=['is_test']),
        'Flatten': Renamer('flatten'),
        'LRN': LRN.get_converter(opset),

        # defs/reduction
        'ReduceMax': AttrCvt('max', {'axes': 'axis'}),
        'ReduceMin': AttrCvt('min', {'axes': 'axis'}),
        'ReduceSum': AttrCvt('sum', {'axes': 'axis'}),
        'ReduceMean': AttrCvt('mean', {'axes': 'axis'}),
        # 'ReduceProd'
        # 'ReduceLogSumExp'
        'ArgMax': ArgMax.get_converter(opset),
        'ArgMin': ArgMin.get_converter(opset),

        # defs/tensor
        'Cast': Cast.get_converter(opset),
        'Reshape': Reshape.get_converter(opset),
        'Concat': Renamer('concatenate'),
        'Split': Split.get_converter(opset),
        'Slice': Slice.get_converter(opset),
        'Transpose': AttrCvt('transpose', {'perm': 'axes'}),
        'Gather': Gather.get_converter(opset),
        'Squeeze': AttrCvt('squeeze', {'axes': 'axis'}),
        'Unsqueeze': Unsqueeze.get_converter(opset),
        'Pad': Pad.get_converter(opset),
        'Shape': Shape.get_converter(opset),
    }
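

# Illustrative sketch (added for exposition; 'Foo'/'foo' are hypothetical
# op names): the three mapping styles used in the table above.
#
#     'Foo': Renamer('foo')                      # 1-to-1, rename only
#     'Foo': AttrCvt('foo', {'bar': 'baz'})      # 1-to-1 with attr rename
#     'Foo': Foo.get_converter(opset)            # custom/composed conversion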


class GraphProto(object):
    """A helper class for handling nnvm graph copying from pb2.GraphProto.
    Definition: https://github.com/onnx/onnx/blob/master/onnx/onnx.proto
    """

    def __init__(self):
        self._nodes = {}
        self._params = {}
        self._renames = {}
        self._num_input = 0
        self._num_param = 0

    def from_onnx(self, graph, opset):
        """Construct nnvm nodes from onnx graph.
        The inputs from onnx graph are vague, only providing "1", "2"...
        For convenience, we rename the `real` input names to "input_0",
        "input_1"... and rename parameters to "param_0", "param_1"...

        Parameters
        ----------
        graph : onnx protobuf object
            The loaded onnx graph
        opset : opset version

        Returns
        -------
        sym : nnvm.sym.Symbol
            The returned nnvm symbol
        params : dict
            A dict of name: tvm.nd.array pairs, used as pretrained weights
        """
        # parse network inputs to nnvm, aka parameters
        for init_tensor in graph.initializer:
            if not init_tensor.name.strip():
                raise ValueError("Tensor's name is required.")
            self._params[init_tensor.name] = self._parse_array(init_tensor)
        for i in graph.input:
            # from onnx v0.2, GraphProto.input has type ValueInfoProto,
            # and the name is 'i.name'
            i_name = self._parse_value_proto(i)
            if i_name in self._params:
                # i is a param instead of an input
                self._num_param += 1
                self._params[i_name] = self._params.pop(i_name)
                self._nodes[i_name] = _sym.Variable(
                    name=i_name, shape=self._params[i_name].shape)
            else:
                self._num_input += 1
                self._nodes[i_name] = _sym.Variable(name=i_name)
        # get list of unsupported ops
        convert_map = _get_convert_map(opset)
        unsupported_ops = set()
        for node in graph.node:
            op_name = node.op_type
            if op_name not in convert_map and \
               op_name != 'Constant' and \
               op_name not in _identity_list:
                unsupported_ops.add(op_name)
        if unsupported_ops:
            msg = 'The following operators are not supported for frontend ONNX: '
            msg += ', '.join(unsupported_ops)
            raise tvm.error.OpNotImplemented(msg)
        # construct nodes, nodes are stored as directed acyclic graph
        for node in graph.node:
            op_name = node.op_type
            attr = self._parse_attr(node.attribute)
            inputs = [self._nodes[self._renames.get(i, i)] for i in node.input]
            if op_name == "Constant":
                t_proto = self._parse_attr(node.attribute)["value"]
                self._num_param += 1
                self._params[node.output[0]] = self._parse_array(t_proto)
                self._nodes[node.output[0]] = _sym.Variable(name=node.output[0],
                                                            shape=list(t_proto.dims))
            else:
                op = self._convert_operator(op_name, inputs, attr, opset)
                node_output = self._fix_outputs(op_name, node.output)
                assert len(node_output) == len(op.list_output_names()), (
                    "Number of outputs mismatch: {} vs {} in {}.".format(
                        len(node_output), len(op.list_output_names()), op_name))
                for k, i in zip(list(node_output), range(len(node_output))):
                    self._nodes[k] = op[i]
        # now return the outputs
        out = [self._nodes[self._parse_value_proto(i)] for i in graph.output]
        if len(out) > 1:
            out = _sym.Group(out)
        else:
            out = out[0]
        return out, self._params

    def _parse_value_proto(self, value_proto):
        """Parse ValueProto or raw str."""
        try:
            name = value_proto.name
        except AttributeError:
            name = value_proto
        return name
    def _parse_array(self, tensor_proto):
        """Grab data in TensorProto and convert to numpy array."""
        try:
            from onnx.numpy_helper import to_array
        except ImportError as e:
            raise ImportError(
                "Unable to import onnx, which is required: {}".format(e))
        np_array = to_array(tensor_proto).reshape(tuple(tensor_proto.dims))
        return tvm.nd.array(np_array)

    def _parse_attr(self, attr_proto):
        """Convert a list of AttributeProto to a dict, with names as keys."""
        attrs = {}
        for a in attr_proto:
            for f in ['f', 'i', 's']:
                if a.HasField(f):
                    attrs[a.name] = getattr(a, f)
            for f in ['floats', 'ints', 'strings']:
                if list(getattr(a, f)):
                    assert a.name not in attrs, "Only one type of attr is allowed"
                    attrs[a.name] = tuple(getattr(a, f))
            for f in ['t']:
                if a.HasField(f):
                    attrs[a.name] = getattr(a, f)
            for f in ['tensors']:
                if list(getattr(a, f)):
                    assert a.name not in attrs, "Only one type of attr is allowed"
                    attrs[a.name] = tuple(getattr(a, f))
            for f in ['g']:
                if a.HasField(f):
                    raise NotImplementedError(
                        "Field {} is not supported in nnvm.".format(f))
            for f in ['graphs']:
                if list(getattr(a, f)):
                    raise NotImplementedError(
                        "Field {} is not supported in nnvm.".format(f))
            if a.name not in attrs:
                raise ValueError("Cannot parse attribute: \n{}\n.".format(a))
        return attrs

    def _convert_operator(self,
                          op_name,
                          inputs,
                          attrs,
                          opset,
                          identity_list=None,
                          convert_map=None):
        """Convert from onnx operator to nnvm operator.
        The converter must specify conversions explicitly for incompatible names,
        and apply handlers to operator attributes.

        Parameters
        ----------
        op_name : str
            Operator name, such as Convolution, FullyConnected
        inputs : list of nnvm.Symbol
            List of input symbols.
        attrs : dict
            Dict of operator attributes
        opset : int
            Opset version
        identity_list : list
            List of operators that don't require conversion
        convert_map : dict
            Dict of name : callable, where name is the op's name that
            requires conversion to nnvm, callable are functions which
            take attrs and return (new_op_name, new_attrs)

        Returns
        -------
        sym : nnvm.Symbol
            Converted nnvm Symbol
        """
        identity_list = identity_list if identity_list else _identity_list
        convert_map = convert_map if convert_map else _get_convert_map(opset)
        if op_name in identity_list:
            sym = get_nnvm_op(op_name)(*inputs, **attrs)
        elif op_name in convert_map:
            sym = convert_map[op_name](inputs, attrs, self._params)
        else:
            raise tvm.error.OpNotImplemented(
                'Operator {} is not supported in frontend ONNX.'.format(op_name))
        return sym

    def _fix_outputs(self, op_name, outputs):
        """A hack to handle dropout or similar operators that have more than
        one output in ONNX.
        """
        if op_name == 'Dropout':
            if len(outputs) == 1:
                return outputs
            # TODO(zhreshold): support dropout mask?
            outputs = outputs[:-1]
        return outputs


def from_onnx(model):
    """Load onnx graph which is a python protobuf object into nnvm graph.
    The companion parameters will be handled automatically.
    The inputs from onnx graph are vague, only providing "1", "2"...
    For convenience, we rename the `real` input names to "input_0",
    "input_1"... and rename parameters to "param_0", "param_1"...

    Parameters
    ----------
    model : protobuf object
        ONNX ModelProto after ONNX v1.1.0

    Returns
    -------
    sym : nnvm.Symbol
        Compatible nnvm symbol

    params : dict of str to tvm.ndarray
        Dict of converted parameters stored in tvm.ndarray format
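
    Examples
    --------
    A minimal usage sketch (illustrative; assumes the ``onnx`` package is
    installed and ``model.onnx`` is a hypothetical model file on disk)::

        import onnx
        onnx_model = onnx.load('model.onnx')
        sym, params = from_onnx(onnx_model)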
    """
    g = GraphProto()
    graph = model.graph
    try:
        opset = model.opset_import[0].version if model.opset_import else 1
    except AttributeError:
        opset = 1
    sym, params = g.from_onnx(graph, opset)
    return sym, params