1import ast
2import errno
3import glob
4import importlib
5import os
6import py_compile
7import stat
8import sys
9import textwrap
10import zipfile
11from functools import partial
12from typing import Dict
13from typing import List
14from typing import Mapping
15from typing import Optional
16from typing import Set
17
18import py
19
20import _pytest._code
21import pytest
22from _pytest.assertion import util
23from _pytest.assertion.rewrite import _get_assertion_exprs
24from _pytest.assertion.rewrite import AssertionRewritingHook
25from _pytest.assertion.rewrite import get_cache_dir
26from _pytest.assertion.rewrite import PYC_TAIL
27from _pytest.assertion.rewrite import PYTEST_TAG
28from _pytest.assertion.rewrite import rewrite_asserts
29from _pytest.config import ExitCode
30from _pytest.pathlib import make_numbered_dir
31from _pytest.pathlib import Path
32from _pytest.pytester import Testdir
33
34
def rewrite(src: str) -> ast.Module:
    """Parse ``src`` and return its module AST after assertion rewriting.

    The parsed tree is passed to ``rewrite_asserts`` together with the encoded
    source text, and that same tree object is returned.
    """
    module = ast.parse(src)
    rewrite_asserts(module, src.encode())
    return module
39
40
41def getmsg(
42    f, extra_ns: Optional[Mapping[str, object]] = None, *, must_pass: bool = False
43) -> Optional[str]:
44    """Rewrite the assertions in f, run it, and get the failure message."""
45    src = "\n".join(_pytest._code.Code(f).source().lines)
46    mod = rewrite(src)
47    code = compile(mod, "<test>", "exec")
48    ns = {}  # type: Dict[str, object]
49    if extra_ns is not None:
50        ns.update(extra_ns)
51    exec(code, ns)
52    func = ns[f.__name__]
53    try:
54        func()  # type: ignore[operator]
55    except AssertionError:
56        if must_pass:
57            pytest.fail("shouldn't have raised")
58        s = str(sys.exc_info()[1])
59        if not s.startswith("assert"):
60            return "AssertionError: " + s
61        return s
62    else:
63        if not must_pass:
64            pytest.fail("function didn't raise at all")
65        return None
66
67
68class TestAssertionRewrite:
69    def test_place_initial_imports(self):
70        s = """'Doc string'\nother = stuff"""
71        m = rewrite(s)
72        assert isinstance(m.body[0], ast.Expr)
73        for imp in m.body[1:3]:
74            assert isinstance(imp, ast.Import)
75            assert imp.lineno == 2
76            assert imp.col_offset == 0
77        assert isinstance(m.body[3], ast.Assign)
78        s = """from __future__ import division\nother_stuff"""
79        m = rewrite(s)
80        assert isinstance(m.body[0], ast.ImportFrom)
81        for imp in m.body[1:3]:
82            assert isinstance(imp, ast.Import)
83            assert imp.lineno == 2
84            assert imp.col_offset == 0
85        assert isinstance(m.body[3], ast.Expr)
86        s = """'doc string'\nfrom __future__ import division"""
87        m = rewrite(s)
88        assert isinstance(m.body[0], ast.Expr)
89        assert isinstance(m.body[1], ast.ImportFrom)
90        for imp in m.body[2:4]:
91            assert isinstance(imp, ast.Import)
92            assert imp.lineno == 2
93            assert imp.col_offset == 0
94        s = """'doc string'\nfrom __future__ import division\nother"""
95        m = rewrite(s)
96        assert isinstance(m.body[0], ast.Expr)
97        assert isinstance(m.body[1], ast.ImportFrom)
98        for imp in m.body[2:4]:
99            assert isinstance(imp, ast.Import)
100            assert imp.lineno == 3
101            assert imp.col_offset == 0
102        assert isinstance(m.body[4], ast.Expr)
103        s = """from . import relative\nother_stuff"""
104        m = rewrite(s)
105        for imp in m.body[:2]:
106            assert isinstance(imp, ast.Import)
107            assert imp.lineno == 1
108            assert imp.col_offset == 0
109        assert isinstance(m.body[3], ast.Expr)
110
111    def test_dont_rewrite(self) -> None:
112        s = """'PYTEST_DONT_REWRITE'\nassert 14"""
113        m = rewrite(s)
114        assert len(m.body) == 2
115        assert isinstance(m.body[1], ast.Assert)
116        assert m.body[1].msg is None
117
118    def test_dont_rewrite_plugin(self, testdir):
119        contents = {
120            "conftest.py": "pytest_plugins = 'plugin'; import plugin",
121            "plugin.py": "'PYTEST_DONT_REWRITE'",
122            "test_foo.py": "def test_foo(): pass",
123        }
124        testdir.makepyfile(**contents)
125        result = testdir.runpytest_subprocess()
126        assert "warning" not in "".join(result.outlines)
127
128    def test_rewrites_plugin_as_a_package(self, testdir):
129        pkgdir = testdir.mkpydir("plugin")
130        pkgdir.join("__init__.py").write(
131            "import pytest\n"
132            "@pytest.fixture\n"
133            "def special_asserter():\n"
134            "    def special_assert(x, y):\n"
135            "        assert x == y\n"
136            "    return special_assert\n"
137        )
138        testdir.makeconftest('pytest_plugins = ["plugin"]')
139        testdir.makepyfile("def test(special_asserter): special_asserter(1, 2)\n")
140        result = testdir.runpytest()
141        result.stdout.fnmatch_lines(["*assert 1 == 2*"])
142
143    def test_honors_pep_235(self, testdir, monkeypatch):
144        # note: couldn't make it fail on macos with a single `sys.path` entry
145        # note: these modules are named `test_*` to trigger rewriting
146        testdir.tmpdir.join("test_y.py").write("x = 1")
147        xdir = testdir.tmpdir.join("x").ensure_dir()
148        xdir.join("test_Y").ensure_dir().join("__init__.py").write("x = 2")
149        testdir.makepyfile(
150            "import test_y\n"
151            "import test_Y\n"
152            "def test():\n"
153            "    assert test_y.x == 1\n"
154            "    assert test_Y.x == 2\n"
155        )
156        monkeypatch.syspath_prepend(xdir)
157        testdir.runpytest().assert_outcomes(passed=1)
158
    def test_name(self, request) -> None:
        """Check the failure message for asserts on bare names.

        Covers a literal, a local, a global, a module and a class.  The
        expected text for the last two depends on the ``verbose`` option of
        the enclosing test run.
        """

        def f1() -> None:
            assert False

        assert getmsg(f1) == "assert False"

        def f2() -> None:
            f = False
            assert f

        assert getmsg(f2) == "assert False"

        def f3() -> None:
            assert a_global  # type: ignore[name-defined] # noqa

        # ``a_global`` only exists in the namespace handed to ``getmsg``.
        assert getmsg(f3, {"a_global": False}) == "assert False"

        def f4() -> None:
            assert sys == 42  # type: ignore[comparison-overlap]

        verbose = request.config.getoption("verbose")
        msg = getmsg(f4, {"sys": sys})
        if verbose > 0:
            assert msg == (
                "assert <module 'sys' (built-in)> == 42\n"
                "  +<module 'sys' (built-in)>\n"
                "  -42"
            )
        else:
            assert msg == "assert sys == 42"

        def f5() -> None:
            assert cls == 42  # type: ignore[name-defined]  # noqa: F821

        class X:
            pass

        msg = getmsg(f5, {"cls": X})
        assert msg is not None
        lines = msg.splitlines()
        # Three expectations: full repr, abbreviated repr, or just the name.
        if verbose > 1:
            assert lines == [
                "assert {!r} == 42".format(X),
                "  +{!r}".format(X),
                "  -42",
            ]
        elif verbose > 0:
            assert lines == [
                "assert <class 'test_...e.<locals>.X'> == 42",
                "  +{!r}".format(X),
                "  -42",
            ]
        else:
            assert lines == ["assert cls == 42"]
