1# -*- coding: utf-8 -*-
2from __future__ import absolute_import
3from __future__ import division
4from __future__ import print_function
5
6import glob
7import os
8import py_compile
9import stat
10import sys
11import textwrap
12import zipfile
13
14import py
15import six
16
17import _pytest._code
18import pytest
19from _pytest.assertion import util
20from _pytest.assertion.rewrite import AssertionRewritingHook
21from _pytest.assertion.rewrite import PYTEST_TAG
22from _pytest.assertion.rewrite import rewrite_asserts
23from _pytest.main import EXIT_NOTESTSCOLLECTED
24
# The rewriter works on ``ast`` trees; skip the whole module if it is missing.
ast = pytest.importorskip("ast")
if sys.platform.startswith("java"):
    # XXX should be xfail
    # ``allow_module_level=True`` is required when skipping at import time;
    # without it pytest reports a collection error instead of a skip.
    pytest.skip(
        "assert rewrite does currently not work on jython", allow_module_level=True
    )
29
30
def setup_module(mod):
    """Save the current ``util._reprcompare`` hook on *mod* for teardown.

    NOTE(review): the reset below targets ``_pytest._code._reprcompare``, not
    ``util._reprcompare`` (the value that was just saved), so it does not clear
    the hook that ``teardown_module`` later restores. Confirm intent before
    changing it: presumably the assertion plugin installs
    ``util._reprcompare`` per test, and the verbose-mode expectations in this
    module appear to rely on that hook being present.
    """
    previous_hook = util._reprcompare
    mod._old_reprcompare = previous_hook
    _pytest._code._reprcompare = None
34
35
def teardown_module(mod):
    """Put back the ``util._reprcompare`` hook stashed on *mod* by setup_module."""
    saved_hook = mod._old_reprcompare
    util._reprcompare = saved_hook
    del mod._old_reprcompare
39
40
def rewrite(src):
    """Parse *src* into a module AST, rewrite its asserts, and return the tree.

    ``rewrite_asserts`` mutates the tree it is given; the same tree object is
    returned to the caller.
    """
    module_tree = ast.parse(src)
    rewrite_asserts(module_tree)
    return module_tree
