1import os
2import pytest
3import re
4
5from xdis import disassemble_file
6from xdis import PYTHON3, PYTHON_VERSION
7
# Choose the in-memory text buffer class and the expected-output file for the
# 2.7 hexstring case according to the major version of the running interpreter.
if not PYTHON3:
    from StringIO import StringIO

    hextring_file = "testdata/01_hexstring-2.7.right"
else:
    from io import StringIO

    hextring_file = "testdata/01_hexstring-2.7-for3x.right"
16
17
def get_srcdir():
    """Return the real (symlink-resolved) path of the directory holding this file.

    The directory name is case-normalized with ``os.path.normcase`` before
    being resolved with ``os.path.realpath``.
    """
    return os.path.realpath(os.path.normcase(os.path.dirname(__file__)))
21
22
# The test below is only defined when PYTHON_VERSION is at least 3.2.
if PYTHON_VERSION >= 3.2:

    @pytest.mark.parametrize(
        ("test_tuple", "function_to_test"),
        [
            (
                ("../test/bytecode_3.6/01_fstring.pyc", "testdata/fstring-3.6.right"),
                disassemble_file,
            ),
            (
                ("../test/bytecode_3.0/04_raise.pyc", "testdata/raise-3.0.right"),
                disassemble_file,
            ),
            (
                (
                    "../test/bytecode_2.7pypy/04_pypy_lambda.pyc",
                    "testdata/pypy_lambda.right",
                ),
                disassemble_file,
            ),
            (
                ("../test/bytecode_3.6/03_big_dict.pyc", "testdata/big_dict-3.6.right"),
                disassemble_file,
            ),
            (("../test/bytecode_2.7/01_hexstring.pyc", hextring_file), disassemble_file),
        ],
    )
    def test_funcoutput(capfd, test_tuple, function_to_test):
        """Compare the output of ``function_to_test`` against a reference file.

        Parameters:
            capfd: accepted for the pytest fixture of that name; the body
                does not use it.
            test_tuple: pair ``(input path, expected-output path)``, both
                relative to the directory returned by ``get_srcdir()``.
            function_to_test: callable invoked as ``f(in_file, stream)``
                that writes its text output to ``stream``.

        Before comparing, ``at 0x...`` addresses and code-object reprs in
        the produced text are rewritten to fixed placeholders, and the
        first five output lines are dropped.  On a mismatch the normalized
        output is saved next to the reference file with a ``.got`` suffix,
        unless ``XDIS_DONT_WRITE_DOT_GOT_FILES`` is set in the environment.
        """
        in_file, filename_expected = [os.path.join(get_srcdir(), p) for p in test_tuple]
        resout = StringIO()
        function_to_test(in_file, resout)

        # Read the reference text inside a context manager so the file
        # handle is closed deterministically instead of being leaked.
        with open(filename_expected, "r") as expected_file:
            expected = expected_file.read()

        got_lines = []
        for line in resout.getvalue().split("\n"):
            # Replace memory addresses with one fixed value.
            line = re.sub(" at 0x[0-9a-f]+", " at 0xdeadbeef0000", line)
            # Collapse every code-object repr to one canonical spelling.
            line = re.sub(
                "<code object .*>|<Code.+ code object .*>",
                "<code object at 0xdeadbeef0000>",
                line,
            )
            got_lines.append(line)

        # NOTE(review): the first five lines are presumably a header that
        # differs between runs -- confirm against the disassembler output.
        got = "\n".join(got_lines[5:])

        # Keep the actual output around for inspection when it differs.
        if got != expected and "XDIS_DONT_WRITE_DOT_GOT_FILES" not in os.environ:
            with open(filename_expected + ".got", "w") as out:
                out.write(got)
        assert got == expected
73
if __name__ == "__main__":
    # Ad-hoc run of a single case without going through pytest.  The test
    # body never touches ``capfd``, so ``sys.stdout`` only fills that slot.
    # Note that ``test_funcoutput`` exists only when PYTHON_VERSION >= 3.2.
    import sys

    test_funcoutput(
        capfd=sys.stdout,
        test_tuple=("../test/bytecode_3.0/04_raise.pyc", "testdata/raise-3.0.right"),
        function_to_test=disassemble_file,
    )
78