1"""Rewrite assertion AST to produce nice error messages."""
2import ast
3import errno
4import functools
5import importlib.abc
6import importlib.machinery
7import importlib.util
8import io
9import itertools
10import marshal
11import os
12import struct
13import sys
14import tokenize
15import types
16from typing import Callable
17from typing import Dict
18from typing import IO
19from typing import Iterable
20from typing import List
21from typing import Optional
22from typing import Sequence
23from typing import Set
24from typing import Tuple
25from typing import Union
26
27import py
28
29from _pytest._io.saferepr import saferepr
30from _pytest._version import version
31from _pytest.assertion import util
32from _pytest.assertion.util import (  # noqa: F401
33    format_explanation as _format_explanation,
34)
35from _pytest.compat import fspath
36from _pytest.compat import TYPE_CHECKING
37from _pytest.config import Config
38from _pytest.main import Session
39from _pytest.pathlib import fnmatch_ex
40from _pytest.pathlib import Path
41from _pytest.pathlib import PurePath
42from _pytest.store import StoreKey
43
44if TYPE_CHECKING:
45    from _pytest.assertion import AssertionState  # noqa: F401
46
47
# Key under which the AssertionState object is stored in ``config._store``.
assertstate_key = StoreKey["AssertionState"]()


# pytest caches rewritten pycs in pycache dirs
PYTEST_TAG = "{}-pytest-{}".format(sys.implementation.cache_tag, version)
# ".pyc" normally, ".pyo" when the interpreter runs with -O (__debug__ False).
PYC_EXT = ".py" + ("c" if __debug__ else "o")
PYC_TAIL = "." + PYTEST_TAG + PYC_EXT
55
56
class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader):
    """PEP302/PEP451 import hook which rewrites asserts.

    ``find_spec`` decides whether an imported module is a rewrite candidate.
    Candidates are conftest files, files given on the command line, files
    matching the ``python_files`` patterns, and names registered through
    ``mark_rewrite``. For a candidate it returns a spec whose loader is this
    object. ``exec_module`` then executes the rewritten code, reusing a cached
    rewritten pyc when a valid one exists.
    """

    def __init__(self, config: Config) -> None:
        self.config = config
        try:
            self.fnpats = config.getini("python_files")
        except ValueError:
            # NOTE(review): presumably raised when "python_files" is not a
            # registered ini option; fall back to the default test patterns.
            self.fnpats = ["test_*.py", "*_test.py"]
        self.session = None  # type: Optional[Session]
        # Names of modules this hook has executed (rewritten or from cache).
        self._rewritten_names = set()  # type: Set[str]
        # Module/package names registered via mark_rewrite().
        self._must_rewrite = set()  # type: Set[str]
        # flag to guard against trying to rewrite a pyc file while we are already writing another pyc file,
        # which might result in infinite recursion (#3506)
        self._writing_pyc = False
        # Last components of dotted names that always pass the early bailout;
        # extended lazily with the basenames of the session's initial paths.
        self._basenames_to_check_rewrite = {"conftest"}
        # Memo for _is_marked_for_rewrite(); invalidated by mark_rewrite().
        self._marked_for_rewrite_cache = {}  # type: Dict[str, bool]
        # Whether the current session's initial paths were already folded
        # into _basenames_to_check_rewrite.
        self._session_paths_checked = False

    def set_session(self, session: Optional[Session]) -> None:
        """Set (or clear) the session whose initial paths affect rewriting."""
        self.session = session
        # Force the new session's initial paths to be scanned on next lookup.
        self._session_paths_checked = False

    # Indirection so we can mock calls to find_spec originated from the hook during testing
    _find_spec = importlib.machinery.PathFinder.find_spec

    def find_spec(
        self,
        name: str,
        path: Optional[Sequence[Union[str, bytes]]] = None,
        target: Optional[types.ModuleType] = None,
    ) -> Optional[importlib.machinery.ModuleSpec]:
        """Return a spec that loads *name* through this hook, or None.

        Returning None defers the import to the remaining meta path finders.
        That happens while a pyc is being written, when the module is not a
        rewrite candidate, or when it is not a plain source file on disk.
        """
        if self._writing_pyc:
            return None
        state = self.config._store[assertstate_key]
        if self._early_rewrite_bailout(name, state):
            return None
        state.trace("find_module called for: %s" % name)

        # Type ignored because mypy is confused about the `self` binding here.
        spec = self._find_spec(name, path)  # type: ignore
        if (
            # the import machinery could not find a file to import
            spec is None
            # this is a namespace package (without `__init__.py`)
            # there's nothing to rewrite there
            # python3.5 - python3.6: `namespace`
            # python3.7+: `None`
            or spec.origin == "namespace"
            or spec.origin is None
            # we can only rewrite source files
            or not isinstance(spec.loader, importlib.machinery.SourceFileLoader)
            # if the file doesn't exist, we can't rewrite it
            or not os.path.exists(spec.origin)
        ):
            return None
        else:
            fn = spec.origin

        if not self._should_rewrite(name, fn, state):
            return None

        return importlib.util.spec_from_file_location(
            name,
            fn,
            loader=self,
            submodule_search_locations=spec.submodule_search_locations,
        )

    def create_module(
        self, spec: importlib.machinery.ModuleSpec
    ) -> Optional[types.ModuleType]:
        return None  # default behaviour is fine

    def exec_module(self, module: types.ModuleType) -> None:
        """Execute the rewritten code of *module* in its namespace.

        A valid cached rewritten pyc is used when present. Otherwise the
        source is rewritten and, unless bytecode writing is disabled or the
        cache directory cannot be created, the result is written to the cache.
        """
        assert module.__spec__ is not None
        assert module.__spec__.origin is not None
        fn = Path(module.__spec__.origin)
        state = self.config._store[assertstate_key]

        self._rewritten_names.add(module.__name__)

        # The requested module looks like a test file, so rewrite it. This is
        # the most magical part of the process: load the source, rewrite the
        # asserts, and load the rewritten source. We also cache the rewritten
        # module code in a special pyc. We must be aware of the possibility of
        # concurrent pytest processes rewriting and loading pycs. To avoid
        # tricky race conditions, we maintain the following invariant: The
        # cached pyc is always a complete, valid pyc. Operations on it must be
        # atomic. POSIX's atomic rename comes in handy.
        write = not sys.dont_write_bytecode
        cache_dir = get_cache_dir(fn)
        if write:
            ok = try_makedirs(cache_dir)
            if not ok:
                write = False
                state.trace("read only directory: {}".format(cache_dir))

        # Strip the 3-character ".py" suffix and append the pytest pyc tag.
        cache_name = fn.name[:-3] + PYC_TAIL
        pyc = cache_dir / cache_name
        # Notice that even if we're in a read-only directory, I'm going
        # to check for a cached pyc. This may not be optimal...
        co = _read_pyc(fn, pyc, state.trace)
        if co is None:
            state.trace("rewriting {!r}".format(fn))
            source_stat, co = _rewrite_test(fn, self.config)
            if write:
                # Imports triggered while writing must bypass this hook
                # (see the _writing_pyc check in find_spec).
                self._writing_pyc = True
                try:
                    _write_pyc(state, co, source_stat, pyc)
                finally:
                    self._writing_pyc = False
        else:
            state.trace("found cached rewritten pyc for {}".format(fn))
        exec(co, module.__dict__)

    def _early_rewrite_bailout(self, name: str, state: "AssertionState") -> bool:
        """A fast way to get out of rewriting modules.

        Profiling has shown that the call to PathFinder.find_spec (inside of
        the find_spec from this class) is a major slowdown, so, this method
        tries to filter what we're sure won't be rewritten before getting to
        it.
        """
        if self.session is not None and not self._session_paths_checked:
            self._session_paths_checked = True
            for initial_path in self.session._initialpaths:
                # Make something as c:/projects/my_project/path.py ->
                #     ['c:', 'projects', 'my_project', 'path.py']
                parts = str(initial_path).split(os.path.sep)
                # add 'path' to basenames to be checked.
                self._basenames_to_check_rewrite.add(os.path.splitext(parts[-1])[0])

        # Note: conftest already by default in _basenames_to_check_rewrite.
        parts = name.split(".")
        if parts[-1] in self._basenames_to_check_rewrite:
            return False

        # For matching the name it must be as if it was a filename.
        path = PurePath(os.path.sep.join(parts) + ".py")

        for pat in self.fnpats:
            # if the pattern contains subdirectories ("tests/**.py" for example) we can't bail out based
            # on the name alone because we need to match against the full path
            if os.path.dirname(pat):
                return False
            if fnmatch_ex(pat, path):
                return False

        if self._is_marked_for_rewrite(name, state):
            return False

        state.trace("early skip of rewriting module: {}".format(name))
        return True

    def _should_rewrite(self, name: str, fn: str, state: "AssertionState") -> bool:
        """Return True if the module *name* loaded from *fn* must be rewritten."""
        # always rewrite conftest files
        if os.path.basename(fn) == "conftest.py":
            state.trace("rewriting conftest file: {!r}".format(fn))
            return True

        if self.session is not None:
            if self.session.isinitpath(py.path.local(fn)):
                state.trace(
                    "matched test file (was specified on cmdline): {!r}".format(fn)
                )
                return True

        # modules not passed explicitly on the command line are only
        # rewritten if they match the naming convention for test files
        fn_path = PurePath(fn)
        for pat in self.fnpats:
            if fnmatch_ex(pat, fn_path):
                state.trace("matched test file {!r}".format(fn))
                return True

        return self._is_marked_for_rewrite(name, state)

    def _is_marked_for_rewrite(self, name: str, state: "AssertionState") -> bool:
        """Return True if *name* or one of its parent packages was marked.

        Results are memoized in _marked_for_rewrite_cache.
        """
        try:
            return self._marked_for_rewrite_cache[name]
        except KeyError:
            for marked in self._must_rewrite:
                if name == marked or name.startswith(marked + "."):
                    state.trace(
                        "matched marked file {!r} (from {!r})".format(name, marked)
                    )
                    self._marked_for_rewrite_cache[name] = True
                    return True

            self._marked_for_rewrite_cache[name] = False
            return False

    def mark_rewrite(self, *names: str) -> None:
        """Mark import names as needing to be rewritten.

        The named module or package as well as any nested modules will
        be rewritten on import. A warning is issued for every name that is
        already present in ``sys.modules`` without having been loaded by this
        hook, unless its docstring disables rewriting.
        """
        already_imported = (
            set(names).intersection(sys.modules).difference(self._rewritten_names)
        )
        for name in already_imported:
            mod = sys.modules[name]
            # Entries in sys.modules are not guaranteed to be module objects
            # (an entry may be None to block an import, or a module may
            # replace itself with another object), so ``__loader__`` may be
            # missing.
            if not AssertionRewriter.is_rewrite_disabled(
                mod.__doc__ or ""
            ) and not isinstance(getattr(mod, "__loader__", None), type(self)):
                self._warn_already_imported(name)
        self._must_rewrite.update(names)
        # Marks changed, so memoized answers may be stale.
        self._marked_for_rewrite_cache.clear()

    def _warn_already_imported(self, name: str) -> None:
        """Issue a PytestAssertRewriteWarning for an already imported *name*."""
        from _pytest.warning_types import PytestAssertRewriteWarning

        self.config.issue_config_time_warning(
            PytestAssertRewriteWarning(
                "Module already imported so cannot be rewritten: %s" % name
            ),
            stacklevel=5,
        )

    def get_data(self, pathname: Union[str, bytes]) -> bytes:
        """Optional PEP302 get_data API."""
        with open(pathname, "rb") as f:
            return f.read()