213
    def test_assertrepr_compare_same_width(self, request) -> None:
        """Should use same width/truncation with same initial width."""

        def f() -> None:
            assert "1234567890" * 5 + "A" == "1234567890" * 5 + "B"

        msg = getmsg(f)
        assert msg is not None
        # Only the first line of the message is checked.
        line = msg.splitlines()[0]
        if request.config.getoption("verbose") > 1:
            # Both operands are shown in full.
            assert line == (
                "assert '12345678901234567890123456789012345678901234567890A' "
                "== '12345678901234567890123456789012345678901234567890B'"
            )
        else:
            # Both operands are abbreviated to the same width.
            assert line == (
                "assert '123456789012...901234567890A' "
                "== '123456789012...901234567890B'"
            )
233
    def test_dont_rewrite_if_hasattr_fails(self, request) -> None:
        """Check the message when attribute lookup can raise a non-``AttributeError``.

        The expected lines depend on the ``verbose`` option of the enclosing
        test run.
        """

        class Y:
            """A class whose getattr fails, but not with `AttributeError`."""

            def __getattr__(self, attribute_name):
                raise KeyError()

            def __repr__(self) -> str:
                return "Y"

            def __init__(self) -> None:
                self.foo = 3

        def f() -> None:
            assert cls().foo == 2  # type: ignore[name-defined] # noqa: F821

        # XXX: looks like the "where" should also be there in verbose mode?!
        msg = getmsg(f, {"cls": Y})
        assert msg is not None
        lines = msg.splitlines()
        if request.config.getoption("verbose") > 0:
            assert lines == ["assert 3 == 2", "  +3", "  -2"]
        else:
            assert lines == [
                "assert 3 == 2",
                " +  where 3 = Y.foo",
                " +    where Y = cls()",
            ]
262
    def test_assert_already_has_message(self):
        """A user-supplied message comes first, followed by the assert line."""

        def f():
            assert False, "something bad!"

        assert getmsg(f) == "AssertionError: something bad!\nassert False"
268
269    def test_assertion_message(self, testdir):
270        testdir.makepyfile(
271            """
272            def test_foo():
273                assert 1 == 2, "The failure message"
274        """
275        )
276        result = testdir.runpytest()
277        assert result.ret == 1
278        result.stdout.fnmatch_lines(
279            ["*AssertionError*The failure message*", "*assert 1 == 2*"]
280        )
281
282    def test_assertion_message_multiline(self, testdir):
283        testdir.makepyfile(
284            """
285            def test_foo():
286                assert 1 == 2, "A multiline\\nfailure message"
287        """
288        )
289        result = testdir.runpytest()
290        assert result.ret == 1
291        result.stdout.fnmatch_lines(
292            ["*AssertionError*A multiline*", "*failure message*", "*assert 1 == 2*"]
293        )
294
295    def test_assertion_message_tuple(self, testdir):
296        testdir.makepyfile(
297            """
298            def test_foo():
299                assert 1 == 2, (1, 2)
300        """
301        )
302        result = testdir.runpytest()
303        assert result.ret == 1
304        result.stdout.fnmatch_lines(
305            ["*AssertionError*%s*" % repr((1, 2)), "*assert 1 == 2*"]
306        )
307
308    def test_assertion_message_expr(self, testdir):
309        testdir.makepyfile(
310            """
311            def test_foo():
312                assert 1 == 2, 1 + 2
313        """
314        )
315        result = testdir.runpytest()
316        assert result.ret == 1
317        result.stdout.fnmatch_lines(["*AssertionError*3*", "*assert 1 == 2*"])
318
319    def test_assertion_message_escape(self, testdir):
320        testdir.makepyfile(
321            """
322            def test_foo():
323                assert 1 == 2, 'To be escaped: %'
324        """
325        )
326        result = testdir.runpytest()
327        assert result.ret == 1
328        result.stdout.fnmatch_lines(
329            ["*AssertionError: To be escaped: %", "*assert 1 == 2"]
330        )
331
332    def test_assertion_messages_bytes(self, testdir):
333        testdir.makepyfile("def test_bytes_assertion():\n    assert False, b'ohai!'\n")
334        result = testdir.runpytest()
335        assert result.ret == 1
336        result.stdout.fnmatch_lines(["*AssertionError: b'ohai!'", "*assert False"])
337
    def test_boolop(self) -> None:
        """Check messages (or success) for asserts built from ``and``/``or``."""

        def f1() -> None:
            f = g = False
            assert f and g

        assert getmsg(f1) == "assert (False)"

        def f2() -> None:
            f = True
            g = False
            assert f and g

        assert getmsg(f2) == "assert (True and False)"

        def f3() -> None:
            f = False
            g = True
            assert f and g

        assert getmsg(f3) == "assert (False)"

        def f4() -> None:
            f = g = False
            assert f or g

        assert getmsg(f4) == "assert (False or False)"

        def f5() -> None:
            f = g = False
            assert not f and not g

        getmsg(f5, must_pass=True)

        # ``x`` is supplied to the rewritten functions through ``extra_ns``.
        def x() -> bool:
            return False

        def f6() -> None:
            assert x() and x()

        assert (
            getmsg(f6, {"x": x})
            == """assert (False)
 +  where False = x()"""
        )

        def f7() -> None:
            assert False or x()  # type: ignore[unreachable]

        assert (
            getmsg(f7, {"x": x})
            == """assert (False or False)
 +  where False = x()"""
        )

        def f8() -> None:
            assert 1 in {} and 2 in {}

        assert getmsg(f8) == "assert (1 in {})"

        def f9() -> None:
            x = 1
            y = 2
            assert x in {1: None} and y in {}

        assert getmsg(f9) == "assert (1 in {1: None} and 2 in {})"

        def f10() -> None:
            f = True
            g = False
            assert f or g

        getmsg(f10, must_pass=True)

        def f11() -> None:
            f = g = h = lambda: True
            assert f() and g() and h()

        getmsg(f11, must_pass=True)
416
    def test_short_circuit_evaluation(self) -> None:
        """Rewritten ``or`` asserts must pass when the left operand is true.

        In ``f1`` the right operand is an undefined name, so the assert can
        only pass if that operand is never evaluated.
        """

        def f1() -> None:
            assert True or explode  # type: ignore[name-defined,unreachable] # noqa: F821

        getmsg(f1, must_pass=True)

        def f2() -> None:
            x = 1
            assert x == 1 or x == 2

        getmsg(f2, must_pass=True)
