1#----------------------------------------------------------------------------------------------
2#  Copyright (c) Microsoft Corporation. All rights reserved.
3#  Licensed under the MIT License. See License.txt in the project root for license information.
4#----------------------------------------------------------------------------------------------
5
6import argparse
7import numpy as np
8import sys
9import os
10from mmdnn.conversion.examples.imagenet_test import TestKit
11import torch
12
13
class TestTorch(TestKit):
    """ImageNet test harness for a PyTorch model produced by MMdnn conversion.

    The network is built through ``self.MainModel.KitModel`` (set up by
    ``TestKit``) from the weight file given in ``self.args.w``.  A single image
    can then be classified and the result compared with the reference values
    stored in ``self.truth``, or the whole model can be serialized with
    ``torch.save``.
    """

    def __init__(self):
        super(TestTorch, self).__init__()

        # Reference top-5 (class index, score) pairs for inception_v3 converted
        # from TensorFlow / Keras; presumably checked by TestKit.test_truth —
        # confirm against the base class.
        self.truth['tensorflow']['inception_v3'] = [(22, 9.6691055), (24, 4.3524747), (25, 3.5957973), (132, 3.5657473), (23, 3.346283)]
        self.truth['keras']['inception_v3'] = [(21, 0.93430489), (23, 0.002883445), (131, 0.0014781791), (24, 0.0014518998), (22, 0.0014435351)]

        self.model = self.MainModel.KitModel(self.args.w)
        # Put layers such as dropout / batch-norm into inference mode.
        self.model.eval()

    def preprocess(self, image_path):
        """Load *image_path* and store it in ``self.data`` as a batched tensor.

        Assumes the base-class preprocessing returns an H x W x C array
        (the transpose below moves the last axis to the front) — TODO confirm.
        """
        x = super(TestTorch, self).preprocess(image_path)
        # HWC -> CHW, then prepend a batch axis of size 1.
        x = np.transpose(x, (2, 0, 1))
        # copy() yields a C-contiguous array that the tensor can own.
        x = np.expand_dims(x, 0).copy()
        # torch.autograd.Variable is deprecated (merged into Tensor in
        # PyTorch 0.4); a tensor from from_numpy already has requires_grad=False.
        self.data = torch.from_numpy(x)

    def print_result(self):
        """Run the model on ``self.data`` and hand the output to TestKit."""
        # Pure inference: no_grad avoids recording an autograd graph.
        with torch.no_grad():
            predict = self.model(self.data)
        predict = predict.detach().cpu().numpy()
        super(TestTorch, self).print_result(predict)

    def print_intermediate_result(self, layer_name, if_transpose=False):
        """Report the tensor stored on ``self.model.test`` via TestKit.

        NOTE(review): *layer_name* is currently ignored; the value is always
        read from the model's ``test`` attribute, which presumably the
        generated model code assigns during forward — confirm.

        Args:
            layer_name: unused (see note above).
            if_transpose: forwarded unchanged to the base-class method.
        """
        intermediate_output = self.model.test.detach().cpu().numpy()
        super(TestTorch, self).print_intermediate_result(intermediate_output, if_transpose)

    def inference(self, image_path):
        """Classify *image_path*, print the prediction and run the truth check."""
        self.preprocess(image_path)

        self.print_result()

        # Debug hook: uncomment to report the model's intermediate ``test`` output.
        # self.print_intermediate_result(None, False)

        self.test_truth()

    def dump(self, path=None):
        """Serialize the whole model object with ``torch.save``.

        Args:
            path: destination file; defaults to ``self.args.dump``.
        """
        if path is None:
            path = self.args.dump
        # torch.save pickles the entire module, so loading it later requires
        # the generated model class to be importable.
        torch.save(self.model, path)
        print('PyTorch model file is saved as [{}], generated by [{}.py] and [{}].'.format(
              path, self.args.n, self.args.w))
59
60
if __name__ == '__main__':
    # Script entry point: build the tester, then either serialize the
    # converted model (when a dump target was requested) or run one
    # inference pass on the configured image.
    tester = TestTorch()
    dump_requested = tester.args.dump
    if not dump_requested:
        tester.inference(tester.args.image)
    else:
        tester.dump()
67