45
46
47def getmsg(f, extra_ns=None, must_pass=False):
48    """Rewrite the assertions in f, run it, and get the failure message."""
49    src = "\n".join(_pytest._code.Code(f).source().lines)
50    mod = rewrite(src)
51    code = compile(mod, "<test>", "exec")
52    ns = {}
53    if extra_ns is not None:
54        ns.update(extra_ns)
55    exec(code, ns)
56    func = ns[f.__name__]
57    try:
58        func()
59    except AssertionError:
60        if must_pass:
61            pytest.fail("shouldn't have raised")
62        s = six.text_type(sys.exc_info()[1])
63        if not s.startswith("assert"):
64            return "AssertionError: " + s
65        return s
66    else:
67        if not must_pass:
68            pytest.fail("function didn't raise at all")
69
70
class TestAssertionRewrite(object):
    """In-process tests of the assertion rewriter.

    Most tests rewrite a small nested function through ``getmsg`` and compare
    the resulting failure message; a few inspect the rewritten AST directly
    (via ``rewrite``) or run a throwaway test module through ``testdir``.
    """

    def test_place_initial_imports(self):
        """The rewriter's helper imports go after a module docstring and any
        ``from __future__`` imports but before every other statement
        (including relative imports). They carry the line number of the
        statement they precede, or of the last statement when none follows.
        """
        s = """'Doc string'\nother = stuff"""
        m = rewrite(s)
        assert isinstance(m.body[0], ast.Expr)
        for imp in m.body[1:3]:
            assert isinstance(imp, ast.Import)
            assert imp.lineno == 2
            assert imp.col_offset == 0
        assert isinstance(m.body[3], ast.Assign)
        s = """from __future__ import division\nother_stuff"""
        m = rewrite(s)
        assert isinstance(m.body[0], ast.ImportFrom)
        for imp in m.body[1:3]:
            assert isinstance(imp, ast.Import)
            assert imp.lineno == 2
            assert imp.col_offset == 0
        assert isinstance(m.body[3], ast.Expr)
        s = """'doc string'\nfrom __future__ import division"""
        m = rewrite(s)
        assert isinstance(m.body[0], ast.Expr)
        assert isinstance(m.body[1], ast.ImportFrom)
        for imp in m.body[2:4]:
            assert isinstance(imp, ast.Import)
            assert imp.lineno == 2
            assert imp.col_offset == 0
        s = """'doc string'\nfrom __future__ import division\nother"""
        m = rewrite(s)
        assert isinstance(m.body[0], ast.Expr)
        assert isinstance(m.body[1], ast.ImportFrom)
        for imp in m.body[2:4]:
            assert isinstance(imp, ast.Import)
            assert imp.lineno == 3
            assert imp.col_offset == 0
        assert isinstance(m.body[4], ast.Expr)
        s = """from . import relative\nother_stuff"""
        m = rewrite(s)
        for imp in m.body[:2]:
            assert isinstance(imp, ast.Import)
            assert imp.lineno == 1
            assert imp.col_offset == 0
        # A relative import is not a __future__ import, so the helper imports
        # are placed in front of it and it keeps its own slot.
        assert isinstance(m.body[2], ast.ImportFrom)
        assert isinstance(m.body[3], ast.Expr)

    def test_dont_rewrite(self):
        """A 'PYTEST_DONT_REWRITE' docstring leaves the module untouched."""
        s = """'PYTEST_DONT_REWRITE'\nassert 14"""
        m = rewrite(s)
        # No helper imports were inserted and the assert got no message.
        assert len(m.body) == 2
        assert m.body[1].msg is None

    def test_dont_rewrite_plugin(self, testdir):
        """Importing a plugin marked PYTEST_DONT_REWRITE produces no warnings."""
        contents = {
            "conftest.py": "pytest_plugins = 'plugin'; import plugin",
            "plugin.py": "'PYTEST_DONT_REWRITE'",
            "test_foo.py": "def test_foo(): pass",
        }
        testdir.makepyfile(**contents)
        result = testdir.runpytest_subprocess()
        assert "warnings" not in "".join(result.outlines)

    def test_name(self, request):
        """Failure messages for asserts on plain names (locals, globals,
        a module object and a class)."""

        def f():
            assert False

        assert getmsg(f) == "assert False"

        def f():
            f = False
            assert f

        assert getmsg(f) == "assert False"

        def f():
            assert a_global  # noqa

        assert getmsg(f, {"a_global": False}) == "assert False"

        def f():
            assert sys == 42

        verbose = request.config.getoption("verbose")
        msg = getmsg(f, {"sys": sys})
        if verbose > 0:
            assert msg == (
                "assert <module 'sys' (built-in)> == 42\n"
                "  -<module 'sys' (built-in)>\n"
                "  +42"
            )
        else:
            assert msg == "assert sys == 42"

        def f():
            assert cls == 42  # noqa: F821

        class X(object):
            pass

        msg = getmsg(f, {"cls": X}).splitlines()
        if verbose > 0:
            if six.PY2:
                assert msg == [
                    "assert <class 'test_assertrewrite.X'> == 42",
                    "  -<class 'test_assertrewrite.X'>",
                    "  +42",
                ]
            else:
                assert msg == [
                    "assert <class 'test_...e.<locals>.X'> == 42",
                    "  -<class 'test_assertrewrite.TestAssertionRewrite.test_name.<locals>.X'>",
                    "  +42",
                ]
        else:
            assert msg == ["assert cls == 42"]

    def test_dont_rewrite_if_hasattr_fails(self, request):
        """An object whose ``__getattr__`` raises something other than
        AttributeError is still reported through its real attribute value."""

        class Y(object):
            """ A class whose getattr fails, but not with `AttributeError` """

            def __getattr__(self, attribute_name):
                raise KeyError()

            def __repr__(self):
                return "Y"

            def __init__(self):
                self.foo = 3

        def f():
            assert cls().foo == 2  # noqa

        # XXX: looks like the "where" should also be there in verbose mode?!
        message = getmsg(f, {"cls": Y}).splitlines()
        if request.config.getoption("verbose") > 0:
            assert message == ["assert 3 == 2", "  -3", "  +2"]
        else:
            assert message == [
                "assert 3 == 2",
                " +  where 3 = Y.foo",
                " +    where Y = cls()",
            ]

    def test_assert_already_has_message(self):
        """An explicit assert message is kept and followed by the explanation."""

        def f():
            assert False, "something bad!"

        assert getmsg(f) == "AssertionError: something bad!\nassert False"

    def test_assertion_message(self, testdir):
        """A string assert message is shown in the test run output."""
        testdir.makepyfile(
            """
            def test_foo():
                assert 1 == 2, "The failure message"
        """
        )
        result = testdir.runpytest()
        assert result.ret == 1
        result.stdout.fnmatch_lines(
            ["*AssertionError*The failure message*", "*assert 1 == 2*"]
        )

    def test_assertion_message_multiline(self, testdir):
        """Every line of a multi-line assert message is shown."""
        testdir.makepyfile(
            """
            def test_foo():
                assert 1 == 2, "A multiline\\nfailure message"
        """
        )
        result = testdir.runpytest()
        assert result.ret == 1
        result.stdout.fnmatch_lines(
            ["*AssertionError*A multiline*", "*failure message*", "*assert 1 == 2*"]
        )

    def test_assertion_message_tuple(self, testdir):
        """A non-string (tuple) assert message is shown via its repr."""
        testdir.makepyfile(
            """
            def test_foo():
                assert 1 == 2, (1, 2)
        """
        )
        result = testdir.runpytest()
        assert result.ret == 1
        result.stdout.fnmatch_lines(
            ["*AssertionError*%s*" % repr((1, 2)), "*assert 1 == 2*"]
        )

    def test_assertion_message_expr(self, testdir):
        """An expression used as assert message is evaluated before display."""
        testdir.makepyfile(
            """
            def test_foo():
                assert 1 == 2, 1 + 2
        """
        )
        result = testdir.runpytest()
        assert result.ret == 1
        result.stdout.fnmatch_lines(["*AssertionError*3*", "*assert 1 == 2*"])

    def test_assertion_message_escape(self, testdir):
        """A literal '%' in the message is not treated as a format directive."""
        testdir.makepyfile(
            """
            def test_foo():
                assert 1 == 2, 'To be escaped: %'
        """
        )
        result = testdir.runpytest()
        assert result.ret == 1
        result.stdout.fnmatch_lines(
            ["*AssertionError: To be escaped: %", "*assert 1 == 2"]
        )

    @pytest.mark.skipif(
        sys.version_info < (3,), reason="bytes is a string type in python 2"
    )
    def test_assertion_messages_bytes(self, testdir):
        """A bytes assert message is shown via its repr on Python 3."""
        testdir.makepyfile("def test_bytes_assertion():\n    assert False, b'ohai!'\n")
        result = testdir.runpytest()
        assert result.ret == 1
        result.stdout.fnmatch_lines(["*AssertionError: b'ohai!'", "*assert False"])

    def test_boolop(self):
        """``and``/``or`` explanations only list the operands that were evaluated."""

        def f():
            f = g = False
            assert f and g

        assert getmsg(f) == "assert (False)"

        def f():
            f = True
            g = False
            assert f and g

        assert getmsg(f) == "assert (True and False)"

        def f():
            f = False
            g = True
            assert f and g

        assert getmsg(f) == "assert (False)"

        def f():
            f = g = False
            assert f or g

        assert getmsg(f) == "assert (False or False)"

        def f():
            f = g = False
            assert not f and not g

        getmsg(f, must_pass=True)

        def x():
            return False

        def f():
            assert x() and x()

        assert (
            getmsg(f, {"x": x})
            == """assert (False)
 +  where False = x()"""
        )

        def f():
            assert False or x()

        assert (
            getmsg(f, {"x": x})
            == """assert (False or False)
 +  where False = x()"""
        )

        def f():
            assert 1 in {} and 2 in {}

        assert getmsg(f) == "assert (1 in {})"

        def f():
            x = 1
            y = 2
            assert x in {1: None} and y in {}

        assert getmsg(f) == "assert (1 in {1: None} and 2 in {})"

        def f():
            f = True
            g = False
            assert f or g

        getmsg(f, must_pass=True)

        def f():
            f = g = h = lambda: True
            assert f() and g() and h()

        getmsg(f, must_pass=True)

    def test_short_circuit_evaluation(self):
        """Rewritten boolean operators still short-circuit."""

        def f():
            assert True or explode  # noqa

        getmsg(f, must_pass=True)

        def f():
            x = 1
            assert x == 1 or x == 2

        getmsg(f, must_pass=True)

    def test_unary_op(self):
        """Unary operators appear in the explanation with their operand value."""

        def f():
            x = True
            assert not x

        assert getmsg(f) == "assert not True"

        def f():
            x = 0
            assert ~x + 1

        assert getmsg(f) == "assert (~0 + 1)"

        def f():
            x = 3
            assert -x + x

        assert getmsg(f) == "assert (-3 + 3)"

        def f():
            x = 0
            assert +x + x

        assert getmsg(f) == "assert (+0 + 0)"

    def test_binary_op(self):
        """Binary operations are explained with parenthesised operand values."""

        def f():
            x = 1
            y = -1
            assert x + y

        assert getmsg(f) == "assert (1 + -1)"

        def f():
            assert not 5 % 4

        assert getmsg(f) == "assert not (5 % 4)"

    def test_boolop_percent(self):
        """A '%' operator inside a boolean operation does not break formatting."""

        def f():
            assert 3 % 2 and False

        assert getmsg(f) == "assert ((3 % 2) and False)"

        def f():
            assert False or 4 % 2

        assert getmsg(f) == "assert (False or (4 % 2))"

    @pytest.mark.skipif("sys.version_info < (3,5)")
    def test_at_operator_issue1290(self, testdir):
        """The matrix-multiplication operator survives rewriting (#1290)."""
        testdir.makepyfile(
            """
            class Matrix(object):
                def __init__(self, num):
                    self.num = num
                def __matmul__(self, other):
                    return self.num * other.num

            def test_multmat_operator():
                assert Matrix(2) @ Matrix(3) == 6"""
        )
        testdir.runpytest().assert_outcomes(passed=1)

    @pytest.mark.skipif("sys.version_info < (3,5)")
    def test_starred_with_side_effect(self, testdir):
        """See #4412"""
        testdir.makepyfile(
            """\
            def test():
                f = lambda x: x
                x = iter([1, 2, 3])
                assert 2 * next(x) == f(*[next(x)])
            """
        )
        testdir.runpytest().assert_outcomes(passed=1)

    def test_call(self):
        """Call explanations show positional, keyword, ``*`` and ``**`` args."""

        def g(a=42, *args, **kwargs):
            return False

        ns = {"g": g}

        def f():
            assert g()

        assert (
            getmsg(f, ns)
            == """assert False
 +  where False = g()"""
        )

        def f():
            assert g(1)

        assert (
            getmsg(f, ns)
            == """assert False
 +  where False = g(1)"""
        )

        def f():
            assert g(1, 2)

        assert (
            getmsg(f, ns)
            == """assert False
 +  where False = g(1, 2)"""
        )

        def f():
            assert g(1, g=42)

        assert (
            getmsg(f, ns)
            == """assert False
 +  where False = g(1, g=42)"""
        )

        def f():
            assert g(1, 3, g=23)

        assert (
            getmsg(f, ns)
            == """assert False
 +  where False = g(1, 3, g=23)"""
        )

        def f():
            seq = [1, 2, 3]
            assert g(*seq)

        assert (
            getmsg(f, ns)
            == """assert False
 +  where False = g(*[1, 2, 3])"""
        )

        def f():
            x = "a"
            assert g(**{x: 2})

        assert (
            getmsg(f, ns)
            == """assert False
 +  where False = g(**{'a': 2})"""
        )

    def test_attribute(self):
        """Attribute access is explained with a 'where' line."""

        class X(object):
            g = 3

        ns = {"x": X}

        def f():
            assert not x.g  # noqa

        assert (
            getmsg(f, ns)
            == """assert not 3
 +  where 3 = x.g"""
        )

        def f():
            x.a = False  # noqa
            assert x.a  # noqa

        assert (
            getmsg(f, ns)
            == """assert False
 +  where False = x.a"""
        )

    def test_comparisons(self):
        """Chained comparisons report only the first failing link."""

        def f():
            a, b = range(2)
            assert b < a

        assert getmsg(f) == """assert 1 < 0"""

        def f():
            a, b, c = range(3)
            assert a > b > c

        assert getmsg(f) == """assert 0 > 1"""

        def f():
            a, b, c = range(3)
            assert a < b > c

        assert getmsg(f) == """assert 1 > 2"""

        def f():
            a, b, c = range(3)
            assert a < b <= c

        getmsg(f, must_pass=True)

        def f():
            a, b, c = range(3)
            assert a < b
            assert b < c

        getmsg(f, must_pass=True)

    def test_len(self, request):
        """A failing ``len(...) == n`` shows the truncated sequence repr."""

        def f():
            values = list(range(10))
            assert len(values) == 11

        msg = getmsg(f)
        if request.config.getoption("verbose") > 0:
            assert msg == "assert 10 == 11\n  -10\n  +11"
        else:
            assert msg == "assert 10 == 11\n +  where 10 = len([0, 1, 2, 3, 4, 5, ...])"

    def test_custom_reprcompare(self, monkeypatch):
        """A custom ``util._reprcompare`` replaces the default comparison text."""

        def my_reprcompare(op, left, right):
            return "42"

        monkeypatch.setattr(util, "_reprcompare", my_reprcompare)

        def f():
            assert 42 < 3

        assert getmsg(f) == "assert 42"

        def my_reprcompare(op, left, right):
            return "{} {} {}".format(left, op, right)

        monkeypatch.setattr(util, "_reprcompare", my_reprcompare)

        def f():
            assert 1 < 3 < 5 <= 4 < 7

        assert getmsg(f) == "assert 5 <= 4"

    def test_assert_raising_nonzero_in_comparison(self):
        """A comparison result whose truth test raises must not break the
        failure report."""

        def f():
            class A(object):
                def __nonzero__(self):
                    raise ValueError(42)

                # Python 3 uses __bool__ for truth testing; alias it so the
                # raising path is exercised on both major versions.
                __bool__ = __nonzero__

                def __lt__(self, other):
                    return A()

                def __repr__(self):
                    return "<MY42 object>"

            def myany(x):
                return False

            assert myany(A() < 0)

        assert "<MY42 object> < 0" in getmsg(f)

    def test_formatchar(self):
        """A '%' inside a compared string literal is shown verbatim."""

        def f():
            assert "%test" == "test"

        assert getmsg(f).startswith("assert '%test' == 'test'")

    def test_custom_repr(self, request):
        """Newlines in a custom ``__repr__`` are escaped in the explanation."""

        def f():
            class Foo(object):
                a = 1

                def __repr__(self):
                    return "\n{ \n~ \n}"

            f = Foo()
            assert 0 == f.a

        lines = util._format_lines([getmsg(f)])
        if request.config.getoption("verbose") > 0:
            assert lines == ["assert 0 == 1\n  -0\n  +1"]
        else:
            assert lines == ["assert 0 == 1\n +  where 1 = \\n{ \\n~ \\n}.a"]

    def test_custom_repr_non_ascii(self):
        """A ``__repr__`` returning non-ASCII bytes must not surface Unicode
        errors in the failure message."""

        def f():
            class A(object):
                name = u"ä"

                def __repr__(self):
                    return self.name.encode("UTF-8")  # only legal in python2

            a = A()
            assert not a.name

        msg = getmsg(f)
        assert "UnicodeDecodeError" not in msg
        assert "UnicodeEncodeError" not in msg
