# xref: /freebsd/tests/atf_python/atf_pytest.py (revision 9768746b)
1import types
2from typing import Any
3from typing import Dict
4from typing import List
5from typing import NamedTuple
6from typing import Optional
7from typing import Tuple
8
9import pytest
10import os
11
12
def nodeid_to_method_name(nodeid: str) -> str:
    """Extract the bare test method name from a pytest node id.

    ``file_name.py::ClassName::method_name[parametrize]`` -> ``method_name``:
    the last ``::``-separated component is kept and any parametrization
    suffix starting at the first ``[`` is dropped.
    """
    last_component = nodeid.rsplit("::", 1)[-1]
    method_name, _sep, _params = last_component.partition("[")
    return method_name
16
17
class ATFCleanupItem(pytest.Item):
    def runtest(self):
        """Execute the test's cleanup routine instead of the test body.

        A fresh instance of the test class (``self.parent.cls``) is created.
        A method named ``cleanup_<test method name>`` is preferred; failing
        that, a generic ``cleanup`` method is used.  Either one receives the
        pytest node id.  When neither exists, nothing is run.
        """
        test_instance = self.parent.cls()
        per_test_name = "cleanup_" + nodeid_to_method_name(self.nodeid)
        if hasattr(test_instance, per_test_name):
            getattr(test_instance, per_test_name)(self.nodeid)
            return
        if hasattr(test_instance, "cleanup"):
            test_instance.cleanup(self.nodeid)

    def setup_method_noop(self, method):
        """Replacement for the test class setup_method: does nothing."""

    def teardown_method_noop(self, method):
        """Replacement for the test class teardown_method: does nothing."""
36
37
class ATFTestObj(object):
    """ATF view of a single collected pytest test item.

    ``as_lines`` renders the item as an ATF test-case header: ``ident``,
    ``descr``, optional ``has.cleanup`` and properties derived from the
    item's pytest marks.
    """

    def __init__(self, obj, has_cleanup):
        # Use nodeid without the file name to properly name class-derived
        # tests: "file.py::Class::test[p]" -> "Class::test[p]"
        self.ident = obj.nodeid.split("::", 1)[1]
        self.description = self._get_test_description(obj)
        self.has_cleanup = has_cleanup
        self.obj = obj

    def _get_test_description(self, obj):
        """Return the first non-blank docstring line, or the test name.

        Lines are stripped of surrounding whitespace, so the usual indented
        multi-line docstring layout yields a clean one-line description.
        """
        docstr = obj.function.__doc__
        if docstr:
            for line in docstr.split("\n"):
                line = line.strip()
                if line:
                    return line
        return obj.name

    @staticmethod
    def _convert_user_mark(mark, obj, ret: Dict):
        """Translate a ``require_user`` mark into ATF properties in *ret*.

        The special user "unprivileged" additionally appends
        "unprivileged_user" to ``require.config``.  If the test class sets
        a truthy ``NEED_ROOT``, ``require.user`` is forced to "root";
        otherwise it is the requested user name.
        """
        username = mark.args[0]
        if username == "unprivileged":
            # Special unprivileged user requested.
            # First, require the unprivileged-user config option presence
            key = "require.config"
            if key not in ret:
                ret[key] = "unprivileged_user"
            else:
                ret[key] = "{} {}".format(ret[key], "unprivileged_user")
        # Check if the framework requires root
        test_cls = ATFHandler.get_test_class(obj)
        if test_cls and getattr(test_cls, "NEED_ROOT", False):
            # Yes, so we ask kyua to run us under root instead
            # It is up to the implementation to switch back to the desired
            # user
            ret["require.user"] = "root"
        else:
            ret["require.user"] = username

    @staticmethod
    def _join_words(value) -> str:
        """Format a mark argument as a space-separated ATF list.

        A plain string is returned unchanged (joining it would split it into
        single characters); any other iterable is joined with spaces.
        """
        if isinstance(value, str):
            return value
        return " ".join(str(item) for item in value)

    def _convert_marks(self, obj) -> Dict[str, Any]:
        """Map the supported pytest marks of *obj* to ATF properties.

        Each mark's first positional argument becomes the property value;
        list-valued properties are space-joined.  Unknown marks are ignored.
        When the same mark is seen more than once, the last one yielded by
        ``iter_markers`` wins.
        """
        _map: Dict[str, Dict] = {
            "require_user": {"handler": self._convert_user_mark},
            "require_arch": {"name": "require.arch", "fmt": self._join_words},
            "require_diskspace": {"name": "require.diskspace"},
            "require_files": {"name": "require.files", "fmt": self._join_words},
            "require_machine": {"name": "require.machine", "fmt": self._join_words},
            "require_memory": {"name": "require.memory"},
            "require_progs": {"name": "require.progs", "fmt": self._join_words},
            "timeout": {},
        }
        ret: Dict[str, Any] = {}
        for mark in obj.iter_markers():
            spec = _map.get(mark.name)
            if spec is None:
                continue
            if "handler" in spec:
                spec["handler"](mark, obj, ret)
                continue
            name = spec.get("name", mark.name)
            fmt = spec.get("fmt")
            val = mark.args[0]
            ret[name] = fmt(val) if fmt is not None else val
        return ret

    def as_lines(self) -> List[str]:
        """Output test definition in ATF-specific format"""
        ret = [
            "ident: {}".format(self.ident),
            "descr: {}".format(self.description),
        ]
        if self.has_cleanup:
            ret.append("has.cleanup: true")
        for key, value in self._convert_marks(self.obj).items():
            ret.append("{}: {}".format(key, value))
        return ret
