#----------------------------------------------------------------------------------------------
#  Copyright (c) Microsoft Corporation. All rights reserved.
#  Licensed under the MIT License. See License.txt in the project root for license information.
#----------------------------------------------------------------------------------------------

import argparse
import numpy as np
import sys
import os
from mmdnn.conversion.examples.imagenet_test import TestKit
import torch


class TestTorch(TestKit):
    """ImageNet test harness for a model converted to PyTorch.

    Builds the generated ``KitModel`` from the weight file given in
    ``self.args.w``, runs it on one preprocessed image and then calls the
    inherited ``test_truth`` check. Alternatively, it serializes the model
    with ``torch.save`` (see ``dump``).
    """

    def __init__(self):
        super(TestTorch, self).__init__()

        # Reference results for inception_v3 converted from other frameworks.
        # Presumably (class_index, score) top-5 pairs consumed by the
        # inherited test_truth(); TODO confirm against TestKit.
        self.truth['tensorflow']['inception_v3'] = [(22, 9.6691055), (24, 4.3524747), (25, 3.5957973), (132, 3.5657473), (23, 3.346283)]
        self.truth['keras']['inception_v3'] = [(21, 0.93430489), (23, 0.002883445), (131, 0.0014781791), (24, 0.0014518998), (22, 0.0014435351)]

        # self.MainModel / self.args are provided by TestKit (presumably the
        # imported generated model module and parsed CLI arguments).
        self.model = self.MainModel.KitModel(self.args.w)
        # Inference mode: switches layers such as dropout/batch-norm to eval.
        self.model.eval()

    def preprocess(self, image_path):
        """Load ``image_path`` and store it in ``self.data`` as a batch of one.

        Args:
            image_path: path of the image handed to the base preprocess.
        """
        x = super(TestTorch, self).preprocess(image_path)
        # Move the last axis to the front; assumes the base preprocess yields
        # an (H, W, C) array, giving (C, H, W) -- TODO confirm.
        x = np.transpose(x, (2, 0, 1))
        # Add the batch axis; copy() materializes a contiguous array instead
        # of keeping the transposed view.
        x = np.expand_dims(x, 0).copy()
        self.data = torch.from_numpy(x)
        # Kept for compatibility with PyTorch < 0.4, where modules required
        # Variables; on newer releases this simply returns the tensor.
        self.data = torch.autograd.Variable(self.data, requires_grad = False)

    def print_result(self):
        """Run the model on ``self.data`` and pass the output to the base printer."""
        # Pure inference: disable autograd so no graph is built/retained.
        # torch.no_grad only exists from PyTorch 0.4 on, so fall back to a
        # plain forward pass on older releases.
        no_grad = getattr(torch, 'no_grad', None)
        if no_grad is not None:
            with no_grad():
                predict = self.model(self.data)
        else:
            predict = self.model(self.data)
        predict = predict.data.numpy()
        super(TestTorch, self).print_result(predict)

    def print_intermediate_result(self, layer_name, if_transpose=False):
        """Print the tensor stored on ``self.model.test`` (debugging aid).

        Args:
            layer_name: currently unused; the output is always read from
                ``self.model.test``.
            if_transpose: forwarded unchanged to the base implementation.
        """
        # NOTE(review): requires the generated model to assign a `test`
        # attribute during forward(); otherwise this raises AttributeError.
        intermediate_output = self.model.test.data.numpy()
        super(TestTorch, self).print_intermediate_result(intermediate_output, if_transpose)

    def inference(self, image_path):
        """Preprocess ``image_path``, print the prediction and run ``test_truth``."""
        self.preprocess(image_path)

        self.print_result()

        # self.print_intermediate_result(None, False)

        self.test_truth()

    def dump(self, path=None):
        """Serialize the whole model object with ``torch.save``.

        Args:
            path: output file; defaults to ``self.args.dump``.
        """
        if path is None:
            path = self.args.dump
        torch.save(self.model, path)
        print('PyTorch model file is saved as [{}], generated by [{}.py] and [{}].'.format(
            path, self.args.n, self.args.w))


if __name__=='__main__':
    tester = TestTorch()
    if tester.args.dump:
        tester.dump()
    else:
        tester.inference(tester.args.image)