#----------------------------------------------------------------------------------------------
#  Copyright (c) Microsoft Corporation. All rights reserved.
#  Licensed under the MIT License. See License.txt in the project root for license information.
#----------------------------------------------------------------------------------------------

from collections import namedtuple
import numpy as np
from mmdnn.conversion.examples.imagenet_test import TestKit
import mxnet as mx

# Lightweight batch container handed to ``Module.forward``; its single ``data``
# field holds a list of input NDArrays.
Batch = namedtuple('Batch', ['data'])


class TestMXNet(TestKit):
    """Run the ImageNet example checks against a model converted to MXNet.

    ``self.MainModel`` and ``self.args`` are used but not defined here; they are
    presumably set up by ``TestKit.__init__`` (confirm against the base class).
    """

    def __init__(self):
        super(TestMXNet, self).__init__()

        # Reference (index, score) pairs for inception_v3 converted from
        # TensorFlow / Keras; presumably consumed by TestKit.test_truth() — confirm.
        self.truth['tensorflow']['inception_v3'] = [(22, 9.6691055), (24, 4.3524752), (25, 3.5957956), (132, 3.5657482), (23, 3.3462858)]
        self.truth['keras']['inception_v3'] = [(21, 0.93430501), (23, 0.0028834261), (131, 0.0014781745), (24, 0.0014518937), (22, 0.0014435325)]

        # Build the generated network, then attach the weights referenced by args.w.
        self.model = self.MainModel.RefactorModel()
        self.model = self.MainModel.deploy_weight(self.model, self.args.w)


    def preprocess(self, image_path):
        """Load ``image_path`` via the base class and store a batched array in ``self.data``.

        The two ``swapaxes`` calls move the last axis to the front, so an
        (H, W, C) input becomes (C, H, W); a leading batch axis of size 1 is
        then added. The HWC layout of the base-class output is an assumption —
        TODO confirm against TestKit.preprocess.
        """
        self.data = super(TestMXNet, self).preprocess(image_path)
        self.data = np.swapaxes(self.data, 0, 2)  # (H, W, C) -> (C, W, H)
        self.data = np.swapaxes(self.data, 1, 2)  # (C, W, H) -> (C, H, W)
        self.data = np.expand_dims(self.data, 0)  # add batch axis


    def print_result(self):
        """Run a forward pass on ``self.data`` and hand the first output to the base printer."""
        self.model.forward(Batch([mx.nd.array(self.data)]))
        prob = self.model.get_outputs()[0].asnumpy()
        super(TestMXNet, self).print_result(prob)


    def inference(self, image_path):
        """Preprocess ``image_path``, print the model's result, and compare it to the reference."""
        self.preprocess(image_path)

        # For layer-level debugging, call e.g.
        # ``self.print_intermediate_result('pooling0', False)`` here.

        self.print_result()

        self.test_truth()


    def print_intermediate_result(self, layer_name, if_transpose = False):
        """Evaluate the graph up to ``layer_name`` on ``self.data`` and print that output.

        Args:
            layer_name: internal symbol name without the ``_output`` suffix.
            if_transpose: forwarded unchanged to ``TestKit.print_intermediate_result``.

        Raises:
            ValueError: if ``self.args.preprocess`` has no known input size.
        """
        preprocess = self.args.preprocess
        if preprocess in ('vgg19', 'inception_v1'):
            input_size = 224
        elif 'resnet' in preprocess or preprocess == 'inception_v3':
            input_size = 299
        else:
            # Raise explicitly instead of ``assert`` so the check survives ``python -O``
            # and fails before building a module that would never be bound.
            raise ValueError(
                'Unsupported preprocess [{}] for print_intermediate_result; expected '
                'vgg19, inception_v1, inception_v3 or a resnet variant.'.format(preprocess))

        internals = self.model.symbol.get_internals()
        intermediate_output = internals[layer_name + "_output"]
        test_model = mx.mod.Module(symbol=intermediate_output, context=mx.cpu(), data_names=['data'])
        test_model.bind(for_training=False, data_shapes=[('data', (1, 3, input_size, input_size))])

        arg_params, aux_params = self.model.get_params()

        # The truncated graph uses only part of the full model's parameters,
        # hence allow_extra; missing entries are tolerated as well.
        test_model.set_params(arg_params=arg_params, aux_params=aux_params, allow_missing=True, allow_extra=True)
        test_model.forward(Batch([mx.nd.array(self.data)]))
        intermediate_output = test_model.get_outputs()[0].asnumpy()

        super(TestMXNet, self).print_intermediate_result(intermediate_output, if_transpose)


    def dump(self, path = None):
        """Save the model as an MXNet checkpoint at iteration 0.

        Args:
            path: checkpoint prefix; defaults to ``self.args.dump`` when None.
        """
        if path is None:
            path = self.args.dump
        self.model.save_checkpoint(path, 0)
        print('MXNet checkpoint file is saved with prefix [{}] and iteration 0, generated by [{}.py] and [{}].'.format(
            path, self.args.n, self.args.w))


if __name__ == '__main__':
    tester = TestMXNet()
    if tester.args.dump:
        tester.dump()
    else:
        tester.inference(tester.args.image)