#----------------------------------------------------------------------------------------------
#  Copyright (c) Microsoft Corporation. All rights reserved.
#  Licensed under the MIT License. See License.txt in the project root for license information.
#----------------------------------------------------------------------------------------------

import os
import numpy as np
import mmdnn.conversion.common.IR.graph_pb2 as graph_pb2
from mmdnn.conversion.common.IR.graph_pb2 import NodeDef, GraphDef, DataType
from mmdnn.conversion.common.utils import *
from mmdnn.conversion.common.DataStructure.parser import Parser
from mmdnn.conversion.pytorch.pytorch_graph import PytorchGraph040
from mmdnn.conversion.pytorch.pytorch_graph import PytorchGraph151
import torch
import torchvision


class PytorchParser(Parser):
    """Convert a serialized PyTorch model into MMdnn IR nodes.

    The source graph exposes ONNX-style node types (``onnx::Conv`` ...).
    ``layer_map`` translates each type to the suffix of the ``rename_*``
    method that emits the matching IR node.  Subclasses provide the concrete
    graph implementation and the weight-name lookup.
    """

    # ONNX node type -> suffix of the ``rename_<suffix>`` handler.
    layer_map = {
        'onnx::Conv': 'Conv',
        'onnx::Flatten': 'Flatten',
        'onnx::Gemm': 'FullyConnected',
        'onnx::MaxPool': 'Maxpool',
        'onnx::AveragePool': 'Avgpool',
        'onnx::GlobalAveragePool': 'GAvgpool',
        'onnx::Dropout': 'Dropout',
        'onnx::BatchNormalization': 'BatchNormalization',
        'onnx::Add': 'Add',
        'onnx::Concat': 'Concat',
        'onnx::Relu': 'Relu',
        'onnx::Tanh': 'Tanh',
        'onnx::Sigmoid': 'Sigmoid',
        'onnx::Mul': 'Mul',
        'onnx::Pad': 'Pad',

        # TODO: operators that are not supported yet
        # 'max_pool2d', 'onnx::Sub', 'onnx::ConvTranspose', 'onnx::LeakyRelu',
        # 'onnx::Softmax', 'onnx::Selu', 'onnx::Transpose', 'onnx::Reshape',
        # 'onnx::MatMul', 'onnx::Gather', 'onnx::ReduceSum', 'onnx::Constant',
        # 'onnx::Upsample'
    }

    ############
    # property #
    ############

    @property
    def src_graph(self):
        """Return the source (PyTorch) graph object."""
        return self.pytorch_graph

    def get_weight_name(self, node):
        """Return the ``state_dict`` key prefix for ``node``.

        The base implementation returns ``None``; the version-specific
        subclasses in this module override it.
        """
        pass

    ####################
    # Public Functions #
    ####################

    def __init__(self, model_file_name, input_shape):
        """Load the serialized model from ``model_file_name``.

        ``input_shape`` is accepted for interface compatibility; the graph is
        built later by ``build_graph`` (called from the subclasses).

        Raises:
            AssertionError: if ``model_file_name`` does not exist.
        """
        super(PytorchParser, self).__init__()
        if not os.path.exists(model_file_name):
            message = "Pytorch model file [{}] is not found.".format(model_file_name)
            print(message)
            # Raised explicitly (instead of ``assert False``) so the check
            # survives ``python -O``; the exception type is kept for callers.
            raise AssertionError(message)

        # SECURITY NOTE: torch.load unpickles the file, which can execute
        # arbitrary code.  Only convert model files from trusted sources.
        # cpu: https://github.com/pytorch/pytorch/issues/5286
        try:
            model = torch.load(model_file_name)
        except Exception:
            # Fall back to CPU deserialization (see the issue linked above).
            # ``Exception`` rather than a bare ``except`` so that
            # KeyboardInterrupt / SystemExit are not swallowed.
            model = torch.load(model_file_name, map_location='cpu')

        self.weight_loaded = True
        self.model = model
        # The network graph is created by the subclass constructors.
        self.pytorch_graph = None

    def build_graph(self, input_shape):
        """Build the source graph for ``input_shape`` with a batch size of 1.

        ``input_shape`` is the per-sample shape; any sequence is accepted.
        """
        self.input_shape = tuple([1] + list(input_shape))
        self.pytorch_graph.build(self.input_shape)
        self.state_dict = self.pytorch_graph.state_dict
        self.shape_dict = self.pytorch_graph.shape_dict

    def gen_IR(self):
        """Emit one IR node per source node, then the ``DataInput`` node.

        Node types without a ``layer_map`` entry or without a matching
        ``rename_*`` method are routed to ``rename_UNKNOWN``.
        """
        for layer in self.src_graph.topological_sort:
            current_node = self.src_graph.get_node(layer)
            # ``.get`` so an unmapped operator reaches rename_UNKNOWN (which
            # reports the operator and node name) instead of a bare KeyError.
            node_type = PytorchParser.layer_map.get(current_node.type)
            handler = getattr(self, "rename_" + node_type, None) if node_type else None

            if handler is None:
                self.rename_UNKNOWN(current_node)
            else:
                handler(current_node)

        self.gen_Input()

    @staticmethod
    def _to_IR_shape(shape_pytorch):
        """Build an IR ``TensorShape`` from a PyTorch shape.

        * 4-D ``(batch, C, H, W)`` becomes ``(batch, H, W, C)``.
        * 2-D ``(batch, N)`` becomes ``(batch, 1, 1, N)``.
        * A batch size of 1, or a falsy non-batch dimension, is written as -1.
        * Any other rank yields a single dimension whose size is left unset.
        """
        shape = graph_pb2.TensorShape()
        new_dim = shape.dim.add()

        # (batch, C, H, W) & NHWC
        if len(shape_pytorch) == 4:
            new_dim.size = -1 if shape_pytorch[0] == 1 else shape_pytorch[0]
            for index in [2, 3, 1]:
                new_dim = shape.dim.add()
                dim = shape_pytorch[index]
                new_dim.size = dim if dim else -1
        elif len(shape_pytorch) == 2:
            new_dim.size = -1 if shape_pytorch[0] == 1 else shape_pytorch[0]
            for _ in range(2):
                new_dim = shape.dim.add()
                new_dim.size = 1
            new_dim = shape.dim.add()
            dim = shape_pytorch[1]
            new_dim.size = dim if dim else -1

        return shape

    def _set_output_shape(self, source_node, IR_node):
        """Append the output shape of ``source_node`` to ``IR_node``."""
        shape = self._to_IR_shape(self.shape_dict[source_node.name])
        IR_node.attr["_output_shapes"].list.shape.extend([shape])

    ##########
    # Layers #
    ##########
    def rename_UNKNOWN(self, source_node):
        """Report an unsupported operator.

        Raises:
            AssertionError: always.
        """
        message = ("PyTorch parser has not supported operator [%s] with name [%s]."
                   % (source_node.type, source_node.name))
        print(message)
        raise AssertionError(message)

    def gen_Input(self):
        """Add the ``DataInput`` node and wire it to the graph's input layers."""
        IR_node = self.IR_graph.node.add()
        IR_node.name = 'input'
        IR_node.op = "DataInput"

        for node in self.IR_graph.node:
            if node.name in self.src_graph.input_layers:
                node.input.append('input')

        assert len(self.input_shape) == 4, "only 4-D (batch, C, H, W) inputs are supported"
        # "shape" attribute: NCHW -> NHWC, a batch size of 1 is written as -1.
        new_dim = IR_node.attr["shape"].shape.dim.add()
        if self.input_shape[0] == 1:
            new_dim.size = -1
        else:
            new_dim.size = self.input_shape[0]
        for index in [2, 3, 1]:
            new_dim = IR_node.attr["shape"].shape.dim.add()
            new_dim.size = self.input_shape[index]

        shape = self._to_IR_shape(self.input_shape)
        IR_node.attr["_output_shapes"].list.shape.extend([shape])

    def rename_Conv(self, source_node):
        """Emit a ``Conv`` IR node with its weights (and bias when present)."""
        attr = source_node.attrs
        kwargs = dict()

        # dilation
        if 'dilations' in attr:
            kwargs['dilations'] = [1] + attr['dilations'] + [1]
        else:
            kwargs['dilations'] = [1] + [1, 1] + [1]

        # Missing pads default to zero, consistent with rename_Avgpool.
        pads = list(attr['pads']) if 'pads' in attr else [0, 0, 0, 0]
        if len(pads) == 4:
            kwargs['pads'] = [0] + pads[0:2] + [0, 0] + pads[2:] + [0]
        elif len(pads) == 2:
            kwargs['pads'] = ([0] + pads[0:2] + [0]) * 2

        if 'strides' not in attr:
            kwargs['strides'] = [1] + [1, 1] + [1]
        else:
            kwargs['strides'] = [1] + attr['strides'] + [1]

        kwargs['group'] = attr['group']

        weights_scope = self.get_weight_name(source_node)

        bias_name = '{0}.bias'.format(weights_scope)
        weights_name = '{0}.weight'.format(weights_scope)
        weight = self.state_dict[weights_name]

        weight = weight.numpy()
        dim = weight.ndim - 2

        IR_node = self._convert_identity_operation(source_node, new_op="Conv")
        # Move the two leading axes behind the spatial axes, swapped:
        # axes (0, 1, 2, ...) -> (2, ..., 1, 0).
        weight = np.transpose(weight, list(range(2, dim + 2)) + [1, 0])

        self.set_weight(source_node.name, 'weights', weight)
        kwargs['kernel_shape'] = list(weight.shape)

        # handle bias
        if bias_name in self.state_dict:
            bias = self.state_dict[bias_name].numpy()
            self.set_weight(source_node.name, 'bias', bias)
            kwargs['use_bias'] = True
        else:
            kwargs['use_bias'] = False

        assign_IRnode_values(IR_node, kwargs)

    def rename_BatchNormalization(self, source_node):
        """Emit a ``BatchNorm`` IR node with scale/bias/mean/var weights."""
        # TODO
        # output_shape

        IR_node = self._convert_identity_operation(source_node, new_op="BatchNorm")

        attr = source_node.attrs
        # epsilon
        IR_node.attr['epsilon'].f = attr['epsilon']
        weights_scope = self.get_weight_name(source_node)

        bias_name = '{0}.bias'.format(weights_scope)
        weights_name = '{0}.weight'.format(weights_scope)
        mean_name = '{0}.running_mean'.format(weights_scope)
        var_name = '{0}.running_var'.format(weights_scope)

        has_bias = bias_name in self.state_dict
        if has_bias:
            beta = self.state_dict[bias_name].numpy()
        IR_node.attr['bias'].b = has_bias

        has_scale = weights_name in self.state_dict
        if has_scale:
            gamma = self.state_dict[weights_name].numpy()
        IR_node.attr['scale'].b = has_scale

        mean = self.state_dict[mean_name].numpy()
        variance = self.state_dict[var_name].numpy()

        if has_scale:
            self.set_weight(source_node.name, "scale", gamma)

        if has_bias:
            self.set_weight(source_node.name, "bias", beta)

        # mean
        self.set_weight(source_node.name, "mean", mean)

        # var
        self.set_weight(source_node.name, "var", variance)

    def rename_Pad(self, source_node):
        """Emit a ``Pad`` IR node, copying mode, pads and constant value."""
        IR_node = self._convert_identity_operation(source_node, new_op="Pad")
        attr = source_node.attrs
        kwargs = dict()
        kwargs['mode'] = attr['mode']
        kwargs['pads'] = attr['pads']
        kwargs['constant_values'] = attr['value']
        assign_IRnode_values(IR_node, kwargs)

    def rename_Relu(self, source_node):
        """Emit a ``Relu`` IR node."""
        self._convert_identity_operation(source_node, new_op="Relu")

    def rename_Tanh(self, source_node):
        """Emit a ``Tanh`` IR node."""
        self._convert_identity_operation(source_node, new_op="Tanh")

    def rename_Sigmoid(self, source_node):
        """Emit a ``Sigmoid`` IR node."""
        self._convert_identity_operation(source_node, new_op="Sigmoid")

    def rename_Mul(self, source_node):
        """Emit a ``Mul`` IR node."""
        self._convert_identity_operation(source_node, new_op="Mul")

    def _convert_onnx_pooling(self, source_node, pooling_type):
        """Emit a ``Pool`` IR node of ``pooling_type`` ('MAX' or 'AVG')."""
        attr = source_node.attrs
        kwargs = dict()
        kwargs['strides'] = [1] + attr['strides'] + [1]
        if 'dilations' not in attr:
            kwargs['dilations'] = [1] + [1, 1] + [1]
        else:
            kwargs['dilations'] = [1] + attr['dilations'] + [1]
        if 'pads' in attr:
            pads = list(attr['pads'])
            kwargs['pads'] = [0] + pads[0:2] + [0, 0] + pads[2:] + [0]
        else:
            kwargs['pads'] = [0, 0, 0, 0, 0, 0, 0, 0]
        kwargs['kernel_shape'] = [1] + attr['kernel_shape'] + [1]
        IR_node = self._convert_identity_operation(source_node, new_op="Pool")

        kwargs['pooling_type'] = pooling_type

        assign_IRnode_values(IR_node, kwargs)

    def rename_Maxpool(self, source_node):
        """Emit a max ``Pool`` IR node."""
        self._convert_onnx_pooling(source_node, 'MAX')

    def rename_Avgpool(self, source_node):
        """Emit an average ``Pool`` IR node."""
        self._convert_onnx_pooling(source_node, 'AVG')

    def rename_GAvgpool(self, source_node):
        """Emit an average ``Pool`` IR node whose kernel covers the whole input.

        The kernel is taken from dimensions 2 onwards of the first input's
        shape.
        """
        input_shape = self.pytorch_graph.shape_dict[source_node.in_edges[0]]
        kwargs = dict()
        kwargs['strides'] = [1, 1, 1, 1]
        kwargs['dilations'] = [1] + [1, 1] + [1]
        kwargs['pads'] = [0, 0, 0, 0, 0, 0, 0, 0]
        kwargs['kernel_shape'] = [1] + list(input_shape[2:]) + [1]
        IR_node = self._convert_identity_operation(source_node, new_op="Pool")

        kwargs['pooling_type'] = 'AVG'

        assign_IRnode_values(IR_node, kwargs)

    def rename_Flatten(self, source_node):
        """Emit a ``Flatten`` IR node."""
        self._convert_identity_operation(source_node, new_op="Flatten")

    def rename_FullyConnected(self, source_node):
        """Emit a ``FullyConnected`` IR node with its weights (and bias).

        When the layer consumes a flattened 4-D tensor, the weight rows are
        permuted from channel-first to channel-last order.
        """
        IR_node = self._convert_identity_operation(source_node, new_op="FullyConnected")
        weights_scope = self.get_weight_name(source_node)
        bias_name = '{0}.bias'.format(weights_scope)
        weights_name = '{0}.weight'.format(weights_scope)

        W = self.state_dict[weights_name].numpy().transpose()
        _, output_channels = W.shape

        # Fix weight transpose
        # weight: N x M -> C x H x W x M -> H x W x C x M -> N x M
        if self.weight_loaded:
            parent = self.src_graph.get_parent(source_node.name, [0])
            # Skip Flatten/Dropout nodes to find the producer that fixes the
            # shape of the data being flattened.
            while parent.type == 'onnx::Flatten' or parent.type == 'onnx::Dropout':
                parent = self.src_graph.get_parent(parent.name, [0])
            if len(self.shape_dict[parent.name]) == 4:
                original_shape = W.shape
                channel_first_list = list(self.shape_dict[parent.name][1:])
                dim = len(channel_first_list) + 1
                weight = W.reshape(channel_first_list + [original_shape[1]])
                assert dim > 2
                weight = weight.transpose(list(range(1, dim - 1)) + [0, dim - 1])
                W = weight.reshape(original_shape)

        # weights
        self.set_weight(source_node.name, 'weights', W)

        # use_bias
        if bias_name in self.state_dict:
            IR_node.attr['use_bias'].b = True
            bias = self.state_dict[bias_name].numpy()
            self.set_weight(source_node.name, 'bias', bias)
        else:
            IR_node.attr['use_bias'].b = False

        # units
        IR_node.attr['units'].i = output_channels

    def rename_Dropout(self, source_node):
        """Emit a ``Dropout`` IR node."""
        IR_node = self._convert_identity_operation(source_node, new_op='Dropout')
        # NOTE(review): the ONNX `ratio` attribute is stored as `keep_prob`
        # unchanged; confirm whether the consumers expect 1 - ratio instead.
        IR_node.attr['keep_prob'].f = source_node.attrs['ratio']

    def rename_Concat(self, source_node):
        """Emit a ``Concat`` IR node.

        Axis 1 is remapped to the last axis of the output; any other axis is
        passed through unchanged.
        """
        IR_node = self._convert_identity_operation(source_node, new_op='Concat')

        if source_node.attrs['axis'] == 1:
            IR_node.attr['axis'].i = len(self.shape_dict[source_node.name]) - 1
        else:
            IR_node.attr['axis'].i = source_node.attrs['axis']

    def rename_Add(self, source_node):
        """Emit an ``Add`` IR node."""
        self._convert_identity_operation(source_node, new_op='Add')

    def rename_MaxPool2d(self, source_node):
        """Emit a ``Pool`` IR node via ``_convert_pooling``."""
        self._convert_pooling(source_node)

    def rename_View(self, source_node):
        """Emit a ``Reshape`` IR node, dropping the first target dimension."""
        IR_node = self._convert_identity_operation(source_node, new_op='Reshape')
        assign_IRnode_values(IR_node, {'shape': list(source_node.get_attr('new_sizes'))[1:]})

    def rename_Addmm(self, source_node):
        """Emit a ``FullyConnected`` IR node from an ``Addmm`` autograd node."""
        IR_node = self._convert_identity_operation(source_node, new_op='FullyConnected')
        kwargs = dict()
        next_functions = source_node.get_attr('next_functions')

        # handle weight
        weight = next_functions[2][0].next_functions[0][0].variable.data.numpy()
        weight = np.transpose(weight)
        kwargs['units'] = weight.shape[1]
        self.set_weight(source_node.name, 'weights', weight)

        # handle bias
        if next_functions[0][0]:
            bias = next_functions[0][0].variable.data.numpy()
            kwargs['use_bias'] = True
            # Fixed: the bias tensor, not the weight matrix, is stored here.
            self.set_weight(source_node.name, 'bias', bias)
        else:
            kwargs['use_bias'] = False

        assign_IRnode_values(IR_node, kwargs)

    ####################
    # Helper Functions #
    ####################

    @staticmethod
    def _copy_and_reop(source_node, IR_node, new_op=None):
        """Copy the node name to ``IR_node`` and set its op.

        ``new_op`` defaults to the source node's own type.
        """
        if new_op is None:
            new_op = source_node.type
        IR_node.name = source_node.name
        IR_node.op = new_op

    def _convert_identity_operation(self, source_node, in_edge_count=None, new_op=None):
        """Create an IR node for ``source_node`` with name, op, inputs and
        output shape filled in, and return it."""
        IR_node = self.IR_graph.node.add()
        PytorchParser._copy_and_reop(source_node, IR_node, new_op)
        self.convert_inedge(source_node, IR_node, 0, in_edge_count)
        self._set_output_shape(source_node, IR_node)
        return IR_node

    def _convert_pooling(self, source_node):
        """Emit a ``Pool`` IR node from stride/dilation/padding/kernel_size attributes.

        The pooling type is derived from the node name prefix ('Max'/'Avg').

        Raises:
            ValueError: if the node name starts with neither prefix.
        """
        kwargs = dict()
        kwargs['strides'] = [1] + list(source_node.get_attr('stride')) + [1]
        kwargs['dilations'] = [1] + list(source_node.get_attr('dilation')) + [1]
        kwargs['pads'] = ([0] + list(source_node.get_attr('padding')) + [0]) * 2
        kwargs['kernel_shape'] = [1] + list(source_node.get_attr('kernel_size')) + [1]
        IR_node = self._convert_identity_operation(source_node, new_op="Pool")

        if source_node.name.startswith('Max'):
            kwargs['pooling_type'] = 'MAX'
        elif source_node.name.startswith('Avg'):
            kwargs['pooling_type'] = 'AVG'
        else:
            raise ValueError('Unknown pooling type')

        assign_IRnode_values(IR_node, kwargs)


class PytorchParser040(PytorchParser):
    """Parser variant that builds its graph with ``PytorchGraph040``."""

    def __init__(self, model_file_name, input_shape):
        super(PytorchParser040, self).__init__(model_file_name, input_shape)
        self.pytorch_graph = PytorchGraph040(self.model)
        self.build_graph(input_shape)

    def get_weight_name(self, node):
        """Return the weight-name prefix stored on the node itself."""
        return node.weights_name


class PytorchParser151(PytorchParser):
    """Parser variant that builds its graph with ``PytorchGraph151``."""

    def __init__(self, model_file_name, input_shape):
        super(PytorchParser151, self).__init__(model_file_name, input_shape)
        self.pytorch_graph = PytorchGraph151(self.model)
        self.build_graph(input_shape)

    def get_weight_name(self, node):
        """Return the weight-name prefix from the graph's layer/weight map."""
        return self.pytorch_graph.layer_weight_map[node.name]