282
283
def _write_pyc_fp(
    fp: IO[bytes], source_stat: os.stat_result, co: types.CodeType
) -> None:
    """Serialize *co* to *fp* in the pytest pyc format.

    Layout: the interpreter's magic number, then the source mtime and size as
    two little-endian unsigned 32-bit ints, then the marshalled code object.
    """
    # Technically, we don't have to have the same pyc format as
    # (C)Python, since these "pycs" should never be seen by builtin
    # import. However, there's little reason to deviate.
    fp.write(importlib.util.MAGIC_NUMBER)
    # as of now, bytecode header expects 32-bit numbers for size and mtime (#4903)
    mtime = int(source_stat.st_mtime) & 0xFFFFFFFF
    size = source_stat.st_size & 0xFFFFFFFF
    # "<LL" stands for 2 unsigned longs, little-endian
    fp.write(struct.pack("<LL", mtime, size))
    fp.write(marshal.dumps(co))


if sys.platform == "win32":
    from atomicwrites import atomic_write

    def _write_pyc(
        state: "AssertionState",
        co: types.CodeType,
        source_stat: os.stat_result,
        pyc: Path,
    ) -> bool:
        """Write *co* as a pytest pyc at *pyc*; return True on success.

        atomic_write (from the atomicwrites package) is relied upon to put
        the file in place atomically.
        """
        try:
            with atomic_write(fspath(pyc), mode="wb", overwrite=True) as fp:
                _write_pyc_fp(fp, source_stat, co)
        except OSError as e:
            state.trace("error writing pyc file at {}: {}".format(pyc, e))
            # we ignore any failure to write the cache file
            # there are many reasons, permission-denied, pycache dir being a
            # file etc.
            return False
        return True


