1"""
2Peephole optimizer of CPython 3.6 reimplemented in pure Python using
3the bytecode module.
4"""
5import opcode
6import operator
7import sys
8from _pydevd_frame_eval.vendored.bytecode import Instr, Bytecode, ControlFlowGraph, BasicBlock, Compare
9
# Names of the conditional jumps that are taken when the top of stack is true.
JUMPS_ON_TRUE = frozenset(("POP_JUMP_IF_TRUE", "JUMP_IF_TRUE_OR_POP",))

# Inverse of each COMPARE_OP argument that can be safely negated; used to
# rewrite COMPARE_OP + UNARY_NOT (e.g. "not (a in b)" -> "a not in b").
NOT_COMPARE = {
    Compare.IN: Compare.NOT_IN,
    Compare.NOT_IN: Compare.IN,
    Compare.IS: Compare.IS_NOT,
    Compare.IS_NOT: Compare.IS,
}

# Maximum length of a constant-folded sized result (string, tuple, ...):
# larger results are not stored, to avoid bloating the constant table.
MAX_SIZE = 20
20
21
class ExitUnchanged(Exception):
    """Raised to bail out of the peephole optimizer, leaving the code as-is."""
26
27
class PeepholeOptimizer:
    """Python reimplementation of the peephole optimizer.

    Copy of the C comment:

    Perform basic peephole optimizations to components of a code object.
    The consts object should still be in list form to allow new constants
    to be appended.

    To keep the optimizer simple, it bails out (does nothing) for code that
    has a length over 32,700, and does not calculate extended arguments.
    That allows us to avoid overflow and sign issues. Likewise, it bails when
    the lineno table has complex encoding for gaps >= 255. EXTENDED_ARG can
    appear before MAKE_FUNCTION; in this case both opcodes are skipped.
    EXTENDED_ARG preceding any other opcode causes the optimizer to bail.

    Optimizations are restricted to simple transformations occurring within a
    single basic block.  All transformations keep the code size the same or
    smaller.  For those that reduce size, the gaps are initially filled with
    NOPs.  Later those NOPs are removed and the jump addresses retargeted in
    a single pass.  Code offset is adjusted accordingly.
    """

    def __init__(self):
        # bytecode.ControlFlowGraph instance
        self.code = None
        # stack of the constant values pushed by the current run of
        # LOAD_CONST instructions (used for constant folding)
        self.const_stack = None
        # index of the current block in self.code
        self.block_index = None
        # bytecode.BasicBlock currently being optimized
        self.block = None
        # index of the current instruction in self.block instructions
        self.index = None
        # whether we are in a LOAD_CONST sequence
        self.in_consts = False

    def check_result(self, value):
        """Return True if a folded constant is acceptable to store.

        Sized results (strings, tuples, ...) longer than MAX_SIZE are
        rejected to avoid bloating the constant table; unsized values
        always pass.
        """
        try:
            size = len(value)
        except TypeError:
            return True
        return size <= MAX_SIZE

    def replace_load_const(self, nconst, instr, result):
        """Replace the last nconst LOAD_CONST plus *instr* with a single
        LOAD_CONST pushing *result*."""
        # FIXME: remove temporary computed constants?
        # FIXME: or at least reuse existing constants?

        load_const = Instr("LOAD_CONST", result, lineno=instr.lineno)
        start = self.index - nconst - 1
        self.block[start : self.index] = (load_const,)
        self.index -= nconst

        if nconst:
            del self.const_stack[-nconst:]
        self.const_stack.append(result)
        # Stay in the LOAD_CONST sequence: the folded constant may itself
        # be folded into a following operation.
        self.in_consts = True

    def eval_LOAD_CONST(self, instr):
        """Track the constant value pushed by a LOAD_CONST."""
        self.in_consts = True
        self.const_stack.append(instr.arg)

    def unaryop(self, op, instr):
        """Constant-fold the unary operation *op* applied to the top constant."""
        try:
            value = self.const_stack[-1]
        except IndexError:
            return

        try:
            result = op(value)
        except Exception:
            # Folding failed (e.g. UNARY_INVERT on a float raises TypeError):
            # keep the original bytecode so the error is raised at runtime,
            # exactly as binop() does. The original code only caught
            # IndexError here, which let such errors crash the optimizer.
            return

        if not self.check_result(result):
            return

        self.replace_load_const(1, instr, result)

    def eval_UNARY_POSITIVE(self, instr):
        return self.unaryop(operator.pos, instr)

    def eval_UNARY_NEGATIVE(self, instr):
        return self.unaryop(operator.neg, instr)

    def eval_UNARY_INVERT(self, instr):
        return self.unaryop(operator.invert, instr)

    def get_next_instr(self, name):
        """Return the instruction following the current one if it is named
        *name*, else None."""
        try:
            next_instr = self.block[self.index]
        except IndexError:
            return None
        if next_instr.name == name:
            return next_instr
        return None

    def eval_UNARY_NOT(self, instr):
        # Note: UNARY_NOT <const> is not optimized

        next_instr = self.get_next_instr("POP_JUMP_IF_FALSE")
        if next_instr is None:
            return None

        # Replace UNARY_NOT+POP_JUMP_IF_FALSE with POP_JUMP_IF_TRUE
        instr.set("POP_JUMP_IF_TRUE", next_instr.arg)
        del self.block[self.index]

    def binop(self, op, instr):
        """Constant-fold the binary operation *op* on the two top constants."""
        try:
            left = self.const_stack[-2]
            right = self.const_stack[-1]
        except IndexError:
            return

        try:
            result = op(left, right)
        except Exception:
            # Example: 1 / 0 must not crash the optimizer; the exception
            # will be raised at runtime instead.
            return

        if not self.check_result(result):
            return

        self.replace_load_const(2, instr, result)

    def eval_BINARY_ADD(self, instr):
        return self.binop(operator.add, instr)

    def eval_BINARY_SUBTRACT(self, instr):
        return self.binop(operator.sub, instr)

    def eval_BINARY_MULTIPLY(self, instr):
        return self.binop(operator.mul, instr)

    def eval_BINARY_TRUE_DIVIDE(self, instr):
        return self.binop(operator.truediv, instr)

    def eval_BINARY_FLOOR_DIVIDE(self, instr):
        return self.binop(operator.floordiv, instr)

    def eval_BINARY_MODULO(self, instr):
        return self.binop(operator.mod, instr)

    def eval_BINARY_POWER(self, instr):
        return self.binop(operator.pow, instr)

    def eval_BINARY_LSHIFT(self, instr):
        return self.binop(operator.lshift, instr)

    def eval_BINARY_RSHIFT(self, instr):
        return self.binop(operator.rshift, instr)

    def eval_BINARY_AND(self, instr):
        return self.binop(operator.and_, instr)

    def eval_BINARY_OR(self, instr):
        return self.binop(operator.or_, instr)

    def eval_BINARY_XOR(self, instr):
        return self.binop(operator.xor, instr)

    def eval_BINARY_SUBSCR(self, instr):
        return self.binop(operator.getitem, instr)

    def replace_container_of_consts(self, instr, container_type):
        """Replace *instr* and its constant operands with one LOAD_CONST of
        a prebuilt container of type *container_type*."""
        items = self.const_stack[-instr.arg :]
        value = container_type(items)
        self.replace_load_const(instr.arg, instr, value)

    def build_tuple_unpack_seq(self, instr):
        """Optimize BUILD_TUPLE/BUILD_LIST n + UNPACK_SEQUENCE n pairs."""
        next_instr = self.get_next_instr("UNPACK_SEQUENCE")
        if next_instr is None or next_instr.arg != instr.arg:
            return

        if instr.arg < 1:
            return

        if self.const_stack and instr.arg <= len(self.const_stack):
            nconst = instr.arg
            start = self.index - 1

            # Rewrite LOAD_CONST instructions in the reverse order
            load_consts = self.block[start - nconst : start]
            self.block[start - nconst : start] = reversed(load_consts)

            # Remove BUILD_TUPLE+UNPACK_SEQUENCE
            self.block[start : start + 2] = ()
            self.index -= 2
            self.const_stack.clear()
            return

        if instr.arg == 1:
            # Remove BUILD_TUPLE 1 + UNPACK_SEQUENCE 1: a no-op pair
            del self.block[self.index - 1 : self.index + 1]
        elif instr.arg == 2:
            # Replace BUILD_TUPLE 2 + UNPACK_SEQUENCE 2 with ROT_TWO
            rot2 = Instr("ROT_TWO", lineno=instr.lineno)
            self.block[self.index - 1 : self.index + 1] = (rot2,)
            self.index -= 1
            self.const_stack.clear()
        elif instr.arg == 3:
            # Replace BUILD_TUPLE 3 + UNPACK_SEQUENCE 3
            # with ROT_THREE + ROT_TWO
            rot3 = Instr("ROT_THREE", lineno=instr.lineno)
            rot2 = Instr("ROT_TWO", lineno=instr.lineno)
            self.block[self.index - 1 : self.index + 1] = (rot3, rot2)
            self.index -= 1
            self.const_stack.clear()

    def build_tuple(self, instr, container_type):
        """Fold constants into one container when followed by a COMPARE_OP
        using "in"/"not in". Return True when the replacement was done."""
        if instr.arg > len(self.const_stack):
            return

        next_instr = self.get_next_instr("COMPARE_OP")
        if next_instr is None or next_instr.arg not in (Compare.IN, Compare.NOT_IN):
            return

        self.replace_container_of_consts(instr, container_type)
        return True

    def eval_BUILD_TUPLE(self, instr):
        if not instr.arg:
            return

        if instr.arg <= len(self.const_stack):
            self.replace_container_of_consts(instr, tuple)
        else:
            self.build_tuple_unpack_seq(instr)

    def eval_BUILD_LIST(self, instr):
        if not instr.arg:
            return

        # A constant list only used for "in"/"not in" can become a tuple
        if not self.build_tuple(instr, tuple):
            self.build_tuple_unpack_seq(instr)

    def eval_BUILD_SET(self, instr):
        if not instr.arg:
            return

        # A constant set only used for "in"/"not in" can become a frozenset
        self.build_tuple(instr, frozenset)

    # Note: BUILD_SLICE is not optimized

    def eval_COMPARE_OP(self, instr):
        # Note: COMPARE_OP: 2 < 3 is not optimized

        try:
            new_arg = NOT_COMPARE[instr.arg]
        except KeyError:
            # Only is/is not/in/not in can be safely negated
            return

        if self.get_next_instr("UNARY_NOT") is None:
            return

        # not (a is b) -->  a is not b
        # not (a in b) -->  a not in b
        # not (a is not b) -->  a is b
        # not (a not in b) -->  a in b
        instr.arg = new_arg
        self.block[self.index - 1 : self.index + 1] = (instr,)

    def jump_if_or_pop(self, instr):
        # Simplify conditional jump to conditional jump where the
        # result of the first test implies the success of a similar
        # test or the failure of the opposite test.
        #
        # Arises in code like:
        # "if a and b:"
        # "if a or b:"
        # "a and b or c"
        # "(a and b) and c"
        #
        # x:JUMP_IF_FALSE_OR_POP y   y:JUMP_IF_FALSE_OR_POP z
        #    -->  x:JUMP_IF_FALSE_OR_POP z
        #
        # x:JUMP_IF_FALSE_OR_POP y   y:JUMP_IF_TRUE_OR_POP z
        #    -->  x:POP_JUMP_IF_FALSE y+3
        # where y+3 is the instruction following the second test.
        target_block = instr.arg
        try:
            target_instr = target_block[0]
        except IndexError:
            return

        if not target_instr.is_cond_jump():
            self.optimize_jump_to_cond_jump(instr)
            return

        if (target_instr.name in JUMPS_ON_TRUE) == (instr.name in JUMPS_ON_TRUE):
            # The second jump will be taken iff the first is.

            target2 = target_instr.arg
            # The current opcode inherits its target's stack behaviour
            instr.name = target_instr.name
            instr.arg = target2
            self.block[self.index - 1] = instr
            self.index -= 1
        else:
            # The second jump is not taken if the first is (so jump past it),
            # and all conditional jumps pop their argument when they're not
            # taken (so change the first jump to pop its argument when it's
            # taken).
            if instr.name in JUMPS_ON_TRUE:
                name = "POP_JUMP_IF_TRUE"
            else:
                name = "POP_JUMP_IF_FALSE"

            new_label = self.code.split_block(target_block, 1)

            instr.name = name
            instr.arg = new_label
            self.block[self.index - 1] = instr
            self.index -= 1

    def eval_JUMP_IF_FALSE_OR_POP(self, instr):
        self.jump_if_or_pop(instr)

    def eval_JUMP_IF_TRUE_OR_POP(self, instr):
        self.jump_if_or_pop(instr)

    def eval_NOP(self, instr):
        # Remove NOP
        del self.block[self.index - 1]
        self.index -= 1

    def optimize_jump_to_cond_jump(self, instr):
        # Replace jumps to unconditional jumps
        jump_label = instr.arg
        assert isinstance(jump_label, BasicBlock), jump_label

        try:
            target_instr = jump_label[0]
        except IndexError:
            return

        if instr.is_uncond_jump() and target_instr.name == "RETURN_VALUE":
            # Replace JUMP_ABSOLUTE => RETURN_VALUE with RETURN_VALUE
            self.block[self.index - 1] = target_instr

        elif target_instr.is_uncond_jump():
            # Replace JUMP_FORWARD t1 jumping to JUMP_FORWARD t2
            # with JUMP_ABSOLUTE t2
            jump_target2 = target_instr.arg

            name = instr.name
            if instr.name == "JUMP_FORWARD":
                name = "JUMP_ABSOLUTE"
            else:
                # FIXME: reimplement this check
                # if jump_target2 < 0:
                #    # No backward relative jumps
                #    return

                # FIXME: remove this workaround and implement comment code ^^
                if instr.opcode in opcode.hasjrel:
                    return

            instr.name = name
            instr.arg = jump_target2
            self.block[self.index - 1] = instr

    def optimize_jump(self, instr):
        """Optimize a jump instruction that has no eval_* handler."""
        if instr.is_uncond_jump() and self.index == len(self.block):
            # JUMP_ABSOLUTE at the end of a block which points to the
            # following block: remove the jump, link the current block
            # to the following block
            block_index = self.block_index
            target_block = instr.arg
            target_block_index = self.code.get_block_index(target_block)
            if target_block_index == block_index:
                del self.block[self.index - 1]
                self.block.next_block = target_block
                return

        self.optimize_jump_to_cond_jump(instr)

    def iterblock(self, block):
        """Yield the instructions of *block*, tolerating in-place mutation
        of the block by the eval_* handlers (which adjust self.index)."""
        self.block = block
        self.index = 0
        while self.index < len(block):
            instr = self.block[self.index]
            self.index += 1
            yield instr

    def optimize_block(self, block):
        """Run the peephole transformations on a single basic block."""
        self.const_stack.clear()
        self.in_consts = False

        for instr in self.iterblock(block):
            if not self.in_consts:
                self.const_stack.clear()
            self.in_consts = False

            meth_name = "eval_%s" % instr.name
            meth = getattr(self, meth_name, None)
            if meth is not None:
                meth(instr)
            elif instr.has_jump():
                self.optimize_jump(instr)

            # Note: Skipping over LOAD_CONST trueconst; POP_JUMP_IF_FALSE
            # <target> is not implemented, since it looks like the optimization
            # is never triggered in practice. The compiler already optimizes if
            # and while statements.

    def remove_dead_blocks(self):
        """Delete the basic blocks that are not reachable from any jump or
        fall-through (the entry block is always kept)."""
        # FIXME: remove empty blocks?

        used_blocks = {id(self.code[0])}
        for block in self.code:
            if block.next_block is not None:
                used_blocks.add(id(block.next_block))
            for instr in block:
                if isinstance(instr, Instr) and isinstance(instr.arg, BasicBlock):
                    used_blocks.add(id(instr.arg))

        block_index = 0
        while block_index < len(self.code):
            block = self.code[block_index]
            if id(block) not in used_blocks:
                del self.code[block_index]
            else:
                block_index += 1

        # FIXME: merge following blocks if block1 does not contain any
        # jump and block1.next_block is block2

    def optimize_cfg(self, cfg):
        """Optimize a bytecode.ControlFlowGraph in place."""
        self.code = cfg
        self.const_stack = []

        self.remove_dead_blocks()

        self.block_index = 0
        while self.block_index < len(self.code):
            block = self.code[self.block_index]
            self.block_index += 1
            self.optimize_block(block)

    def optimize(self, code_obj):
        """Return a new optimized code object built from *code_obj*."""
        bytecode = Bytecode.from_code(code_obj)
        cfg = ControlFlowGraph.from_bytecode(bytecode)

        self.optimize_cfg(cfg)

        bytecode = cfg.to_bytecode()
        code = bytecode.to_code()
        return code
473
474
475# Code transformer for the PEP 511
class CodeTransformer:
    """PEP 511 code transformer running the peephole optimizer."""

    name = "pyopt"

    def code_transformer(self, code, context):
        if sys.flags.verbose:
            message = "Optimize %s:%s: %s" % (
                code.co_filename,
                code.co_firstlineno,
                code.co_name,
            )
            print(message)
        return PeepholeOptimizer().optimize(code)
487