674
class TestRewriteOnImport(object):
    """End-to-end tests of rewriting modules as they are imported by a
    ``testdir`` pytest run, including pyc caching and plugin handling."""

    def test_pycache_is_a_file(self, testdir):
        """Rewriting still works when ``__pycache__`` is a regular file."""
        testdir.tmpdir.join("__pycache__").write("Hello")
        testdir.makepyfile(
            """
            def test_rewritten():
                assert "@py_builtins" in globals()"""
        )
        assert testdir.runpytest().ret == 0

    def test_pycache_is_readonly(self, testdir):
        """Rewriting still works when ``__pycache__`` is not writable."""
        cache = testdir.tmpdir.mkdir("__pycache__")
        old_mode = cache.stat().mode
        # XOR toggles the owner-write bit; assumes it is set on the fresh dir.
        cache.chmod(old_mode ^ stat.S_IWRITE)
        testdir.makepyfile(
            """
            def test_rewritten():
                assert "@py_builtins" in globals()"""
        )
        try:
            assert testdir.runpytest().ret == 0
        finally:
            cache.chmod(old_mode)

    def test_zipfile(self, testdir):
        """A test package imported from a zip archive on sys.path does not
        break the run (nothing is collected from it)."""
        z = testdir.tmpdir.join("myzip.zip")
        z_fn = str(z)
        with zipfile.ZipFile(z_fn, "w") as f:
            f.writestr("test_gum/__init__.py", "")
            f.writestr("test_gum/test_lizard.py", "")
        # 256 == 0o400: owner read-only.
        z.chmod(256)
        testdir.makepyfile(
            """
            import sys
            sys.path.append(%r)
            import test_gum.test_lizard"""
            % (z_fn,)
        )
        assert testdir.runpytest().ret == EXIT_NOTESTSCOLLECTED

    def test_readonly(self, testdir):
        """Rewriting works for tests living in a read-only directory."""
        sub = testdir.mkdir("testing")
        sub.join("test_readonly.py").write(
            b"""
def test_rewritten():
    assert "@py_builtins" in globals()
            """,
            "wb",
        )
        old_mode = sub.stat().mode
        # 320 == 0o500: owner read + execute, no write.
        sub.chmod(320)
        try:
            assert testdir.runpytest().ret == 0
        finally:
            sub.chmod(old_mode)

    def test_dont_write_bytecode(self, testdir, monkeypatch):
        """With PYTHONDONTWRITEBYTECODE set, no cached pyc (or cache dir) is
        written for the rewritten module."""
        testdir.makepyfile(
            """
            import os
            def test_no_bytecode():
                assert "__pycache__" in __cached__
                assert not os.path.exists(__cached__)
                assert not os.path.exists(os.path.dirname(__cached__))"""
        )
        monkeypatch.setenv("PYTHONDONTWRITEBYTECODE", "1")
        assert testdir.runpytest_subprocess().ret == 0

    def test_orphaned_pyc_file(self, testdir):
        """A module that exists only as a .pyc (no source) can still be imported."""
        if sys.version_info < (3, 0) and hasattr(sys, "pypy_version_info"):
            pytest.skip("pypy2 doesn't run orphaned pyc files")

        testdir.makepyfile(
            """
            import orphan
            def test_it():
                assert orphan.value == 17
            """
        )
        testdir.makepyfile(
            orphan="""
            value = 17
            """
        )
        py_compile.compile("orphan.py")
        os.remove("orphan.py")

        # Python 3 puts the .pyc files in a __pycache__ directory, and will
        # not import from there without source.  It will import a .pyc from
        # the source location though.
        if not os.path.exists("orphan.pyc"):
            pycs = glob.glob("__pycache__/orphan.*.pyc")
            assert len(pycs) == 1
            os.rename(pycs[0], "orphan.pyc")

        assert testdir.runpytest().ret == 0

    @pytest.mark.skipif('"__pypy__" in sys.modules')
    def test_pyc_vs_pyo(self, testdir, monkeypatch):
        """Optimized runs cache to a ``.pyo`` file distinct from the ``.pyc``."""
        testdir.makepyfile(
            """
            import pytest
            def test_optimized():
                "hello"
                assert test_optimized.__doc__ is None"""
        )
        p = py.path.local.make_numbered_dir(
            prefix="runpytest-", keep=None, rootdir=testdir.tmpdir
        )
        tmp = "--basetemp=%s" % p
        # -OO strips docstrings, so the test above passes only when optimized.
        monkeypatch.setenv("PYTHONOPTIMIZE", "2")
        monkeypatch.delenv("PYTHONDONTWRITEBYTECODE", raising=False)
        assert testdir.runpytest_subprocess(tmp).ret == 0
        tagged = "test_pyc_vs_pyo." + PYTEST_TAG
        assert tagged + ".pyo" in os.listdir("__pycache__")
        # undo() restores the original environment, which may set
        # PYTHONDONTWRITEBYTECODE again; clear it so a .pyc gets written.
        monkeypatch.undo()
        monkeypatch.delenv("PYTHONDONTWRITEBYTECODE", raising=False)
        # Unoptimized: the docstring is kept, so the test fails (ret == 1).
        assert testdir.runpytest_subprocess(tmp).ret == 1
        assert tagged + ".pyc" in os.listdir("__pycache__")

    def test_package(self, testdir):
        """Test modules inside a regular package are rewritten."""
        pkg = testdir.tmpdir.join("pkg")
        pkg.mkdir()
        pkg.join("__init__.py").ensure()
        pkg.join("test_blah.py").write(
            """
def test_rewritten():
    assert "@py_builtins" in globals()"""
        )
        assert testdir.runpytest().ret == 0

    def test_translate_newlines(self, testdir):
        """Sources with CRLF line endings are rewritten correctly."""
        content = "def test_rewritten():\r\n assert '@py_builtins' in globals()"
        b = content.encode("utf-8")
        testdir.tmpdir.join("test_newlines.py").write(b, "wb")
        assert testdir.runpytest().ret == 0

    @pytest.mark.skipif(
        sys.version_info < (3, 4),
        reason="packages without __init__.py not supported on python 2",
    )
    def test_package_without__init__py(self, testdir):
        """Importing a module from a namespace package does not break the run."""
        pkg = testdir.mkdir("a_package_without_init_py")
        pkg.join("module.py").ensure()
        testdir.makepyfile("import a_package_without_init_py.module")
        assert testdir.runpytest().ret == EXIT_NOTESTSCOLLECTED

    def test_rewrite_warning(self, testdir):
        """Registering an already-imported module for rewriting warns."""
        testdir.makeconftest(
            """
            import pytest
            pytest.register_assert_rewrite("_pytest")
        """
        )
        # needs to be a subprocess because pytester explicitly disables this warning
        result = testdir.runpytest_subprocess()
        result.stdout.fnmatch_lines(["*Module already imported*: _pytest"])

    def test_rewrite_module_imported_from_conftest(self, testdir):
        """A test module first imported by conftest is still rewritten."""
        testdir.makeconftest(
            """
            import test_rewrite_module_imported
        """
        )
        testdir.makepyfile(
            test_rewrite_module_imported="""
            def test_rewritten():
                assert "@py_builtins" in globals()
        """
        )
        assert testdir.runpytest_subprocess().ret == 0

    def test_remember_rewritten_modules(self, pytestconfig, testdir, monkeypatch):
        """
        AssertionRewriteHook should remember rewritten modules so it
        doesn't give false positives (#2005).
        """
        monkeypatch.syspath_prepend(testdir.tmpdir)
        testdir.makepyfile(test_remember_rewritten_modules="")
        warnings = []
        hook = AssertionRewritingHook(pytestconfig)
        monkeypatch.setattr(
            hook, "_warn_already_imported", lambda code, msg: warnings.append(msg)
        )
        hook.find_module("test_remember_rewritten_modules")
        hook.load_module("test_remember_rewritten_modules")
        hook.mark_rewrite("test_remember_rewritten_modules")
        hook.mark_rewrite("test_remember_rewritten_modules")
        assert warnings == []

    def test_rewrite_warning_using_pytest_plugins(self, testdir):
        """Plugins listed (even repeatedly) in pytest_plugins do not trigger
        the already-imported warning."""
        testdir.makepyfile(
            **{
                "conftest.py": "pytest_plugins = ['core', 'gui', 'sci']",
                "core.py": "",
                "gui.py": "pytest_plugins = ['core', 'sci']",
                "sci.py": "pytest_plugins = ['core']",
                "test_rewrite_warning_pytest_plugins.py": "def test(): pass",
            }
        )
        testdir.chdir()
        result = testdir.runpytest_subprocess()
        result.stdout.fnmatch_lines(["*= 1 passed in *=*"])
        assert "pytest-warning summary" not in result.stdout.str()

    def test_rewrite_warning_using_pytest_plugins_env_var(self, testdir, monkeypatch):
        """A plugin from PYTEST_PLUGINS that a test module also imports does
        not trigger the already-imported warning."""
        monkeypatch.setenv("PYTEST_PLUGINS", "plugin")
        testdir.makepyfile(
            **{
                "plugin.py": "",
                "test_rewrite_warning_using_pytest_plugins_env_var.py": """
                import plugin
                pytest_plugins = ['plugin']
                def test():
                    pass
            """,
            }
        )
        testdir.chdir()
        result = testdir.runpytest_subprocess()
        result.stdout.fnmatch_lines(["*= 1 passed in *=*"])
        assert "pytest-warning summary" not in result.stdout.str()

    @pytest.mark.skipif(sys.version_info[0] > 2, reason="python 2 only")
    def test_rewrite_future_imports(self, testdir):
        """Test that rewritten modules don't inherit the __future__ flags
        from the assertrewrite module.

        assertion.rewrite imports __future__.division (and others), so
        ensure rewritten modules don't inherit those flags.

        The test below will fail if __future__.division is enabled
        """
        testdir.makepyfile(
            """
            def test():
                x = 1 / 2
                assert type(x) is int
        """
        )
        result = testdir.runpytest()
        assert result.ret == 0