else:

    def _write_pyc(
        state: "AssertionState",
        co: types.CodeType,
        source_stat: os.stat_result,
        pyc: Path,
    ) -> bool:
        """Atomically write *co* as a pytest pyc at *pyc*.

        The data goes to a per-process temporary file. That file is closed
        (and therefore flushed) *before* it is renamed over *pyc*, so a
        concurrent reader never observes a partially written pyc under the
        final name. Returns True on success. On failure the reason is traced,
        the temporary file is removed and False is returned.
        """
        proc_pyc = "{}.{}".format(pyc, os.getpid())
        try:
            fp = open(proc_pyc, "wb")
        except OSError as e:
            state.trace(
                "error writing pyc file at {}: errno={}".format(proc_pyc, e.errno)
            )
            return False

        try:
            # Closing the file here flushes all buffered data to disk before
            # the rename makes it visible under the final name.
            with fp:
                _write_pyc_fp(fp, source_stat, co)
            os.replace(proc_pyc, str(pyc))
        except OSError as e:
            state.trace("error writing pyc file at {}: {}".format(pyc, e))
            # we ignore any failure to write the cache file
            # there are many reasons, permission-denied, pycache dir being a
            # file etc.
            # Do not leave the orphaned temporary file behind.
            try:
                os.unlink(proc_pyc)
            except OSError:
                pass
            return False
        return True
349
350
def _rewrite_test(fn: Path, config: Config) -> Tuple[os.stat_result, types.CodeType]:
    """Read and rewrite *fn* and return the code object.

    The source file is stat-ed before it is read. The returned stat result is
    what gets recorded in the pyc header, next to the compiled, assert-rewritten
    code object.
    """
    path_str = fspath(fn)
    source_stat = os.stat(path_str)
    with open(path_str, "rb") as stream:
        source_bytes = stream.read()
    module_ast = ast.parse(source_bytes, filename=path_str)
    rewrite_asserts(module_ast, source_bytes, path_str, config)
    code = compile(module_ast, path_str, "exec", dont_inherit=True)
    return source_stat, code
361
362
def _read_pyc(
    source: Path, pyc: Path, trace: Callable[[str], None] = lambda x: None
) -> Optional[types.CodeType]:
    """Possibly read a pytest pyc containing rewritten code.

    The pyc is accepted only if its 12-byte header carries the interpreter's
    magic number and matches the current mtime and size of *source*, and if
    its payload unmarshals to a code object. Each rejection reason is
    reported through *trace*.

    Return rewritten code if successful or None if not.
    """
    try:
        handle = open(fspath(pyc), "rb")
    except OSError:
        return None

    with handle:
        try:
            src_stat = os.stat(fspath(source))
            header = handle.read(12)
        except OSError as exc:
            trace("_read_pyc({}): OSError {}".format(source, exc))
            return None

        expected = (
            int(src_stat.st_mtime) & 0xFFFFFFFF,
            src_stat.st_size & 0xFFFFFFFF,
        )
        # Short-circuiting keeps struct.unpack away from truncated headers.
        header_ok = (
            len(header) == 12
            and header[:4] == importlib.util.MAGIC_NUMBER
            and struct.unpack("<LL", header[4:]) == expected
        )
        if not header_ok:
            trace("_read_pyc(%s): invalid or out of date pyc" % source)
            return None

        try:
            code = marshal.load(handle)
        except Exception as exc:
            trace("_read_pyc({}): marshal.load error {}".format(source, exc))
            return None

        if isinstance(code, types.CodeType):
            return code
        trace("_read_pyc(%s): not a code object" % source)
        return None
400
401
def rewrite_asserts(
    mod: ast.Module,
    source: bytes,
    module_path: Optional[str] = None,
    config: Optional[Config] = None,
) -> None:
    """Rewrite the assert statements in *mod* in place.

    *source* is the raw module source the tree was parsed from. *module_path*
    and *config* are passed through to the AssertionRewriter.
    """
    rewriter = AssertionRewriter(module_path, config, source)
    rewriter.run(mod)
410
411
def _saferepr(obj: object) -> str:
    r"""Get a safe repr of an object for assertion error messages.

    util.format_explanation() treats newlines as special characters, so any
    real newline in the repr is escaped as the two characters ``\n``. A
    custom ``__repr__`` (JSON reprs in particular) may otherwise contain
    sequences such as '\n{' or '\n}' that the formatter would interpret.
    """
    representation = saferepr(obj)
    return representation.replace("\n", "\\n")
423
424
def _format_assertmsg(obj: object) -> str:
    r"""Format the custom assertion message given.

    Newlines become '\n~' so that util.format_explanation() preserves them
    instead of escaping them, and '%' is doubled because the message later
    goes through %-formatting. Objects that are not strings are passed
    through saferepr() first.
    """
    if isinstance(obj, str):
        return obj.replace("\n", "\n~").replace("%", "%%")
    # reprlib appears to escape newlines inside strings, but a custom
    # __repr__() may contain real newlines. Both forms (the real newline and
    # the escaped two-character "\n") are turned into '\n~'.
    text = saferepr(obj)
    return text.replace("\n", "\n~").replace("%", "%%").replace("\\n", "\n~")
