1"""This module includes tests of the code object representation.
2
3>>> def f(x):
4...     def g(y):
5...         return x + y
6...     return g
7...
8
9>>> dump(f.__code__)
10name: f
11argcount: 1
12posonlyargcount: 0
13kwonlyargcount: 0
14names: ()
15varnames: ('x', 'g')
16cellvars: ('x',)
17freevars: ()
18nlocals: 2
19flags: 3
20consts: ('None', '<code object g>')
21
22>>> dump(f(4).__code__)
23name: g
24argcount: 1
25posonlyargcount: 0
26kwonlyargcount: 0
27names: ()
28varnames: ('y',)
29cellvars: ()
30freevars: ('x',)
31nlocals: 1
32flags: 19
33consts: ('None',)
34
35>>> def h(x, y):
36...     a = x + y
37...     b = x - y
38...     c = a * b
39...     return c
40...
41
42>>> dump(h.__code__)
43name: h
44argcount: 2
45posonlyargcount: 0
46kwonlyargcount: 0
47names: ()
48varnames: ('x', 'y', 'a', 'b', 'c')
49cellvars: ()
50freevars: ()
51nlocals: 5
52flags: 3
53consts: ('None',)
54
55>>> def attrs(obj):
56...     print(obj.attr1)
57...     print(obj.attr2)
58...     print(obj.attr3)
59
60>>> dump(attrs.__code__)
61name: attrs
62argcount: 1
63posonlyargcount: 0
64kwonlyargcount: 0
65names: ('print', 'attr1', 'attr2', 'attr3')
66varnames: ('obj',)
67cellvars: ()
68freevars: ()
69nlocals: 1
70flags: 3
71consts: ('None',)
72
73>>> def optimize_away():
74...     'doc string'
75...     'not a docstring'
76...     53
77...     0x53
78
79>>> dump(optimize_away.__code__)
80name: optimize_away
81argcount: 0
82posonlyargcount: 0
83kwonlyargcount: 0
84names: ()
85varnames: ()
86cellvars: ()
87freevars: ()
88nlocals: 0
89flags: 3
90consts: ("'doc string'", 'None')
91
92>>> def keywordonly_args(a,b,*,k1):
93...     return a,b,k1
94...
95
96>>> dump(keywordonly_args.__code__)
97name: keywordonly_args
98argcount: 2
99posonlyargcount: 0
100kwonlyargcount: 1
101names: ()
102varnames: ('a', 'b', 'k1')
103cellvars: ()
104freevars: ()
105nlocals: 3
106flags: 3
107consts: ('None',)
108
109>>> def posonly_args(a,b,/,c):
110...     return a,b,c
111...
112
113>>> dump(posonly_args.__code__)
114name: posonly_args
115argcount: 3
116posonlyargcount: 2
117kwonlyargcount: 0
118names: ()
119varnames: ('a', 'b', 'c')
120cellvars: ()
121freevars: ()
122nlocals: 3
123flags: 3
124consts: ('None',)
125
126"""
127
128import inspect
129import sys
130import threading
131import doctest
132import unittest
133import textwrap
134import weakref
135
136try:
137    import ctypes
138except ImportError:
139    ctypes = None
140from test.support import (cpython_only,
141                          check_impl_detail, requires_debug_ranges,
142                          gc_collect)
143from test.support.script_helper import assert_python_ok
144from opcode import opmap
# Numeric opcode for COPY_FREE_VARS; test_closure_injection prepends this
# instruction to a function's bytecode when it injects a __class__ cell.
COPY_FREE_VARS = opmap['COPY_FREE_VARS']
146
147
def consts(t):
    """Yield a doctest-safe sequence of object reprs.

    Each element of *t* is rendered with repr(), except code objects,
    which are shortened to ``<code object NAME>`` so the variable parts
    of their repr never reach doctest output.
    """
    for item in t:
        text = repr(item)
        if not text.startswith("<code object"):
            yield text
            continue
        yield "<code object %s>" % item.co_name
156
def dump(co):
    """Print out a text representation of a code object.

    One ``name: value`` line is printed per selected ``co_*`` attribute,
    followed by a ``consts:`` line built from consts() so the output is
    stable enough for the doctests in this module's docstring.
    """
    fields = ("name", "argcount", "posonlyargcount", "kwonlyargcount",
              "names", "varnames", "cellvars", "freevars", "nlocals",
              "flags")
    for field in fields:
        value = getattr(co, f"co_{field}")
        print(f"{field}: {value}")
    print("consts:", tuple(consts(co.co_consts)))