428
    def test_unary_op(self) -> None:
        """Check messages for asserts using ``not``, ``~``, unary ``-`` and ``+``."""

        def f1() -> None:
            x = True
            assert not x

        assert getmsg(f1) == "assert not True"

        def f2() -> None:
            x = 0
            assert ~x + 1

        assert getmsg(f2) == "assert (~0 + 1)"

        def f3() -> None:
            x = 3
            assert -x + x

        assert getmsg(f3) == "assert (-3 + 3)"

        def f4() -> None:
            x = 0
            assert +x + x

        assert getmsg(f4) == "assert (+0 + 0)"
453
    def test_binary_op(self) -> None:
        """Check messages for asserts using the binary ``+`` and ``%`` operators."""

        def f1() -> None:
            x = 1
            y = -1
            assert x + y

        assert getmsg(f1) == "assert (1 + -1)"

        def f2() -> None:
            assert not 5 % 4

        assert getmsg(f2) == "assert not (5 % 4)"
466
    def test_boolop_percent(self) -> None:
        """A ``%`` operation inside ``and``/``or`` is shown parenthesised."""

        def f1() -> None:
            assert 3 % 2 and False

        assert getmsg(f1) == "assert ((3 % 2) and False)"

        def f2() -> None:
            assert False or 4 % 2  # type: ignore[unreachable]

        assert getmsg(f2) == "assert (False or (4 % 2))"
477
478    def test_at_operator_issue1290(self, testdir):
479        testdir.makepyfile(
480            """
481            class Matrix(object):
482                def __init__(self, num):
483                    self.num = num
484                def __matmul__(self, other):
485                    return self.num * other.num
486
487            def test_multmat_operator():
488                assert Matrix(2) @ Matrix(3) == 6"""
489        )
490        testdir.runpytest().assert_outcomes(passed=1)
491
492    def test_starred_with_side_effect(self, testdir):
493        """See #4412"""
494        testdir.makepyfile(
495            """\
496            def test():
497                f = lambda x: x
498                x = iter([1, 2, 3])
499                assert 2 * next(x) == f(*[next(x)])
500            """
501        )
502        testdir.runpytest().assert_outcomes(passed=1)
503
    def test_call(self) -> None:
        """Check the ``where ... = g(...)`` explanation for several call forms.

        Covers no arguments, positional and keyword arguments, and ``*`` /
        ``**`` unpacking.
        """

        def g(a=42, *args, **kwargs) -> bool:
            return False

        # ``g`` is made visible to the rewritten functions via this namespace.
        ns = {"g": g}

        def f1() -> None:
            assert g()

        assert (
            getmsg(f1, ns)
            == """assert False
 +  where False = g()"""
        )

        def f2() -> None:
            assert g(1)

        assert (
            getmsg(f2, ns)
            == """assert False
 +  where False = g(1)"""
        )

        def f3() -> None:
            assert g(1, 2)

        assert (
            getmsg(f3, ns)
            == """assert False
 +  where False = g(1, 2)"""
        )

        def f4() -> None:
            assert g(1, g=42)

        assert (
            getmsg(f4, ns)
            == """assert False
 +  where False = g(1, g=42)"""
        )

        def f5() -> None:
            assert g(1, 3, g=23)

        assert (
            getmsg(f5, ns)
            == """assert False
 +  where False = g(1, 3, g=23)"""
        )

        def f6() -> None:
            seq = [1, 2, 3]
            assert g(*seq)

        assert (
            getmsg(f6, ns)
            == """assert False
 +  where False = g(*[1, 2, 3])"""
        )

        def f7() -> None:
            x = "a"
            assert g(**{x: 2})

        assert (
            getmsg(f7, ns)
            == """assert False
 +  where False = g(**{'a': 2})"""
        )
574
    def test_attribute(self) -> None:
        """Attribute access is explained with a ``where <value> = <expr>`` line."""

        class X:
            g = 3

        # The class itself is exposed under the name ``x``.
        ns = {"x": X}

        def f1() -> None:
            assert not x.g  # type: ignore[name-defined] # noqa: F821

        assert (
            getmsg(f1, ns)
            == """assert not 3
 +  where 3 = x.g"""
        )

        def f2() -> None:
            x.a = False  # type: ignore[name-defined] # noqa: F821
            assert x.a  # type: ignore[name-defined] # noqa: F821

        assert (
            getmsg(f2, ns)
            == """assert False
 +  where False = x.a"""
        )
599
    def test_comparisons(self) -> None:
        """Check messages for simple and chained comparisons.

        For a failing chain only the failing pair is expected in the message.
        """

        def f1() -> None:
            a, b = range(2)
            assert b < a

        assert getmsg(f1) == """assert 1 < 0"""

        def f2() -> None:
            a, b, c = range(3)
            assert a > b > c

        assert getmsg(f2) == """assert 0 > 1"""

        def f3() -> None:
            a, b, c = range(3)
            assert a < b > c

        assert getmsg(f3) == """assert 1 > 2"""

        def f4() -> None:
            a, b, c = range(3)
            assert a < b <= c

        getmsg(f4, must_pass=True)

        def f5() -> None:
            a, b, c = range(3)
            assert a < b
            assert b < c

        getmsg(f5, must_pass=True)
631
    def test_len(self, request):
        """Check the message for a failing ``len(...)`` comparison.

        The expected text depends on the ``verbose`` option of the enclosing
        test run.
        """

        def f():
            values = list(range(10))
            assert len(values) == 11

        msg = getmsg(f)
        if request.config.getoption("verbose") > 0:
            assert msg == "assert 10 == 11\n  +10\n  -11"
        else:
            assert msg == "assert 10 == 11\n +  where 10 = len([0, 1, 2, 3, 4, 5, ...])"
642
    def test_custom_reprcompare(self, monkeypatch) -> None:
        """The text returned by a patched ``util._reprcompare`` ends up in the message."""

        def my_reprcompare1(op, left, right) -> str:
            return "42"

        monkeypatch.setattr(util, "_reprcompare", my_reprcompare1)

        def f1() -> None:
            assert 42 < 3

        assert getmsg(f1) == "assert 42"

        def my_reprcompare2(op, left, right) -> str:
            return "{} {} {}".format(left, op, right)

        monkeypatch.setattr(util, "_reprcompare", my_reprcompare2)

        def f2() -> None:
            assert 1 < 3 < 5 <= 4 < 7

        # Only the failing link of the chain is expected in the message.
        assert getmsg(f2) == "assert 5 <= 4"
663
    def test_assert_raising__bool__in_comparison(self) -> None:
        """A comparison whose result raises in ``__bool__`` is still explained.

        ``A() < 0`` yields another ``A`` whose ``__bool__`` raises
        ``ValueError``; the failure message must still contain the comparison.
        """

        def f() -> None:
            class A:
                def __bool__(self):
                    raise ValueError(42)

                def __lt__(self, other):
                    return A()

                def __repr__(self):
                    return "<MY42 object>"

            def myany(x) -> bool:
                return False

            assert myany(A() < 0)

        msg = getmsg(f)
        assert msg is not None
        assert "<MY42 object> < 0" in msg