445
446
def _should_repr_global_name(obj: object) -> bool:
    """Decide whether a global's value is shown via repr in explanations.

    Callables never are. Other objects are shown unless they have a
    ``__name__`` attribute. If probing for ``__name__`` raises, the object
    is shown.
    """
    if not callable(obj):
        try:
            has_name = hasattr(obj, "__name__")
        except Exception:
            return True
        return not has_name
    return False
455
456
def _format_boolop(explanations: Iterable[str], is_or: bool) -> str:
    """Join *explanations* with ``or``/``and`` inside parentheses.

    '%' characters are doubled so the result is safe for %-formatting.
    """
    joiner = " or " if is_or else " and "
    joined = "({})".format(joiner.join(explanations))
    return joined.replace("%", "%%")
460
461
def _call_reprcompare(
    ops: Sequence[str],
    results: Sequence[bool],
    expls: Sequence[str],
    each_obj: Sequence[object],
) -> str:
    """Pick the explanation for a (possibly chained) comparison.

    Walks the per-operator results in order and stops at the first one that
    is falsy, or whose truthiness cannot be evaluated. If no such result
    exists, the last position is used. When ``util._reprcompare`` is set and
    returns a non-None value for that operator and its two operands, that
    value is returned; otherwise the precomputed explanation is returned.
    """
    for idx, (_op, outcome, explanation) in enumerate(zip(ops, results, expls)):
        try:
            failed = not outcome
        except Exception:
            failed = True
        if failed:
            break
    hook = util._reprcompare
    if hook is not None:
        custom = hook(ops[idx], each_obj[idx], each_obj[idx + 1])
        if custom is not None:
            return custom
    return explanation
480
481
def _call_assertion_pass(lineno: int, orig: str, expl: str) -> None:
    """Forward a passed assertion to ``util._assertion_pass`` when it is set."""
    callback = util._assertion_pass
    if callback is None:
        return
    callback(lineno, orig, expl)
485
486
def _check_if_assertion_pass_impl() -> bool:
    """Check if any plugins implement the pytest_assertion_pass hook,
    so the explanation is not generated unnecessarily (it might be expensive)."""
    return bool(util._assertion_pass)
491
492
# %-format templates used to render unary operations in explanations.
UNARY_MAP = {
    ast.Not: "not %s",
    ast.Invert: "~%s",
    ast.USub: "-%s",
    ast.UAdd: "+%s",
}

# Source spelling of binary and comparison operators used in explanations.
BINOP_MAP = {
    # bitwise
    ast.BitOr: "|",
    ast.BitXor: "^",
    ast.BitAnd: "&",
    ast.LShift: "<<",
    ast.RShift: ">>",
    # arithmetic
    ast.Add: "+",
    ast.Sub: "-",
    ast.Mult: "*",
    ast.Div: "/",
    ast.FloorDiv: "//",
    ast.Mod: "%%",  # escaped for string formatting
    # comparisons
    ast.Eq: "==",
    ast.NotEq: "!=",
    ast.Lt: "<",
    ast.LtE: "<=",
    ast.Gt: ">",
    ast.GtE: ">=",
    # power (kept at this position to preserve the original ordering)
    ast.Pow: "**",
    # identity / membership
    ast.Is: "is",
    ast.IsNot: "is not",
    ast.In: "in",
    ast.NotIn: "not in",
    # matrix multiplication
    ast.MatMult: "@",
}
520
521
def set_location(node, lineno, col_offset):
    """Set node location information recursively.

    Every node in the subtree rooted at *node* (including *node* itself) has
    its ``lineno`` / ``col_offset`` overwritten, but only where the node type
    declares that field in ``_attributes``. Returns *node*.
    """
    for sub in ast.walk(node):
        declared = sub._attributes
        if "lineno" in declared:
            sub.lineno = lineno
        if "col_offset" in declared:
            sub.col_offset = col_offset
    return node
535
536
def _get_assertion_exprs(src: bytes) -> Dict[int, str]:
    """Return a mapping from {lineno: "assertion test expression"}.

    *src* is tokenized. For every ``assert`` keyword, the source text of its
    test expression is collected: the ``assert`` keyword itself and any
    ``, message`` part are excluded. The text is keyed by the line number on
    which the ``assert`` keyword appears.
    """
    ret = {}  # type: Dict[int, str]

    # Bracket nesting depth inside the current assert; used to recognise the
    # top-level comma that separates the test from the message.
    depth = 0
    # Source fragments (whole or partial physical lines) of the current assert.
    lines = []  # type: List[str]
    # Line of the ``assert`` keyword being collected, or None outside one.
    assert_lineno = None  # type: Optional[int]
    # Physical line numbers already appended to ``lines``.
    seen_lines = set()  # type: Set[int]

    def _write_and_reset() -> None:
        # Record the collected expression, dropping trailing whitespace and a
        # trailing line-continuation backslash, then reset all state.
        nonlocal depth, lines, assert_lineno, seen_lines
        assert assert_lineno is not None
        ret[assert_lineno] = "".join(lines).rstrip().rstrip("\\")
        depth = 0
        lines = []
        assert_lineno = None
        seen_lines = set()

    tokens = tokenize.tokenize(io.BytesIO(src).readline)
    for tp, source, (lineno, offset), _, line in tokens:
        if tp == tokenize.NAME and source == "assert":
            assert_lineno = lineno
        elif assert_lineno is not None:
            # keep track of depth for the assert-message `,` lookup
            if tp == tokenize.OP and source in "([{":
                depth += 1
            elif tp == tokenize.OP and source in ")]}":
                depth -= 1

            # First token after ``assert``: keep the rest of its physical line.
            if not lines:
                lines.append(line[offset:])
                seen_lines.add(lineno)
            # a non-nested comma separates the expression from the message
            elif depth == 0 and tp == tokenize.OP and source == ",":
                # one line assert with message
                if lineno in seen_lines and len(lines) == 1:
                    # ``offset`` is relative to the full line, while the first
                    # fragment was trimmed at the start; convert it.
                    offset_in_trimmed = offset + len(lines[-1]) - len(line)
                    lines[-1] = lines[-1][:offset_in_trimmed]
                # multi-line assert with message
                elif lineno in seen_lines:
                    lines[-1] = lines[-1][:offset]
                # multi line assert with escaped newline before message
                else:
                    lines.append(line[:offset])
                _write_and_reset()
            elif tp in {tokenize.NEWLINE, tokenize.ENDMARKER}:
                _write_and_reset()
            # A token on a physical line not collected yet: add the whole line.
            elif lines and lineno not in seen_lines:
                lines.append(line)
                seen_lines.add(lineno)

    return ret
