1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3#
4# Licensed to the Apache Software Foundation (ASF) under one
5# or more contributor license agreements.  See the NOTICE file
6# distributed with this work for additional information
7# regarding copyright ownership.  The ASF licenses this file
8# to you under the Apache License, Version 2.0 (the
9# "License"); you may not use this file except in compliance
10# with the License.  You may obtain a copy of the License at
11#
12#   http://www.apache.org/licenses/LICENSE-2.0
13#
14# Unless required by applicable law or agreed to in writing,
15# software distributed under the License is distributed on an
16# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
17# KIND, either express or implied.  See the License for the
18# specific language governing permissions and limitations
19# under the License.
20
21"""ONNX test backend wrapper"""
22try:
23    import onnx.backend.test
24except ImportError:
25    raise ImportError("Onnx and protobuf need to be installed")
26
27import test_cases
28import unittest
29import backend as mxnet_backend
30import logging
31
# Conversion directions exercised by this module (ONNX -> MXNet and back).
operations = 'import export'.split()
# Backend flavours the suites are run against, in execution order.
backends = 'mxnet gluon'.split()
# pytest looks up this module-level name and loads every plugin it lists.
# The ONNX report plugin is presumably the counterpart of the
# BackendTest.enable_report() call used when the suites are run — confirm
# against the onnx package.
pytest_plugins = ("onnx.backend.test.report",)
36
37
def test_suite(backend_tests) -> unittest.TestSuite:
    """Return a ``unittest.TestSuite`` holding every test case of *backend_tests*.

    Adapted from ``onnx/backend/test/runner/__init__.py``: the upstream
    ``Runner.test_suite()`` sorts the test-case classes, which Python 3
    cannot do for ``type`` objects. Here the classes are added in the
    iteration order of ``backend_tests.test_cases`` instead.

    :param backend_tests: object exposing a ``test_cases`` mapping whose
        values are ``unittest.TestCase`` subclasses (for example the value
        returned by ``BackendTest.enable_report()``).
    :return: suite that can be handed to a ``unittest`` test runner.
    """
    suite = unittest.TestSuite()
    for case in backend_tests.test_cases.values():
        suite.addTests(unittest.defaultTestLoader.loadTestsFromTestCase(case))
    return suite


# The name matches pytest's ``test_*`` collection pattern, but this is a
# helper that requires a ``backend_tests`` argument, not a test. Without this
# flag pytest would collect it and fail looking for a ``backend_tests`` fixture.
test_suite.__test__ = False
49
50
def prepare_tests(backend, oper):
    """Build an ONNX ``BackendTest`` limited to the cases listed in ``test_cases``.

    Three registries from ``test_cases`` are consulted, in this order:
    implemented operators, basic models and standard models. For each one,
    the patterns stored under ``'both'`` and those stored under *oper* are
    passed to ``BackendTest.include``. Scalar-op tests are then excluded.

    :param backend: backend handed to ``onnx.backend.test.BackendTest``
        (the caller passes the mxnet backend module).
    :param oper: str. ``'export'`` or ``'import'``.
    :return: the configured ``BackendTest`` instance.
    """
    backend_test = onnx.backend.test.BackendTest(backend, __name__)

    registries = (
        test_cases.IMPLEMENTED_OPERATORS_TEST,
        test_cases.BASIC_MODEL_TESTS,
        test_cases.STANDARD_MODEL,
    )
    for registry in registries:
        # Shared patterns first, then the ones specific to this operation.
        selected = registry.get('both', []) + registry.get(oper, [])
        for pattern in selected:
            backend_test.include(pattern)

    # Scalar-op tests live in test_node.py, so keep them out of this run.
    backend_test.exclude('.*scalar.*')

    return backend_test
81
82
# Run the selected ONNX backend tests for every (backend, operation) pair.
# This loop sits at module top level, so it executes on import (presumably
# during pytest collection, given the pytest_plugins declaration).
for bkend in backends:
    for operation in operations:
        log = logging.getLogger(bkend + operation)
        if bkend == 'gluon' and operation == 'export':
            log.warning('Gluon->ONNX export not implemented. Skipping tests...')
            continue
        # Lazy %-style arguments: the message is only built if INFO is enabled.
        log.info('Executing tests for %s backend: %s', bkend, operation)
        # NOTE(review): set_params is called on the class itself, so it
        # presumably stores the backend/operation selection globally for
        # MXNetBackend — confirm in backend.py.
        mxnet_backend.MXNetBackend.set_params(bkend, operation)
        BACKEND_TESTS = prepare_tests(mxnet_backend, operation)
        result = unittest.TextTestRunner().run(
            test_suite(BACKEND_TESTS.enable_report()))
        # TextTestRunner only prints failures to its stream. Also report them
        # through logging so they are not silently lost.
        if not result.wasSuccessful():
            log.error('%s backend %s tests: %d failure(s), %d error(s)',
                      bkend, operation,
                      len(result.failures), len(result.errors))
93