684
    def test_formatchar(self) -> None:
        """A ``%`` inside a compared string is reported literally."""

        def f() -> None:
            assert "%test" == "test"  # type: ignore[comparison-overlap]

        msg = getmsg(f)
        assert msg is not None
        assert msg.startswith("assert '%test' == 'test'")
692
    def test_custom_repr(self, request) -> None:
        """Check the message for an object whose ``__repr__`` contains newlines.

        The message is passed through ``util._format_lines`` before it is
        compared; the expected text depends on the ``verbose`` option of the
        enclosing test run.
        """

        def f() -> None:
            class Foo:
                a = 1

                def __repr__(self):
                    return "\n{ \n~ \n}"

            f = Foo()
            assert 0 == f.a

        msg = getmsg(f)
        assert msg is not None
        lines = util._format_lines([msg])
        if request.config.getoption("verbose") > 0:
            assert lines == ["assert 0 == 1\n  +0\n  -1"]
        else:
            # The newlines of the repr are expected in escaped form here.
            assert lines == ["assert 0 == 1\n +  where 1 = \\n{ \\n~ \\n}.a"]
711
    def test_custom_repr_non_ascii(self) -> None:
        """A ``__repr__`` returning encoded bytes must not leak a Unicode error."""

        def f() -> None:
            class A:
                name = "ä"

                def __repr__(self):
                    return self.name.encode("UTF-8")  # only legal in python2

            a = A()
            assert not a.name

        msg = getmsg(f)
        assert msg is not None
        assert "UnicodeDecodeError" not in msg
        assert "UnicodeEncodeError" not in msg