589
590
591class AssertionRewriter(ast.NodeVisitor):
592    """Assertion rewriting implementation.
593
594    The main entrypoint is to call .run() with an ast.Module instance,
595    this will then find all the assert statements and rewrite them to
596    provide intermediate values and a detailed assertion error.  See
597    http://pybites.blogspot.be/2011/07/behind-scenes-of-pytests-new-assertion.html
598    for an overview of how this works.
599
600    The entry point here is .run() which will iterate over all the
601    statements in an ast.Module and for each ast.Assert statement it
602    finds call .visit() with it.  Then .visit_Assert() takes over and
603    is responsible for creating new ast statements to replace the
604    original assert statement: it rewrites the test of an assertion
605    to provide intermediate values and replace it with an if statement
606    which raises an assertion error with a detailed explanation in
607    case the expression is false and calls pytest_assertion_pass hook
608    if expression is true.
609
610    For this .visit_Assert() uses the visitor pattern to visit all the
611    AST nodes of the ast.Assert.test field, each visit call returning
612    an AST node and the corresponding explanation string.  During this
613    state is kept in several instance attributes:
614
615    :statements: All the AST statements which will replace the assert
616       statement.
617
618    :variables: This is populated by .variable() with each variable
619       used by the statements so that they can all be set to None at
620       the end of the statements.
621
622    :variable_counter: Counter to create new unique variables needed
623       by statements.  Variables are created using .variable() and
624       have the form of "@py_assert0".
625
626    :expl_stmts: The AST statements which will be executed to get
627       data from the assertion.  This is the code which will construct
628       the detailed assertion message that is used in the AssertionError
629       or for the pytest_assertion_pass hook.
630
631    :explanation_specifiers: A dict filled by .explanation_param()
632       with %-formatting placeholders and their corresponding
633       expressions to use in the building of an assertion message.
634       This is used by .pop_format_context() to build a message.
635
636    :stack: A stack of the explanation_specifiers dicts maintained by
637       .push_format_context() and .pop_format_context() which allows
638       to build another %-formatted string while already building one.
639
640    This state is reset on every new assert statement visited and used
641    by the other visitors.
642    """
643
644    def __init__(
645        self, module_path: Optional[str], config: Optional[Config], source: bytes
646    ) -> None:
647        super().__init__()
648        self.module_path = module_path
649        self.config = config
650        if config is not None:
651            self.enable_assertion_pass_hook = config.getini(
652                "enable_assertion_pass_hook"
653            )
654        else:
655            self.enable_assertion_pass_hook = False
656        self.source = source
657
658    @functools.lru_cache(maxsize=1)
659    def _assert_expr_to_lineno(self) -> Dict[int, str]:
660        return _get_assertion_exprs(self.source)
661
    def run(self, mod: ast.Module) -> None:
        """Find all assert statements in *mod* and rewrite them.

        *mod* is modified in place. Two imports (``builtins`` as
        ``@py_builtins`` and this module as ``@pytest_ar``) are inserted
        after any leading docstring and ``from __future__`` imports. Every
        ``ast.Assert`` found in a statement list is replaced by the
        statements returned from ``self.visit()``. Nothing is done when the
        module body is empty or its docstring contains PYTEST_DONT_REWRITE.
        """
        if not mod.body:
            # Nothing to do.
            return
        # Insert some special imports at the top of the module but after any
        # docstrings and __future__ imports.
        # The "@" prefix is not valid in Python identifiers, so these alias
        # names cannot clash with names in the user's module.
        aliases = [
            ast.alias("builtins", "@py_builtins"),
            ast.alias("_pytest.assertion.rewrite", "@pytest_ar"),
        ]
        # If the AST exposes the docstring as ``mod.docstring`` it is used
        # here; otherwise the docstring is detected in the loop below.
        doc = getattr(mod, "docstring", None)
        expect_docstring = doc is None
        if doc is not None and self.is_rewrite_disabled(doc):
            return
        # ``pos`` ends up as the index at which the imports are inserted.
        pos = 0
        lineno = 1
        for item in mod.body:
            if (
                expect_docstring
                and isinstance(item, ast.Expr)
                and isinstance(item.value, ast.Str)
            ):
                doc = item.value.s
                if self.is_rewrite_disabled(doc):
                    return
                expect_docstring = False
            elif (
                isinstance(item, ast.ImportFrom)
                and item.level == 0
                and item.module == "__future__"
            ):
                pass
            else:
                break
            pos += 1
        # NOTE: ``item`` is the loop variable leaked from the loop above. It
        # is the first statement that is neither the docstring nor a
        # __future__ import, or the last statement if there is none.
        # Special case: for a decorated function, set the lineno to that of the
        # first decorator, not the `def`. Issue #4984.
        if isinstance(item, ast.FunctionDef) and item.decorator_list:
            lineno = item.decorator_list[0].lineno
        else:
            lineno = item.lineno
        imports = [
            ast.Import([alias], lineno=lineno, col_offset=0) for alias in aliases
        ]
        mod.body[pos:pos] = imports
        # Collect asserts.
        # This is an iterative traversal of the whole tree:
        # - List-valued fields are rebuilt, with each Assert replaced by the
        #   statements from self.visit(). Those replacement statements are
        #   not traversed again.
        # - Other AST items in those lists are queued for traversal.
        # - Single-node fields are queued only when they are not expressions.
        nodes = [mod]  # type: List[ast.AST]
        while nodes:
            node = nodes.pop()
            for name, field in ast.iter_fields(node):
                if isinstance(field, list):
                    new = []  # type: List[ast.AST]
                    for i, child in enumerate(field):
                        if isinstance(child, ast.Assert):
                            # Transform assert.
                            new.extend(self.visit(child))
                        else:
                            new.append(child)
                            if isinstance(child, ast.AST):
                                nodes.append(child)
                    setattr(node, name, new)
                elif (
                    isinstance(field, ast.AST)
                    # Don't recurse into expressions as they can't contain
                    # asserts.
                    and not isinstance(field, ast.expr)
                ):
                    nodes.append(field)