920
921
922class TestAssertionRewriteHookDetails(object):
923    def test_loader_is_package_false_for_module(self, testdir):
924        testdir.makepyfile(
925            test_fun="""
926            def test_loader():
927                assert not __loader__.is_package(__name__)
928            """
929        )
930        result = testdir.runpytest()
931        result.stdout.fnmatch_lines(["* 1 passed*"])
932
933    def test_loader_is_package_true_for_package(self, testdir):
934        testdir.makepyfile(
935            test_fun="""
936            def test_loader():
937                assert not __loader__.is_package(__name__)
938
939            def test_fun():
940                assert __loader__.is_package('fun')
941
942            def test_missing():
943                assert not __loader__.is_package('pytest_not_there')
944            """
945        )
946        testdir.mkpydir("fun")
947        result = testdir.runpytest()
948        result.stdout.fnmatch_lines(["* 3 passed*"])
949
950    @pytest.mark.skipif("sys.version_info[0] >= 3")
951    @pytest.mark.xfail("hasattr(sys, 'pypy_translation_info')")
952    def test_assume_ascii(self, testdir):
953        content = "u'\xe2\x99\xa5\x01\xfe'"
954        testdir.tmpdir.join("test_encoding.py").write(content, "wb")
955        res = testdir.runpytest()
956        assert res.ret != 0
957        assert "SyntaxError: Non-ASCII character" in res.stdout.str()
958
959    @pytest.mark.skipif("sys.version_info[0] >= 3")
960    def test_detect_coding_cookie(self, testdir):
961        testdir.makepyfile(
962            test_cookie="""
963            # -*- coding: utf-8 -*-
964            u"St\xc3\xa4d"
965            def test_rewritten():
966                assert "@py_builtins" in globals()"""
967        )
968        assert testdir.runpytest().ret == 0
969
970    @pytest.mark.skipif("sys.version_info[0] >= 3")
971    def test_detect_coding_cookie_second_line(self, testdir):
972        testdir.makepyfile(
973            test_cookie="""
974            # -*- coding: utf-8 -*-
975            u"St\xc3\xa4d"
976            def test_rewritten():
977                assert "@py_builtins" in globals()"""
978        )
979        assert testdir.runpytest().ret == 0
980
981    @pytest.mark.skipif("sys.version_info[0] >= 3")
982    def test_detect_coding_cookie_crlf(self, testdir):
983        testdir.makepyfile(
984            test_cookie="""
985            # -*- coding: utf-8 -*-
986            u"St\xc3\xa4d"
987            def test_rewritten():
988                assert "@py_builtins" in globals()"""
989        )
990        assert testdir.runpytest().ret == 0
991
992    def test_sys_meta_path_munged(self, testdir):
993        testdir.makepyfile(
994            """
995            def test_meta_path():
996                import sys; sys.meta_path = []"""
997        )
998        assert testdir.runpytest().ret == 0
999
1000    def test_write_pyc(self, testdir, tmpdir, monkeypatch):
1001        from _pytest.assertion.rewrite import _write_pyc
1002        from _pytest.assertion import AssertionState
1003        import atomicwrites
1004        from contextlib import contextmanager
1005
1006        config = testdir.parseconfig([])
1007        state = AssertionState(config, "rewrite")
1008        source_path = tmpdir.ensure("source.py")
1009        pycpath = tmpdir.join("pyc").strpath
1010        assert _write_pyc(state, [1], source_path.stat(), pycpath)
1011
1012        @contextmanager
1013        def atomic_write_failed(fn, mode="r", overwrite=False):
1014            e = IOError()
1015            e.errno = 10
1016            raise e
1017            yield
1018
1019        monkeypatch.setattr(atomicwrites, "atomic_write", atomic_write_failed)
1020        assert not _write_pyc(state, [1], source_path.stat(), pycpath)
1021
1022    def test_resources_provider_for_loader(self, testdir):
1023        """
1024        Attempts to load resources from a package should succeed normally,
1025        even when the AssertionRewriteHook is used to load the modules.
1026
1027        See #366 for details.
1028        """
1029        pytest.importorskip("pkg_resources")
1030
1031        testdir.mkpydir("testpkg")
1032        contents = {
1033            "testpkg/test_pkg": """
1034                import pkg_resources
1035
1036                import pytest
1037                from _pytest.assertion.rewrite import AssertionRewritingHook
1038
1039                def test_load_resource():
1040                    assert isinstance(__loader__, AssertionRewritingHook)
1041                    res = pkg_resources.resource_string(__name__, 'resource.txt')
1042                    res = res.decode('ascii')
1043                    assert res == 'Load me please.'
1044                """
1045        }
1046        testdir.makepyfile(**contents)
1047        testdir.maketxtfile(**{"testpkg/resource": "Load me please."})
1048
1049        result = testdir.runpytest_subprocess()
1050        result.assert_outcomes(passed=1)
1051
1052    def test_read_pyc(self, tmpdir):
1053        """
1054        Ensure that the `_read_pyc` can properly deal with corrupted pyc files.
1055        In those circumstances it should just give up instead of generating
1056        an exception that is propagated to the caller.
1057        """
1058        import py_compile
1059        from _pytest.assertion.rewrite import _read_pyc
1060
1061        source = tmpdir.join("source.py")
1062        pyc = source + "c"
1063
1064        source.write("def test(): pass")
1065        py_compile.compile(str(source), str(pyc))
1066
1067        contents = pyc.read(mode="rb")
1068        strip_bytes = 20  # header is around 8 bytes, strip a little more
1069        assert len(contents) > strip_bytes
1070        pyc.write(contents[:strip_bytes], mode="wb")
1071
1072        assert _read_pyc(source, str(pyc)) is None  # no error
1073
1074    def test_reload_is_same(self, testdir):
1075        # A file that will be picked up during collecting.
1076        testdir.tmpdir.join("file.py").ensure()
1077        testdir.tmpdir.join("pytest.ini").write(
1078            textwrap.dedent(
1079                """
1080            [pytest]
1081            python_files = *.py
1082        """
1083            )
1084        )
1085
1086        testdir.makepyfile(
1087            test_fun="""
1088            import sys
1089            try:
1090                from imp import reload
1091            except ImportError:
1092                pass
1093
1094            def test_loader():
1095                import file
1096                assert sys.modules["file"] is reload(file)
1097            """
1098        )
1099        result = testdir.runpytest("-s")
1100        result.stdout.fnmatch_lines(["* 1 passed*"])
1101
1102    def test_reload_reloads(self, testdir):
1103        """Reloading a module after change picks up the change."""
1104        testdir.tmpdir.join("file.py").write(
1105            textwrap.dedent(
1106                """
1107            def reloaded():
1108                return False
1109
1110            def rewrite_self():
1111                with open(__file__, 'w') as self:
1112                    self.write('def reloaded(): return True')
1113        """
1114            )
1115        )
1116        testdir.tmpdir.join("pytest.ini").write(
1117            textwrap.dedent(
1118                """
1119            [pytest]
1120            python_files = *.py
1121        """
1122            )
1123        )
1124
1125        testdir.makepyfile(
1126            test_fun="""
1127            import sys
1128            try:
1129                from imp import reload
1130            except ImportError:
1131                pass
1132
1133            def test_loader():
1134                import file
1135                assert not file.reloaded()
1136                file.rewrite_self()
1137                reload(file)
1138                assert file.reloaded()
1139            """
1140        )
1141        result = testdir.runpytest("-s")
1142        result.stdout.fnmatch_lines(["* 1 passed*"])
1143
1144    def test_get_data_support(self, testdir):
1145        """Implement optional PEP302 api (#808).
1146        """
1147        path = testdir.mkpydir("foo")
1148        path.join("test_foo.py").write(
1149            textwrap.dedent(
1150                """\
1151                class Test(object):
1152                    def test_foo(self):
1153                        import pkgutil
1154                        data = pkgutil.get_data('foo.test_foo', 'data.txt')
1155                        assert data == b'Hey'
1156                """
1157            )
1158        )
1159        path.join("data.txt").write("Hey")
1160        result = testdir.runpytest()
1161        result.stdout.fnmatch_lines(["*1 passed*"])
1162
1163
def test_issue731(testdir):
    """A failing assert on an object whose long repr contains braces must not
    report "unbalanced braces" in the output (#731)."""
    testdir.makepyfile(
        """
    class LongReprWithBraces(object):
        def __repr__(self):
           return 'LongReprWithBraces({' + ('a' * 80) + '}' + ('a' * 120) + ')'

        def some_method(self):
            return False

    def test_long_repr():
        obj = LongReprWithBraces()
        assert obj.some_method()
    """
    )
    output = testdir.runpytest().stdout.str()
    assert "unbalanced braces" not in output
