1"""
2Peephole optimizer of CPython 3.6 reimplemented in pure Python using
3the bytecode module.
4"""
5import opcode
6import operator
7import sys
8from _pydevd_frame_eval.vendored.bytecode import Instr, Bytecode, ControlFlowGraph, BasicBlock, Compare
9
# Names of the conditional jumps that branch when the tested value is true.
JUMPS_ON_TRUE = frozenset({"POP_JUMP_IF_TRUE", "JUMP_IF_TRUE_OR_POP"})

# Negation of each comparison that PeepholeOptimizer.eval_COMPARE_OP can merge
# with a following UNARY_NOT, e.g. "not (a in b)" --> "a not in b".
NOT_COMPARE = dict(
    [
        (Compare.IN, Compare.NOT_IN),
        (Compare.NOT_IN, Compare.IN),
        (Compare.IS, Compare.IS_NOT),
        (Compare.IS_NOT, Compare.IS),
    ]
)

# Largest len() accepted for a constant produced by folding
# (checked by PeepholeOptimizer.check_result).
MAX_SIZE = 20
25
26
class ExitUnchanged(Exception):
    """Signal that the peephole optimizer should be skipped."""

    # NOTE(review): nothing in this module raises it; presumably kept for
    # callers / API compatibility — confirm before removing.