731
732    @staticmethod
733    def is_rewrite_disabled(docstring: str) -> bool:
734        return "PYTEST_DONT_REWRITE" in docstring
735
736    def variable(self) -> str:
737        """Get a new variable."""
738        # Use a character invalid in python identifiers to avoid clashing.
739        name = "@py_assert" + str(next(self.variable_counter))
740        self.variables.append(name)
741        return name
742
743    def assign(self, expr: ast.expr) -> ast.Name:
744        """Give *expr* a name."""
745        name = self.variable()
746        self.statements.append(ast.Assign([ast.Name(name, ast.Store())], expr))
747        return ast.Name(name, ast.Load())
748
749    def display(self, expr: ast.expr) -> ast.expr:
750        """Call saferepr on the expression."""
751        return self.helper("_saferepr", expr)
752
753    def helper(self, name: str, *args: ast.expr) -> ast.expr:
754        """Call a helper in this module."""
755        py_name = ast.Name("@pytest_ar", ast.Load())
756        attr = ast.Attribute(py_name, name, ast.Load())
757        return ast.Call(attr, list(args), [])
758
759    def builtin(self, name: str) -> ast.Attribute:
760        """Return the builtin called *name*."""
761        builtin_name = ast.Name("@py_builtins", ast.Load())
762        return ast.Attribute(builtin_name, name, ast.Load())
763
764    def explanation_param(self, expr: ast.expr) -> str:
765        """Return a new named %-formatting placeholder for expr.
766
767        This creates a %-formatting placeholder for expr in the
768        current formatting context, e.g. ``%(py0)s``.  The placeholder
769        and expr are placed in the current format context so that it
770        can be used on the next call to .pop_format_context().
771        """
772        specifier = "py" + str(next(self.variable_counter))
773        self.explanation_specifiers[specifier] = expr
774        return "%(" + specifier + ")s"
775
776    def push_format_context(self) -> None:
777        """Create a new formatting context.
778
779        The format context is used for when an explanation wants to
780        have a variable value formatted in the assertion message.  In
781        this case the value required can be added using
782        .explanation_param().  Finally .pop_format_context() is used
783        to format a string of %-formatted values as added by
784        .explanation_param().
785        """
786        self.explanation_specifiers = {}  # type: Dict[str, ast.expr]
787        self.stack.append(self.explanation_specifiers)
788
789    def pop_format_context(self, expl_expr: ast.expr) -> ast.Name:
790        """Format the %-formatted string with current format context.
791
792        The expl_expr should be an str ast.expr instance constructed from
793        the %-placeholders created by .explanation_param().  This will
794        add the required code to format said string to .expl_stmts and
795        return the ast.Name instance of the formatted string.
796        """
797        current = self.stack.pop()
798        if self.stack:
799            self.explanation_specifiers = self.stack[-1]
800        keys = [ast.Str(key) for key in current.keys()]
801        format_dict = ast.Dict(keys, list(current.values()))
802        form = ast.BinOp(expl_expr, ast.Mod(), format_dict)
803        name = "@py_format" + str(next(self.variable_counter))
804        if self.enable_assertion_pass_hook:
805            self.format_variables.append(name)
806        self.expl_stmts.append(ast.Assign([ast.Name(name, ast.Store())], form))
807        return ast.Name(name, ast.Load())
808
809    def generic_visit(self, node: ast.AST) -> Tuple[ast.Name, str]:
810        """Handle expressions we don't have custom code for."""
811        assert isinstance(node, ast.expr)
812        res = self.assign(node)
813        return res, self.explanation_param(self.display(res))
814
    def visit_Assert(self, assert_: ast.Assert) -> List[ast.stmt]:
        """Return the AST statements to replace the ast.Assert instance.

        This rewrites the test of an assertion to provide
        intermediate values and replace it with an if statement which
        raises an assertion error with a detailed explanation in case
        the expression is false.

        When ``self.enable_assertion_pass_hook`` is set, the generated ``if``
        also gets an else branch.  If the ``_check_if_assertion_pass_impl``
        helper returns a true value, that branch passes the assert's line
        number, its original source text and the formatted explanation to
        the ``_call_assertion_pass`` helper.  The ``@py_format*`` variables
        are reset to None afterwards.

        In both modes, the temporary variables recorded in ``self.variables``
        are reset to None after the check.  Every returned statement is passed
        to ``set_location()`` with the assert's line number and column offset.
        """
        if isinstance(assert_.test, ast.Tuple) and len(assert_.test.elts) >= 1:
            # A non-empty tuple is always truthy, so e.g. ``assert (x, "msg")``
            # can never fail; warn about it rather than silently accepting it.
            from _pytest.warning_types import PytestAssertRewriteWarning
            import warnings

            # TODO: This assert should not be needed.
            assert self.module_path is not None
            warnings.warn_explicit(
                PytestAssertRewriteWarning(
                    "assertion is always true, perhaps remove parentheses?"
                ),
                category=None,
                filename=fspath(self.module_path),
                lineno=assert_.lineno,
            )

        # Reset the per-assertion state shared with the visit_* methods:
        # generated statements, temporary variable names, and the counter
        # used for unique placeholder/variable suffixes.
        self.statements = []  # type: List[ast.stmt]
        self.variables = []  # type: List[str]
        self.variable_counter = itertools.count()

        if self.enable_assertion_pass_hook:
            # Collects @py_format* names (see pop_format_context) so they can
            # be reset to None after the check.
            self.format_variables = []  # type: List[str]

        self.stack = []  # type: List[Dict[str, ast.expr]]
        # Statements that build the explanation; they only run inside the
        # generated branches below.
        self.expl_stmts = []  # type: List[ast.stmt]
        self.push_format_context()
        # Rewrite assert into a bunch of statements.
        top_condition, explanation = self.visit(assert_.test)

        negation = ast.UnaryOp(ast.Not(), top_condition)

        if self.enable_assertion_pass_hook:  # Experimental pytest_assertion_pass hook
            # Format the explanation once; ``msg`` is used by both branches.
            msg = self.pop_format_context(ast.Str(explanation))

            # Failed
            if assert_.msg:
                assertmsg = self.helper("_format_assertmsg", assert_.msg)
                gluestr = "\n>assert "
            else:
                assertmsg = ast.Str("")
                gluestr = "assert "
            err_explanation = ast.BinOp(ast.Str(gluestr), ast.Add(), msg)
            err_msg = ast.BinOp(assertmsg, ast.Add(), err_explanation)
            err_name = ast.Name("AssertionError", ast.Load())
            fmt = self.helper("_format_explanation", err_msg)
            exc = ast.Call(err_name, [fmt], [])
            raise_ = ast.Raise(exc, None)
            statements_fail = []
            statements_fail.extend(self.expl_stmts)
            statements_fail.append(raise_)

            # Passed
            fmt_pass = self.helper("_format_explanation", msg)
            orig = self._assert_expr_to_lineno()[assert_.lineno]
            hook_call_pass = ast.Expr(
                self.helper(
                    "_call_assertion_pass",
                    ast.Num(assert_.lineno),
                    ast.Str(orig),
                    fmt_pass,
                )
            )
            # If any hooks implement assert_pass hook
            # (the explanation statements, which define ``msg``, must run on
            # the pass path as well before the hook call can use it)
            hook_impl_test = ast.If(
                self.helper("_check_if_assertion_pass_impl"),
                self.expl_stmts + [hook_call_pass],
                [],
            )
            statements_pass = [hook_impl_test]

            # Test for assertion condition
            main_test = ast.If(negation, statements_fail, statements_pass)
            self.statements.append(main_test)
            if self.format_variables:
                variables = [
                    ast.Name(name, ast.Store()) for name in self.format_variables
                ]
                clear_format = ast.Assign(variables, ast.NameConstant(None))
                self.statements.append(clear_format)

        else:  # Original assertion rewriting
            # Create failure message.
            # ``body`` is the very same list as self.expl_stmts, so the
            # assignment appended by pop_format_context() below and the
            # final raise both end up inside this ``if``.
            body = self.expl_stmts
            self.statements.append(ast.If(negation, body, []))
            if assert_.msg:
                assertmsg = self.helper("_format_assertmsg", assert_.msg)
                explanation = "\n>assert " + explanation
            else:
                assertmsg = ast.Str("")
                explanation = "assert " + explanation
            template = ast.BinOp(assertmsg, ast.Add(), ast.Str(explanation))
            msg = self.pop_format_context(template)
            fmt = self.helper("_format_explanation", msg)
            err_name = ast.Name("AssertionError", ast.Load())
            exc = ast.Call(err_name, [fmt], [])
            raise_ = ast.Raise(exc, None)

            body.append(raise_)

        # Clear temporary variables by setting them to None.
        if self.variables:
            variables = [ast.Name(name, ast.Store()) for name in self.variables]
            clear = ast.Assign(variables, ast.NameConstant(None))
            self.statements.append(clear)
        # Fix line numbers.
        for stmt in self.statements:
            set_location(stmt, assert_.lineno, assert_.col_offset)
        return self.statements