727
728
729class TestRewriteOnImport:
730    def test_pycache_is_a_file(self, testdir):
731        testdir.tmpdir.join("__pycache__").write("Hello")
732        testdir.makepyfile(
733            """
734            def test_rewritten():
735                assert "@py_builtins" in globals()"""
736        )
737        assert testdir.runpytest().ret == 0
738
739    def test_pycache_is_readonly(self, testdir):
740        cache = testdir.tmpdir.mkdir("__pycache__")
741        old_mode = cache.stat().mode
742        cache.chmod(old_mode ^ stat.S_IWRITE)
743        testdir.makepyfile(
744            """
745            def test_rewritten():
746                assert "@py_builtins" in globals()"""
747        )
748        try:
749            assert testdir.runpytest().ret == 0
750        finally:
751            cache.chmod(old_mode)
752
753    def test_zipfile(self, testdir):
754        z = testdir.tmpdir.join("myzip.zip")
755        z_fn = str(z)
756        f = zipfile.ZipFile(z_fn, "w")
757        try:
758            f.writestr("test_gum/__init__.py", "")
759            f.writestr("test_gum/test_lizard.py", "")
760        finally:
761            f.close()
762        z.chmod(256)
763        testdir.makepyfile(
764            """
765            import sys
766            sys.path.append(%r)
767            import test_gum.test_lizard"""
768            % (z_fn,)
769        )
770        assert testdir.runpytest().ret == ExitCode.NO_TESTS_COLLECTED
771
772    def test_readonly(self, testdir):
773        sub = testdir.mkdir("testing")
774        sub.join("test_readonly.py").write(
775            b"""
776def test_rewritten():
777    assert "@py_builtins" in globals()
778            """,
779            "wb",
780        )
781        old_mode = sub.stat().mode
782        sub.chmod(320)
783        try:
784            assert testdir.runpytest().ret == 0
785        finally:
786            sub.chmod(old_mode)
787
788    def test_dont_write_bytecode(self, testdir, monkeypatch):
789        testdir.makepyfile(
790            """
791            import os
792            def test_no_bytecode():
793                assert "__pycache__" in __cached__
794                assert not os.path.exists(__cached__)
795                assert not os.path.exists(os.path.dirname(__cached__))"""
796        )
797        monkeypatch.setenv("PYTHONDONTWRITEBYTECODE", "1")
798        assert testdir.runpytest_subprocess().ret == 0
799
800    def test_orphaned_pyc_file(self, testdir):
801        testdir.makepyfile(
802            """
803            import orphan
804            def test_it():
805                assert orphan.value == 17
806            """
807        )
808        testdir.makepyfile(
809            orphan="""
810            value = 17
811            """
812        )
813        py_compile.compile("orphan.py")
814        os.remove("orphan.py")
815
816        # Python 3 puts the .pyc files in a __pycache__ directory, and will
817        # not import from there without source.  It will import a .pyc from
818        # the source location though.
819        if not os.path.exists("orphan.pyc"):
820            pycs = glob.glob("__pycache__/orphan.*.pyc")
821            assert len(pycs) == 1
822            os.rename(pycs[0], "orphan.pyc")
823
824        assert testdir.runpytest().ret == 0
825
826    def test_cached_pyc_includes_pytest_version(self, testdir, monkeypatch):
827        """Avoid stale caches (#1671)"""
828        monkeypatch.delenv("PYTHONDONTWRITEBYTECODE", raising=False)
829        testdir.makepyfile(
830            test_foo="""
831            def test_foo():
832                assert True
833            """
834        )
835        result = testdir.runpytest_subprocess()
836        assert result.ret == 0
837        found_names = glob.glob(
838            "__pycache__/*-pytest-{}.pyc".format(pytest.__version__)
839        )
840        assert found_names, "pyc with expected tag not found in names: {}".format(
841            glob.glob("__pycache__/*.pyc")
842        )
843
844    @pytest.mark.skipif('"__pypy__" in sys.modules')
845    def test_pyc_vs_pyo(self, testdir, monkeypatch):
846        testdir.makepyfile(
847            """
848            import pytest
849            def test_optimized():
850                "hello"
851                assert test_optimized.__doc__ is None"""
852        )
853        p = make_numbered_dir(root=Path(testdir.tmpdir), prefix="runpytest-")
854        tmp = "--basetemp=%s" % p
855        monkeypatch.setenv("PYTHONOPTIMIZE", "2")
856        monkeypatch.delenv("PYTHONDONTWRITEBYTECODE", raising=False)
857        assert testdir.runpytest_subprocess(tmp).ret == 0
858        tagged = "test_pyc_vs_pyo." + PYTEST_TAG
859        assert tagged + ".pyo" in os.listdir("__pycache__")
860        monkeypatch.undo()
861        monkeypatch.delenv("PYTHONDONTWRITEBYTECODE", raising=False)
862        assert testdir.runpytest_subprocess(tmp).ret == 1
863        assert tagged + ".pyc" in os.listdir("__pycache__")
864
865    def test_package(self, testdir):
866        pkg = testdir.tmpdir.join("pkg")
867        pkg.mkdir()
868        pkg.join("__init__.py").ensure()
869        pkg.join("test_blah.py").write(
870            """
871def test_rewritten():
872    assert "@py_builtins" in globals()"""
873        )
874        assert testdir.runpytest().ret == 0
875
876    def test_translate_newlines(self, testdir):
877        content = "def test_rewritten():\r\n assert '@py_builtins' in globals()"
878        b = content.encode("utf-8")
879        testdir.tmpdir.join("test_newlines.py").write(b, "wb")
880        assert testdir.runpytest().ret == 0
881
882    def test_package_without__init__py(self, testdir):
883        pkg = testdir.mkdir("a_package_without_init_py")
884        pkg.join("module.py").ensure()
885        testdir.makepyfile("import a_package_without_init_py.module")
886        assert testdir.runpytest().ret == ExitCode.NO_TESTS_COLLECTED
887
888    def test_rewrite_warning(self, testdir):
889        testdir.makeconftest(
890            """
891            import pytest
892            pytest.register_assert_rewrite("_pytest")
893        """
894        )
895        # needs to be a subprocess because pytester explicitly disables this warning
896        result = testdir.runpytest_subprocess()
897        result.stdout.fnmatch_lines(["*Module already imported*: _pytest"])
898
899    def test_rewrite_module_imported_from_conftest(self, testdir):
900        testdir.makeconftest(
901            """
902            import test_rewrite_module_imported
903        """
904        )
905        testdir.makepyfile(
906            test_rewrite_module_imported="""
907            def test_rewritten():
908                assert "@py_builtins" in globals()
909        """
910        )
911        assert testdir.runpytest_subprocess().ret == 0
912
913    def test_remember_rewritten_modules(self, pytestconfig, testdir, monkeypatch):
914        """`AssertionRewriteHook` should remember rewritten modules so it
915        doesn't give false positives (#2005)."""
916        monkeypatch.syspath_prepend(testdir.tmpdir)
917        testdir.makepyfile(test_remember_rewritten_modules="")
918        warnings = []
919        hook = AssertionRewritingHook(pytestconfig)
920        monkeypatch.setattr(
921            hook, "_warn_already_imported", lambda code, msg: warnings.append(msg)
922        )
923        spec = hook.find_spec("test_remember_rewritten_modules")
924        assert spec is not None
925        module = importlib.util.module_from_spec(spec)
926        hook.exec_module(module)
927        hook.mark_rewrite("test_remember_rewritten_modules")
928        hook.mark_rewrite("test_remember_rewritten_modules")
929        assert warnings == []
930
931    def test_rewrite_warning_using_pytest_plugins(self, testdir):
932        testdir.makepyfile(
933            **{
934                "conftest.py": "pytest_plugins = ['core', 'gui', 'sci']",
935                "core.py": "",
936                "gui.py": "pytest_plugins = ['core', 'sci']",
937                "sci.py": "pytest_plugins = ['core']",
938                "test_rewrite_warning_pytest_plugins.py": "def test(): pass",
939            }
940        )
941        testdir.chdir()
942        result = testdir.runpytest_subprocess()
943        result.stdout.fnmatch_lines(["*= 1 passed in *=*"])
944        result.stdout.no_fnmatch_line("*pytest-warning summary*")
945
946    def test_rewrite_warning_using_pytest_plugins_env_var(self, testdir, monkeypatch):
947        monkeypatch.setenv("PYTEST_PLUGINS", "plugin")
948        testdir.makepyfile(
949            **{
950                "plugin.py": "",
951                "test_rewrite_warning_using_pytest_plugins_env_var.py": """
952                import plugin
953                pytest_plugins = ['plugin']
954                def test():
955                    pass
956            """,
957            }
958        )
959        testdir.chdir()
960        result = testdir.runpytest_subprocess()
961        result.stdout.fnmatch_lines(["*= 1 passed in *=*"])
962        result.stdout.no_fnmatch_line("*pytest-warning summary*")
963
964
965class TestAssertionRewriteHookDetails:
966    def test_sys_meta_path_munged(self, testdir):
967        testdir.makepyfile(
968            """
969            def test_meta_path():
970                import sys; sys.meta_path = []"""
971        )
972        assert testdir.runpytest().ret == 0
973
    def test_write_pyc(self, testdir: Testdir, tmpdir, monkeypatch) -> None:
        """``_write_pyc`` returns a true value on success and a false one on ``OSError``.

        After a first successful write the underlying write step is patched to
        raise ``OSError``: ``atomic_write`` on win32, ``os.rename`` elsewhere.
        """
        from _pytest.assertion.rewrite import _write_pyc
        from _pytest.assertion import AssertionState

        config = testdir.parseconfig()
        state = AssertionState(config, "rewrite")
        source_path = str(tmpdir.ensure("source.py"))
        pycpath = tmpdir.join("pyc").strpath
        co = compile("1", "f.py", "single")
        assert _write_pyc(state, co, os.stat(source_path), pycpath)

        if sys.platform == "win32":
            from contextlib import contextmanager

            @contextmanager
            def atomic_write_failed(fn, mode="r", overwrite=False):
                e = OSError()
                e.errno = 10
                raise e
                # Never reached; its presence makes this function a generator,
                # as ``contextmanager`` requires.
                yield  # type:ignore[unreachable]

            monkeypatch.setattr(
                _pytest.assertion.rewrite, "atomic_write", atomic_write_failed
            )
        else:

            def raise_oserror(*args):
                raise OSError()

            monkeypatch.setattr("os.rename", raise_oserror)

        # With the write step failing, the same call must now report failure.
        assert not _write_pyc(state, co, os.stat(source_path), pycpath)
