1# -*- coding: utf-8 -*-
2
3# stdlib imports
4import subprocess
5import re
6import sys
7
8# third-party imports
9import pytest
10import toml
11
12
# pytest cache key under which the mtimes of files that passed the black
# check are stored between sessions (read in pytest_configure, written back
# in pytest_unconfigure).
HISTKEY = "black/mtimes"
14
15
def pytest_addoption(parser):
    """Register the ``--black`` command-line flag in pytest's "general" group."""
    general = parser.getgroup("general")
    general.addoption(
        "--black",
        action="store_true",
        help="enable format checking with black",
    )
21
22
def pytest_collect_file(path, parent):
    """Collect a BlackItem for each ``.py`` file when ``--black`` is enabled.

    Returns ``None`` (nothing collected) for any other file, or when the
    flag is off.
    """
    if not (parent.config.option.black and path.ext == ".py"):
        return None
    # Newer pytest requires nodes to be built via ``from_parent``; older
    # versions without it still accept direct construction.
    if hasattr(BlackItem, "from_parent"):
        return BlackItem.from_parent(parent, fspath=path)
    return BlackItem(path, parent)
30
31
def pytest_configure(config):
    """Load previously cached mtimes and register the ``black`` marker.

    The mtimes are only loaded when ``--black`` is set and the config exposes
    a cache; they are stored on ``config._blackmtimes`` for this session.
    """
    wants_cache = config.option.black and hasattr(config, "cache")
    if wants_cache:
        # Mtimes of files that passed black in an earlier session.
        config._blackmtimes = config.cache.get(HISTKEY, {})
    config.addinivalue_line("markers", "black: enable format checking with black")
37
38
def pytest_unconfigure(config):
    """Write this session's mtimes back to the pytest cache, if any were loaded."""
    if not hasattr(config, "_blackmtimes"):
        return
    config.cache.set(HISTKEY, config._blackmtimes)
43
44
class BlackItem(pytest.Item, pytest.File):
    """Test node that checks one ``.py`` file with ``black --check``.

    The node is both an Item and a File collector; ``collect`` returns the
    node itself, so each file contributes exactly one ``::BLACK`` test.
    """

    def __init__(self, fspath, parent):
        super(BlackItem, self).__init__(fspath, parent)
        self._nodeid += "::BLACK"
        self.add_marker("black")
        self.pyproject = self._load_pyproject()

    def _load_pyproject(self):
        """Return the ``[tool.black]`` table of ./pyproject.toml, or ``{}``.

        ``include``/``exclude`` values are compiled into regex objects.
        Loading is best-effort: any failure (missing file, missing table,
        malformed TOML, invalid regex) yields an empty configuration.
        """
        try:
            # TOML documents are UTF-8 by specification; do not depend on the
            # locale's default encoding (e.g. cp1252 on Windows).
            with open("pyproject.toml", encoding="utf-8") as toml_file:
                settings = toml.load(toml_file)["tool"]["black"]
            for key in ("include", "exclude"):
                if key in settings:
                    settings[key] = self._re_fix_verbose(settings[key])
        except Exception:
            # Deliberately best-effort: a broken or absent config must not
            # abort collection; the check then runs without include/exclude.
            return {}
        return settings

    def setup(self):
        """Skip when black is missing, the file is unchanged since it last
        passed, or pyproject.toml include/exclude rules filter it out."""
        pytest.importorskip("black")
        mtimes = getattr(self.config, "_blackmtimes", {})
        self._blackmtime = self.fspath.mtime()
        old = mtimes.get(str(self.fspath), 0)
        if self._blackmtime == old:
            pytest.skip("file(s) previously passed black format checks")

        if self._skip_test():
            pytest.skip("file(s) excluded by pyproject.toml")

    def runtest(self):
        """Run black in check/diff mode on the file.

        Raises:
            BlackError: wrapping the CalledProcessError when black exits
                non-zero (file would be reformatted, or black failed).
        """
        cmd = [sys.executable, "-m", "black", "--check", "--diff", "--quiet", str(self.fspath)]
        try:
            subprocess.run(
                cmd,
                check=True,
                stdout=subprocess.PIPE,
                # black still reports errors (e.g. "cannot format ...: Cannot
                # parse") on stderr under --quiet; capture them so the
                # failure report is not empty for such files.
                stderr=subprocess.PIPE,
                universal_newlines=True,
            )
        except subprocess.CalledProcessError as e:
            raise BlackError(e)

        # Record the mtime only after a successful check so the file is
        # skipped next session until it changes.
        mtimes = getattr(self.config, "_blackmtimes", {})
        mtimes[str(self.fspath)] = self._blackmtime

    def repr_failure(self, excinfo):
        """Report black's diff and error output for BlackError failures;
        defer to pytest's default representation otherwise."""
        if excinfo.errisinstance(BlackError):
            err = excinfo.value.args[0]
            return "".join(part for part in (err.stdout, err.stderr) if part)
        return super(BlackItem, self).repr_failure(excinfo)

    def reportinfo(self):
        return (self.fspath, -1, "Black format check")

    def _skip_test(self):
        """Return True when pyproject.toml rules exclude this file."""
        return self._excluded() or not self._included()

    def _included(self):
        """True when there is no ``include`` rule or the path matches it."""
        if "include" not in self.pyproject:
            return True
        return bool(self.pyproject["include"].search(str(self.fspath)))

    def _excluded(self):
        """True when an ``exclude`` rule exists and the path matches it."""
        if "exclude" not in self.pyproject:
            return False
        return bool(self.pyproject["exclude"].search(str(self.fspath)))

    def _re_fix_verbose(self, regex):
        """Compile *regex*, enabling verbose mode when it spans several lines.

        Multi-line patterns in pyproject.toml rely on verbose-mode whitespace
        handling, so ``(?x)`` is prepended to them.
        """
        if "\n" in regex:
            regex = "(?x)" + regex
        return re.compile(regex)

    def collect(self):
        """ returns a list of children (items and collectors)
            for this collection node.
        """
        return (self,)
115
116
class BlackError(Exception):
    """Raised when black reports that a file is not correctly formatted.

    The single positional argument is the subprocess.CalledProcessError
    produced by the black invocation.
    """
119