1"""Rewrite assertion AST to produce nice error messages.""" 2import ast 3import errno 4import functools 5import importlib.abc 6import importlib.machinery 7import importlib.util 8import io 9import itertools 10import marshal 11import os 12import struct 13import sys 14import tokenize 15import types 16from typing import Callable 17from typing import Dict 18from typing import IO 19from typing import Iterable 20from typing import List 21from typing import Optional 22from typing import Sequence 23from typing import Set 24from typing import Tuple 25from typing import Union 26 27import py 28 29from _pytest._io.saferepr import saferepr 30from _pytest._version import version 31from _pytest.assertion import util 32from _pytest.assertion.util import ( # noqa: F401 33 format_explanation as _format_explanation, 34) 35from _pytest.compat import fspath 36from _pytest.compat import TYPE_CHECKING 37from _pytest.config import Config 38from _pytest.main import Session 39from _pytest.pathlib import fnmatch_ex 40from _pytest.pathlib import Path 41from _pytest.pathlib import PurePath 42from _pytest.store import StoreKey 43 44if TYPE_CHECKING: 45 from _pytest.assertion import AssertionState # noqa: F401 46 47 48assertstate_key = StoreKey["AssertionState"]() 49 50 51# pytest caches rewritten pycs in pycache dirs 52PYTEST_TAG = "{}-pytest-{}".format(sys.implementation.cache_tag, version) 53PYC_EXT = ".py" + (__debug__ and "c" or "o") 54PYC_TAIL = "." + PYTEST_TAG + PYC_EXT 55 56 57class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader): 58 """PEP302/PEP451 import hook which rewrites asserts.""" 59 60 def __init__(self, config: Config) -> None: 61 self.config = config 62 try: 63 self.fnpats = config.getini("python_files") 64 except ValueError: 65 self.fnpats = ["test_*.py", "*_test.py"] 66 self.session = None # type: Optional[Session] 67 self._rewritten_names = set() # type: Set[str] 68 self._must_rewrite = set() # type: Set[str] 69 # flag to guard against trying to rewrite a pyc file while we are already writing another pyc file, 70 # which might result in infinite recursion (#3506) 71 self._writing_pyc = False 72 self._basenames_to_check_rewrite = {"conftest"} 73 self._marked_for_rewrite_cache = {} # type: Dict[str, bool] 74 self._session_paths_checked = False 75 76 def set_session(self, session: Optional[Session]) -> None: 77 self.session = session 78 self._session_paths_checked = False 79 80 # Indirection so we can mock calls to find_spec originated from the hook during testing 81 _find_spec = importlib.machinery.PathFinder.find_spec 82 83 def find_spec( 84 self, 85 name: str, 86 path: Optional[Sequence[Union[str, bytes]]] = None, 87 target: Optional[types.ModuleType] = None, 88 ) -> Optional[importlib.machinery.ModuleSpec]: 89 if self._writing_pyc: 90 return None 91 state = self.config._store[assertstate_key] 92 if self._early_rewrite_bailout(name, state): 93 return None 94 state.trace("find_module called for: %s" % name) 95 96 # Type ignored because mypy is confused about the `self` binding here. 97 spec = self._find_spec(name, path) # type: ignore 98 if ( 99 # the import machinery could not find a file to import 100 spec is None 101 # this is a namespace package (without `__init__.py`) 102 # there's nothing to rewrite there 103 # python3.5 - python3.6: `namespace` 104 # python3.7+: `None` 105 or spec.origin == "namespace" 106 or spec.origin is None 107 # we can only rewrite source files 108 or not isinstance(spec.loader, importlib.machinery.SourceFileLoader) 109 # if the file doesn't exist, we can't rewrite it 110 or not os.path.exists(spec.origin) 111 ): 112 return None 113 else: 114 fn = spec.origin 115 116 if not self._should_rewrite(name, fn, state): 117 return None 118 119 return importlib.util.spec_from_file_location( 120 name, 121 fn, 122 loader=self, 123 submodule_search_locations=spec.submodule_search_locations, 124 ) 125 126 def create_module( 127 self, spec: importlib.machinery.ModuleSpec 128 ) -> Optional[types.ModuleType]: 129 return None # default behaviour is fine 130 131 def exec_module(self, module: types.ModuleType) -> None: 132 assert module.__spec__ is not None 133 assert module.__spec__.origin is not None 134 fn = Path(module.__spec__.origin) 135 state = self.config._store[assertstate_key] 136 137 self._rewritten_names.add(module.__name__) 138 139 # The requested module looks like a test file, so rewrite it. This is 140 # the most magical part of the process: load the source, rewrite the 141 # asserts, and load the rewritten source. We also cache the rewritten 142 # module code in a special pyc. We must be aware of the possibility of 143 # concurrent pytest processes rewriting and loading pycs. To avoid 144 # tricky race conditions, we maintain the following invariant: The 145 # cached pyc is always a complete, valid pyc. Operations on it must be 146 # atomic. POSIX's atomic rename comes in handy. 147 write = not sys.dont_write_bytecode 148 cache_dir = get_cache_dir(fn) 149 if write: 150 ok = try_makedirs(cache_dir) 151 if not ok: 152 write = False 153 state.trace("read only directory: {}".format(cache_dir)) 154 155 cache_name = fn.name[:-3] + PYC_TAIL 156 pyc = cache_dir / cache_name 157 # Notice that even if we're in a read-only directory, I'm going 158 # to check for a cached pyc. This may not be optimal... 159 co = _read_pyc(fn, pyc, state.trace) 160 if co is None: 161 state.trace("rewriting {!r}".format(fn)) 162 source_stat, co = _rewrite_test(fn, self.config) 163 if write: 164 self._writing_pyc = True 165 try: 166 _write_pyc(state, co, source_stat, pyc) 167 finally: 168 self._writing_pyc = False 169 else: 170 state.trace("found cached rewritten pyc for {}".format(fn)) 171 exec(co, module.__dict__) 172 173 def _early_rewrite_bailout(self, name: str, state: "AssertionState") -> bool: 174 """A fast way to get out of rewriting modules. 175 176 Profiling has shown that the call to PathFinder.find_spec (inside of 177 the find_spec from this class) is a major slowdown, so, this method 178 tries to filter what we're sure won't be rewritten before getting to 179 it. 180 """ 181 if self.session is not None and not self._session_paths_checked: 182 self._session_paths_checked = True 183 for initial_path in self.session._initialpaths: 184 # Make something as c:/projects/my_project/path.py -> 185 # ['c:', 'projects', 'my_project', 'path.py'] 186 parts = str(initial_path).split(os.path.sep) 187 # add 'path' to basenames to be checked. 188 self._basenames_to_check_rewrite.add(os.path.splitext(parts[-1])[0]) 189 190 # Note: conftest already by default in _basenames_to_check_rewrite. 191 parts = name.split(".") 192 if parts[-1] in self._basenames_to_check_rewrite: 193 return False 194 195 # For matching the name it must be as if it was a filename. 196 path = PurePath(os.path.sep.join(parts) + ".py") 197 198 for pat in self.fnpats: 199 # if the pattern contains subdirectories ("tests/**.py" for example) we can't bail out based 200 # on the name alone because we need to match against the full path 201 if os.path.dirname(pat): 202 return False 203 if fnmatch_ex(pat, path): 204 return False 205 206 if self._is_marked_for_rewrite(name, state): 207 return False 208 209 state.trace("early skip of rewriting module: {}".format(name)) 210 return True 211 212 def _should_rewrite(self, name: str, fn: str, state: "AssertionState") -> bool: 213 # always rewrite conftest files 214 if os.path.basename(fn) == "conftest.py": 215 state.trace("rewriting conftest file: {!r}".format(fn)) 216 return True 217 218 if self.session is not None: 219 if self.session.isinitpath(py.path.local(fn)): 220 state.trace( 221 "matched test file (was specified on cmdline): {!r}".format(fn) 222 ) 223 return True 224 225 # modules not passed explicitly on the command line are only 226 # rewritten if they match the naming convention for test files 227 fn_path = PurePath(fn) 228 for pat in self.fnpats: 229 if fnmatch_ex(pat, fn_path): 230 state.trace("matched test file {!r}".format(fn)) 231 return True 232 233 return self._is_marked_for_rewrite(name, state) 234 235 def _is_marked_for_rewrite(self, name: str, state: "AssertionState") -> bool: 236 try: 237 return self._marked_for_rewrite_cache[name] 238 except KeyError: 239 for marked in self._must_rewrite: 240 if name == marked or name.startswith(marked + "."): 241 state.trace( 242 "matched marked file {!r} (from {!r})".format(name, marked) 243 ) 244 self._marked_for_rewrite_cache[name] = True 245 return True 246 247 self._marked_for_rewrite_cache[name] = False 248 return False 249 250 def mark_rewrite(self, *names: str) -> None: 251 """Mark import names as needing to be rewritten. 252 253 The named module or package as well as any nested modules will 254 be rewritten on import. 255 """ 256 already_imported = ( 257 set(names).intersection(sys.modules).difference(self._rewritten_names) 258 ) 259 for name in already_imported: 260 mod = sys.modules[name] 261 if not AssertionRewriter.is_rewrite_disabled( 262 mod.__doc__ or "" 263 ) and not isinstance(mod.__loader__, type(self)): 264 self._warn_already_imported(name) 265 self._must_rewrite.update(names) 266 self._marked_for_rewrite_cache.clear() 267 268 def _warn_already_imported(self, name: str) -> None: 269 from _pytest.warning_types import PytestAssertRewriteWarning 270 271 self.config.issue_config_time_warning( 272 PytestAssertRewriteWarning( 273 "Module already imported so cannot be rewritten: %s" % name 274 ), 275 stacklevel=5, 276 ) 277 278 def get_data(self, pathname: Union[str, bytes]) -> bytes: 279 """Optional PEP302 get_data API.""" 280 with open(pathname, "rb") as f: 281 return f.read() 282 283 284def _write_pyc_fp( 285 fp: IO[bytes], source_stat: os.stat_result, co: types.CodeType 286) -> None: 287 # Technically, we don't have to have the same pyc format as 288 # (C)Python, since these "pycs" should never be seen by builtin 289 # import. However, there's little reason deviate. 290 fp.write(importlib.util.MAGIC_NUMBER) 291 # as of now, bytecode header expects 32-bit numbers for size and mtime (#4903) 292 mtime = int(source_stat.st_mtime) & 0xFFFFFFFF 293 size = source_stat.st_size & 0xFFFFFFFF 294 # "<LL" stands for 2 unsigned longs, little-ending 295 fp.write(struct.pack("<LL", mtime, size)) 296 fp.write(marshal.dumps(co)) 297 298 299if sys.platform == "win32": 300 from atomicwrites import atomic_write 301 302 def _write_pyc( 303 state: "AssertionState", 304 co: types.CodeType, 305 source_stat: os.stat_result, 306 pyc: Path, 307 ) -> bool: 308 try: 309 with atomic_write(fspath(pyc), mode="wb", overwrite=True) as fp: 310 _write_pyc_fp(fp, source_stat, co) 311 except OSError as e: 312 state.trace("error writing pyc file at {}: {}".format(pyc, e)) 313 # we ignore any failure to write the cache file 314 # there are many reasons, permission-denied, pycache dir being a 315 # file etc. 316 return False 317 return True 318 319 320else: 321 322 def _write_pyc( 323 state: "AssertionState", 324 co: types.CodeType, 325 source_stat: os.stat_result, 326 pyc: Path, 327 ) -> bool: 328 proc_pyc = "{}.{}".format(pyc, os.getpid()) 329 try: 330 fp = open(proc_pyc, "wb") 331 except OSError as e: 332 state.trace( 333 "error writing pyc file at {}: errno={}".format(proc_pyc, e.errno) 334 ) 335 return False 336 337 try: 338 _write_pyc_fp(fp, source_stat, co) 339 os.rename(proc_pyc, fspath(pyc)) 340 except OSError as e: 341 state.trace("error writing pyc file at {}: {}".format(pyc, e)) 342 # we ignore any failure to write the cache file 343 # there are many reasons, permission-denied, pycache dir being a 344 # file etc. 345 return False 346 finally: 347 fp.close() 348 return True 349 350 351def _rewrite_test(fn: Path, config: Config) -> Tuple[os.stat_result, types.CodeType]: 352 """Read and rewrite *fn* and return the code object.""" 353 fn_ = fspath(fn) 354 stat = os.stat(fn_) 355 with open(fn_, "rb") as f: 356 source = f.read() 357 tree = ast.parse(source, filename=fn_) 358 rewrite_asserts(tree, source, fn_, config) 359 co = compile(tree, fn_, "exec", dont_inherit=True) 360 return stat, co 361 362 363def _read_pyc( 364 source: Path, pyc: Path, trace: Callable[[str], None] = lambda x: None 365) -> Optional[types.CodeType]: 366 """Possibly read a pytest pyc containing rewritten code. 367 368 Return rewritten code if successful or None if not. 369 """ 370 try: 371 fp = open(fspath(pyc), "rb") 372 except OSError: 373 return None 374 with fp: 375 try: 376 stat_result = os.stat(fspath(source)) 377 mtime = int(stat_result.st_mtime) 378 size = stat_result.st_size 379 data = fp.read(12) 380 except OSError as e: 381 trace("_read_pyc({}): OSError {}".format(source, e)) 382 return None 383 # Check for invalid or out of date pyc file. 384 if ( 385 len(data) != 12 386 or data[:4] != importlib.util.MAGIC_NUMBER 387 or struct.unpack("<LL", data[4:]) != (mtime & 0xFFFFFFFF, size & 0xFFFFFFFF) 388 ): 389 trace("_read_pyc(%s): invalid or out of date pyc" % source) 390 return None 391 try: 392 co = marshal.load(fp) 393 except Exception as e: 394 trace("_read_pyc({}): marshal.load error {}".format(source, e)) 395 return None 396 if not isinstance(co, types.CodeType): 397 trace("_read_pyc(%s): not a code object" % source) 398 return None 399 return co 400 401 402def rewrite_asserts( 403 mod: ast.Module, 404 source: bytes, 405 module_path: Optional[str] = None, 406 config: Optional[Config] = None, 407) -> None: 408 """Rewrite the assert statements in mod.""" 409 AssertionRewriter(module_path, config, source).run(mod) 410 411 412def _saferepr(obj: object) -> str: 413 r"""Get a safe repr of an object for assertion error messages. 414 415 The assertion formatting (util.format_explanation()) requires 416 newlines to be escaped since they are a special character for it. 417 Normally assertion.util.format_explanation() does this but for a 418 custom repr it is possible to contain one of the special escape 419 sequences, especially '\n{' and '\n}' are likely to be present in 420 JSON reprs. 421 """ 422 return saferepr(obj).replace("\n", "\\n") 423 424 425def _format_assertmsg(obj: object) -> str: 426 r"""Format the custom assertion message given. 427 428 For strings this simply replaces newlines with '\n~' so that 429 util.format_explanation() will preserve them instead of escaping 430 newlines. For other objects saferepr() is used first. 431 """ 432 # reprlib appears to have a bug which means that if a string 433 # contains a newline it gets escaped, however if an object has a 434 # .__repr__() which contains newlines it does not get escaped. 435 # However in either case we want to preserve the newline. 436 replaces = [("\n", "\n~"), ("%", "%%")] 437 if not isinstance(obj, str): 438 obj = saferepr(obj) 439 replaces.append(("\\n", "\n~")) 440 441 for r1, r2 in replaces: 442 obj = obj.replace(r1, r2) 443 444 return obj 445 446 447def _should_repr_global_name(obj: object) -> bool: 448 if callable(obj): 449 return False 450 451 try: 452 return not hasattr(obj, "__name__") 453 except Exception: 454 return True 455 456 457def _format_boolop(explanations: Iterable[str], is_or: bool) -> str: 458 explanation = "(" + (is_or and " or " or " and ").join(explanations) + ")" 459 return explanation.replace("%", "%%") 460 461 462def _call_reprcompare( 463 ops: Sequence[str], 464 results: Sequence[bool], 465 expls: Sequence[str], 466 each_obj: Sequence[object], 467) -> str: 468 for i, res, expl in zip(range(len(ops)), results, expls): 469 try: 470 done = not res 471 except Exception: 472 done = True 473 if done: 474 break 475 if util._reprcompare is not None: 476 custom = util._reprcompare(ops[i], each_obj[i], each_obj[i + 1]) 477 if custom is not None: 478 return custom 479 return expl 480 481 482def _call_assertion_pass(lineno: int, orig: str, expl: str) -> None: 483 if util._assertion_pass is not None: 484 util._assertion_pass(lineno, orig, expl) 485 486 487def _check_if_assertion_pass_impl() -> bool: 488 """Check if any plugins implement the pytest_assertion_pass hook 489 in order not to generate explanation unecessarily (might be expensive).""" 490 return True if util._assertion_pass else False 491 492 493UNARY_MAP = {ast.Not: "not %s", ast.Invert: "~%s", ast.USub: "-%s", ast.UAdd: "+%s"} 494 495BINOP_MAP = { 496 ast.BitOr: "|", 497 ast.BitXor: "^", 498 ast.BitAnd: "&", 499 ast.LShift: "<<", 500 ast.RShift: ">>", 501 ast.Add: "+", 502 ast.Sub: "-", 503 ast.Mult: "*", 504 ast.Div: "/", 505 ast.FloorDiv: "//", 506 ast.Mod: "%%", # escaped for string formatting 507 ast.Eq: "==", 508 ast.NotEq: "!=", 509 ast.Lt: "<", 510 ast.LtE: "<=", 511 ast.Gt: ">", 512 ast.GtE: ">=", 513 ast.Pow: "**", 514 ast.Is: "is", 515 ast.IsNot: "is not", 516 ast.In: "in", 517 ast.NotIn: "not in", 518 ast.MatMult: "@", 519} 520 521 522def set_location(node, lineno, col_offset): 523 """Set node location information recursively.""" 524 525 def _fix(node, lineno, col_offset): 526 if "lineno" in node._attributes: 527 node.lineno = lineno 528 if "col_offset" in node._attributes: 529 node.col_offset = col_offset 530 for child in ast.iter_child_nodes(node): 531 _fix(child, lineno, col_offset) 532 533 _fix(node, lineno, col_offset) 534 return node 535 536 537def _get_assertion_exprs(src: bytes) -> Dict[int, str]: 538 """Return a mapping from {lineno: "assertion test expression"}.""" 539 ret = {} # type: Dict[int, str] 540 541 depth = 0 542 lines = [] # type: List[str] 543 assert_lineno = None # type: Optional[int] 544 seen_lines = set() # type: Set[int] 545 546 def _write_and_reset() -> None: 547 nonlocal depth, lines, assert_lineno, seen_lines 548 assert assert_lineno is not None 549 ret[assert_lineno] = "".join(lines).rstrip().rstrip("\\") 550 depth = 0 551 lines = [] 552 assert_lineno = None 553 seen_lines = set() 554 555 tokens = tokenize.tokenize(io.BytesIO(src).readline) 556 for tp, source, (lineno, offset), _, line in tokens: 557 if tp == tokenize.NAME and source == "assert": 558 assert_lineno = lineno 559 elif assert_lineno is not None: 560 # keep track of depth for the assert-message `,` lookup 561 if tp == tokenize.OP and source in "([{": 562 depth += 1 563 elif tp == tokenize.OP and source in ")]}": 564 depth -= 1 565 566 if not lines: 567 lines.append(line[offset:]) 568 seen_lines.add(lineno) 569 # a non-nested comma separates the expression from the message 570 elif depth == 0 and tp == tokenize.OP and source == ",": 571 # one line assert with message 572 if lineno in seen_lines and len(lines) == 1: 573 offset_in_trimmed = offset + len(lines[-1]) - len(line) 574 lines[-1] = lines[-1][:offset_in_trimmed] 575 # multi-line assert with message 576 elif lineno in seen_lines: 577 lines[-1] = lines[-1][:offset] 578 # multi line assert with escapd newline before message 579 else: 580 lines.append(line[:offset]) 581 _write_and_reset() 582 elif tp in {tokenize.NEWLINE, tokenize.ENDMARKER}: 583 _write_and_reset() 584 elif lines and lineno not in seen_lines: 585 lines.append(line) 586 seen_lines.add(lineno) 587 588 return ret 589 590 591class AssertionRewriter(ast.NodeVisitor): 592 """Assertion rewriting implementation. 593 594 The main entrypoint is to call .run() with an ast.Module instance, 595 this will then find all the assert statements and rewrite them to 596 provide intermediate values and a detailed assertion error. See 597 http://pybites.blogspot.be/2011/07/behind-scenes-of-pytests-new-assertion.html 598 for an overview of how this works. 599 600 The entry point here is .run() which will iterate over all the 601 statements in an ast.Module and for each ast.Assert statement it 602 finds call .visit() with it. Then .visit_Assert() takes over and 603 is responsible for creating new ast statements to replace the 604 original assert statement: it rewrites the test of an assertion 605 to provide intermediate values and replace it with an if statement 606 which raises an assertion error with a detailed explanation in 607 case the expression is false and calls pytest_assertion_pass hook 608 if expression is true. 609 610 For this .visit_Assert() uses the visitor pattern to visit all the 611 AST nodes of the ast.Assert.test field, each visit call returning 612 an AST node and the corresponding explanation string. During this 613 state is kept in several instance attributes: 614 615 :statements: All the AST statements which will replace the assert 616 statement. 617 618 :variables: This is populated by .variable() with each variable 619 used by the statements so that they can all be set to None at 620 the end of the statements. 621 622 :variable_counter: Counter to create new unique variables needed 623 by statements. Variables are created using .variable() and 624 have the form of "@py_assert0". 625 626 :expl_stmts: The AST statements which will be executed to get 627 data from the assertion. This is the code which will construct 628 the detailed assertion message that is used in the AssertionError 629 or for the pytest_assertion_pass hook. 630 631 :explanation_specifiers: A dict filled by .explanation_param() 632 with %-formatting placeholders and their corresponding 633 expressions to use in the building of an assertion message. 634 This is used by .pop_format_context() to build a message. 635 636 :stack: A stack of the explanation_specifiers dicts maintained by 637 .push_format_context() and .pop_format_context() which allows 638 to build another %-formatted string while already building one. 639 640 This state is reset on every new assert statement visited and used 641 by the other visitors. 642 """ 643 644 def __init__( 645 self, module_path: Optional[str], config: Optional[Config], source: bytes 646 ) -> None: 647 super().__init__() 648 self.module_path = module_path 649 self.config = config 650 if config is not None: 651 self.enable_assertion_pass_hook = config.getini( 652 "enable_assertion_pass_hook" 653 ) 654 else: 655 self.enable_assertion_pass_hook = False 656 self.source = source 657 658 @functools.lru_cache(maxsize=1) 659 def _assert_expr_to_lineno(self) -> Dict[int, str]: 660 return _get_assertion_exprs(self.source) 661 662 def run(self, mod: ast.Module) -> None: 663 """Find all assert statements in *mod* and rewrite them.""" 664 if not mod.body: 665 # Nothing to do. 666 return 667 # Insert some special imports at the top of the module but after any 668 # docstrings and __future__ imports. 669 aliases = [ 670 ast.alias("builtins", "@py_builtins"), 671 ast.alias("_pytest.assertion.rewrite", "@pytest_ar"), 672 ] 673 doc = getattr(mod, "docstring", None) 674 expect_docstring = doc is None 675 if doc is not None and self.is_rewrite_disabled(doc): 676 return 677 pos = 0 678 lineno = 1 679 for item in mod.body: 680 if ( 681 expect_docstring 682 and isinstance(item, ast.Expr) 683 and isinstance(item.value, ast.Str) 684 ): 685 doc = item.value.s 686 if self.is_rewrite_disabled(doc): 687 return 688 expect_docstring = False 689 elif ( 690 isinstance(item, ast.ImportFrom) 691 and item.level == 0 692 and item.module == "__future__" 693 ): 694 pass 695 else: 696 break 697 pos += 1 698 # Special case: for a decorated function, set the lineno to that of the 699 # first decorator, not the `def`. Issue #4984. 700 if isinstance(item, ast.FunctionDef) and item.decorator_list: 701 lineno = item.decorator_list[0].lineno 702 else: 703 lineno = item.lineno 704 imports = [ 705 ast.Import([alias], lineno=lineno, col_offset=0) for alias in aliases 706 ] 707 mod.body[pos:pos] = imports 708 # Collect asserts. 709 nodes = [mod] # type: List[ast.AST] 710 while nodes: 711 node = nodes.pop() 712 for name, field in ast.iter_fields(node): 713 if isinstance(field, list): 714 new = [] # type: List[ast.AST] 715 for i, child in enumerate(field): 716 if isinstance(child, ast.Assert): 717 # Transform assert. 718 new.extend(self.visit(child)) 719 else: 720 new.append(child) 721 if isinstance(child, ast.AST): 722 nodes.append(child) 723 setattr(node, name, new) 724 elif ( 725 isinstance(field, ast.AST) 726 # Don't recurse into expressions as they can't contain 727 # asserts. 728 and not isinstance(field, ast.expr) 729 ): 730 nodes.append(field) 731 732 @staticmethod 733 def is_rewrite_disabled(docstring: str) -> bool: 734 return "PYTEST_DONT_REWRITE" in docstring 735 736 def variable(self) -> str: 737 """Get a new variable.""" 738 # Use a character invalid in python identifiers to avoid clashing. 739 name = "@py_assert" + str(next(self.variable_counter)) 740 self.variables.append(name) 741 return name 742 743 def assign(self, expr: ast.expr) -> ast.Name: 744 """Give *expr* a name.""" 745 name = self.variable() 746 self.statements.append(ast.Assign([ast.Name(name, ast.Store())], expr)) 747 return ast.Name(name, ast.Load()) 748 749 def display(self, expr: ast.expr) -> ast.expr: 750 """Call saferepr on the expression.""" 751 return self.helper("_saferepr", expr) 752 753 def helper(self, name: str, *args: ast.expr) -> ast.expr: 754 """Call a helper in this module.""" 755 py_name = ast.Name("@pytest_ar", ast.Load()) 756 attr = ast.Attribute(py_name, name, ast.Load()) 757 return ast.Call(attr, list(args), []) 758 759 def builtin(self, name: str) -> ast.Attribute: 760 """Return the builtin called *name*.""" 761 builtin_name = ast.Name("@py_builtins", ast.Load()) 762 return ast.Attribute(builtin_name, name, ast.Load()) 763 764 def explanation_param(self, expr: ast.expr) -> str: 765 """Return a new named %-formatting placeholder for expr. 766 767 This creates a %-formatting placeholder for expr in the 768 current formatting context, e.g. ``%(py0)s``. The placeholder 769 and expr are placed in the current format context so that it 770 can be used on the next call to .pop_format_context(). 771 """ 772 specifier = "py" + str(next(self.variable_counter)) 773 self.explanation_specifiers[specifier] = expr 774 return "%(" + specifier + ")s" 775 776 def push_format_context(self) -> None: 777 """Create a new formatting context. 778 779 The format context is used for when an explanation wants to 780 have a variable value formatted in the assertion message. In 781 this case the value required can be added using 782 .explanation_param(). Finally .pop_format_context() is used 783 to format a string of %-formatted values as added by 784 .explanation_param(). 785 """ 786 self.explanation_specifiers = {} # type: Dict[str, ast.expr] 787 self.stack.append(self.explanation_specifiers) 788 789 def pop_format_context(self, expl_expr: ast.expr) -> ast.Name: 790 """Format the %-formatted string with current format context. 791 792 The expl_expr should be an str ast.expr instance constructed from 793 the %-placeholders created by .explanation_param(). This will 794 add the required code to format said string to .expl_stmts and 795 return the ast.Name instance of the formatted string. 796 """ 797 current = self.stack.pop() 798 if self.stack: 799 self.explanation_specifiers = self.stack[-1] 800 keys = [ast.Str(key) for key in current.keys()] 801 format_dict = ast.Dict(keys, list(current.values())) 802 form = ast.BinOp(expl_expr, ast.Mod(), format_dict) 803 name = "@py_format" + str(next(self.variable_counter)) 804 if self.enable_assertion_pass_hook: 805 self.format_variables.append(name) 806 self.expl_stmts.append(ast.Assign([ast.Name(name, ast.Store())], form)) 807 return ast.Name(name, ast.Load()) 808 809 def generic_visit(self, node: ast.AST) -> Tuple[ast.Name, str]: 810 """Handle expressions we don't have custom code for.""" 811 assert isinstance(node, ast.expr) 812 res = self.assign(node) 813 return res, self.explanation_param(self.display(res)) 814 815 def visit_Assert(self, assert_: ast.Assert) -> List[ast.stmt]: 816 """Return the AST statements to replace the ast.Assert instance. 817 818 This rewrites the test of an assertion to provide 819 intermediate values and replace it with an if statement which 820 raises an assertion error with a detailed explanation in case 821 the expression is false. 822 """ 823 if isinstance(assert_.test, ast.Tuple) and len(assert_.test.elts) >= 1: 824 from _pytest.warning_types import PytestAssertRewriteWarning 825 import warnings 826 827 # TODO: This assert should not be needed. 828 assert self.module_path is not None 829 warnings.warn_explicit( 830 PytestAssertRewriteWarning( 831 "assertion is always true, perhaps remove parentheses?" 832 ), 833 category=None, 834 filename=fspath(self.module_path), 835 lineno=assert_.lineno, 836 ) 837 838 self.statements = [] # type: List[ast.stmt] 839 self.variables = [] # type: List[str] 840 self.variable_counter = itertools.count() 841 842 if self.enable_assertion_pass_hook: 843 self.format_variables = [] # type: List[str] 844 845 self.stack = [] # type: List[Dict[str, ast.expr]] 846 self.expl_stmts = [] # type: List[ast.stmt] 847 self.push_format_context() 848 # Rewrite assert into a bunch of statements. 849 top_condition, explanation = self.visit(assert_.test) 850 851 negation = ast.UnaryOp(ast.Not(), top_condition) 852 853 if self.enable_assertion_pass_hook: # Experimental pytest_assertion_pass hook 854 msg = self.pop_format_context(ast.Str(explanation)) 855 856 # Failed 857 if assert_.msg: 858 assertmsg = self.helper("_format_assertmsg", assert_.msg) 859 gluestr = "\n>assert " 860 else: 861 assertmsg = ast.Str("") 862 gluestr = "assert " 863 err_explanation = ast.BinOp(ast.Str(gluestr), ast.Add(), msg) 864 err_msg = ast.BinOp(assertmsg, ast.Add(), err_explanation) 865 err_name = ast.Name("AssertionError", ast.Load()) 866 fmt = self.helper("_format_explanation", err_msg) 867 exc = ast.Call(err_name, [fmt], []) 868 raise_ = ast.Raise(exc, None) 869 statements_fail = [] 870 statements_fail.extend(self.expl_stmts) 871 statements_fail.append(raise_) 872 873 # Passed 874 fmt_pass = self.helper("_format_explanation", msg) 875 orig = self._assert_expr_to_lineno()[assert_.lineno] 876 hook_call_pass = ast.Expr( 877 self.helper( 878 "_call_assertion_pass", 879 ast.Num(assert_.lineno), 880 ast.Str(orig), 881 fmt_pass, 882 ) 883 ) 884 # If any hooks implement assert_pass hook 885 hook_impl_test = ast.If( 886 self.helper("_check_if_assertion_pass_impl"), 887 self.expl_stmts + [hook_call_pass], 888 [], 889 ) 890 statements_pass = [hook_impl_test] 891 892 # Test for assertion condition 893 main_test = ast.If(negation, statements_fail, statements_pass) 894 self.statements.append(main_test) 895 if self.format_variables: 896 variables = [ 897 ast.Name(name, ast.Store()) for name in self.format_variables 898 ] 899 clear_format = ast.Assign(variables, ast.NameConstant(None)) 900 self.statements.append(clear_format) 901 902 else: # Original assertion rewriting 903 # Create failure message. 904 body = self.expl_stmts 905 self.statements.append(ast.If(negation, body, [])) 906 if assert_.msg: 907 assertmsg = self.helper("_format_assertmsg", assert_.msg) 908 explanation = "\n>assert " + explanation 909 else: 910 assertmsg = ast.Str("") 911 explanation = "assert " + explanation 912 template = ast.BinOp(assertmsg, ast.Add(), ast.Str(explanation)) 913 msg = self.pop_format_context(template) 914 fmt = self.helper("_format_explanation", msg) 915 err_name = ast.Name("AssertionError", ast.Load()) 916 exc = ast.Call(err_name, [fmt], []) 917 raise_ = ast.Raise(exc, None) 918 919 body.append(raise_) 920 921 # Clear temporary variables by setting them to None. 922 if self.variables: 923 variables = [ast.Name(name, ast.Store()) for name in self.variables] 924 clear = ast.Assign(variables, ast.NameConstant(None)) 925 self.statements.append(clear) 926 # Fix line numbers. 927 for stmt in self.statements: 928 set_location(stmt, assert_.lineno, assert_.col_offset) 929 return self.statements 930 931 def visit_Name(self, name: ast.Name) -> Tuple[ast.Name, str]: 932 # Display the repr of the name if it's a local variable or 933 # _should_repr_global_name() thinks it's acceptable. 934 locs = ast.Call(self.builtin("locals"), [], []) 935 inlocs = ast.Compare(ast.Str(name.id), [ast.In()], [locs]) 936 dorepr = self.helper("_should_repr_global_name", name) 937 test = ast.BoolOp(ast.Or(), [inlocs, dorepr]) 938 expr = ast.IfExp(test, self.display(name), ast.Str(name.id)) 939 return name, self.explanation_param(expr) 940 941 def visit_BoolOp(self, boolop: ast.BoolOp) -> Tuple[ast.Name, str]: 942 res_var = self.variable() 943 expl_list = self.assign(ast.List([], ast.Load())) 944 app = ast.Attribute(expl_list, "append", ast.Load()) 945 is_or = int(isinstance(boolop.op, ast.Or)) 946 body = save = self.statements 947 fail_save = self.expl_stmts 948 levels = len(boolop.values) - 1 949 self.push_format_context() 950 # Process each operand, short-circuiting if needed. 951 for i, v in enumerate(boolop.values): 952 if i: 953 fail_inner = [] # type: List[ast.stmt] 954 # cond is set in a prior loop iteration below 955 self.expl_stmts.append(ast.If(cond, fail_inner, [])) # noqa 956 self.expl_stmts = fail_inner 957 self.push_format_context() 958 res, expl = self.visit(v) 959 body.append(ast.Assign([ast.Name(res_var, ast.Store())], res)) 960 expl_format = self.pop_format_context(ast.Str(expl)) 961 call = ast.Call(app, [expl_format], []) 962 self.expl_stmts.append(ast.Expr(call)) 963 if i < levels: 964 cond = res # type: ast.expr 965 if is_or: 966 cond = ast.UnaryOp(ast.Not(), cond) 967 inner = [] # type: List[ast.stmt] 968 self.statements.append(ast.If(cond, inner, [])) 969 self.statements = body = inner 970 self.statements = save 971 self.expl_stmts = fail_save 972 expl_template = self.helper("_format_boolop", expl_list, ast.Num(is_or)) 973 expl = self.pop_format_context(expl_template) 974 return ast.Name(res_var, ast.Load()), self.explanation_param(expl) 975 976 def visit_UnaryOp(self, unary: ast.UnaryOp) -> Tuple[ast.Name, str]: 977 pattern = UNARY_MAP[unary.op.__class__] 978 operand_res, operand_expl = self.visit(unary.operand) 979 res = self.assign(ast.UnaryOp(unary.op, operand_res)) 980 return res, pattern % (operand_expl,) 981 982 def visit_BinOp(self, binop: ast.BinOp) -> Tuple[ast.Name, str]: 983 symbol = BINOP_MAP[binop.op.__class__] 984 left_expr, left_expl = self.visit(binop.left) 985 right_expr, right_expl = self.visit(binop.right) 986 explanation = "({} {} {})".format(left_expl, symbol, right_expl) 987 res = self.assign(ast.BinOp(left_expr, binop.op, right_expr)) 988 return res, explanation 989 990 def visit_Call(self, call: ast.Call) -> Tuple[ast.Name, str]: 991 new_func, func_expl = self.visit(call.func) 992 arg_expls = [] 993 new_args = [] 994 new_kwargs = [] 995 for arg in call.args: 996 res, expl = self.visit(arg) 997 arg_expls.append(expl) 998 new_args.append(res) 999 for keyword in call.keywords: 1000 res, expl = self.visit(keyword.value) 1001 new_kwargs.append(ast.keyword(keyword.arg, res)) 1002 if keyword.arg: 1003 arg_expls.append(keyword.arg + "=" + expl) 1004 else: # **args have `arg` keywords with an .arg of None 1005 arg_expls.append("**" + expl) 1006 1007 expl = "{}({})".format(func_expl, ", ".join(arg_expls)) 1008 new_call = ast.Call(new_func, new_args, new_kwargs) 1009 res = self.assign(new_call) 1010 res_expl = self.explanation_param(self.display(res)) 1011 outer_expl = "{}\n{{{} = {}\n}}".format(res_expl, res_expl, expl) 1012 return res, outer_expl 1013 1014 def visit_Starred(self, starred: ast.Starred) -> Tuple[ast.Starred, str]: 1015 # From Python 3.5, a Starred node can appear in a function call. 1016 res, expl = self.visit(starred.value) 1017 new_starred = ast.Starred(res, starred.ctx) 1018 return new_starred, "*" + expl 1019 1020 def visit_Attribute(self, attr: ast.Attribute) -> Tuple[ast.Name, str]: 1021 if not isinstance(attr.ctx, ast.Load): 1022 return self.generic_visit(attr) 1023 value, value_expl = self.visit(attr.value) 1024 res = self.assign(ast.Attribute(value, attr.attr, ast.Load())) 1025 res_expl = self.explanation_param(self.display(res)) 1026 pat = "%s\n{%s = %s.%s\n}" 1027 expl = pat % (res_expl, res_expl, value_expl, attr.attr) 1028 return res, expl 1029 1030 def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]: 1031 self.push_format_context() 1032 left_res, left_expl = self.visit(comp.left) 1033 if isinstance(comp.left, (ast.Compare, ast.BoolOp)): 1034 left_expl = "({})".format(left_expl) 1035 res_variables = [self.variable() for i in range(len(comp.ops))] 1036 load_names = [ast.Name(v, ast.Load()) for v in res_variables] 1037 store_names = [ast.Name(v, ast.Store()) for v in res_variables] 1038 it = zip(range(len(comp.ops)), comp.ops, comp.comparators) 1039 expls = [] 1040 syms = [] 1041 results = [left_res] 1042 for i, op, next_operand in it: 1043 next_res, next_expl = self.visit(next_operand) 1044 if isinstance(next_operand, (ast.Compare, ast.BoolOp)): 1045 next_expl = "({})".format(next_expl) 1046 results.append(next_res) 1047 sym = BINOP_MAP[op.__class__] 1048 syms.append(ast.Str(sym)) 1049 expl = "{} {} {}".format(left_expl, sym, next_expl) 1050 expls.append(ast.Str(expl)) 1051 res_expr = ast.Compare(left_res, [op], [next_res]) 1052 self.statements.append(ast.Assign([store_names[i]], res_expr)) 1053 left_res, left_expl = next_res, next_expl 1054 # Use pytest.assertion.util._reprcompare if that's available. 1055 expl_call = self.helper( 1056 "_call_reprcompare", 1057 ast.Tuple(syms, ast.Load()), 1058 ast.Tuple(load_names, ast.Load()), 1059 ast.Tuple(expls, ast.Load()), 1060 ast.Tuple(results, ast.Load()), 1061 ) 1062 if len(comp.ops) > 1: 1063 res = ast.BoolOp(ast.And(), load_names) # type: ast.expr 1064 else: 1065 res = load_names[0] 1066 return res, self.explanation_param(self.pop_format_context(expl_call)) 1067 1068 1069def try_makedirs(cache_dir: Path) -> bool: 1070 """Attempt to create the given directory and sub-directories exist. 1071 1072 Returns True if successful or if it already exists. 1073 """ 1074 try: 1075 os.makedirs(fspath(cache_dir), exist_ok=True) 1076 except (FileNotFoundError, NotADirectoryError, FileExistsError): 1077 # One of the path components was not a directory: 1078 # - we're in a zip file 1079 # - it is a file 1080 return False 1081 except PermissionError: 1082 return False 1083 except OSError as e: 1084 # as of now, EROFS doesn't have an equivalent OSError-subclass 1085 if e.errno == errno.EROFS: 1086 return False 1087 raise 1088 return True 1089 1090 1091def get_cache_dir(file_path: Path) -> Path: 1092 """Return the cache directory to write .pyc files for the given .py file path.""" 1093 if sys.version_info >= (3, 8) and sys.pycache_prefix: 1094 # given: 1095 # prefix = '/tmp/pycs' 1096 # path = '/home/user/proj/test_app.py' 1097 # we want: 1098 # '/tmp/pycs/home/user/proj' 1099 return Path(sys.pycache_prefix) / Path(*file_path.parts[1:-1]) 1100 else: 1101 # classic pycache directory 1102 return file_path.parent / "__pycache__" 1103