930
931    def visit_Name(self, name: ast.Name) -> Tuple[ast.Name, str]:
932        # Display the repr of the name if it's a local variable or
933        # _should_repr_global_name() thinks it's acceptable.
934        locs = ast.Call(self.builtin("locals"), [], [])
935        inlocs = ast.Compare(ast.Str(name.id), [ast.In()], [locs])
936        dorepr = self.helper("_should_repr_global_name", name)
937        test = ast.BoolOp(ast.Or(), [inlocs, dorepr])
938        expr = ast.IfExp(test, self.display(name), ast.Str(name.id))
939        return name, self.explanation_param(expr)
940
    def visit_BoolOp(self, boolop: ast.BoolOp) -> Tuple[ast.Name, str]:
        """Rewrite an ``and``/``or`` chain while preserving short-circuiting.

        Each operand after the first is evaluated inside an ``if`` nested
        under the previous one, so it only runs when the earlier operands did
        not already decide the result.  The value of each evaluated operand
        is assigned to a shared result variable, which is returned.

        The explanation statements mirror that nesting.  Each evaluated
        operand's formatted explanation is appended to a runtime list.  That
        list is rendered by the ``_format_boolop`` helper, together with a
        0/1 flag telling whether the operator is ``or``.
        """
        res_var = self.variable()
        # Runtime list that collects one explanation per evaluated operand.
        expl_list = self.assign(ast.List([], ast.Load()))
        app = ast.Attribute(expl_list, "append", ast.Load())
        is_or = int(isinstance(boolop.op, ast.Or))
        # Remember the outer statement lists: the loop redirects
        # self.statements / self.expl_stmts into nested if-bodies.
        body = save = self.statements
        fail_save = self.expl_stmts
        levels = len(boolop.values) - 1
        self.push_format_context()
        # Process each operand, short-circuiting if needed.
        for i, v in enumerate(boolop.values):
            if i:
                # Only explain this operand if the previous condition let
                # evaluation continue, i.e. if it was actually evaluated.
                fail_inner = []  # type: List[ast.stmt]
                # cond is set in a prior loop iteration below
                self.expl_stmts.append(ast.If(cond, fail_inner, []))  # noqa
                self.expl_stmts = fail_inner
            self.push_format_context()
            res, expl = self.visit(v)
            body.append(ast.Assign([ast.Name(res_var, ast.Store())], res))
            expl_format = self.pop_format_context(ast.Str(expl))
            call = ast.Call(app, [expl_format], [])
            self.expl_stmts.append(ast.Expr(call))
            if i < levels:
                # Go on to the next operand only if this one did not decide
                # the result: truthy for ``and``, falsy for ``or``.
                cond = res  # type: ast.expr
                if is_or:
                    cond = ast.UnaryOp(ast.Not(), cond)
                inner = []  # type: List[ast.stmt]
                self.statements.append(ast.If(cond, inner, []))
                self.statements = body = inner
        self.statements = save
        self.expl_stmts = fail_save
        expl_template = self.helper("_format_boolop", expl_list, ast.Num(is_or))
        expl = self.pop_format_context(expl_template)
        return ast.Name(res_var, ast.Load()), self.explanation_param(expl)