1181
1182
class TestIssue925(object):
    """Parenthesised comparisons keep their grouping in the assertion
    explanation (#925)."""

    @staticmethod
    def _expect_explanation(testdir, source, pattern):
        """Write *source* as the test module, run pytest, and require a
        stdout line matching the fnmatch *pattern*."""
        testdir.makepyfile(source)
        testdir.runpytest().stdout.fnmatch_lines([pattern])

    def test_simple_case(self, testdir):
        self._expect_explanation(
            testdir,
            """
        def test_ternary_display():
            assert (False == False) == False
        """,
            "*E*assert (False == False) == False",
        )

    def test_long_case(self, testdir):
        self._expect_explanation(
            testdir,
            """
        def test_ternary_display():
             assert False == (False == True) == True
        """,
            "*E*assert (False == True) == True",
        )

    def test_many_brackets(self, testdir):
        self._expect_explanation(
            testdir,
            """
            def test_ternary_display():
                 assert True == ((False == True) == True)
            """,
            "*E*assert True == ((False == True) == True)",
        )
1213
1214
class TestIssue2121(object):
    """Asserts in files matched by a ``python_files`` pattern that contains a
    subdirectory (``tests/**.py``) are still rewritten (#2121)."""

    # Derive from object like every other test class in this module, so that
    # it is a new-style class on Python 2 as well.

    def test_rewrite_python_files_contain_subdirs(self, testdir):
        testdir.makepyfile(
            **{
                "tests/file.py": """
                def test_simple_failure():
                    assert 1 + 1 == 3
                """
            }
        )
        testdir.makeini(
            """
                [pytest]
                python_files = tests/**.py
            """
        )
        result = testdir.runpytest()
        # The intermediate "(1 + 1)" only appears when the assert was rewritten.
        result.stdout.fnmatch_lines(["*E*assert (1 + 1) == 3"])
