# -*- coding: utf-8 -*-

# stdlib imports
import pathlib
import re
import subprocess
import sys

# third-party imports
import pytest
import toml


# pytest cache key under which {file path: mtime} of files that passed is stored.
HISTKEY = "black/mtimes"


def pytest_addoption(parser):
    """Register the ``--black`` command-line flag in the "general" option group."""
    group = parser.getgroup("general")
    group.addoption(
        "--black", action="store_true", help="enable format checking with black"
    )


def pytest_collect_file(path, parent):
    """Create a :class:`BlackItem` for every ``.py`` file when ``--black`` is set.

    Uses ``BlackItem.from_parent`` when the running pytest provides it and
    falls back to direct construction otherwise. Returns ``None`` (collect
    nothing) for other files or when ``--black`` is not given.
    """
    config = parent.config
    if config.option.black and path.ext == ".py":
        if hasattr(BlackItem, "from_parent"):
            return BlackItem.from_parent(parent, fspath=path)
        else:
            return BlackItem(path, parent)


def pytest_configure(config):
    """Load cached mtimes at session start and register the ``black`` marker.

    The mtimes are only loaded when ``--black`` is active and the cache plugin
    is available; the marker is registered unconditionally.
    """
    # load cached mtimes at session startup
    if config.option.black and hasattr(config, "cache"):
        config._blackmtimes = config.cache.get(HISTKEY, {})
    config.addinivalue_line("markers", "black: enable format checking with black")


def pytest_unconfigure(config):
    """Persist the mtimes of files that passed, if they were loaded at startup."""
    # save cached mtimes at end of session
    if hasattr(config, "_blackmtimes"):
        config.cache.set(HISTKEY, config._blackmtimes)


class BlackItem(pytest.Item, pytest.File):
    """A collected file checked with ``black --check --diff``.

    The node is both the file collector and its single test item (see
    :meth:`collect`); its node id gets a ``::BLACK`` suffix.
    """

    def __init__(self, fspath, parent):
        super(BlackItem, self).__init__(fspath, parent)
        self._nodeid += "::BLACK"
        self.add_marker("black")
        self.pyproject = self._load_pyproject()

    def _load_pyproject(self):
        """Return the ``[tool.black]`` table of ``pyproject.toml``, or ``{}``.

        ``include``/``exclude`` entries, when present, are replaced by compiled
        regexes (see :meth:`_re_fix_verbose`).
        """
        # Deliberate best effort: a missing file, a missing [tool.black] table,
        # unparsable TOML or an invalid regex all mean "no include/exclude
        # filtering" rather than a collection error.
        try:
            # NOTE(review): resolved against the current working directory,
            # not the pytest rootdir.
            with open("pyproject.toml") as toml_file:
                settings = toml.load(toml_file)["tool"]["black"]
            for key in ("include", "exclude"):
                if key in settings:
                    settings[key] = self._re_fix_verbose(settings[key])
            return settings
        except Exception:
            return {}

    def setup(self):
        """Skip when black is missing, the file is unchanged since it last
        passed, or pyproject.toml include/exclude rules leave it out."""
        pytest.importorskip("black")
        mtimes = getattr(self.config, "_blackmtimes", {})
        self._blackmtime = self.fspath.mtime()
        old = mtimes.get(str(self.fspath), 0)
        if self._blackmtime == old:
            pytest.skip("file(s) previously passed black format checks")

        if self._skip_test():
            pytest.skip("file(s) excluded by pyproject.toml")

    def runtest(self):
        """Run black in check mode on the file in a subprocess.

        Raises:
            BlackError: wrapping the ``CalledProcessError`` when black exits
                non-zero (file would be reformatted, or black itself failed).

        On success the file's mtime is recorded so later sessions can skip it.
        """
        cmd = [
            sys.executable,
            "-m",
            "black",
            "--check",
            "--diff",
            "--quiet",
            str(self.fspath),
        ]
        try:
            # stderr is captured too so black's own error messages end up in
            # the failure report instead of being lost or shown separately.
            subprocess.run(
                cmd,
                check=True,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                universal_newlines=True,
            )
        except subprocess.CalledProcessError as e:
            raise BlackError(e) from e

        mtimes = getattr(self.config, "_blackmtimes", {})
        mtimes[str(self.fspath)] = self._blackmtime

    def repr_failure(self, excinfo):
        """Report black's diff (stdout) followed by any error output (stderr)."""
        if excinfo.errisinstance(BlackError):
            proc_error = excinfo.value.args[0]
            return "".join(
                part for part in (proc_error.stdout, proc_error.stderr) if part
            )
        return super(BlackItem, self).repr_failure(excinfo)

    def reportinfo(self):
        """Location and description shown for this item in reports."""
        return (self.fspath, -1, "Black format check")

    def _skip_test(self):
        """True when pyproject.toml excludes the file or does not include it."""
        return self._excluded() or (not self._included())

    def _match_path(self):
        """File path in the form matched against include/exclude regexes.

        Black's include/exclude patterns are written with forward slashes on
        every platform, so the path is converted to POSIX form (a no-op on
        POSIX systems).
        """
        return pathlib.PurePath(str(self.fspath)).as_posix()

    def _included(self):
        """True if there is no ``include`` rule or the file path matches it."""
        if "include" not in self.pyproject:
            return True
        return bool(re.search(self.pyproject["include"], self._match_path()))

    def _excluded(self):
        """True if an ``exclude`` rule exists and the file path matches it."""
        if "exclude" not in self.pyproject:
            return False
        return bool(re.search(self.pyproject["exclude"], self._match_path()))

    def _re_fix_verbose(self, regex):
        """Compile *regex*, enabling verbose mode ``(?x)`` for multi-line patterns."""
        if "\n" in regex:
            regex = "(?x)" + regex
        return re.compile(regex)

    def collect(self):
        """ returns a list of children (items and collectors)
        for this collection node.
        """
        return (self,)


class BlackError(Exception):
    """Raised by :meth:`BlackItem.runtest`; ``args[0]`` is the
    ``subprocess.CalledProcessError`` from the failed black run."""

    pass