164
# Needed for test_closure_injection below.
# Defined at global scope to avoid implicitly closing over __class__: the
# test copies this function's code object, adds a '__class__' free variable
# and attaches it to a list subclass, so the zero-argument super() call
# below resolves against that injected class.
def external_getitem(self, i):
    # Only meaningful once test_closure_injection has supplied the
    # __class__ cell for the zero-argument super() call.
    return f"Foreign getitem: {super().__getitem__(i)}"
169
class CodeTest(unittest.TestCase):
    """Tests for creating, copying and inspecting code objects."""

    @cpython_only
    def test_newempty(self):
        import _testcapi
        co = _testcapi.code_newempty("filename", "funcname", 15)
        self.assertEqual(co.co_filename, "filename")
        self.assertEqual(co.co_name, "funcname")
        self.assertEqual(co.co_firstlineno, 15)

    @cpython_only
    def test_closure_injection(self):
        # From https://bugs.python.org/issue32176
        from types import FunctionType

        def create_closure(__class__):
            return (lambda: __class__).__closure__

        def new_code(c):
            '''A new code object with a __class__ cell added to freevars'''
            # Prepend COPY_FREE_VARS so the injected closure cell is copied
            # into the frame before the original bytecode runs.
            return c.replace(co_freevars=c.co_freevars + ('__class__',),
                             co_code=bytes([COPY_FREE_VARS, 1]) + c.co_code)

        def add_foreign_method(cls, name, f):
            code = new_code(f.__code__)
            # The injected cell must be the only closure entry.
            assert not f.__closure__
            closure = create_closure(cls)
            defaults = f.__defaults__
            setattr(cls, name,
                    FunctionType(code, globals(), name, defaults, closure))

        class List(list):
            pass

        add_foreign_method(List, "__getitem__", external_getitem)

        # Ensure the closure injection actually worked
        function = List.__getitem__
        class_ref = function.__closure__[0].cell_contents
        self.assertIs(class_ref, List)

        # Ensure the zero-arg super() call in the injected method works
        obj = List([1, 2, 3])
        self.assertEqual(obj[0], "Foreign getitem: 1")

    def test_constructor(self):
        def func(): pass
        co = func.__code__
        CodeType = type(co)

        # test code constructor.  The positional slot after co_firstlineno
        # is the line table, so pass co_linetable rather than co_lnotab,
        # which uses a different (legacy) encoding.
        CodeType(co.co_argcount,
                        co.co_posonlyargcount,
                        co.co_kwonlyargcount,
                        co.co_nlocals,
                        co.co_stacksize,
                        co.co_flags,
                        co.co_code,
                        co.co_consts,
                        co.co_names,
                        co.co_varnames,
                        co.co_filename,
                        co.co_name,
                        co.co_qualname,
                        co.co_firstlineno,
                        co.co_linetable,
                        co.co_endlinetable,
                        co.co_columntable,
                        co.co_exceptiontable,
                        co.co_freevars,
                        co.co_cellvars)

    def test_qualname(self):
        self.assertEqual(
            CodeTest.test_qualname.__code__.co_qualname,
            CodeTest.test_qualname.__qualname__
        )

    def test_replace(self):
        def func():
            x = 1
            return x
        code = func.__code__

        # different co_name, co_varnames, co_consts
        def func2():
            y = 2
            z = 3
            return y
        code2 = func2.__code__

        # Each attribute is replaced on its own and must read back unchanged.
        for attr, value in (
            ("co_argcount", 0),
            ("co_posonlyargcount", 0),
            ("co_kwonlyargcount", 0),
            ("co_nlocals", 1),
            ("co_stacksize", 0),
            ("co_flags", code.co_flags | inspect.CO_COROUTINE),
            ("co_firstlineno", 100),
            ("co_code", code2.co_code),
            ("co_consts", code2.co_consts),
            ("co_names", ("myname",)),
            ("co_varnames", ('spam',)),
            ("co_freevars", ("freevar",)),
            ("co_cellvars", ("cellvar",)),
            ("co_filename", "newfilename"),
            ("co_name", "newname"),
            ("co_linetable", code2.co_linetable),
            ("co_endlinetable", code2.co_endlinetable),
            ("co_columntable", code2.co_columntable),
        ):
            with self.subTest(attr=attr, value=value):
                new_code = code.replace(**{attr: value})
                self.assertEqual(getattr(new_code, attr), value)

        # co_varnames and co_nlocals must be changed together to stay
        # consistent.
        new_code = code.replace(co_varnames=code2.co_varnames,
                                co_nlocals=code2.co_nlocals)
        self.assertEqual(new_code.co_varnames, code2.co_varnames)
        self.assertEqual(new_code.co_nlocals, code2.co_nlocals)

    def test_nlocals_mismatch(self):
        def func():
            x = 1
            return x
        co = func.__code__
        # Precondition: func must have locals for the +/-1 mismatch below
        # to be meaningful.  Use a real assertion (not `assert`, which -O
        # strips).
        self.assertGreater(co.co_nlocals, 0)

        # First we try the constructor.
        CodeType = type(co)
        for diff in (-1, 1):
            with self.subTest(diff=diff):
                with self.assertRaises(ValueError):
                    CodeType(co.co_argcount,
                             co.co_posonlyargcount,
                             co.co_kwonlyargcount,
                             # This is the only change.
                             co.co_nlocals + diff,
                             co.co_stacksize,
                             co.co_flags,
                             co.co_code,
                             co.co_consts,
                             co.co_names,
                             co.co_varnames,
                             co.co_filename,
                             co.co_name,
                             co.co_qualname,
                             co.co_firstlineno,
                             co.co_linetable,
                             co.co_endlinetable,
                             co.co_columntable,
                             co.co_exceptiontable,
                             co.co_freevars,
                             co.co_cellvars,
                             )
        # Then we try the replace method.
        with self.assertRaises(ValueError):
            co.replace(co_nlocals=co.co_nlocals - 1)
        with self.assertRaises(ValueError):
            co.replace(co_nlocals=co.co_nlocals + 1)

    def test_shrinking_localsplus(self):
        # Check that PyCode_NewWithPosOnlyArgs resizes both
        # localsplusnames and localspluskinds, if an argument is a cell.
        def func(arg):
            return lambda: arg
        code = func.__code__
        newcode = code.replace(co_name="func")  # Should not raise SystemError
        self.assertEqual(code, newcode)

    def test_empty_linetable(self):
        def func():
            pass
        new_code = func.__code__.replace(co_linetable=b'')
        self.assertEqual(list(new_code.co_lines()), [])

    @requires_debug_ranges()
    def test_co_positions_artificial_instructions(self):
        import dis

        namespace = {}
        exec(textwrap.dedent("""\
        try:
            1/0
        except Exception as e:
            exc = e
        """), namespace)

        exc = namespace['exc']
        traceback = exc.__traceback__
        code = traceback.tb_frame.f_code

        artificial_instructions = []
        for instr, positions in zip(
            dis.get_instructions(code),
            code.co_positions(),
            strict=True
        ):
            # If any of the positions is None, then all have to
            # be None as well for the case above. There are still
            # some places in the compiler, where the artificial instructions
            # get assigned the first_lineno but they don't have other positions.
            # There is no easy way of inferring them at that stage, so for now
            # we don't support it.
            self.assertIn(positions.count(None), (0, 4))

            if not any(positions):
                artificial_instructions.append(instr)

        self.assertEqual(
            [
                (instruction.opname, instruction.argval)
                for instruction in artificial_instructions
            ],
            [
                ("PUSH_EXC_INFO", None),
                ("LOAD_CONST", None), # artificial 'None'
                ("STORE_NAME", "e"),  # XX: we know the location for this
                ("DELETE_NAME", "e"),
                ("RERAISE", 1),
                ("POP_EXCEPT_AND_RERAISE", None)
            ]
        )

    def test_endline_and_columntable_none_when_no_debug_ranges(self):
        # Make sure that if `-X no_debug_ranges` is used, the endlinetable and
        # columntable are None.
        code = textwrap.dedent("""
            def f():
                pass

            assert f.__code__.co_endlinetable is None
            assert f.__code__.co_columntable is None
            """)
        assert_python_ok('-X', 'no_debug_ranges', '-c', code)

    def test_endline_and_columntable_none_when_no_debug_ranges_env(self):
        # Same as above but using the environment variable opt out.
        code = textwrap.dedent("""
            def f():
                pass

            assert f.__code__.co_endlinetable is None
            assert f.__code__.co_columntable is None
            """)
        assert_python_ok('-c', code, PYTHONNODEBUGRANGES='1')

    # co_positions behavior when info is missing.

    @requires_debug_ranges()
    def test_co_positions_empty_linetable(self):
        def func():
            x = 1
        new_code = func.__code__.replace(co_linetable=b'')
        for line, end_line, column, end_column in new_code.co_positions():
            self.assertIsNone(line)
            self.assertEqual(end_line, new_code.co_firstlineno + 1)

    @requires_debug_ranges()
    def test_co_positions_empty_endlinetable(self):
        def func():
            x = 1
        new_code = func.__code__.replace(co_endlinetable=b'')
        for line, end_line, column, end_column in new_code.co_positions():
            self.assertEqual(line, new_code.co_firstlineno + 1)
            self.assertIsNone(end_line)

    @requires_debug_ranges()
    def test_co_positions_empty_columntable(self):
        def func():
            x = 1
        new_code = func.__code__.replace(co_columntable=b'')
        for line, end_line, column, end_column in new_code.co_positions():
            self.assertEqual(line, new_code.co_firstlineno + 1)
            self.assertEqual(end_line, new_code.co_firstlineno + 1)
            self.assertIsNone(column)
            self.assertIsNone(end_column)
