1#----------------------------------------------------------------------------------------------
2#  Copyright (c) Microsoft Corporation. All rights reserved.
3#  Licensed under the MIT License. See License.txt in the project root for license information.
4#----------------------------------------------------------------------------------------------
5
6import os
7import numpy as np
8import mmdnn.conversion.common.IR.graph_pb2 as graph_pb2
9from mmdnn.conversion.common.IR.graph_pb2 import NodeDef, GraphDef, DataType
10from mmdnn.conversion.common.utils import *
11from mmdnn.conversion.common.DataStructure.parser import Parser
12from mmdnn.conversion.pytorch.pytorch_graph import PytorchGraph040
13from mmdnn.conversion.pytorch.pytorch_graph import PytorchGraph151
14import torch
15import torchvision
16
class PytorchParser(Parser):
    """Convert a serialized PyTorch model into MMdnn IR.

    The source graph is supplied by a version-specific subclass, which sets
    ``self.pytorch_graph`` and then calls :meth:`build_graph`.  Shapes
    reported by the source graph are channel-first (``N, C, H, W``); the
    shapes written into the IR are reordered to ``N, H, W, C``.
    """

    # Maps the ONNX op type reported by the traced graph to the suffix of the
    # ``rename_<suffix>`` method that converts it.
    layer_map = {
        'onnx::Conv': 'Conv',
        'onnx::Flatten': 'Flatten',
        'onnx::Gemm': 'FullyConnected',
        'onnx::MaxPool': 'Maxpool',
        'onnx::AveragePool': 'Avgpool',
        'onnx::GlobalAveragePool': 'GAvgpool',
        'onnx::Dropout': 'Dropout',
        'onnx::BatchNormalization': 'BatchNormalization',
        'onnx::Add': 'Add',
        'onnx::Concat': 'Concat',
        'onnx::Relu': 'Relu',
        'onnx::Tanh': 'Tanh',
        'onnx::Sigmoid': 'Sigmoid',
        'onnx::Mul': 'Mul',
        'onnx::Pad': 'Pad',

        # TODO: not yet supported
        # 'max_pool2d', 'onnx::Sub', 'onnx::ConvTranspose', 'onnx::LeakyRelu',
        # 'onnx::Softmax', 'onnx::Selu', 'onnx::Transpose', 'onnx::Reshape',
        # 'onnx::MatMul', 'onnx::Gather', 'onnx::ReduceSum', 'onnx::Constant',
        # 'onnx::Upsample'
    }

    ############
    # property #
    ############

    @property
    def src_graph(self):
        """The PyTorch source graph being converted."""
        return self.pytorch_graph

    def get_weight_name(self, node):
        """Return the ``state_dict`` key prefix of ``node``'s parameters.

        Version-specific subclasses override this; the base returns None.
        """
        pass

    ####################
    # Public Functions #
    ####################

    def __init__(self, model_file_name, input_shape):
        """Load the PyTorch model stored at ``model_file_name``.

        ``input_shape`` is not used here; subclasses pass it to
        :meth:`build_graph` after creating ``self.pytorch_graph``.
        """
        super(PytorchParser, self).__init__()
        if not os.path.exists(model_file_name):
            print("Pytorch model file [{}] is not found.".format(model_file_name))
            assert False

        # NOTE: torch.load unpickles the file; only load trusted model files.
        # A model saved on a GPU machine may fail to load on a CPU-only host,
        # so retry with storages remapped to the CPU.
        # cpu: https://github.com/pytorch/pytorch/issues/5286
        try:
            model = torch.load(model_file_name)
        except Exception:
            model = torch.load(model_file_name, map_location='cpu')

        self.weight_loaded = True
        self.model = model
        # The network graph is created and built by the subclass.
        self.pytorch_graph = None

    def build_graph(self, input_shape):
        """Build the source graph for a single-sample input of ``input_shape``.

        ``input_shape`` excludes the batch dimension (e.g. ``[C, H, W]``); a
        batch size of 1 is prepended.  Lists and tuples are both accepted.
        """
        self.input_shape = tuple([1] + list(input_shape))
        self.pytorch_graph.build(self.input_shape)
        self.state_dict = self.pytorch_graph.state_dict
        self.shape_dict = self.pytorch_graph.shape_dict

    def gen_IR(self):
        """Convert every source node in topological order, then add the input node.

        Nodes whose ONNX type is not in ``layer_map`` (or whose handler is
        missing) are passed to :meth:`rename_UNKNOWN`.
        """
        for layer in self.src_graph.topological_sort:
            current_node = self.src_graph.get_node(layer)
            node_type = self.layer_map.get(current_node.type)
            func = getattr(self, "rename_" + node_type, None) if node_type else None
            if func is None:
                self.rename_UNKNOWN(current_node)
            else:
                func(current_node)

        self.gen_Input()

    @staticmethod
    def _build_IR_shape(shape_pytorch):
        """Convert a channel-first PyTorch shape into an NHWC ``TensorShape``.

        A rank-4 ``(N, C, H, W)`` shape becomes ``(N, H, W, C)``, and a rank-2
        ``(N, C)`` shape becomes ``(N, 1, 1, C)``.  A batch size of 1 (or a
        falsy one) is written as -1 (unknown), and any falsy non-batch
        dimension is also written as -1.

        NOTE(review): other ranks are not converted; a single unset dim is
        emitted, as in the previous implementation.  Confirm whether that
        case is reachable.
        """
        shape = graph_pb2.TensorShape()
        rank = len(shape_pytorch)
        if rank not in (2, 4):
            shape.dim.add()
            return shape

        batch = shape_pytorch[0]
        shape.dim.add().size = batch if batch and batch != 1 else -1
        if rank == 4:
            rest = [shape_pytorch[2], shape_pytorch[3], shape_pytorch[1]]
        else:
            rest = [1, 1, shape_pytorch[1]]
        for dim in rest:
            shape.dim.add().size = dim if dim else -1
        return shape

    def _set_output_shape(self, source_node, IR_node):
        """Record ``source_node``'s output shape (NHWC) on ``IR_node``."""
        shape = self._build_IR_shape(self.shape_dict[source_node.name])
        IR_node.attr["_output_shapes"].list.shape.extend([shape])

    ##########
    # Layers #
    ##########
    def rename_UNKNOWN(self, source_node):
        """Report an unsupported operator and abort the conversion."""
        print("PyTorch parser has not supported operator [%s] with name [%s]."
              % (source_node.type, source_node.name))
        assert False

    def gen_Input(self):
        """Add the ``DataInput`` node and connect it to the graph's input layers."""
        IR_node = self.IR_graph.node.add()
        IR_node.name = 'input'
        IR_node.op = "DataInput"

        for node in self.IR_graph.node:
            if node.name in self.src_graph.input_layers:
                node.input.append('input')

        assert len(self.input_shape) == 4
        # Both the declared input shape and the output shape are NHWC.
        IR_node.attr["shape"].shape.CopyFrom(self._build_IR_shape(self.input_shape))
        IR_node.attr["_output_shapes"].list.shape.extend(
            [self._build_IR_shape(self.input_shape)])

    @staticmethod
    def _onnx_pads_to_IR(pads):
        """Convert ONNX ``[begin..., end...]`` pads to IR ``[0, begin..., 0, 0, end..., 0]``."""
        half = len(pads) // 2
        return [0] + list(pads[:half]) + [0, 0] + list(pads[half:]) + [0]

    def rename_Conv(self, source_node):
        """Convert an ``onnx::Conv`` node into an IR ``Conv`` node.

        Missing ``dilations``/``strides``/``pads``/``group`` attributes fall
        back to 1, 1, 0 and 1 respectively.  The kernel axes are reordered
        from ``(out, in, *spatial)`` to ``(*spatial, in, out)``.
        """
        attr = source_node.attrs
        kwargs = dict()

        kwargs['dilations'] = [1] + list(attr.get('dilations', [1, 1])) + [1]
        kwargs['strides'] = [1] + list(attr.get('strides', [1, 1])) + [1]

        pads = list(attr.get('pads', [0, 0, 0, 0]))
        if len(pads) == 2:
            # Two values are replicated as both begin and end pads
            # (same layout as _convert_pooling).
            kwargs['pads'] = ([0] + pads + [0]) * 2
        else:
            kwargs['pads'] = self._onnx_pads_to_IR(pads)

        kwargs['group'] = attr.get('group', 1)

        weights_scope = self.get_weight_name(source_node)
        bias_name = '{0}.bias'.format(weights_scope)
        weights_name = '{0}.weight'.format(weights_scope)

        weight = self.state_dict[weights_name].numpy()
        spatial_rank = weight.ndim - 2

        IR_node = self._convert_identity_operation(source_node, new_op="Conv")
        weight = np.transpose(weight, list(range(2, spatial_rank + 2)) + [1, 0])
        self.set_weight(source_node.name, 'weights', weight)
        kwargs['kernel_shape'] = list(weight.shape)

        if bias_name in self.state_dict:
            self.set_weight(source_node.name, 'bias', self.state_dict[bias_name].numpy())
            kwargs['use_bias'] = True
        else:
            kwargs['use_bias'] = False

        assign_IRnode_values(IR_node, kwargs)

    def rename_BatchNormalization(self, source_node):
        """Convert ``onnx::BatchNormalization`` into an IR ``BatchNorm`` node.

        ``scale``/``bias`` are exported only when the corresponding
        ``weight``/``bias`` tensors exist; running mean and variance are
        required.
        """
        IR_node = self._convert_identity_operation(source_node, new_op="BatchNorm")
        IR_node.attr['epsilon'].f = source_node.attrs['epsilon']

        weights_scope = self.get_weight_name(source_node)
        bias_name = '{0}.bias'.format(weights_scope)
        weights_name = '{0}.weight'.format(weights_scope)
        mean_name = '{0}.running_mean'.format(weights_scope)
        var_name = '{0}.running_var'.format(weights_scope)

        has_bias = bias_name in self.state_dict
        has_scale = weights_name in self.state_dict
        IR_node.attr['bias'].b = has_bias
        IR_node.attr['scale'].b = has_scale

        mean = self.state_dict[mean_name].numpy()
        variance = self.state_dict[var_name].numpy()

        if has_scale:
            self.set_weight(source_node.name, "scale", self.state_dict[weights_name].numpy())
        if has_bias:
            self.set_weight(source_node.name, "bias", self.state_dict[bias_name].numpy())
        self.set_weight(source_node.name, "mean", mean)
        self.set_weight(source_node.name, "var", variance)

    def rename_Pad(self, source_node):
        """Convert ``onnx::Pad``, copying ``mode``, ``pads`` and ``value`` unchanged."""
        IR_node = self._convert_identity_operation(source_node, new_op="Pad")
        attr = source_node.attrs
        kwargs = dict()
        kwargs['mode'] = attr['mode']
        kwargs['pads'] = attr['pads']
        kwargs['constant_values'] = attr['value']
        assign_IRnode_values(IR_node, kwargs)

    def rename_Relu(self, source_node):
        self._convert_identity_operation(source_node, new_op="Relu")

    def rename_Tanh(self, source_node):
        self._convert_identity_operation(source_node, new_op="Tanh")

    def rename_Sigmoid(self, source_node):
        self._convert_identity_operation(source_node, new_op="Sigmoid")

    def rename_Mul(self, source_node):
        self._convert_identity_operation(source_node, new_op="Mul")

    def _convert_onnx_pool(self, source_node, pooling_type):
        """Shared conversion of ONNX ``MaxPool``/``AveragePool`` into an IR ``Pool``.

        Missing ``strides``/``dilations``/``pads`` default to 1, 1 and 0.
        """
        attr = source_node.attrs
        kwargs = dict()
        kwargs['strides'] = [1] + list(attr.get('strides', [1, 1])) + [1]
        kwargs['dilations'] = [1] + list(attr.get('dilations', [1, 1])) + [1]
        kwargs['pads'] = self._onnx_pads_to_IR(attr.get('pads', [0, 0, 0, 0]))
        kwargs['kernel_shape'] = [1] + list(attr['kernel_shape']) + [1]
        IR_node = self._convert_identity_operation(source_node, new_op="Pool")
        kwargs['pooling_type'] = pooling_type
        assign_IRnode_values(IR_node, kwargs)

    def rename_Maxpool(self, source_node):
        self._convert_onnx_pool(source_node, 'MAX')

    def rename_Avgpool(self, source_node):
        self._convert_onnx_pool(source_node, 'AVG')

    def rename_GAvgpool(self, source_node):
        """Convert global average pooling into an AVG ``Pool`` spanning the input's spatial dims."""
        input_shape = self.shape_dict[source_node.in_edges[0]]
        kwargs = dict()
        kwargs['strides'] = [1, 1, 1, 1]
        kwargs['dilations'] = [1, 1, 1, 1]
        kwargs['pads'] = [0, 0, 0, 0, 0, 0, 0, 0]
        kwargs['kernel_shape'] = [1] + list(input_shape[2:]) + [1]
        IR_node = self._convert_identity_operation(source_node, new_op="Pool")
        kwargs['pooling_type'] = 'AVG'
        assign_IRnode_values(IR_node, kwargs)

    def rename_Flatten(self, source_node):
        self._convert_identity_operation(source_node, new_op="Flatten")

    def rename_FullyConnected(self, source_node):
        """Convert ``onnx::Gemm`` into an IR ``FullyConnected`` node.

        The weight is transposed to ``(in, out)``.  When the nearest ancestor
        other than Flatten/Dropout produced a rank-4 tensor, the input rows
        are reordered from channel-first to channel-last flattening so the FC
        matches the NHWC layout used by the IR.
        """
        IR_node = self._convert_identity_operation(source_node, new_op="FullyConnected")
        weights_scope = self.get_weight_name(source_node)
        bias_name = '{0}.bias'.format(weights_scope)
        weights_name = '{0}.weight'.format(weights_scope)

        W = self.state_dict[weights_name].numpy().transpose()
        output_channels = W.shape[1]

        # weight: N x M -> C x H x W x M -> H x W x C x M -> N x M
        if self.weight_loaded:
            parent = self.src_graph.get_parent(source_node.name, [0])
            while parent.type in ('onnx::Flatten', 'onnx::Dropout'):
                parent = self.src_graph.get_parent(parent.name, [0])
            if len(self.shape_dict[parent.name]) == 4:
                original_shape = W.shape
                channel_first_list = list(self.shape_dict[parent.name][1:])
                dim = len(channel_first_list) + 1
                assert dim > 2
                weight = W.reshape(channel_first_list + [original_shape[1]])
                weight = weight.transpose(list(range(1, dim - 1)) + [0, dim - 1])
                W = weight.reshape(original_shape)

        self.set_weight(source_node.name, 'weights', W)

        if bias_name in self.state_dict:
            IR_node.attr['use_bias'].b = True
            self.set_weight(source_node.name, 'bias', self.state_dict[bias_name].numpy())
        else:
            IR_node.attr['use_bias'].b = False

        IR_node.attr['units'].i = output_channels

    def rename_Dropout(self, source_node):
        # NOTE(review): ONNX `ratio` is the drop probability, yet it is stored
        # unchanged under 'keep_prob' - confirm which meaning the emitters expect.
        IR_node = self._convert_identity_operation(source_node, new_op='Dropout')
        IR_node.attr['keep_prob'].f = source_node.attrs['ratio']

    def rename_Concat(self, source_node):
        """Convert ``onnx::Concat``; the channel axis 1 is remapped to the last (NHWC) axis."""
        IR_node = self._convert_identity_operation(source_node, new_op='Concat')

        if source_node.attrs['axis'] == 1:
            IR_node.attr['axis'].i = len(self.shape_dict[source_node.name]) - 1
        else:
            IR_node.attr['axis'].i = source_node.attrs['axis']

    def rename_Add(self, source_node):
        self._convert_identity_operation(source_node, new_op='Add')

    def rename_MaxPool2d(self, source_node):
        self._convert_pooling(source_node)

    def rename_View(self, source_node):
        IR_node = self._convert_identity_operation(source_node, new_op='Reshape')
        assign_IRnode_values(IR_node, {'shape' : list(source_node.get_attr('new_sizes'))[1:]})

    def rename_Addmm(self, source_node):
        """Convert a legacy autograd ``Addmm`` node into an IR ``FullyConnected``."""
        IR_node = self._convert_identity_operation(source_node, new_op='FullyConnected')
        kwargs = dict()
        next_functions = source_node.get_attr('next_functions')

        # handle weight
        weight = next_functions[2][0].next_functions[0][0].variable.data.numpy()
        weight = np.transpose(weight)
        kwargs['units'] = weight.shape[1]
        self.set_weight(source_node.name, 'weights', weight)

        # handle bias (previously the weight tensor was stored as the bias)
        if next_functions[0][0]:
            bias = next_functions[0][0].variable.data.numpy()
            kwargs['use_bias'] = True
            self.set_weight(source_node.name, 'bias', bias)
        else:
            kwargs['use_bias'] = False

        assign_IRnode_values(IR_node, kwargs)

    ####################
    # Helper Functions #
    ####################

    @staticmethod
    def _copy_and_reop(source_node, IR_node, new_op = None):
        """Copy the node name and set the IR op (defaulting to the source type)."""
        if new_op is None:
            new_op = source_node.type
        IR_node.name = source_node.name
        IR_node.op = new_op

    def _convert_identity_operation(self, source_node, in_edge_count = None, new_op = None):
        """Add an IR node mirroring ``source_node``: name, op, inputs and output shape."""
        IR_node = self.IR_graph.node.add()
        PytorchParser._copy_and_reop(source_node, IR_node, new_op)
        self.convert_inedge(source_node, IR_node, 0, in_edge_count)
        self._set_output_shape(source_node, IR_node)
        return IR_node

    def _convert_pooling(self, source_node):
        """Convert a legacy pooling node whose name starts with 'Max' or 'Avg'."""
        kwargs = dict()
        kwargs['strides'] = [1] + list(source_node.get_attr('stride')) + [1]
        kwargs['dilations'] = [1] + list(source_node.get_attr('dilation')) + [1]
        kwargs['pads'] = ([0] + list(source_node.get_attr('padding')) + [0]) * 2
        kwargs['kernel_shape'] = [1] + list(source_node.get_attr('kernel_size')) + [1]
        IR_node = self._convert_identity_operation(source_node, new_op="Pool")

        if source_node.name.startswith('Max'):
            kwargs['pooling_type'] = 'MAX'
        elif source_node.name.startswith('Avg'):
            kwargs['pooling_type'] = 'AVG'
        else:
            raise ValueError('Unknown pooling type')

        assign_IRnode_values(IR_node, kwargs)