1006
1007    def test_resources_provider_for_loader(self, testdir):
1008        """
1009        Attempts to load resources from a package should succeed normally,
1010        even when the AssertionRewriteHook is used to load the modules.
1011
1012        See #366 for details.
1013        """
1014        pytest.importorskip("pkg_resources")
1015
1016        testdir.mkpydir("testpkg")
1017        contents = {
1018            "testpkg/test_pkg": """
1019                import pkg_resources
1020
1021                import pytest
1022                from _pytest.assertion.rewrite import AssertionRewritingHook
1023
1024                def test_load_resource():
1025                    assert isinstance(__loader__, AssertionRewritingHook)
1026                    res = pkg_resources.resource_string(__name__, 'resource.txt')
1027                    res = res.decode('ascii')
1028                    assert res == 'Load me please.'
1029                """
1030        }
1031        testdir.makepyfile(**contents)
1032        testdir.maketxtfile(**{"testpkg/resource": "Load me please."})
1033
1034        result = testdir.runpytest_subprocess()
1035        result.assert_outcomes(passed=1)
1036
1037    def test_read_pyc(self, tmp_path: Path) -> None:
1038        """
1039        Ensure that the `_read_pyc` can properly deal with corrupted pyc files.
1040        In those circumstances it should just give up instead of generating
1041        an exception that is propagated to the caller.
1042        """
1043        import py_compile
1044        from _pytest.assertion.rewrite import _read_pyc
1045
1046        source = tmp_path / "source.py"
1047        pyc = Path(str(source) + "c")
1048
1049        source.write_text("def test(): pass")
1050        py_compile.compile(str(source), str(pyc))
1051
1052        contents = pyc.read_bytes()
1053        strip_bytes = 20  # header is around 8 bytes, strip a little more
1054        assert len(contents) > strip_bytes
1055        pyc.write_bytes(contents[:strip_bytes])
1056
1057        assert _read_pyc(source, pyc) is None  # no error
1058
1059    def test_reload_is_same_and_reloads(self, testdir: Testdir) -> None:
1060        """Reloading a (collected) module after change picks up the change."""
1061        testdir.makeini(
1062            """
1063            [pytest]
1064            python_files = *.py
1065            """
1066        )
1067        testdir.makepyfile(
1068            file="""
1069            def reloaded():
1070                return False
1071
1072            def rewrite_self():
1073                with open(__file__, 'w') as self:
1074                    self.write('def reloaded(): return True')
1075            """,
1076            test_fun="""
1077            import sys
1078            from importlib import reload
1079
1080            def test_loader():
1081                import file
1082                assert not file.reloaded()
1083                file.rewrite_self()
1084                assert sys.modules["file"] is reload(file)
1085                assert file.reloaded()
1086            """,
1087        )
1088        result = testdir.runpytest()
1089        result.stdout.fnmatch_lines(["* 1 passed*"])
1090
1091    def test_get_data_support(self, testdir):
1092        """Implement optional PEP302 api (#808)."""
1093        path = testdir.mkpydir("foo")
1094        path.join("test_foo.py").write(
1095            textwrap.dedent(
1096                """\
1097                class Test(object):
1098                    def test_foo(self):
1099                        import pkgutil
1100                        data = pkgutil.get_data('foo.test_foo', 'data.txt')
1101                        assert data == b'Hey'
1102                """
1103            )
1104        )
1105        path.join("data.txt").write("Hey")
1106        result = testdir.runpytest()
1107        result.stdout.fnmatch_lines(["*1 passed*"])
1108
1109
def test_issue731(testdir):
    """A failing assert on an object whose long repr contains braces must
    not print an "unbalanced braces" message."""
    module_source = """
    class LongReprWithBraces(object):
        def __repr__(self):
           return 'LongReprWithBraces({' + ('a' * 80) + '}' + ('a' * 120) + ')'

        def some_method(self):
            return False

    def test_long_repr():
        obj = LongReprWithBraces()
        assert obj.some_method()
    """
    testdir.makepyfile(module_source)

    stdout = testdir.runpytest().stdout
    stdout.no_fnmatch_line("*unbalanced braces*")
1127
1128
class TestIssue925:
    """Failure output for chained / parenthesised ``==`` comparisons."""

    def test_simple_case(self, testdir):
        """A parenthesised comparison on the left is shown with its parentheses."""
        module_source = """
        def test_ternary_display():
            assert (False == False) == False
        """
        testdir.makepyfile(module_source)

        stdout = testdir.runpytest().stdout
        stdout.fnmatch_lines(["*E*assert (False == False) == False"])

    def test_long_case(self, testdir):
        """A chained comparison with a parenthesised middle operand."""
        module_source = """
        def test_ternary_display():
             assert False == (False == True) == True
        """
        testdir.makepyfile(module_source)

        stdout = testdir.runpytest().stdout
        stdout.fnmatch_lines(["*E*assert (False == True) == True"])

    def test_many_brackets(self, testdir):
        """Nested parentheses on the right-hand side are kept in the output."""
        module_source = """
            def test_ternary_display():
                 assert True == ((False == True) == True)
            """
        testdir.makepyfile(module_source)

        stdout = testdir.runpytest().stdout
        stdout.fnmatch_lines(["*E*assert True == ((False == True) == True)"])
1159
1160
class TestIssue2121:
    """Rewriting with a ``python_files`` pattern that contains a directory part."""

    def test_rewrite_python_files_contain_subdirs(self, testdir):
        """A file matched only by ``tests/**.py`` still shows rewritten assertion output."""
        failing_test = """
                def test_simple_failure():
                    assert 1 + 1 == 3
                """
        testdir.makepyfile(**{"tests/file.py": failing_test})

        ini_source = """
                [pytest]
                python_files = tests/**.py
            """
        testdir.makeini(ini_source)

        stdout = testdir.runpytest().stdout
        stdout.fnmatch_lines(["*E*assert (1 + 1) == 3"])
1179
1180
@pytest.mark.skipif(
    sys.maxsize < 2 ** 31, reason="Causes OverflowError on 32bit systems"
)
@pytest.mark.parametrize("offset", [-1, +1])
def test_source_mtime_long_long(testdir, offset):
    """Source files with a modification time after 2038 can be rewritten (#4903).

    The crash this guards against was:

            fp.write(struct.pack("<ll", mtime, size))
        E   struct.error: argument out of range
    """
    test_file = testdir.makepyfile(
        """
        def test(): pass
    """
    )
    # The bug was caused by an unsigned-long timestamp overflowing a signed
    # long; the +1 offset additionally covers masking with 0xFFFFFFFF.
    mtime = 2 ** 32 + offset
    os.utime(str(test_file), (mtime, mtime))

    assert testdir.runpytest().ret == 0
1205
1206
def test_rewrite_infinite_recursion(testdir, pytestconfig, monkeypatch) -> None:
    """Guard against infinite recursion while a pyc file is written (#3506).

    An import triggered during pyc writing would call the hook again, which
    would write another pyc, which could trigger another import, and so on.
    """
    from _pytest.assertion import rewrite as rewritemod

    testdir.syspathinsert()
    testdir.makepyfile(test_foo="def test_foo(): pass")
    testdir.makepyfile(test_bar="def test_bar(): pass")

    real_write_pyc = rewritemod._write_pyc
    recorded_calls = []

    def spy_write_pyc(*args, **kwargs):
        # Record the call, then behave like an import happening mid-write.
        recorded_calls.append(True)
        # The hook must decline to handle a module at this point.
        nested_spec = hook.find_spec("test_bar")
        assert nested_spec is None
        return real_write_pyc(*args, **kwargs)

    monkeypatch.setattr(rewritemod, "_write_pyc", spy_write_pyc)
    monkeypatch.setattr(sys, "dont_write_bytecode", False)

    hook = AssertionRewritingHook(pytestconfig)
    foo_spec = hook.find_spec("test_foo")
    assert foo_spec is not None
    hook.exec_module(importlib.util.module_from_spec(foo_spec))

    assert len(recorded_calls) == 1
