#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""ONNX test backend wrapper

Runs the ONNX backend test suite against the MXNet backend module for
each backend ('mxnet', 'gluon') and each direction ('import', 'export').
Gluon export is skipped.
"""
try:
    import onnx.backend.test
except ImportError:
    raise ImportError("Onnx and protobuf need to be installed")

import sys
import test_cases
import unittest
import backend as mxnet_backend
import logging

operations = ['import', 'export']
backends = ['mxnet', 'gluon']
# This is a pytest magic variable to load extra plugins
pytest_plugins = "onnx.backend.test.report",


def test_suite(backend_tests):  # type: (onnx.backend.test.BackendTest) -> unittest.TestSuite
    '''
    Build a TestSuite that can be run by a unittest TestRunner.

    This is borrowed from onnx/onnx/backend/test/runner/__init__.py.
    onnx's own Runner.test_suite() expects to sort the generated
    TestCase classes, and Python 3 cannot sort objects of type 'type'.

    :param backend_tests: onnx.backend.test.BackendTest whose
        include/exclude filters have already been applied.
    :return: unittest.TestSuite containing every generated test case.
    '''
    suite = unittest.TestSuite()
    for case in backend_tests.test_cases.values():
        suite.addTests(unittest.defaultTestLoader.loadTestsFromTestCase(case))
    return suite


# Despite its name this is a helper, not a test. Stop pytest from collecting
# it: it would otherwise fail looking for a ``backend_tests`` fixture.
test_suite.__test__ = False


def _names_for(table, oper):
    """Return the test patterns in *table* that apply to *oper*.

    Entries under the 'both' key apply to import and export alike.
    Entries under *oper* apply only to that direction. Missing keys
    count as empty.
    """
    return table.get('both', []) + table.get(oper, [])


def prepare_tests(backend, oper):
    """
    Prepare the test list
    :param backend: backend module handed to onnx.backend.test.BackendTest
        (mxnet/gluon backend)
    :param oper: str. export or import
    :return: onnx.backend.test.BackendTest with the operator, basic-model
        and standard-model tests for *oper* included and scalar tests
        excluded
    """
    backend_tests = onnx.backend.test.BackendTest(backend, __name__)

    # Same inclusion order as before: operators, basic models, standard models.
    for table in (test_cases.IMPLEMENTED_OPERATORS_TEST,
                  test_cases.BASIC_MODEL_TESTS,
                  test_cases.STANDARD_MODEL):
        for pattern in _names_for(table, oper):
            backend_tests.include(pattern)

    # Tests for scalar ops are in test_node.py
    backend_tests.exclude('.*scalar.*')

    return backend_tests


# The tests run at module import time; each run's unittest result is kept
# so the script can report failure through its exit status.
_results = []
for bkend in backends:
    for operation in operations:
        log = logging.getLogger(bkend + operation)
        if bkend == 'gluon' and operation == 'export':
            log.warning('Gluon->ONNX export not implemented. Skipping tests...')
            continue
        log.info('Executing tests for %s backend: %s', bkend, operation)
        # Pass the backend/direction pair to MXNetBackend before the
        # BackendTest is built from the mxnet_backend module.
        mxnet_backend.MXNetBackend.set_params(bkend, operation)
        BACKEND_TESTS = prepare_tests(mxnet_backend, operation)
        _results.append(
            unittest.TextTestRunner().run(test_suite(BACKEND_TESTS.enable_report())))

# TextTestRunner.run() does not raise when tests fail, so without this check
# the script would exit 0 on failure. Exit only when executed directly, so
# that importing the module keeps its previous behaviour.
if __name__ == '__main__' and not all(r.wasSuccessful() for r in _results):
    sys.exit(1)