1# -*- coding: utf-8 -*-
2"""Pytest plugin for testing xsh files."""
3import sys
4import importlib
5from traceback import format_list, extract_tb
6
7import pytest
8
9from xonsh.imphooks import install_import_hooks
10
11
def pytest_configure(config):
    """pytest hook: install xonsh's import hooks at configuration time.

    ``config`` is part of the hook signature and is not used here.
    """
    install_import_hooks()
14
15
def pytest_collection_modifyitems(items):
    """pytest hook: move every XshFunction item to the front of ``items``.

    The list is reordered in place with a stable sort, so the relative order
    within the xonsh items, and within all other items, is unchanged.
    """

    def _goes_last(item):
        # False (an XshFunction) sorts before True (any other item).
        return not isinstance(item, XshFunction)

    items.sort(key=_goes_last)
18
19
20def _limited_traceback(excinfo):
21    """ Return a formatted traceback with all the stack
22        from this frame (i.e __file__) up removed
23    """
24    tb = extract_tb(excinfo.tb)
25    try:
26        idx = [__file__ in e for e in tb].index(True)
27        return format_list(tb[idx + 1 :])
28    except ValueError:
29        return format_list(tb)
30
31
def pytest_collect_file(parent, path):
    """pytest hook: collect ``test_*.xsh`` files as :class:`XshFile` nodes.

    :param parent: the collector that is walking the directory tree.
    :param path: path of the candidate file (``path.ext`` / ``path.basename``
        are read, so a ``py.path.local``-like object is expected).
    :returns: an ``XshFile`` for matching files, otherwise ``None``.
    """
    if path.ext.lower() != ".xsh" or not path.basename.startswith("test_"):
        return None
    if hasattr(XshFile, "from_parent"):
        # Direct node construction was deprecated in pytest 5.4 and is
        # rejected by later versions; ``from_parent`` is the supported way.
        return XshFile.from_parent(parent, fspath=path)
    # Older pytest without ``from_parent``.
    return XshFile(path, parent)
35
36
class XshFile(pytest.File):
    """Collector for a ``test_*.xsh`` file.

    The file is imported as a module by its base name. Loading an ``.xsh``
    file this way presumably relies on the import hooks installed in
    ``pytest_configure`` (confirm). Every callable module attribute whose
    name starts with ``test_`` becomes an :class:`XshFunction` item.
    """

    def collect(self):
        """Import the test file and yield one XshFunction per test callable."""
        dirname = self.fspath.dirname
        # Temporarily put the file's directory first on sys.path so its
        # base name resolves to this file. Afterwards remove exactly the
        # entry added here, even if the import raises.
        sys.path.insert(0, dirname)
        try:
            mod = importlib.import_module(self.fspath.purebasename)
        finally:
            try:
                sys.path.remove(dirname)
            except ValueError:
                pass  # already removed during the import; nothing to undo
        use_from_parent = hasattr(XshFunction, "from_parent")
        for test_name in dir(mod):
            if not test_name.startswith("test_"):
                continue
            obj = getattr(mod, test_name)
            if not callable(obj):
                continue
            if use_from_parent:
                # Direct node construction was deprecated in pytest 5.4.
                yield XshFunction.from_parent(
                    self, name=test_name, test_func=obj, test_module=mod
                )
            else:
                yield XshFunction(
                    name=test_name, parent=self, test_func=obj, test_module=mod
                )
49
50
class XshFunction(pytest.Item):
    """A pytest item wrapping one ``test_*`` callable from an ``.xsh`` file."""

    def __init__(self, name, parent, test_func, test_module):
        super().__init__(name, parent)
        # The callable to run, and the module it was collected from.
        self._test_func, self._test_module = test_func, test_module

    def runtest(self, *args, **kwargs):
        """Call the wrapped test function, forwarding any arguments."""
        func = self._test_func
        func(*args, **kwargs)

    def repr_failure(self, excinfo):
        """Render the report shown when :meth:`runtest` raises.

        The report has a fixed header line, then the traceback trimmed by
        ``_limited_traceback``, then a final ``ExcType: message`` line.
        """
        frames = _limited_traceback(excinfo)
        summary = "{}: {}".format(excinfo.type.__name__, excinfo.value)
        return "".join(["xonsh execution failed\n", *frames, summary])

    def reportinfo(self):
        """Return ``(path, line number, description)`` for pytest reports."""
        location = self.fspath
        return location, 0, "xonsh test: {}".format(self.name)
69