507
class PytorchParser040(PytorchParser):
    """PytorchParser whose source graph is a ``PytorchGraph040``.

    Weight-name prefixes are read directly from each graph node.
    """

    def __init__(self, model_file_name, input_shape):
        super(PytorchParser040, self).__init__(model_file_name, input_shape)
        # The base constructor only loads the model; create and build the graph here.
        graph = PytorchGraph040(self.model)
        self.pytorch_graph = graph
        self.build_graph(input_shape)

    def get_weight_name(self, node):
        """Return the weight-name prefix stored on ``node``."""
        prefix = node.weights_name
        return prefix
517
class PytorchParser151(PytorchParser):
    """PytorchParser whose source graph is a ``PytorchGraph151``.

    Weight-name prefixes are looked up in the graph's ``layer_weight_map``.
    """

    def __init__(self, model_file_name, input_shape):
        super(PytorchParser151, self).__init__(model_file_name, input_shape)
        # The base constructor only loads the model; create and build the graph here.
        graph = PytorchGraph151(self.model)
        self.pytorch_graph = graph
        self.build_graph(input_shape)

    def get_weight_name(self, node):
        """Return the weight-name prefix recorded for ``node`` in the graph."""
        weight_map = self.pytorch_graph.layer_weight_map
        return weight_map[node.name]
527
528