113
114
class ATFHandler(object):
    """Glue between pytest and the ATF test-program interface.

    Lists collected tests in ATF format, can rewire a test item to run its
    cleanup routine instead of its body, and folds pytest reports into a
    single ATF result line written to an optional result file.
    """

    class ReportState(NamedTuple):
        # ATF result state: "passed", "failed", "skipped",
        # "expected_failure", ...
        state: str
        # Free-form reason text; None when no reason could be extracted
        reason: Optional[str]

    def __init__(self, report_file_name: Optional[str]):
        """Create a handler.

        :param report_file_name: path of the ATF result file, or None/empty
            when no result file should be written.
        """
        # test name -> ReportState of the most significant outcome so far
        self._tests_state_map: Dict[str, "ATFHandler.ReportState"] = {}
        self._report_file_name = report_file_name
        self._report_file_handle = None

    def setup_configure(self):
        """Open the result file for writing, if one was requested."""
        fname = self._report_file_name
        if fname:
            self._report_file_handle = open(fname, mode="w")

    def setup_method_pre(self, item):
        """Called before actually running the test setup_method.

        Stores the user named by the first ``require_user`` mark as
        ``TARGET_USER`` on the test class, so the class can drop privileges
        manually.  Items that do not belong to a class are left untouched.
        """
        for mark in item.iter_markers():
            if mark.name == "require_user":
                cls = self.get_test_class(item)
                if cls is not None:
                    cls.TARGET_USER = mark.args[0]
                break

    def override_runtest(self, obj):
        """Make *obj* run its ATF cleanup routine instead of its body.

        Note that setup_method/teardown_method are replaced on the test
        class itself, not only for this item.
        """
        # Override basic runtest command
        obj.runtest = types.MethodType(ATFCleanupItem.runtest, obj)
        # Override class setup/teardown
        obj.parent.cls.setup_method = ATFCleanupItem.setup_method_noop
        obj.parent.cls.teardown_method = ATFCleanupItem.teardown_method_noop

    @staticmethod
    def get_test_class(obj):
        """Return ``obj.parent.cls`` when available, otherwise None."""
        if hasattr(obj, "parent") and obj.parent is not None:
            if hasattr(obj.parent, "cls"):
                return obj.parent.cls
        return None

    def has_object_cleanup(self, obj):
        """Tell whether the test class of *obj* defines a cleanup routine.

        Either a generic ``cleanup`` or a per-test ``cleanup_<method>``
        counts.  Tests without a class never have cleanup.
        """
        cls = self.get_test_class(obj)
        if cls is not None:
            method_name = nodeid_to_method_name(obj.nodeid)
            cleanup_name = "cleanup_{}".format(method_name)
            if hasattr(cls, "cleanup") or hasattr(cls, cleanup_name):
                return True
        return False

    def list_tests(self, tests: List[Any]):
        """Print the ATF test-program listing for the collected *tests*.

        Emits the ATF content-type header and a blank line, then one
        blank-line-terminated stanza per test item.
        """
        print('Content-Type: application/X-atf-tp; version="1"')
        print()
        for test_obj in tests:
            has_cleanup = self.has_object_cleanup(test_obj)
            atf_test = ATFTestObj(test_obj, has_cleanup)
            for line in atf_test.as_lines():
                print(line)
            print()

    def set_report_state(self, test_name: str, state: str, reason: Optional[str]):
        """Record (overwriting) the ATF *state* and *reason* for a test."""
        self._tests_state_map[test_name] = self.ReportState(state, reason)

    def _extract_report_reason(self, report):
        """Return a one-line reason from ``report.longrepr``, or None.

        Tuple reprs ``(path, lineno, message)`` yield the message with a
        leading "Skipped: " removed.  Any other repr yields its last
        non-blank line ("" when there is none).
        """
        data = report.longrepr
        if data is None:
            return None
        if isinstance(data, tuple):
            # ('/path/to/test.py', 23, 'Skipped: unable to test')
            reason = data[2]
            prefix = "Skipped: "
            if reason.startswith(prefix):
                reason = reason[len(prefix):]
            return reason
        # string / traceback / exception report: capture the last
        # non-blank line (a trailing newline would otherwise give "")
        lines = [line for line in str(data).splitlines() if line.strip()]
        return lines[-1] if lines else ""

    def add_report(self, report):
        """Fold one pytest report into the per-test ATF state."""
        # MAP pytest report state to the atf-desired state
        #
        # ATF test states:
        # (1) expected_death, (2) expected_exit, (3) expected_failure
        # (4) expected_signal, (5) expected_timeout, (6) passed
        # (7) skipped, (8) failed
        #
        # Note that ATF doesn't have the concept of "soft xfail" - xpass
        # is a failure. It also calls teardown routine in a separate
        # process, thus teardown states (pytest-only) are handled as
        # body continuation.

        # (stage, state, wasxfail)

        # Just a passing test: WANT: passed
        # GOT: (setup, passed, F), (call, passed, F), (teardown, passed, F)
        #
        # Failing body test: WANT: failed
        # GOT: (setup, passed, F), (call, failed, F), (teardown, passed, F)
        #
        # pytest.skip test decorator: WANT: skipped
        # GOT: (setup,skipped, False), (teardown, passed, False)
        #
        # pytest.skip call inside test function: WANT: skipped
        # GOT: (setup, passed, F), (call, skipped, F), (teardown,passed, F)
        #
        # mark.xfail decorator+pytest.xfail: WANT: expected_failure
        # GOT: (setup, passed, F), (call, skipped, T), (teardown, passed, F)
        #
        # mark.xfail decorator+pass: WANT: failed
        # GOT: (setup, passed, F), (call, passed, T), (teardown, passed, F)

        test_name = report.location[2]
        stage = report.when
        state = report.outcome
        reason = self._extract_report_reason(report)
        # "wasxfail" lives on the report object (not on the reason text);
        # its value is the xfail reason string.
        was_xfail = hasattr(report, "wasxfail")
        xfail_reason = getattr(report, "wasxfail", "")

        # We don't care about strict xfail - it gets translated to False

        if stage == "setup":
            if state in ("skipped", "failed"):
                # failed init -> failed test, skipped setup -> xskip
                # for the whole test
                self.set_report_state(test_name, state, reason)
        elif stage == "call":
            # "call" stage shouldn't matter if setup failed
            previous = self._tests_state_map.get(test_name)
            if previous is not None and previous.state == "failed":
                return
            if state == "failed":
                # Record failure & override "skipped" state
                self.set_report_state(test_name, state, reason)
            elif state == "skipped":
                if was_xfail:
                    # xfail() called in the test body / expected failure
                    state = "expected_failure"
                    reason = xfail_reason or reason
                # otherwise: skip inside the body
                self.set_report_state(test_name, state, reason)
            elif state == "passed":
                if was_xfail:
                    # the test was expected to fail but didn't
                    # mark as hard failure
                    state = "failed"
                    reason = reason or "test was expected to fail but passed"
                self.set_report_state(test_name, state, reason)
        elif stage == "teardown":
            if state == "failed":
                # teardown should be empty, as the cleanup
                # procedures should be implemented as a separate
                # function/method, so mark teardown failure as
                # global failure
                self.set_report_state(test_name, state, reason)

    def write_report(self):
        """Write the ATF result line and close the result file.

        Writes "passed", or "<state>: <reason>" for any other state.  Does
        nothing when no result file is open.
        """
        if self._report_file_handle is None:
            return
        if self._tests_state_map:
            # If we're executing in ATF mode, there has to be just one test
            # Anyway, deterministically pick the first one
            first_test_name = next(iter(self._tests_state_map))
            test = self._tests_state_map[first_test_name]
            if test.state == "passed":
                line = test.state
            else:
                line = "{}: {}".format(test.state, test.reason)
            print(line, file=self._report_file_handle)
        self._report_file_handle.close()
        # Forget the closed handle so a repeated call is a harmless no-op
        self._report_file_handle = None

    @staticmethod
    def get_atf_vars() -> Dict[str, str]:
        """Return environment variables prefixed with ``_ATF_VAR_``.

        The prefix is stripped from the returned keys.
        """
        px = "_ATF_VAR_"
        return {k[len(px):]: v for k, v in os.environ.items() if k.startswith(px)}
284