1#!/usr/local/bin/python3.8
2# vim:fileencoding=utf-8
3# License: GPLv3 Copyright: 2021, Kovid Goyal <kovid at kovidgoyal.net>
4
5import importlib
6import os
7import sys
8import unittest
9try:
10    from importlib.resources import contents
11except Exception:
12    from importlib_resources import contents
13from typing import Callable, Generator, NoReturn, Sequence, Set
14
15
def itertests(suite: unittest.TestSuite) -> Generator[unittest.TestCase, None, None]:
    """Yield every individual test case contained in *suite*, flattening nested suites.

    Nested suites are walked with an explicit stack rather than recursion.

    :raises Exception: when a placeholder test case that unittest substitutes
        for a module it could not import/load is encountered.
    """
    # unittest replaces a module it fails to load with a synthetic test case.
    # Python 3.5+ names that class ``_FailedTest``; earlier versions used
    # ``ModuleImportFailure``. Checking only the old name made this guard dead.
    failure_class_names = ('ModuleImportFailure', '_FailedTest')
    stack = [suite]
    while stack:
        suite = stack.pop()
        for test in suite:
            if isinstance(test, unittest.TestSuite):
                stack.append(test)
                continue
            if test.__class__.__name__ in failure_class_names:
                raise Exception('Failed to import a test module: %s' % test)
            yield test
27
28
29def find_all_tests(package: str = '', excludes: Sequence[str] = ('main', 'gr')) -> unittest.TestSuite:
30    suits = []
31    if not package:
32        package = __name__.rpartition('.')[0] if '.' in __name__ else 'kitty_tests'
33    for x in contents(package):
34        name, ext = os.path.splitext(x)
35        if ext in ('.py', '.pyc') and name not in excludes:
36            m = importlib.import_module(package + '.' + x.partition('.')[0])
37            suits.append(unittest.defaultTestLoader.loadTestsFromModule(m))
38    return unittest.TestSuite(suits)
39
40
def filter_tests(suite: unittest.TestSuite, test_ok: Callable[[unittest.TestCase], bool]) -> unittest.TestSuite:
    """Return a new flat suite of the tests in *suite* that *test_ok* accepts.

    Tests appear in the order ``itertests`` produces them. A test equal to one
    already accepted is dropped, so each test is included at most once.
    """
    seen: Set[unittest.TestCase] = set()
    filtered = unittest.TestSuite()
    for candidate in itertests(suite):
        # The predicate runs first; the duplicate check only runs for accepted tests.
        if not test_ok(candidate) or candidate in seen:
            continue
        seen.add(candidate)
        filtered.addTest(candidate)
    return filtered
49
50
def filter_tests_by_name(suite: unittest.TestSuite, *names: str) -> unittest.TestSuite:
    """Keep only the tests whose method name is one of *names*.

    A name that does not already start with ``test_`` gets that prefix added,
    so ``linebuf`` selects ``test_linebuf``.
    """
    wanted: Set[str] = set()
    for raw in names:
        wanted.add(raw if raw.startswith('test_') else 'test_' + raw)
    return filter_tests(suite, lambda test: test._testMethodName in wanted)
57
58
def filter_tests_by_module(suite: unittest.TestSuite, *names: str) -> unittest.TestSuite:
    """Keep only tests whose defining module's last dotted component is in *names*."""
    allowed = frozenset(names)

    def in_allowed_module(test: unittest.TestCase) -> bool:
        # e.g. "pkg.sub.mod" -> "mod"; a dot-free name is used unchanged.
        short_name = test.__class__.__module__.split('.')[-1]
        return short_name in allowed

    return filter_tests(suite, in_allowed_module)
66
67
def type_check() -> NoReturn:
    """Run the project's ``generate_stub`` helpers, then exec mypy.

    The two helpers come from ``kitty.cli_stub`` and
    ``kittens.tui.operations_stub``. Each is imported immediately before it is
    called, in this order; presumably the second may depend on the first's
    output (unverified). The current process is then replaced by
    ``python -m mypy --pretty``, so this never returns normally.
    """
    from kitty.cli_stub import generate_stub as make_cli_stub  # type:ignore
    make_cli_stub()
    from kittens.tui.operations_stub import generate_stub as make_operations_stub  # type: ignore
    make_operations_stub()
    mypy_argv = ('python', '-m', 'mypy', '--pretty')
    os.execlp(sys.executable, *mypy_argv)
74
75
76def run_cli(suite: unittest.TestSuite, verbosity: int = 4) -> None:
77    r = unittest.TextTestRunner
78    r.resultclass = unittest.TextTestResult
79    runner = r(verbosity=verbosity)
80    runner.tb_locals = True  # type: ignore
81    result = runner.run(suite)
82    if not result.wasSuccessful():
83        raise SystemExit(1)
84
85
def run_tests() -> None:
    """Command-line entry point: parse ``sys.argv``, then type-check or run tests.

    Positional arguments select individual tests; ``linebuf`` selects
    ``test_linebuf``. If the first positional argument is ``type-check``,
    ``type_check`` or ``mypy``, the process is replaced by a mypy run instead.

    :raises SystemExit: if names were given but no test matched, or if the
        test run failed.
    """
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument(
        'name', nargs='*', default=[],
        help='The name of the test to run, for e.g. linebuf corresponds to test_linebuf. Can be specified multiple times')
    parser.add_argument('--verbosity', default=4, type=int, help='Test verbosity')
    args = parser.parse_args()
    if args.name and args.name[0] in ('type-check', 'type_check', 'mypy'):
        type_check()  # annotated NoReturn: exec's mypy
    tests = find_all_tests()
    if args.name:
        tests = filter_tests_by_name(tests, *args.name)
        # Use the public counting API rather than the private TestSuite._tests.
        # Join the names so the user sees "a, b" rather than a Python list repr.
        if not tests.countTestCases():
            raise SystemExit('No test named %s found' % ', '.join(args.name))
    run_cli(tests, args.verbosity)
102