443
444
def isinterned(s):
    """Return True when *s* is the interned instance of its value.

    A fresh, equal copy of *s* is built and handed to sys.intern(); the
    result is *s* itself only if *s* was already the interned object.
    """
    fresh_copy = ('_' + s + '_')[1:-1]
    canonical = sys.intern(fresh_copy)
    return canonical is s
447
class CodeConstsTest(unittest.TestCase):
    """Check which string constants end up interned in code objects."""

    def find_const(self, consts, value):
        # Return the stored element of *consts* that equals *value* (the
        # element itself, so callers can check its identity); otherwise
        # fail the test via assertIn.
        absent = object()
        found = next((item for item in consts if item == value), absent)
        if found is not absent:
            return found
        self.assertIn(value, consts)  # raises an exception
        self.fail('Should never be reached')

    def assertIsInterned(self, s):
        if isinterned(s):
            return
        self.fail(f'String {s!r} is not interned')

    def assertIsNotInterned(self, s):
        if not isinterned(s):
            return
        self.fail(f'String {s!r} is interned')

    @cpython_only
    def test_interned_string(self):
        code = compile('res = "str_value"', '?', 'exec')
        self.assertIsInterned(self.find_const(code.co_consts, 'str_value'))

    @cpython_only
    def test_interned_string_in_tuple(self):
        code = compile('res = ("str_value",)', '?', 'exec')
        const = self.find_const(code.co_consts, ('str_value',))
        self.assertIsInterned(const[0])

    @cpython_only
    def test_interned_string_in_frozenset(self):
        code = compile('res = a in {"str_value"}', '?', 'exec')
        const = self.find_const(code.co_consts, frozenset(('str_value',)))
        self.assertIsInterned(next(iter(const)))

    @cpython_only
    def test_interned_string_default(self):
        def echo(a='str_value'):
            return a
        self.assertIsInterned(echo())

    @cpython_only
    def test_interned_string_with_null(self):
        code = compile(r'res = "str\0value!"', '?', 'exec')
        const = self.find_const(code.co_consts, 'str\0value!')
        self.assertIsNotInterned(const)
