1#----------------------------------------------------------------------------------------------
2#  Copyright (c) Microsoft Corporation. All rights reserved.
3#  Licensed under the MIT License. See License.txt in the project root for license information.
4#----------------------------------------------------------------------------------------------
5
6from collections import namedtuple
7import numpy as np
8from mmdnn.conversion.examples.imagenet_test import TestKit
9import mxnet as mx
10
# Lightweight batch record handed to Module.forward() in this file; its single
# ``data`` field holds a list of input NDArrays.
Batch = namedtuple('Batch', 'data')
12
13
class TestMXNet(TestKit):
    """MXNet backend for the ImageNet conversion test harness.

    Runs a converted MXNet model on a test image and prints its output. It
    can also save the loaded model as an MXNet checkpoint.

    NOTE(review): ``self.truth``, ``self.args`` and ``self.MainModel`` come
    from ``TestKit`` (not visible here). ``MainModel`` is presumably the
    generated network module named by ``self.args.n`` — confirm in TestKit.
    """

    def __init__(self):
        super(TestMXNet, self).__init__()

        # Expected results for inception_v3, keyed by source framework. Each
        # entry is five (index, score) pairs — presumably (ImageNet class,
        # score) consumed by TestKit.test_truth(); confirm there.
        self.truth['tensorflow']['inception_v3'] = [
            (22, 9.6691055), (24, 4.3524752), (25, 3.5957956),
            (132, 3.5657482), (23, 3.3462858)]
        self.truth['keras']['inception_v3'] = [
            (21, 0.93430501), (23, 0.0028834261), (131, 0.0014781745),
            (24, 0.0014518937), (22, 0.0014435325)]

        # Build the module with the generated code's RefactorModel(), then
        # load the weights file given by self.args.w into it.
        self.model = self.MainModel.RefactorModel()
        self.model = self.MainModel.deploy_weight(self.model, self.args.w)

    def preprocess(self, image_path):
        """Load ``image_path`` through TestKit and store it in ``self.data``.

        The base preprocess result has its axes reordered from (a, b, c) to
        (c, a, b), and a leading batch axis of size 1 is added. Assumes the
        base class returns an (H, W, C) image, making ``self.data``
        (1, C, H, W) — TODO confirm in TestKit.
        """
        image = super(TestMXNet, self).preprocess(image_path)
        # Same permutation as swapaxes(0, 2) followed by swapaxes(1, 2).
        image = np.transpose(image, (2, 0, 1))
        self.data = np.expand_dims(image, 0)

    def print_result(self):
        """Forward ``self.data`` through the model and print its first output.

        The first output is converted to a numpy array and passed to
        TestKit.print_result.
        """
        self.model.forward(Batch([mx.nd.array(self.data)]))
        prob = self.model.get_outputs()[0].asnumpy()
        super(TestMXNet, self).print_result(prob)

    def inference(self, image_path):
        """Preprocess ``image_path``, print the model output and run test_truth().

        Args:
            image_path: Path of the image to classify.
        """
        self.preprocess(image_path)

        # NOTE: to inspect a single layer while debugging, call e.g.
        # self.print_intermediate_result('<layer name>', False) at this point.

        self.print_result()

        self.test_truth()

    def print_intermediate_result(self, layer_name, if_transpose=False):
        """Print the output of internal layer ``layer_name`` for ``self.data``.

        Builds a temporary CPU module whose only output is the internal symbol
        ``<layer_name>_output``. Copies this model's parameters into it, runs
        one inference forward pass and hands the numpy result to
        TestKit.print_intermediate_result.

        Args:
            layer_name: Internal layer name, without the ``_output`` suffix.
            if_transpose: Passed unchanged to TestKit.print_intermediate_result.
        """
        internals = self.model.symbol.get_internals()
        intermediate_output = internals[layer_name + "_output"]
        test_model = mx.mod.Module(symbol=intermediate_output, context=mx.cpu(), data_names=['data'])

        # Bind with the shape of the exact array forwarded below. This keeps
        # bind and forward consistent for every preprocess mode. The previous
        # hard-coded 224/299 table ended in ``assert False``, which ``-O``
        # strips, leaving the module unbound.
        test_model.bind(for_training=False, data_shapes=[('data', tuple(self.data.shape))])

        arg_params, aux_params = self.model.get_params()

        # The truncated symbol needs only a subset of the full model's
        # parameters (allow_extra); allow_missing tolerates any it lacks.
        test_model.set_params(arg_params=arg_params, aux_params=aux_params, allow_missing=True, allow_extra=True)
        test_model.forward(Batch([mx.nd.array(self.data)]))
        intermediate_output = test_model.get_outputs()[0].asnumpy()

        super(TestMXNet, self).print_intermediate_result(intermediate_output, if_transpose)

    def dump(self, path=None):
        """Save the loaded model as an MXNet checkpoint at epoch 0.

        Args:
            path: Checkpoint prefix. Defaults to ``self.args.dump``.
        """
        if path is None:
            path = self.args.dump
        self.model.save_checkpoint(path, 0)
        print('MXNet checkpoint file is saved with prefix [{}] and iteration 0, generated by [{}.py] and [{}].'.format(
            path, self.args.n, self.args.w))
74
75
if __name__ == '__main__':
    # Script entry point: either save a checkpoint (when a dump target was
    # given on the command line) or run inference on the supplied image.
    tester = TestMXNet()
    if not tester.args.dump:
        tester.inference(tester.args.image)
    else:
        tester.dump()
82