1233
1234
@pytest.mark.skipif(
    sys.maxsize <= (2 ** 31 - 1), reason="Causes OverflowError on 32bit systems"
)
@pytest.mark.parametrize("offset", [-1, +1])
def test_source_mtime_long_long(testdir, offset):
    """Support modification dates after 2038 in rewritten files (#4903).

    pytest would crash with:

            fp.write(struct.pack("<ll", mtime, size))
        E   struct.error: argument out of range
    """
    test_file = testdir.makepyfile(
        """
        def test(): pass
    """
    )
    # 2**32 - 1 fits an unsigned long but overflows a signed one, which caused
    # the bug; 2**32 + 1 additionally exercises masking with 0xFFFFFFFF.
    mtime = 2 ** 32 + offset
    os.utime(str(test_file), (mtime, mtime))
    assert testdir.runpytest().ret == 0
1259
1260
def test_rewrite_infinite_recursion(testdir, pytestconfig, monkeypatch):
    """Fix infinite recursion when writing pyc files: if an import happens to be triggered when writing the pyc
    file, this would cause another call to the hook, which would trigger another pyc writing, which could
    trigger another import, and so on. (#3506)"""
    from _pytest.assertion import rewrite

    testdir.syspathinsert()
    testdir.makepyfile(test_foo="def test_foo(): pass")
    testdir.makepyfile(test_bar="def test_bar(): pass")

    real_write_pyc = rewrite._write_pyc
    calls = []

    def spy_write_pyc(*args, **kwargs):
        calls.append(True)
        # While a pyc is being written, the hook must refuse to rewrite any
        # other module; otherwise the import would recurse into _write_pyc.
        # ``hook`` is bound below, before this spy can ever be invoked.
        assert hook.find_module("test_bar") is None
        return real_write_pyc(*args, **kwargs)

    monkeypatch.setattr(rewrite, "_write_pyc", spy_write_pyc)
    monkeypatch.setattr(sys, "dont_write_bytecode", False)

    hook = AssertionRewritingHook(pytestconfig)
    assert hook.find_module("test_foo") is not None
    assert len(calls) == 1
