1import dis
2import unittest
3
4from test.bytecode_helper import BytecodeTestCase
5
6
def count_instr_recursively(f, opname):
    """Return how many instructions named *opname* occur in *f*.

    *f* may be anything dis.get_instructions() accepts.  Besides the
    instructions of *f* itself, every constant of its code object that has
    a ``co_code`` attribute (i.e. a nested code object) is searched too,
    recursively.
    """
    total = sum(1 for ins in dis.get_instructions(f) if ins.opname == opname)
    # Functions keep their constants on __code__; code objects are used as-is.
    code = getattr(f, '__code__', f)
    for const in code.co_consts:
        if hasattr(const, 'co_code'):
            total += count_instr_recursively(const, opname)
    return total
18
19
class TestTranforms(BytecodeTestCase):
    # Bytecode-level tests: each test compiles a snippet (or defines a small
    # function) and inspects the resulting instructions with dis and the
    # BytecodeTestCase assertions.

    def check_jump_targets(self, code):
        """Fail if *code* contains a jump that should have been simplified.

        Every instruction whose opname contains 'JUMP_' is compared with the
        instruction found at its argval (the target offset).  The test fails
        when:

        * the target is an unconditional jump,
        * an unconditional jump targets RETURN_VALUE, or
        * a *_OR_POP jump targets an instruction whose opname contains
          'JUMP_IF_'.
        """
        instructions = list(dis.get_instructions(code))
        by_offset = {instr.offset: instr for instr in instructions}
        unconditional = ('JUMP_ABSOLUTE', 'JUMP_FORWARD')
        for instr in instructions:
            if 'JUMP_' not in instr.opname:
                continue
            tgt = by_offset[instr.argval]
            description = (f'{instr.opname} at {instr.offset} '
                           f'jumps to {tgt.opname} at {tgt.offset}')
            # jump to unconditional jump
            if tgt.opname in unconditional:
                self.fail(description)
            # unconditional jump to RETURN_VALUE
            if (instr.opname in unconditional and
                    tgt.opname == 'RETURN_VALUE'):
                self.fail(description)
            # JUMP_IF_*_OR_POP jump to conditional jump
            if '_OR_POP' in instr.opname and 'JUMP_IF_' in tgt.opname:
                self.fail(description)

    def check_lnotab(self, code):
        """Check that the lnotab byte offsets are sensible.

        Every line-start offset reported by dis.findlinestarts() must lie
        in the range [0, len(co_code)).
        """
        code = dis._get_code_object(code)
        offsets = [start[0] for start in dis.findlinestarts(code)]
        # Don't bother checking if the line info is sensible, because
        # most of the line info we can get at comes from lnotab.
        min_bytecode = min(offsets)
        max_bytecode = max(offsets)
        self.assertGreaterEqual(min_bytecode, 0)
        self.assertLess(max_bytecode, len(code.co_code))
        # This could conceivably test more (and probably should, as there
        # aren't very many tests of lnotab), if peepholer wasn't scheduled
        # to be replaced anyway.

    def _assert_no_opname_prefix(self, code, *prefixes):
        """Assert that no opname in *code* starts with any of *prefixes*.

        The failure message names the offending instruction and its offset.
        """
        for instr in dis.get_instructions(code):
            self.assertFalse(
                instr.opname.startswith(prefixes),
                f'unexpected {instr.opname} at offset {instr.offset}')

    def test_unot(self):
        # UNARY_NOT POP_JUMP_IF_FALSE  -->  POP_JUMP_IF_TRUE
        def unot(x):
            if not x == 2:
                del x
        self.assertNotInBytecode(unot, 'UNARY_NOT')
        self.assertNotInBytecode(unot, 'POP_JUMP_IF_FALSE')
        self.assertInBytecode(unot, 'POP_JUMP_IF_TRUE')
        self.check_lnotab(unot)

    def test_elim_inversion_of_is_or_in(self):
        # 'not' applied to an is/in comparison should show up as a
        # COMPARE_OP using the inverted operator.
        for line, cmp_op in (
            ('not a is b', 'is not',),
            ('not a in b', 'not in',),
            ('not a is not b', 'is',),
            ('not a not in b', 'in',),
            ):
            with self.subTest(line=line):
                code = compile(line, '', 'single')
                self.assertInBytecode(code, 'COMPARE_OP', cmp_op)
                self.check_lnotab(code)

    def test_global_as_constant(self):
        # LOAD_GLOBAL None/True/False  -->  LOAD_CONST None/True/False
        def f():
            x = None
            x = None
            return x
        def g():
            x = True
            return x
        def h():
            x = False
            return x

        for func, elem in ((f, None), (g, True), (h, False)):
            with self.subTest(func=func.__name__):
                self.assertNotInBytecode(func, 'LOAD_GLOBAL')
                self.assertInBytecode(func, 'LOAD_CONST', elem)
                self.check_lnotab(func)

        def f():
            'Adding a docstring made this test fail in Py2.5.0'
            return None

        self.assertNotInBytecode(f, 'LOAD_GLOBAL')
        self.assertInBytecode(f, 'LOAD_CONST', None)
        self.check_lnotab(f)

    def test_while_one(self):
        # Skip over:  LOAD_CONST trueconst  POP_JUMP_IF_FALSE xx
        def f():
            while 1:
                pass
            return list
        for elem in ('LOAD_CONST', 'POP_JUMP_IF_FALSE'):
            self.assertNotInBytecode(f, elem)
        for elem in ('JUMP_ABSOLUTE',):
            self.assertInBytecode(f, elem)
        self.check_lnotab(f)

    def test_pack_unpack(self):
        # Building a tuple only to unpack it again should leave neither the
        # build nor the unpack instruction behind.
        for line, elem in (
            ('a, = a,', 'LOAD_CONST',),
            ('a, b = a, b', 'ROT_TWO',),
            ('a, b, c = a, b, c', 'ROT_THREE',),
            ):
            with self.subTest(line=line):
                code = compile(line,'','single')
                self.assertInBytecode(code, elem)
                self.assertNotInBytecode(code, 'BUILD_TUPLE')
                # The unpack opcode is named UNPACK_SEQUENCE; checking for
                # the nonexistent 'UNPACK_TUPLE' could never fail.
                self.assertNotInBytecode(code, 'UNPACK_SEQUENCE')
                self.check_lnotab(code)

    def test_folding_of_tuples_of_constants(self):
        for line, elem in (
            ('a = 1,2,3', (1, 2, 3)),
            ('("a","b","c")', ('a', 'b', 'c')),
            ('a,b,c = 1,2,3', (1, 2, 3)),
            ('(None, 1, None)', (None, 1, None)),
            ('((1, 2), 3, 4)', ((1, 2), 3, 4)),
            ):
            with self.subTest(line=line):
                code = compile(line,'','single')
                self.assertInBytecode(code, 'LOAD_CONST', elem)
                self.assertNotInBytecode(code, 'BUILD_TUPLE')
                self.check_lnotab(code)

        # Long tuples should be folded too.
        code = compile(repr(tuple(range(10000))),'','single')
        self.assertNotInBytecode(code, 'BUILD_TUPLE')
        # One LOAD_CONST for the tuple, one for the None return value
        load_consts = [instr for instr in dis.get_instructions(code)
                              if instr.opname == 'LOAD_CONST']
        self.assertEqual(len(load_consts), 2)
        self.check_lnotab(code)

        # Bug 1053819:  Tuple of constants misidentified when presented with:
        # . . . opcode_with_arg 100   unary_opcode   BUILD_TUPLE 1  . . .
        # The following would segfault upon compilation
        def crater():
            (~[
                0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
                0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
                0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
                0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
                0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
                0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
                0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
                0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
                0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
                0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
            ],)
        self.check_lnotab(crater)

    def test_folding_of_lists_of_constants(self):
        for line, elem in (
            # in/not in constants with BUILD_LIST should be folded to a tuple:
            ('a in [1,2,3]', (1, 2, 3)),
            ('a not in ["a","b","c"]', ('a', 'b', 'c')),
            ('a in [None, 1, None]', (None, 1, None)),
            ('a not in [(1, 2), 3, 4]', ((1, 2), 3, 4)),
            ):
            with self.subTest(line=line):
                code = compile(line, '', 'single')
                self.assertInBytecode(code, 'LOAD_CONST', elem)
                self.assertNotInBytecode(code, 'BUILD_LIST')
                self.check_lnotab(code)

    def test_folding_of_sets_of_constants(self):
        for line, elem in (
            # in/not in constants with BUILD_SET should be folded to a frozenset:
            ('a in {1,2,3}', frozenset({1, 2, 3})),
            ('a not in {"a","b","c"}', frozenset({'a', 'c', 'b'})),
            ('a in {None, 1, None}', frozenset({1, None})),
            ('a not in {(1, 2), 3, 4}', frozenset({(1, 2), 3, 4})),
            ('a in {1, 2, 3, 3, 2, 1}', frozenset({1, 2, 3})),
            ):
            with self.subTest(line=line):
                code = compile(line, '', 'single')
                self.assertNotInBytecode(code, 'BUILD_SET')
                self.assertInBytecode(code, 'LOAD_CONST', elem)
                self.check_lnotab(code)

        # Ensure that the resulting code actually works:
        def f(a):
            return a in {1, 2, 3}

        def g(a):
            return a not in {1, 2, 3}

        self.assertTrue(f(3))
        self.assertFalse(f(4))
        self.check_lnotab(f)

        self.assertFalse(g(3))
        self.assertTrue(g(4))
        self.check_lnotab(g)


    def test_folding_of_binops_on_constants(self):
        for line, elem in (
            ('a = 2+3+4', 9),                   # chained fold
            ('"@"*4', '@@@@'),                  # check string ops
            ('a="abc" + "def"', 'abcdef'),      # check string ops
            ('a = 3**4', 81),                   # binary power
            ('a = 3*4', 12),                    # binary multiply
            ('a = 13//4', 3),                   # binary floor divide
            ('a = 14%4', 2),                    # binary modulo
            ('a = 2+3', 5),                     # binary add
            ('a = 13-4', 9),                    # binary subtract
            ('a = (12,13)[1]', 13),             # binary subscr
            ('a = 13 << 2', 52),                # binary lshift
            ('a = 13 >> 2', 3),                 # binary rshift
            ('a = 13 & 7', 5),                  # binary and
            ('a = 13 ^ 7', 10),                 # binary xor
            ('a = 13 | 7', 15),                 # binary or
            ):
            with self.subTest(line=line):
                code = compile(line, '', 'single')
                self.assertInBytecode(code, 'LOAD_CONST', elem)
                self._assert_no_opname_prefix(code, 'BINARY_')
                self.check_lnotab(code)

        # Verify that unfoldables are skipped
        code = compile('a=2+"b"', '', 'single')
        self.assertInBytecode(code, 'LOAD_CONST', 2)
        self.assertInBytecode(code, 'LOAD_CONST', 'b')
        self.check_lnotab(code)

        # Verify that large sequences do not result from folding
        code = compile('a="x"*10000', '', 'single')
        self.assertInBytecode(code, 'LOAD_CONST', 10000)
        self.assertNotIn("x"*10000, code.co_consts)
        self.check_lnotab(code)
        code = compile('a=1<<1000', '', 'single')
        self.assertInBytecode(code, 'LOAD_CONST', 1000)
        self.assertNotIn(1<<1000, code.co_consts)
        self.check_lnotab(code)
        code = compile('a=2**1000', '', 'single')
        self.assertInBytecode(code, 'LOAD_CONST', 1000)
        self.assertNotIn(2**1000, code.co_consts)
        self.check_lnotab(code)

    def test_binary_subscr_on_unicode(self):
        # valid code get optimized
        code = compile('"foo"[0]', '', 'single')
        self.assertInBytecode(code, 'LOAD_CONST', 'f')
        self.assertNotInBytecode(code, 'BINARY_SUBSCR')
        self.check_lnotab(code)
        code = compile('"\u0061\uffff"[1]', '', 'single')
        self.assertInBytecode(code, 'LOAD_CONST', '\uffff')
        self.assertNotInBytecode(code,'BINARY_SUBSCR')
        self.check_lnotab(code)

        # With PEP 393, non-BMP char get optimized
        code = compile('"\U00012345"[0]', '', 'single')
        self.assertInBytecode(code, 'LOAD_CONST', '\U00012345')
        self.assertNotInBytecode(code, 'BINARY_SUBSCR')
        self.check_lnotab(code)

        # invalid code doesn't get optimized
        # out of range
        code = compile('"fuu"[10]', '', 'single')
        self.assertInBytecode(code, 'BINARY_SUBSCR')
        self.check_lnotab(code)

    def test_folding_of_unaryops_on_constants(self):
        for line, elem in (
            ('-0.5', -0.5),                     # unary negative
            ('-0.0', -0.0),                     # -0.0
            ('-(1.0-1.0)', -0.0),               # -0.0 after folding
            ('-0', 0),                          # -0
            ('~-2', 1),                         # unary invert
            ('+1', 1),                          # unary positive
        ):
            with self.subTest(line=line):
                code = compile(line, '', 'single')
                self.assertInBytecode(code, 'LOAD_CONST', elem)
                self._assert_no_opname_prefix(code, 'UNARY_')
                self.check_lnotab(code)

        # Check that -0.0 works after marshaling
        def negzero():
            return -(1.0-1.0)

        self._assert_no_opname_prefix(negzero, 'UNARY_')
        self.check_lnotab(negzero)

        # Verify that unfoldables are skipped
        for line, elem, opname in (
            ('-"abc"', 'abc', 'UNARY_NEGATIVE'),
            ('~"abc"', 'abc', 'UNARY_INVERT'),
        ):
            with self.subTest(line=line):
                code = compile(line, '', 'single')
                self.assertInBytecode(code, 'LOAD_CONST', elem)
                self.assertInBytecode(code, opname)
                self.check_lnotab(code)

    def test_elim_extra_return(self):
        # RETURN LOAD_CONST None RETURN  -->  RETURN
        def f(x):
            return x
        self.assertNotInBytecode(f, 'LOAD_CONST', None)
        returns = [instr for instr in dis.get_instructions(f)
                          if instr.opname == 'RETURN_VALUE']
        self.assertEqual(len(returns), 1)
        self.check_lnotab(f)

    def test_elim_jump_to_return(self):
        # JUMP_FORWARD to RETURN -->  RETURN
        def f(cond, true_value, false_value):
            # Intentionally use two-line expression to test issue37213.
            return (true_value if cond
                    else false_value)
        self.check_jump_targets(f)
        self.assertNotInBytecode(f, 'JUMP_FORWARD')
        self.assertNotInBytecode(f, 'JUMP_ABSOLUTE')
        returns = [instr for instr in dis.get_instructions(f)
                          if instr.opname == 'RETURN_VALUE']
        self.assertEqual(len(returns), 2)
        self.check_lnotab(f)

    def test_elim_jump_to_uncond_jump(self):
        # POP_JUMP_IF_FALSE to JUMP_FORWARD --> POP_JUMP_IF_FALSE to non-jump
        def f():
            if a:
                # Intentionally use two-line expression to test issue37213.
                if (c
                    or d):
                    foo()
            else:
                baz()
        self.check_jump_targets(f)
        self.check_lnotab(f)

    def test_elim_jump_to_uncond_jump2(self):
        # POP_JUMP_IF_FALSE to JUMP_ABSOLUTE --> POP_JUMP_IF_FALSE to non-jump
        def f():
            while a:
                # Intentionally use two-line expression to test issue37213.
                if (c
                    or d):
                    a = foo()
        self.check_jump_targets(f)
        self.check_lnotab(f)

    def test_elim_jump_to_uncond_jump3(self):
        # Intentionally use two-line expressions to test issue37213.
        # JUMP_IF_FALSE_OR_POP to JUMP_IF_FALSE_OR_POP --> JUMP_IF_FALSE_OR_POP to non-jump
        def f(a, b, c):
            return ((a and b)
                    and c)
        self.check_jump_targets(f)
        self.check_lnotab(f)
        self.assertEqual(count_instr_recursively(f, 'JUMP_IF_FALSE_OR_POP'), 2)
        # JUMP_IF_TRUE_OR_POP to JUMP_IF_TRUE_OR_POP --> JUMP_IF_TRUE_OR_POP to non-jump
        def f(a, b, c):
            return ((a or b)
                    or c)
        self.check_jump_targets(f)
        self.check_lnotab(f)
        self.assertEqual(count_instr_recursively(f, 'JUMP_IF_TRUE_OR_POP'), 2)
        # JUMP_IF_FALSE_OR_POP to JUMP_IF_TRUE_OR_POP --> POP_JUMP_IF_FALSE to non-jump
        def f(a, b, c):
            return ((a and b)
                    or c)
        self.check_jump_targets(f)
        self.check_lnotab(f)
        self.assertNotInBytecode(f, 'JUMP_IF_FALSE_OR_POP')
        self.assertInBytecode(f, 'JUMP_IF_TRUE_OR_POP')
        self.assertInBytecode(f, 'POP_JUMP_IF_FALSE')
        # JUMP_IF_TRUE_OR_POP to JUMP_IF_FALSE_OR_POP --> POP_JUMP_IF_TRUE to non-jump
        def f(a, b, c):
            return ((a or b)
                    and c)
        self.check_jump_targets(f)
        self.check_lnotab(f)
        self.assertNotInBytecode(f, 'JUMP_IF_TRUE_OR_POP')
        self.assertInBytecode(f, 'JUMP_IF_FALSE_OR_POP')
        self.assertInBytecode(f, 'POP_JUMP_IF_TRUE')

    def test_elim_jump_after_return1(self):
        # Eliminate dead code: jumps immediately after returns can't be reached
        def f(cond1, cond2):
            if cond1: return 1
            if cond2: return 2
            while 1:
                return 3
            while 1:
                if cond1: return 4
                return 5
            return 6
        self.assertNotInBytecode(f, 'JUMP_FORWARD')
        self.assertNotInBytecode(f, 'JUMP_ABSOLUTE')
        returns = [instr for instr in dis.get_instructions(f)
                          if instr.opname == 'RETURN_VALUE']
        self.assertLessEqual(len(returns), 6)
        self.check_lnotab(f)

    def test_elim_jump_after_return2(self):
        # Eliminate dead code: jumps immediately after returns can't be reached
        def f(cond1, cond2):
            while 1:
                if cond1: return 4
        self.assertNotInBytecode(f, 'JUMP_FORWARD')
        # There should be one jump for the while loop.
        jumps = [instr for instr in dis.get_instructions(f)
                        if instr.opname == 'JUMP_ABSOLUTE']
        self.assertEqual(len(jumps), 1)
        returns = [instr for instr in dis.get_instructions(f)
                          if instr.opname == 'RETURN_VALUE']
        self.assertLessEqual(len(returns), 2)
        self.check_lnotab(f)

    def test_make_function_doesnt_bail(self):
        # A constant expression in a return annotation should still be folded.
        def f():
            def g()->1+1:
                pass
            return g
        self.assertNotInBytecode(f, 'BINARY_ADD')
        self.check_lnotab(f)

    def test_constant_folding(self):
        # Issue #11244: aggressive constant folding.
        exprs = [
            '3 * -5',
            '-3 * 5',
            '2 * (3 * 4)',
            '(2 * 3) * 4',
            '(-1, 2, 3)',
            '(1, -2, 3)',
            '(1, 2, -3)',
            '(1, 2, -3) * 6',
            'lambda x: x in {(3 * -5) + (-1 - 6), (1, -2, 3) * 2, None}',
        ]
        for e in exprs:
            with self.subTest(e=e):
                code = compile(e, '', 'single')
                self._assert_no_opname_prefix(
                    code, 'UNARY_', 'BINARY_', 'BUILD_')
                self.check_lnotab(code)

    def test_in_literal_list(self):
        def containtest():
            return x in [a, b]
        self.assertEqual(count_instr_recursively(containtest, 'BUILD_LIST'), 0)
        self.check_lnotab(containtest)

    def test_iterate_literal_list(self):
        def forloop():
            for x in [a, b]:
                pass
        self.assertEqual(count_instr_recursively(forloop, 'BUILD_LIST'), 0)
        self.check_lnotab(forloop)

    def test_condition_with_binop_with_bools(self):
        def f():
            if True or False:
                return 1
            return 0
        self.assertEqual(f(), 1)
        self.check_lnotab(f)

    def test_if_with_if_expression(self):  # XXX does this belong in 3.8?
        # Check bpo-37289
        def f(x):
            if (True if x else False):
                return True
            return False
        self.assertTrue(f(True))
        self.check_lnotab(f)

    def test_trailing_nops(self):
        # Check the lnotab of a function that even after trivial
        # optimization has trailing nops, which the lnotab adjustment has to
        # handle properly (bpo-38115).
        def f(x):
            while 1:
                return 3
            while 1:
                return 5
            return 6
        self.check_lnotab(f)
497
498
class TestBuglets(unittest.TestCase):

    def test_bug_11510(self):
        # The folding of constant sets used to be entangled with the
        # tuple-unpacking optimization.  A set display with duplicate
        # elements ends up shorter than written, which that optimization
        # did not expect; unpacking it must simply raise ValueError.
        def unpack_duplicates():
            x, y = {1, 1}
            return x, y
        self.assertRaises(ValueError, unpack_duplicates)
510
511
if __name__ == '__main__':
    # Run the tests in this module when it is executed as a script.
    unittest.main()
514