1from __future__ import absolute_import
2from __future__ import print_function
3
4import os
5import sys
6from conversion_imagenet import TestModels
7
def get_test_table():
    """Return the conversion test matrix for PyTorch source models.

    Returns:
        dict: ``{'pytorch': {model_name: [emitter, ...]}}``, where each
        emitter is a ``TestModels.*_emit`` attribute listed for that model.
        Every model gets its own, independent list object.
    """
    # Not exercised for these models: TestModels.onnx_emit, TestModels.cntk_emit.
    shared_emitters = (
        TestModels.caffe_emit,
        TestModels.coreml_emit,
        TestModels.keras_emit,
        TestModels.mxnet_emit,
        TestModels.pytorch_emit,
        TestModels.tensorflow_emit,
    )
    model_names = ('vgg19', 'vgg19_bn')
    # list(...) per model so that no two entries alias the same list.
    per_model = {name: list(shared_emitters) for name in model_names}
    return {'pytorch': per_model}
33
def test_pytorch():
    """Drive ``TestModels._test_function`` over the ``'pytorch'`` table entry.

    The tester is built from ``get_test_table()``, and its own
    ``pytorch_parse`` method is passed along for the ``'pytorch'`` key.
    """
    runner = TestModels(get_test_table())
    parse_fn = runner.pytorch_parse
    runner._test_function('pytorch', parse_fn)
38
39
if __name__ == '__main__':
    # Allow running the PyTorch conversion checks directly as a script.
    test_pytorch()
42