1288
1289
class TestEarlyRewriteBailout(object):
    @pytest.fixture
    def hook(self, pytestconfig, monkeypatch, testdir):
        """Returns a patched AssertionRewritingHook instance so we can configure its initial paths and track
        which module names it passes to its ``_imp_find_module`` indirection.

        The spy wraps the hook's own ``_imp_find_module`` instead of re-implementing
        it, so the real lookup still runs and this test does not depend on how the
        hook locates modules internally.
        """
        self.find_module_calls = []
        self.initial_paths = set()

        class StubSession(object):
            _initialpaths = self.initial_paths

            def isinitpath(self, p):
                return p in self._initialpaths

        hook = AssertionRewritingHook(pytestconfig)
        # Capture the real bound method before it is monkeypatched below.
        original_imp_find_module = hook._imp_find_module

        def spy_imp_find_module(name, path):
            self.find_module_calls.append(name)
            return original_imp_find_module(name, path)

        # use default patterns, otherwise we inherit pytest's testing config
        hook.fnpats[:] = ["test_*.py", "*_test.py"]
        monkeypatch.setattr(hook, "_imp_find_module", spy_imp_find_module)
        hook.set_session(StubSession())
        testdir.syspathinsert()
        return hook

    def test_basic(self, testdir, hook):
        """
        Ensure we avoid calling imp.find_module when we know for sure a certain module will not be rewritten
        to optimize assertion rewriting (#3918).
        """
        testdir.makeconftest(
            """
            import pytest
            @pytest.fixture
            def fix(): return 1
        """
        )
        testdir.makepyfile(test_foo="def test_foo(): pass")
        testdir.makepyfile(bar="def bar(): pass")
        foobar_path = testdir.makepyfile(foobar="def foobar(): pass")
        self.initial_paths.add(foobar_path)

        # conftest files should always be rewritten
        assert hook.find_module("conftest") is not None
        assert self.find_module_calls == ["conftest"]

        # files matching "python_files" mask should always be rewritten
        assert hook.find_module("test_foo") is not None
        assert self.find_module_calls == ["conftest", "test_foo"]

        # file does not match "python_files": early bailout
        assert hook.find_module("bar") is None
        assert self.find_module_calls == ["conftest", "test_foo"]

        # file is an initial path (passed on the command-line): should be rewritten
        assert hook.find_module("foobar") is not None
        assert self.find_module_calls == ["conftest", "test_foo", "foobar"]

    def test_pattern_contains_subdirectories(self, testdir, hook):
        """If one of the python_files patterns contain subdirectories ("tests/**.py") we can't bailout early
        because we need to match with the full path, which can only be found by calling imp.find_module.
        """
        p = testdir.makepyfile(
            **{
                "tests/file.py": """
                        def test_simple_failure():
                            assert 1 + 1 == 3
                        """
            }
        )
        testdir.syspathinsert(p.dirpath())
        hook.fnpats[:] = ["tests/**.py"]
        assert hook.find_module("file") is not None
        assert self.find_module_calls == ["file"]

    @pytest.mark.skipif(
        sys.platform.startswith("win32"), reason="cannot remove cwd on Windows"
    )
    def test_cwd_changed(self, testdir, monkeypatch):
        # Setup conditions for py's fspath trying to import pathlib on py34
        # always (previously triggered via xdist only).
        # Ref: https://github.com/pytest-dev/py/pull/207
        monkeypatch.syspath_prepend("")
        monkeypatch.delitem(sys.modules, "pathlib", raising=False)

        testdir.makepyfile(
            **{
                "test_setup_nonexisting_cwd.py": """
                import os
                import shutil
                import tempfile

                d = tempfile.mkdtemp()
                os.chdir(d)
                shutil.rmtree(d)
            """,
                "test_test.py": """
                def test():
                    pass
            """,
            }
        )
        result = testdir.runpytest()
        result.stdout.fnmatch_lines(["* 1 passed in *"])
1398