494
495
class CodeWeakRefTest(unittest.TestCase):
    """Code objects can be weakly referenced."""

    def test_basic(self):
        # Build the function inside a throwaway namespace; once that
        # namespace is deleted, the function is the only thing keeping its
        # code object alive.
        scope = {}
        exec("def f(): pass", globals(), scope)
        func = scope["f"]
        del scope

        self.called = False

        def on_collect(_ref):
            self.called = True

        # While func is alive the weakref must resolve.  After func is
        # dropped (and a collection is forced for non-refcounting GCs),
        # the weakref must be dead and its callback must have run.
        coderef = weakref.ref(func.__code__, on_collect)
        self.assertTrue(bool(coderef()))
        del func
        gc_collect()  # For PyPy or other GCs.
        self.assertFalse(bool(coderef()))
        self.assertTrue(self.called)
520
521
if check_impl_detail(cpython=True) and ctypes is not None:
    # The code-extra C functions below are underscore-prefixed C symbols
    # with no Python-level wrapper, so they are reached via ctypes.pythonapi.
    py = ctypes.pythonapi
    # Signature of the free callback: void (*)(void *)
    freefunc = ctypes.CFUNCTYPE(None, ctypes.c_voidp)

    RequestCodeExtraIndex = py._PyEval_RequestCodeExtraIndex
    RequestCodeExtraIndex.argtypes = (freefunc,)
    RequestCodeExtraIndex.restype = ctypes.c_ssize_t

    SetExtra = py._PyCode_SetExtra
    SetExtra.argtypes = (ctypes.py_object, ctypes.c_ssize_t, ctypes.c_voidp)
    SetExtra.restype = ctypes.c_int

    GetExtra = py._PyCode_GetExtra
    GetExtra.argtypes = (ctypes.py_object, ctypes.c_ssize_t,
                         ctypes.POINTER(ctypes.c_voidp))
    GetExtra.restype = ctypes.c_int

    # Pointer value most recently passed to the free callback; the tests
    # read it to observe when (and with what value) a slot was released.
    LAST_FREED = None
    def myfree(ptr):
        global LAST_FREED
        LAST_FREED = ptr

    # Keep a module-level reference to the callback object: ctypes does not
    # keep CFUNCTYPE instances alive for C code that calls them later.
    FREE_FUNC = freefunc(myfree)
    FREE_INDEX = RequestCodeExtraIndex(FREE_FUNC)

    class CoExtra(unittest.TestCase):
        def get_func(self):
            # Defining a function causes the containing function to have a
            # reference to the code object.  We need the code objects to go
            # away, so we eval a lambda.
            return eval('lambda:42')

        def test_get_non_code(self):
            # Both the setter and the getter must reject non-code objects.
            self.assertRaises(SystemError, SetExtra, 42, FREE_INDEX,
                              ctypes.c_voidp(100))
            self.assertRaises(SystemError, GetExtra, 42, FREE_INDEX,
                              ctypes.c_voidp(100))

        def test_bad_index(self):
            # Setting an unrequested index raises; getting one returns 0.
            f = self.get_func()
            self.assertRaises(SystemError, SetExtra, f.__code__,
                              FREE_INDEX+100, ctypes.c_voidp(100))
            self.assertEqual(GetExtra(f.__code__, FREE_INDEX+100,
                              ctypes.c_voidp(100)), 0)

        def test_free_called(self):
            # Verify that the provided free function gets invoked
            # when the code object is cleaned up.
            f = self.get_func()

            SetExtra(f.__code__, FREE_INDEX, ctypes.c_voidp(100))
            del f
            self.assertEqual(LAST_FREED, 100)

        def test_get_set(self):
            # Test basic get/set round tripping.
            f = self.get_func()

            SetExtra(f.__code__, FREE_INDEX, ctypes.c_voidp(200))
            # reset should free...
            SetExtra(f.__code__, FREE_INDEX, ctypes.c_voidp(300))
            self.assertEqual(LAST_FREED, 200)

            extra = ctypes.c_voidp()
            GetExtra(f.__code__, FREE_INDEX, extra)
            self.assertEqual(extra.value, 300)
            del f

        def test_free_different_thread(self):
            # Freeing a code object on a different thread than
            # where the co_extra was set should be safe.
            f = self.get_func()
            class ThreadTest(threading.Thread):
                def __init__(self, f, test):
                    super().__init__()
                    self.f = f
                    self.test = test
                def run(self):
                    del self.f
                    self.test.assertEqual(LAST_FREED, 500)

            SetExtra(f.__code__, FREE_INDEX, ctypes.c_voidp(500))
            tt = ThreadTest(f, self)
            # Drop the main thread's reference so the thread holds the last one.
            del f
            tt.start()
            tt.join()
            self.assertEqual(LAST_FREED, 500)
613
614
def load_tests(loader, tests, pattern):
    """unittest ``load_tests`` hook: add this module's doctests to *tests*."""
    # With no module argument, DocTestSuite uses the calling module.
    doctest_suite = doctest.DocTestSuite()
    tests.addTest(doctest_suite)
    return tests


if __name__ == "__main__":
    unittest.main()
622