975
976    def visit_UnaryOp(self, unary: ast.UnaryOp) -> Tuple[ast.Name, str]:
977        pattern = UNARY_MAP[unary.op.__class__]
978        operand_res, operand_expl = self.visit(unary.operand)
979        res = self.assign(ast.UnaryOp(unary.op, operand_res))
980        return res, pattern % (operand_expl,)
981
982    def visit_BinOp(self, binop: ast.BinOp) -> Tuple[ast.Name, str]:
983        symbol = BINOP_MAP[binop.op.__class__]
984        left_expr, left_expl = self.visit(binop.left)
985        right_expr, right_expl = self.visit(binop.right)
986        explanation = "({} {} {})".format(left_expl, symbol, right_expl)
987        res = self.assign(ast.BinOp(left_expr, binop.op, right_expr))
988        return res, explanation
989
990    def visit_Call(self, call: ast.Call) -> Tuple[ast.Name, str]:
991        new_func, func_expl = self.visit(call.func)
992        arg_expls = []
993        new_args = []
994        new_kwargs = []
995        for arg in call.args:
996            res, expl = self.visit(arg)
997            arg_expls.append(expl)
998            new_args.append(res)
999        for keyword in call.keywords:
1000            res, expl = self.visit(keyword.value)
1001            new_kwargs.append(ast.keyword(keyword.arg, res))
1002            if keyword.arg:
1003                arg_expls.append(keyword.arg + "=" + expl)
1004            else:  # **args have `arg` keywords with an .arg of None
1005                arg_expls.append("**" + expl)
1006
1007        expl = "{}({})".format(func_expl, ", ".join(arg_expls))
1008        new_call = ast.Call(new_func, new_args, new_kwargs)
1009        res = self.assign(new_call)
1010        res_expl = self.explanation_param(self.display(res))
1011        outer_expl = "{}\n{{{} = {}\n}}".format(res_expl, res_expl, expl)
1012        return res, outer_expl
1013
1014    def visit_Starred(self, starred: ast.Starred) -> Tuple[ast.Starred, str]:
1015        # From Python 3.5, a Starred node can appear in a function call.
1016        res, expl = self.visit(starred.value)
1017        new_starred = ast.Starred(res, starred.ctx)
1018        return new_starred, "*" + expl
1019
1020    def visit_Attribute(self, attr: ast.Attribute) -> Tuple[ast.Name, str]:
1021        if not isinstance(attr.ctx, ast.Load):
1022            return self.generic_visit(attr)
1023        value, value_expl = self.visit(attr.value)
1024        res = self.assign(ast.Attribute(value, attr.attr, ast.Load()))
1025        res_expl = self.explanation_param(self.display(res))
1026        pat = "%s\n{%s = %s.%s\n}"
1027        expl = pat % (res_expl, res_expl, value_expl, attr.attr)
1028        return res, expl
1029
1030    def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]:
1031        self.push_format_context()
1032        left_res, left_expl = self.visit(comp.left)
1033        if isinstance(comp.left, (ast.Compare, ast.BoolOp)):
1034            left_expl = "({})".format(left_expl)
1035        res_variables = [self.variable() for i in range(len(comp.ops))]
1036        load_names = [ast.Name(v, ast.Load()) for v in res_variables]
1037        store_names = [ast.Name(v, ast.Store()) for v in res_variables]
1038        it = zip(range(len(comp.ops)), comp.ops, comp.comparators)
1039        expls = []
1040        syms = []
1041        results = [left_res]
1042        for i, op, next_operand in it:
1043            next_res, next_expl = self.visit(next_operand)
1044            if isinstance(next_operand, (ast.Compare, ast.BoolOp)):
1045                next_expl = "({})".format(next_expl)
1046            results.append(next_res)
1047            sym = BINOP_MAP[op.__class__]
1048            syms.append(ast.Str(sym))
1049            expl = "{} {} {}".format(left_expl, sym, next_expl)
1050            expls.append(ast.Str(expl))
1051            res_expr = ast.Compare(left_res, [op], [next_res])
1052            self.statements.append(ast.Assign([store_names[i]], res_expr))
1053            left_res, left_expl = next_res, next_expl
1054        # Use pytest.assertion.util._reprcompare if that's available.
1055        expl_call = self.helper(
1056            "_call_reprcompare",
1057            ast.Tuple(syms, ast.Load()),
1058            ast.Tuple(load_names, ast.Load()),
1059            ast.Tuple(expls, ast.Load()),
1060            ast.Tuple(results, ast.Load()),
1061        )
1062        if len(comp.ops) > 1:
1063            res = ast.BoolOp(ast.And(), load_names)  # type: ast.expr
1064        else:
1065            res = load_names[0]
1066        return res, self.explanation_param(self.pop_format_context(expl_call))
1067
1068
def try_makedirs(cache_dir: Path) -> bool:
    """Attempt to create the given directory, including any missing parents.

    Returns True if the directory was created or already exists.  Returns
    False when it cannot be created because a path component is not a real
    directory, access is denied, or the file system is read-only.  Any other
    OSError is re-raised.
    """
    try:
        os.makedirs(fspath(cache_dir), exist_ok=True)
    except (FileNotFoundError, NotADirectoryError, FileExistsError):
        # One of the path components was not a directory:
        # - we're in a zip file
        # - it is a file
        return False
    except PermissionError:
        # No write access to some part of the path.
        return False
    except OSError as e:
        # EROFS (read-only file system) has no dedicated OSError subclass.
        # Some read-only FUSE mounts (e.g. squashfuse_ll) report ENOSYS
        # ("Function not implemented") instead of EROFS, so treat it the same.
        if e.errno in (errno.EROFS, errno.ENOSYS):
            return False
        raise
    return True
1089
1090
def get_cache_dir(file_path: Path) -> Path:
    """Return the directory that holds the rewritten .pyc for *file_path*.

    Without a (non-empty) ``sys.pycache_prefix`` (always the case before
    Python 3.8), this is the ``__pycache__`` directory next to the source
    file.  Otherwise the source file's directory, minus its root anchor, is
    mirrored under the prefix.  For example, prefix ``/tmp/pycs`` and
    ``/home/user/proj/test_app.py`` give ``/tmp/pycs/home/user/proj``.
    """
    if sys.version_info < (3, 8) or not sys.pycache_prefix:
        return file_path.parent / "__pycache__"
    mirrored_dirs = file_path.parts[1:-1]
    return Path(sys.pycache_prefix) / Path(*mirrored_dirs)
1103