31
32
class PeepholeOptimizer:
    """Python reimplementation of the peephole optimizer.

    Copy of the C comment:

    Perform basic peephole optimizations to components of a code object.
    The consts object should still be in list form to allow new constants
    to be appended.

    To keep the optimizer simple, it bails out (does nothing) for code that
    has a length over 32,700, and does not calculate extended arguments.
    That allows us to avoid overflow and sign issues. Likewise, it bails when
    the lineno table has complex encoding for gaps >= 255. EXTENDED_ARG can
    appear before MAKE_FUNCTION; in this case both opcodes are skipped.
    EXTENDED_ARG preceding any other opcode causes the optimizer to bail.

    Optimizations are restricted to simple transformations occurring within a
    single basic block.  All transformations keep the code size the same or
    smaller.  For those that reduce size, the gaps are initially filled with
    NOPs.  Later those NOPs are removed and the jump addresses retargeted in
    a single pass.  Code offset is adjusted accordingly.

    Notes on this Python version: it works on a ``ControlFlowGraph`` built
    from the code object.  Each block is walked by ``optimize_block``, which
    dispatches every instruction to the matching ``eval_<NAME>`` method (or
    to ``optimize_jump`` for other jumps).  Instructions are edited in place
    in the block; removed instructions are deleted directly rather than
    replaced with NOPs, and jump arguments are ``BasicBlock`` objects rather
    than offsets.
    """

    def __init__(self):
        # bytecode.ControlFlowGraph instance being optimized
        self.code = None
        # Values loaded by the run of consecutive LOAD_CONST (or folded)
        # instructions right before the current one; last item = stack top.
        self.const_stack = None
        # Index of the block *after* the one being optimized: optimize_cfg()
        # increments it before calling optimize_block().
        self.block_index = None
        # BasicBlock currently being optimized
        self.block = None
        # Index of the next instruction to visit in self.block: iterblock()
        # increments it before yielding, so the current instruction sits at
        # self.index - 1.
        self.index = None
        # whether we are in a LOAD_CONST sequence
        self.in_consts = False

    def check_result(self, value):
        """Return True if a folded constant is small enough to be kept.

        Values without a length (ints, floats, ...) are always accepted;
        sized values must have at most MAX_SIZE items.
        """
        try:
            size = len(value)
        except TypeError:
            return True
        return size <= MAX_SIZE

    def replace_load_const(self, nconst, instr, result):
        """Replace the current instruction and the *nconst* LOAD_CONST
        instructions before it with a single ``LOAD_CONST result``.

        The last *nconst* entries of const_stack are replaced by *result*,
        and self.index is moved back so iteration resumes right after the
        new LOAD_CONST.
        """
        # FIXME: remove temporary computed constants?
        # FIXME: or at least reuse existing constants?

        load_const = Instr("LOAD_CONST", result, lineno=instr.lineno)
        start = self.index - nconst - 1
        self.block[start : self.index] = (load_const,)
        self.index -= nconst

        if nconst:
            del self.const_stack[-nconst:]
        self.const_stack.append(result)
        # The new LOAD_CONST continues the current constant sequence.
        self.in_consts = True

    def eval_LOAD_CONST(self, instr):
        """Record the loaded constant and keep the constant sequence going."""
        self.const_stack.append(instr.arg)
        self.in_consts = True

    def unaryop(self, op, instr):
        """Fold a unary operator applied to the constant on top of the stack.

        Nothing is changed when no constant precedes *instr*, when *op*
        raises (e.g. ``-"abc"`` or ``~1.5``: the error belongs to run time
        and must not crash the optimizer), or when the result is too large.
        """
        try:
            value = self.const_stack[-1]
        except IndexError:
            return

        try:
            result = op(value)
        except Exception:
            # Same policy as binop(): leave failing operations unfolded.
            return

        if not self.check_result(result):
            return

        self.replace_load_const(1, instr, result)

    def eval_UNARY_POSITIVE(self, instr):
        return self.unaryop(operator.pos, instr)

    def eval_UNARY_NEGATIVE(self, instr):
        return self.unaryop(operator.neg, instr)

    def eval_UNARY_INVERT(self, instr):
        return self.unaryop(operator.invert, instr)

    def get_next_instr(self, name):
        """Return the instruction after the current one if it is named
        *name*; return None otherwise (including at the end of the block).
        """
        try:
            next_instr = self.block[self.index]
        except IndexError:
            return None
        if next_instr.name == name:
            return next_instr
        return None

    def eval_UNARY_NOT(self, instr):
        """Merge ``UNARY_NOT; POP_JUMP_IF_FALSE t`` into ``POP_JUMP_IF_TRUE t``."""
        # Note: UNARY_NOT <const> is not optimized

        next_instr = self.get_next_instr("POP_JUMP_IF_FALSE")
        if next_instr is None:
            return None

        # Replace UNARY_NOT+POP_JUMP_IF_FALSE with POP_JUMP_IF_TRUE
        instr.set("POP_JUMP_IF_TRUE", next_instr.arg)
        del self.block[self.index]

    def binop(self, op, instr):
        """Fold a binary operator applied to the two constants on top of the
        stack.  Operations that raise, or whose result is too large, are left
        unfolded.
        """
        try:
            left = self.const_stack[-2]
            right = self.const_stack[-1]
        except IndexError:
            return

        try:
            result = op(left, right)
        except Exception:
            return

        if not self.check_result(result):
            return

        self.replace_load_const(2, instr, result)

    def eval_BINARY_ADD(self, instr):
        return self.binop(operator.add, instr)

    def eval_BINARY_SUBTRACT(self, instr):
        return self.binop(operator.sub, instr)

    def eval_BINARY_MULTIPLY(self, instr):
        return self.binop(operator.mul, instr)

    def eval_BINARY_TRUE_DIVIDE(self, instr):
        return self.binop(operator.truediv, instr)

    def eval_BINARY_FLOOR_DIVIDE(self, instr):
        return self.binop(operator.floordiv, instr)

    def eval_BINARY_MODULO(self, instr):
        return self.binop(operator.mod, instr)

    def eval_BINARY_POWER(self, instr):
        return self.binop(operator.pow, instr)

    def eval_BINARY_LSHIFT(self, instr):
        return self.binop(operator.lshift, instr)

    def eval_BINARY_RSHIFT(self, instr):
        return self.binop(operator.rshift, instr)

    def eval_BINARY_AND(self, instr):
        return self.binop(operator.and_, instr)

    def eval_BINARY_OR(self, instr):
        return self.binop(operator.or_, instr)

    def eval_BINARY_XOR(self, instr):
        return self.binop(operator.xor, instr)

    def eval_BINARY_SUBSCR(self, instr):
        return self.binop(operator.getitem, instr)

    def replace_container_of_consts(self, instr, container_type):
        """Fold the last ``instr.arg`` constants into one constant of type
        *container_type* (e.g. tuple or frozenset).  Callers ensure
        ``0 < instr.arg <= len(self.const_stack)``.
        """
        items = self.const_stack[-instr.arg :]
        value = container_type(items)
        self.replace_load_const(instr.arg, instr, value)

    def build_tuple_unpack_seq(self, instr):
        """Simplify ``BUILD_TUPLE/BUILD_LIST n; UNPACK_SEQUENCE n``.

        With n preceding constants, the constants are loaded in reverse
        order and both instructions are dropped.  Otherwise n == 1 removes
        the pair, n == 2 becomes ROT_TWO and n == 3 ROT_THREE + ROT_TWO.
        """
        next_instr = self.get_next_instr("UNPACK_SEQUENCE")
        if next_instr is None or next_instr.arg != instr.arg:
            return

        if instr.arg < 1:
            return

        if self.const_stack and instr.arg <= len(self.const_stack):
            nconst = instr.arg
            start = self.index - 1

            # Rewrite LOAD_CONST instructions in the reverse order
            load_consts = self.block[start - nconst : start]
            self.block[start - nconst : start] = reversed(load_consts)

            # Remove BUILD_TUPLE+UNPACK_SEQUENCE
            self.block[start : start + 2] = ()
            self.index -= 2
            self.const_stack.clear()
            return

        if instr.arg == 1:
            # Replace BUILD_TUPLE 1 + UNPACK_SEQUENCE 1 with nothing: the
            # pair leaves the stack unchanged.
            del self.block[self.index - 1 : self.index + 1]
            # Step back so the instruction that followed UNPACK_SEQUENCE
            # (now at the old BUILD_TUPLE position) is not skipped.
            self.index -= 1
        elif instr.arg == 2:
            # Replace BUILD_TUPLE 2 + UNPACK_SEQUENCE 2 with ROT_TWO
            rot2 = Instr("ROT_TWO", lineno=instr.lineno)
            self.block[self.index - 1 : self.index + 1] = (rot2,)
            self.index -= 1
            self.const_stack.clear()
        elif instr.arg == 3:
            # Replace BUILD_TUPLE 3 + UNPACK_SEQUENCE 3
            # with ROT_THREE + ROT_TWO
            rot3 = Instr("ROT_THREE", lineno=instr.lineno)
            rot2 = Instr("ROT_TWO", lineno=instr.lineno)
            self.block[self.index - 1 : self.index + 1] = (rot3, rot2)
            self.index -= 1
            self.const_stack.clear()

    def build_tuple(self, instr, container_type):
        """Fold a container of constants that is immediately tested with
        ``in`` / ``not in`` into a constant of *container_type*.

        Returns True when the fold happened, None otherwise.
        """
        if instr.arg > len(self.const_stack):
            return

        next_instr = self.get_next_instr("COMPARE_OP")
        if next_instr is None or next_instr.arg not in (Compare.IN, Compare.NOT_IN):
            return

        self.replace_container_of_consts(instr, container_type)
        return True

    def eval_BUILD_TUPLE(self, instr):
        """Fold a tuple of constants, else try the BUILD+UNPACK rewrite."""
        if not instr.arg:
            return

        if instr.arg <= len(self.const_stack):
            self.replace_container_of_consts(instr, tuple)
        else:
            self.build_tuple_unpack_seq(instr)

    def eval_BUILD_LIST(self, instr):
        """Turn a list of constants used only for a membership test into a
        tuple constant, else try the BUILD+UNPACK rewrite."""
        if not instr.arg:
            return

        if not self.build_tuple(instr, tuple):
            self.build_tuple_unpack_seq(instr)

    def eval_BUILD_SET(self, instr):
        """Turn a set of constants used only for a membership test into a
        frozenset constant."""
        if not instr.arg:
            return

        self.build_tuple(instr, frozenset)

    # Note: BUILD_SLICE is not optimized

    def eval_COMPARE_OP(self, instr):
        """Merge an ``in``/``not in``/``is``/``is not`` comparison with a
        following UNARY_NOT by negating the comparison."""
        # Note: COMPARE_OP: 2 < 3 is not optimized

        try:
            new_arg = NOT_COMPARE[instr.arg]
        except KeyError:
            return

        if self.get_next_instr("UNARY_NOT") is None:
            return

        # not (a is b) -->  a is not b
        # not (a in b) -->  a not in b
        # not (a is not b) -->  a is b
        # not (a not in b) -->  a in b
        instr.arg = new_arg
        self.block[self.index - 1 : self.index + 1] = (instr,)

    def jump_if_or_pop(self, instr):
        # Simplify conditional jump to conditional jump where the
        # result of the first test implies the success of a similar
        # test or the failure of the opposite test.
        #
        # Arises in code like:
        # "if a and b:"
        # "if a or b:"
        # "a and b or c"
        # "(a and b) and c"
        #
        # x:JUMP_IF_FALSE_OR_POP y   y:JUMP_IF_FALSE_OR_POP z
        #    -->  x:JUMP_IF_FALSE_OR_POP z
        #
        # x:JUMP_IF_FALSE_OR_POP y   y:JUMP_IF_TRUE_OR_POP z
        #    -->  x:POP_JUMP_IF_FALSE y+3
        # where y+3 is the instruction following the second test.
        #
        # In both rewrites self.index is decremented so the rewritten jump
        # is visited again and can be optimized further.
        target_block = instr.arg
        try:
            target_instr = target_block[0]
        except IndexError:
            return

        if not target_instr.is_cond_jump():
            self.optimize_jump_to_cond_jump(instr)
            return

        if (target_instr.name in JUMPS_ON_TRUE) == (instr.name in JUMPS_ON_TRUE):
            # The second jump will be taken iff the first is.

            target2 = target_instr.arg
            # The current opcode inherits its target's stack behaviour
            instr.name = target_instr.name
            instr.arg = target2
            self.block[self.index - 1] = instr
            self.index -= 1
        else:
            # The second jump is not taken if the first is (so jump past it),
            # and all conditional jumps pop their argument when they're not
            # taken (so change the first jump to pop its argument when it's
            # taken).
            if instr.name in JUMPS_ON_TRUE:
                name = "POP_JUMP_IF_TRUE"
            else:
                name = "POP_JUMP_IF_FALSE"

            # Split after the target's first instruction (the second test)
            # so the new jump lands right after it.
            new_label = self.code.split_block(target_block, 1)

            instr.name = name
            instr.arg = new_label
            self.block[self.index - 1] = instr
            self.index -= 1

    def eval_JUMP_IF_FALSE_OR_POP(self, instr):
        self.jump_if_or_pop(instr)

    def eval_JUMP_IF_TRUE_OR_POP(self, instr):
        self.jump_if_or_pop(instr)

    def eval_NOP(self, instr):
        """Delete the NOP and step back so the next instruction is visited."""
        del self.block[self.index - 1]
        self.index -= 1

    def optimize_jump_to_cond_jump(self, instr):
        """Shortcut a jump whose target block starts with RETURN_VALUE or
        with an unconditional jump (despite the method name).

        An unconditional jump to RETURN_VALUE becomes RETURN_VALUE; a jump
        to an unconditional jump is retargeted to the final destination.
        """
        # Replace jumps to unconditional jumps
        jump_label = instr.arg
        assert isinstance(jump_label, BasicBlock), jump_label

        try:
            target_instr = jump_label[0]
        except IndexError:
            return

        if instr.is_uncond_jump() and target_instr.name == "RETURN_VALUE":
            # Replace JUMP_ABSOLUTE => RETURN_VALUE with RETURN_VALUE
            self.block[self.index - 1] = target_instr

        elif target_instr.is_uncond_jump():
            # Replace JUMP_FORWARD t1 jumping to JUMP_FORWARD t2
            # with JUMP_ABSOLUTE t2
            jump_target2 = target_instr.arg

            name = instr.name
            if instr.name == "JUMP_FORWARD":
                name = "JUMP_ABSOLUTE"
            else:
                # FIXME: reimplement this check
                # if jump_target2 < 0:
                #    # No backward relative jumps
                #    return

                # FIXME: remove this workaround and implement comment code ^^
                if instr.opcode in opcode.hasjrel:
                    return

            instr.name = name
            instr.arg = jump_target2
            self.block[self.index - 1] = instr

    def optimize_jump(self, instr):
        """Optimize a jump instruction that has no dedicated eval_ method."""
        if instr.is_uncond_jump() and self.index == len(self.block):
            # JUMP_ABSOLUTE at the end of a block which points to the
            # following block: remove the jump, link the current block
            # to the following block
            # (self.block_index already holds the index of the next block.)
            block_index = self.block_index
            target_block = instr.arg
            target_block_index = self.code.get_block_index(target_block)
            if target_block_index == block_index:
                del self.block[self.index - 1]
                self.block.next_block = target_block
                return

        self.optimize_jump_to_cond_jump(instr)

    def iterblock(self, block):
        """Yield the instructions of *block* one at a time.

        The position lives in self.index (re-read on every step), so eval_
        methods may insert/delete instructions and move self.index to choose
        where iteration resumes.
        """
        self.block = block
        self.index = 0
        while self.index < len(block):
            instr = self.block[self.index]
            self.index += 1
            yield instr

    def optimize_block(self, block):
        """Run every applicable peephole rewrite over the instructions of
        *block*."""
        self.const_stack.clear()
        self.in_consts = False

        for instr in self.iterblock(block):
            # The constant stack only describes an unbroken run of
            # LOAD_CONSTs; any other instruction ends the run.
            if not self.in_consts:
                self.const_stack.clear()
            self.in_consts = False

            meth_name = "eval_%s" % instr.name
            meth = getattr(self, meth_name, None)
            if meth is not None:
                meth(instr)
            elif instr.has_jump():
                self.optimize_jump(instr)

            # Note: Skipping over LOAD_CONST trueconst; POP_JUMP_IF_FALSE
            # <target> is not implemented, since it looks like the optimization
            # is never triggered in practice. The compiler already optimizes if
            # and while statements.

    def remove_dead_blocks(self):
        """Delete blocks that are not the entry block, not the next_block of
        any block and not the target of any jump.

        This is a single pass: a block referenced only from a dead block is
        kept.
        """
        # FIXME: remove empty blocks?

        used_blocks = {id(self.code[0])}
        for block in self.code:
            if block.next_block is not None:
                used_blocks.add(id(block.next_block))
            for instr in block:
                if isinstance(instr, Instr) and isinstance(instr.arg, BasicBlock):
                    used_blocks.add(id(instr.arg))

        block_index = 0
        while block_index < len(self.code):
            block = self.code[block_index]
            if id(block) not in used_blocks:
                del self.code[block_index]
            else:
                block_index += 1

        # FIXME: merge following blocks if block1 does not contain any
        # jump and block1.next_block is block2

    def optimize_cfg(self, cfg):
        """Optimize the ControlFlowGraph *cfg* in place."""
        self.code = cfg
        self.const_stack = []

        self.remove_dead_blocks()

        self.block_index = 0
        while self.block_index < len(self.code):
            block = self.code[self.block_index]
            # Incremented before optimizing, so while a block is processed
            # self.block_index is the index of the block that follows it.
            self.block_index += 1
            self.optimize_block(block)

    def optimize(self, code_obj):
        """Return a new code object: *code_obj* after peephole optimization."""
        bytecode = Bytecode.from_code(code_obj)
        cfg = ControlFlowGraph.from_bytecode(bytecode)

        self.optimize_cfg(cfg)

        bytecode = cfg.to_bytecode()
        code = bytecode.to_code()
        return code
478
479
# Code transformer following the PEP 511 API
class CodeTransformer:
    """PEP 511 style code transformer backed by PeepholeOptimizer."""

    name = "pyopt"

    def code_transformer(self, code, context):
        """Return the code object produced by PeepholeOptimizer.optimize().

        *context* is not used.  When ``sys.flags.verbose`` is set, the
        location of *code* is printed before optimizing it.
        """
        if sys.flags.verbose:
            location = (code.co_filename, code.co_firstlineno, code.co_name)
            print("Optimize %s:%s: %s" % location)
        return PeepholeOptimizer().optimize(code)
492