1237
1238
class TestEarlyRewriteBailout:
    """Tests for when the rewriting hook's ``find_spec`` bails out before
    delegating to ``PathFinder.find_spec`` and when it does not."""

    @pytest.fixture
    def hook(self, pytestconfig, monkeypatch, testdir) -> AssertionRewritingHook:
        """Returns a patched AssertionRewritingHook instance so we can configure its initial paths and track
        if PathFinder.find_spec has been called.
        """
        import importlib.machinery

        # Names passed to the spied ``_find_spec``, in call order.
        self.find_spec_calls = []  # type: List[str]
        # Tests add paths here to mark them as "initial" paths.
        self.initial_paths = set()  # type: Set[py.path.local]

        class StubSession:
            # Same set object as ``self.initial_paths``, so paths a test adds
            # later are seen by ``isinitpath``.
            _initialpaths = self.initial_paths

            def isinitpath(self, p):
                return p in self._initialpaths

        def spy_find_spec(name, path):
            # Record the lookup, then delegate to the real PathFinder.
            self.find_spec_calls.append(name)
            return importlib.machinery.PathFinder.find_spec(name, path)

        hook = AssertionRewritingHook(pytestconfig)
        # use default patterns, otherwise we inherit pytest's testing config
        hook.fnpats[:] = ["test_*.py", "*_test.py"]
        monkeypatch.setattr(hook, "_find_spec", spy_find_spec)
        hook.set_session(StubSession())  # type: ignore[arg-type]
        testdir.syspathinsert()
        return hook

    def test_basic(self, testdir, hook: AssertionRewritingHook) -> None:
        """
        Ensure we avoid calling PathFinder.find_spec when we know for sure a certain
        module will not be rewritten to optimize assertion rewriting (#3918).
        """
        testdir.makeconftest(
            """
            import pytest
            @pytest.fixture
            def fix(): return 1
        """
        )
        testdir.makepyfile(test_foo="def test_foo(): pass")
        testdir.makepyfile(bar="def bar(): pass")
        foobar_path = testdir.makepyfile(foobar="def foobar(): pass")
        # Register foobar.py as an initial path through the shared set.
        self.initial_paths.add(foobar_path)

        # conftest files should always be rewritten
        assert hook.find_spec("conftest") is not None
        assert self.find_spec_calls == ["conftest"]

        # files matching "python_files" mask should always be rewritten
        assert hook.find_spec("test_foo") is not None
        assert self.find_spec_calls == ["conftest", "test_foo"]

        # file does not match "python_files": early bailout
        assert hook.find_spec("bar") is None
        assert self.find_spec_calls == ["conftest", "test_foo"]

        # file is an initial path (passed on the command-line): should be rewritten
        assert hook.find_spec("foobar") is not None
        assert self.find_spec_calls == ["conftest", "test_foo", "foobar"]

    def test_pattern_contains_subdirectories(
        self, testdir, hook: AssertionRewritingHook
    ) -> None:
        """If one of the python_files patterns contain subdirectories ("tests/**.py") we can't bailout early
        because we need to match with the full path, which can only be found by calling PathFinder.find_spec
        """
        p = testdir.makepyfile(
            **{
                "tests/file.py": """\
                    def test_simple_failure():
                        assert 1 + 1 == 3
                """
            }
        )
        # Make the ``tests`` directory importable so "file" can be found.
        testdir.syspathinsert(p.dirpath())
        hook.fnpats[:] = ["tests/**.py"]
        assert hook.find_spec("file") is not None
        assert self.find_spec_calls == ["file"]

    @pytest.mark.skipif(
        sys.platform.startswith("win32"), reason="cannot remove cwd on Windows"
    )
    def test_cwd_changed(self, testdir, monkeypatch):
        """The run still reports one passing test when a test module changes
        into a temporary directory and deletes it at import time."""
        # Setup conditions for py's fspath trying to import pathlib on py34
        # always (previously triggered via xdist only).
        # Ref: https://github.com/pytest-dev/py/pull/207
        monkeypatch.syspath_prepend("")
        monkeypatch.delitem(sys.modules, "pathlib", raising=False)

        testdir.makepyfile(
            **{
                "test_setup_nonexisting_cwd.py": """\
                    import os
                    import shutil
                    import tempfile

                    d = tempfile.mkdtemp()
                    os.chdir(d)
                    shutil.rmtree(d)
                """,
                "test_test.py": """\
                    def test():
                        pass
                """,
            }
        )
        result = testdir.runpytest()
        result.stdout.fnmatch_lines(["* 1 passed in *"])
1349
1350
class TestAssertionPass:
    """Tests for the ``pytest_assertion_pass`` hook and the
    ``enable_assertion_pass_hook`` ini option."""

    def test_option_default(self, testdir):
        """``enable_assertion_pass_hook`` is ``False`` when nothing configures it."""
        ini_value = testdir.parseconfig().getini("enable_assertion_pass_hook")
        assert ini_value is False

    @pytest.fixture
    def flag_on(self, testdir):
        """Write an ini file that sets ``enable_assertion_pass_hook = True``."""
        ini_text = "[pytest]\nenable_assertion_pass_hook = True\n"
        testdir.makeini(ini_text)

    @pytest.fixture
    def hook_on(self, testdir):
        """Write a conftest whose ``pytest_assertion_pass`` raises with the hook arguments."""
        conftest_source = """\
            def pytest_assertion_pass(item, lineno, orig, expl):
                raise Exception("Assertion Passed: {} {} at line {}".format(orig, expl, lineno))
            """
        testdir.makeconftest(conftest_source)

    def test_hook_call(self, testdir, flag_on, hook_on):
        """With option and hook both present, the hook's exception text appears in the output."""
        module_source = """\
            def test_simple():
                a=1
                b=2
                c=3
                d=0

                assert a+b == c+d

            # cover failing assertions with a message
            def test_fails():
                assert False, "assert with message"
            """
        testdir.makepyfile(module_source)

        stdout = testdir.runpytest().stdout
        stdout.fnmatch_lines(
            "*Assertion Passed: a+b == c+d (1 + 2) == (3 + 0) at line 7*"
        )

    def test_hook_call_with_parens(self, testdir, flag_on, hook_on):
        """The original expression text keeps the call parentheses (``f()``)."""
        module_source = """\
            def f(): return 1
            def test():
                assert f()
            """
        testdir.makepyfile(module_source)

        stdout = testdir.runpytest().stdout
        stdout.fnmatch_lines("*Assertion Passed: f() 1")

    def test_hook_not_called_without_hookimpl(self, testdir, monkeypatch, flag_on):
        """With the ini option on but no ``pytest_assertion_pass`` implementation,
        ``_call_assertion_pass`` must not be called (so no formatting happens)."""

        def fail_if_called(*args, **kwargs):
            raise Exception("Assertion passed called when it shouldn't!")

        rewrite_module = _pytest.assertion.rewrite
        monkeypatch.setattr(rewrite_module, "_call_assertion_pass", fail_if_called)

        module_source = """\
            def test_simple():
                a=1
                b=2
                c=3
                d=0

                assert a+b == c+d
            """
        testdir.makepyfile(module_source)

        testdir.runpytest().assert_outcomes(passed=1)

    def test_hook_not_called_without_cmd_option(self, testdir, monkeypatch):
        """With a ``pytest_assertion_pass`` implementation in conftest but the
        ``enable_assertion_pass_hook`` ini option left unset,
        ``_call_assertion_pass`` must not be called (so no formatting happens)."""

        def fail_if_called(*args, **kwargs):
            raise Exception("Assertion passed called when it shouldn't!")

        rewrite_module = _pytest.assertion.rewrite
        monkeypatch.setattr(rewrite_module, "_call_assertion_pass", fail_if_called)

        conftest_source = """\
            def pytest_assertion_pass(item, lineno, orig, expl):
                raise Exception("Assertion Passed: {} {} at line {}".format(orig, expl, lineno))
            """
        testdir.makeconftest(conftest_source)

        module_source = """\
            def test_simple():
                a=1
                b=2
                c=3
                d=0

                assert a+b == c+d
            """
        testdir.makepyfile(module_source)

        testdir.runpytest().assert_outcomes(passed=1)
1457
1458
# Each case pairs raw source bytes with the expected mapping of
# line number -> text of the assertion expression found on that line.
@pytest.mark.parametrize(
    ("src", "expected"),
    (
        # fmt: off
        pytest.param(b"", {}, id="trivial"),
        pytest.param(
            b"def x(): assert 1\n",
            {1: "1"},
            id="assert statement not on own line",
        ),
        pytest.param(
            b"def x():\n"
            b"    assert 1\n"
            b"    assert 1+2\n",
            {2: "1", 3: "1+2"},
            id="multiple assertions",
        ),
        pytest.param(
            # changes in encoding cause the byte offsets to be different
            "# -*- coding: latin1\n"
            "def ÀÀÀÀÀ(): assert 1\n".encode("latin1"),
            {2: "1"},
            id="latin1 encoded on first line\n",
        ),
        pytest.param(
            # using the default utf-8 encoding
            "def ÀÀÀÀÀ(): assert 1\n".encode(),
            {1: "1"},
            id="utf-8 encoded on first line",
        ),
        pytest.param(
            b"def x():\n"
            b"    assert (\n"
            b"        1 + 2  # comment\n"
            b"    )\n",
            {2: "(\n        1 + 2  # comment\n    )"},
            id="multi-line assertion",
        ),
        pytest.param(
            b"def x():\n"
            b"    assert y == [\n"
            b"        1, 2, 3\n"
            b"    ]\n",
            {2: "y == [\n        1, 2, 3\n    ]"},
            id="multi line assert with list continuation",
        ),
        pytest.param(
            b"def x():\n"
            b"    assert 1 + \\\n"
            b"        2\n",
            {2: "1 + \\\n        2"},
            id="backslash continuation",
        ),
        pytest.param(
            b"def x():\n"
            b"    assert x, y\n",
            {2: "x"},
            id="assertion with message",
        ),
        pytest.param(
            b"def x():\n"
            b"    assert (\n"
            b"        f(1, 2, 3)\n"
            b"    ),  'f did not work!'\n",
            {2: "(\n        f(1, 2, 3)\n    )"},
            id="assertion with message, test spanning multiple lines",
        ),
        pytest.param(
            b"def x():\n"
            b"    assert \\\n"
            b"        x\\\n"
            b"        , 'failure message'\n",
            {2: "x"},
            id="escaped newlines plus message",
        ),
        pytest.param(
            b"def x(): assert 5",
            {1: "5"},
            id="no newline at end of file",
        ),
        # fmt: on
    ),
)
def test_get_assertion_exprs(src, expected):
    """Check that ``_get_assertion_exprs(src)`` equals the expected mapping for each case."""
    assert _get_assertion_exprs(src) == expected
1544
1545
def test_try_makedirs(monkeypatch, tmp_path: Path) -> None:
    """``try_makedirs`` reports success or failure as a boolean and
    re-raises an ``OSError`` whose errno it does not handle."""
    from _pytest.assertion.rewrite import try_makedirs

    target = tmp_path / "foo"

    # First call creates the directory; the second finds it already present.
    assert try_makedirs(target)
    assert target.is_dir()
    assert try_makedirs(target)

    def fake_mkdir(path, exist_ok=False, *, exc):
        # Stand-in for os.makedirs that always fails with the given exception.
        assert isinstance(path, str)
        raise exc

    def patch_makedirs(exc):
        monkeypatch.setattr(os, "makedirs", partial(fake_mkdir, exc=exc))

    read_only_fs = OSError()
    read_only_fs.errno = errno.EROFS

    # Each of these failures must be reported as a falsy result.
    tolerated_errors = (
        FileNotFoundError(),
        NotADirectoryError(),
        PermissionError(),
        read_only_fs,
    )
    for tolerated in tolerated_errors:
        patch_makedirs(tolerated)
        assert not try_makedirs(target)

    # unhandled OSError should raise
    unhandled = OSError()
    unhandled.errno = errno.ECHILD
    patch_makedirs(unhandled)
    with pytest.raises(OSError) as exc_info:
        try_makedirs(target)
    assert exc_info.value.errno == errno.ECHILD
1584
1585
class TestPyCacheDir:
    """Tests for where cache files go, with and without ``sys.pycache_prefix``."""

    @pytest.mark.parametrize(
        "prefix, source, expected",
        [
            ("c:/tmp/pycs", "d:/projects/src/foo.py", "c:/tmp/pycs/projects/src"),
            (None, "d:/projects/src/foo.py", "d:/projects/src/__pycache__"),
            ("/tmp/pycs", "/home/projects/src/foo.py", "/tmp/pycs/home/projects/src"),
            (None, "/home/projects/src/foo.py", "/home/projects/src/__pycache__"),
        ],
    )
    def test_get_cache_dir(self, monkeypatch, prefix, source, expected):
        """``get_cache_dir`` gives ``__pycache__`` beside the source, or a
        directory under ``sys.pycache_prefix`` when a prefix is set."""
        if prefix and sys.version_info < (3, 8):
            pytest.skip("pycache_prefix not available in py<38")
        if prefix:
            monkeypatch.setattr(sys, "pycache_prefix", prefix)  # type:ignore

        cache_dir = get_cache_dir(Path(source))
        assert cache_dir == Path(expected)

    @pytest.mark.skipif(
        sys.version_info < (3, 8), reason="pycache_prefix not available in py<38"
    )
    def test_sys_pycache_prefix_integration(self, tmp_path, monkeypatch, testdir):
        """Integration test for sys.pycache_prefix (#4730)."""
        prefix_dir = tmp_path / "my/pycs"
        monkeypatch.setattr(sys, "pycache_prefix", str(prefix_dir))
        monkeypatch.setattr(sys, "dont_write_bytecode", False)

        test_foo_source = """
                import bar
                def test_foo():
                    pass
            """
        sources = {
            "src/test_foo.py": test_foo_source,
            "src/bar/__init__.py": "",
        }
        testdir.makepyfile(**sources)
        assert testdir.runpytest().ret == 0

        test_foo = Path(testdir.tmpdir) / "src/test_foo.py"
        bar_init = Path(testdir.tmpdir) / "src/bar/__init__.py"
        assert test_foo.is_file()
        assert bar_init.is_file()

        # test file: rewritten, custom pytest cache tag
        rewritten_pyc = get_cache_dir(test_foo) / ("test_foo" + PYC_TAIL)
        assert rewritten_pyc.is_file()

        # normal file: not touched by pytest, normal cache tag
        plain_pyc_name = "__init__.{cache_tag}.pyc".format(
            cache_tag=sys.implementation.cache_tag
        )
        plain_pyc = get_cache_dir(bar_init) / plain_pyc_name
        assert plain_pyc.is_file()
1640