1"""Helpers for generating for loops and comprehensions.
2
We special case certain kinds of for loops such as "for x in range(...)"
for better efficiency.  Each for loop generator class below deals with
one such special case.
6"""
7
8from typing import Union, List, Optional, Tuple, Callable
9from typing_extensions import Type, ClassVar
10
11from mypy.nodes import (
12    Lvalue, Expression, TupleExpr, CallExpr, RefExpr, GeneratorExpr, ARG_POS, MemberExpr, TypeAlias
13)
14from mypyc.ir.ops import (
15    Value, BasicBlock, Integer, Branch, Register, TupleGet, TupleSet, IntOp
16)
17from mypyc.ir.rtypes import (
18    RType, is_short_int_rprimitive, is_list_rprimitive, is_sequence_rprimitive,
19    is_tuple_rprimitive, is_dict_rprimitive,
20    RTuple, short_int_rprimitive, int_rprimitive
21)
22from mypyc.primitives.registry import CFunctionDescription
23from mypyc.primitives.dict_ops import (
24    dict_next_key_op, dict_next_value_op, dict_next_item_op, dict_check_size_op,
25    dict_key_iter_op, dict_value_iter_op, dict_item_iter_op
26)
27from mypyc.primitives.list_ops import list_append_op, list_get_item_unsafe_op, new_list_set_item_op
28from mypyc.primitives.set_ops import set_add_op
29from mypyc.primitives.generic_ops import iter_op, next_op
30from mypyc.primitives.exc_ops import no_err_occurred_op
31from mypyc.irbuild.builder import IRBuilder
32from mypyc.irbuild.targets import AssignmentTarget, AssignmentTargetTuple
33
# Type of a callback that takes no arguments and returns nothing. for_loop_helper
# calls such callbacks to generate the loop body and the else clause.
GenFunc = Callable[[], None]
35
36
def for_loop_helper(builder: IRBuilder, index: Lvalue, expr: Expression,
                    body_insts: GenFunc, else_insts: Optional[GenFunc],
                    line: int) -> None:
    """Emit the IR for a complete "for" statement.

    The blocks are wired as condition -> body -> step -> condition. When
    the condition check fails, control goes to the else clause if there is
    one, and otherwise straight to the block following the loop.

    Args:
        builder: the IR builder that receives the generated ops
        index: the loop index Lvalue
        expr: the expression to iterate over
        body_insts: a function that generates the body of the loop
        else_insts: a function that generates the else block instructions,
            or None if there is no else clause
        line: the line number passed on to the loop generator
    """
    loop_body = BasicBlock()
    # Steps to the next item.
    advance = BasicBlock()
    # Only activated when an else clause exists.
    orelse = BasicBlock()
    # Everything after the loop.
    after_loop = BasicBlock()

    # Pick the destination for a failed condition check.
    if else_insts is not None:
        on_exhausted = orelse
    else:
        on_exhausted = after_loop

    loop_gen = make_for_loop_generator(builder, index, expr, loop_body, on_exhausted, line)

    builder.push_loop_stack(advance, after_loop)
    check = BasicBlock()
    builder.goto_and_activate(check)

    # Loop condition check.
    loop_gen.gen_condition()

    # Loop body.
    builder.activate_block(loop_body)
    loop_gen.begin_body()
    body_insts()

    # The step block is always emitted, even if the generator leaves it empty.
    builder.goto_and_activate(advance)
    loop_gen.gen_step()
    # Jump back to re-check the condition.
    builder.goto(check)

    loop_gen.add_cleanup(on_exhausted)
    builder.pop_loop_stack()

    if else_insts is not None:
        builder.activate_block(orelse)
        else_insts()
        builder.goto(after_loop)

    builder.activate_block(after_loop)
89
90
def for_loop_helper_with_index(builder: IRBuilder,
                               index: Lvalue,
                               expr: Expression,
                               expr_reg: Value,
                               body_insts: Callable[[Value], None], line: int) -> None:
    """Emit a loop over a sequence, handing the iteration index to the body.

    This always iterates with ForSequence, so expr_reg must have a sequence
    type. In contrast to for_loop_helper there is no else clause, and
    body_insts receives the value read from the generator's index target.

    Args:
        builder: the IR builder that receives the generated ops
        index: the loop index Lvalue
        expr: the expression to iterate over
        expr_reg: the already evaluated value of expr
        body_insts: a function that generates the body of the loop.
                    It takes the iteration index as its argument.
        line: the line number passed on to the loop generator
    """
    assert is_sequence_rprimitive(expr_reg.type)
    item_type = builder.get_sequence_type(expr)

    loop_body, advance, after_loop, check = (
        BasicBlock(), BasicBlock(), BasicBlock(), BasicBlock())

    seq_gen = ForSequence(builder, index, loop_body, after_loop, line, False)
    seq_gen.init(expr_reg, item_type, reverse=False)

    builder.push_loop_stack(advance, after_loop)

    builder.goto_and_activate(check)
    seq_gen.gen_condition()

    builder.activate_block(loop_body)
    seq_gen.begin_body()
    position = builder.read(seq_gen.index_target)
    body_insts(position)

    builder.goto_and_activate(advance)
    seq_gen.gen_step()
    builder.goto(check)

    seq_gen.add_cleanup(after_loop)
    builder.pop_loop_stack()

    builder.activate_block(after_loop)
135
136
def sequence_from_generator_preallocate_helper(
        builder: IRBuilder,
        gen: GeneratorExpr,
        empty_op_llbuilder: Callable[[Value, int], Value],
        set_item_op: CFunctionDescription) -> Optional[Value]:
    """Build a preallocated tuple or list from a simple generator expression.

    Only the simplest form is optimized: exactly one source sequence, exactly
    one index, and no conditions. The source must be a list or a tuple.

    e.g.  (1) tuple(f(x) for x in a_list/a_tuple)
          (2) list(f(x) for x in a_list/a_tuple)
          (3) [f(x) for x in a_list/a_tuple]
    RTuple as an original sequence is not supported yet.

    Args:
        builder: the IR builder that receives the generated ops
        gen: the generator expression to translate
        empty_op_llbuilder: A function that can generate an empty sequence op when
            passed in length. See `new_list_op_with_length` and `new_tuple_op_with_length`
            for detailed implementation.
        set_item_op: A primitive that can modify an arbitrary position of a sequence.
            The op should have three arguments:
                - Self
                - Target position
                - New Value
            See `new_list_set_item_op` and `new_tuple_set_item_op` for detailed
            implementation.

    Returns:
        The filled sequence, or None if the generator expression does not
        have the supported form (nothing is emitted in that case).
    """
    is_simple = (len(gen.sequences) == 1
                 and len(gen.indices) == 1
                 and len(gen.condlists[0]) == 0)
    if not is_simple:
        return None

    source_expr = gen.sequences[0]
    source_rtype = builder.node_type(source_expr)
    if not (is_list_rprimitive(source_rtype) or is_tuple_rprimitive(source_rtype)):
        return None

    source = builder.accept(source_expr)
    size = builder.builder.builtin_len(source, gen.line, use_pyssize_t=True)
    result = empty_op_llbuilder(size, gen.line)

    def store_item(position: Value) -> None:
        # Evaluate the element expression and write it at the current position.
        item = builder.accept(gen.left_expr)
        builder.call_c(set_item_op, [result, position, item], gen.line)

    for_loop_helper_with_index(builder, gen.indices[0], source_expr, source,
                               store_item, gen.line)
    return result
181
182
def translate_list_comprehension(builder: IRBuilder, gen: GeneratorExpr) -> Value:
    """Generate IR that evaluates a list comprehension and return the list value."""
    # Fast path: preallocate the list when the comprehension is simple enough.
    preallocated = sequence_from_generator_preallocate_helper(
        builder, gen,
        empty_op_llbuilder=builder.builder.new_list_op_with_length,
        set_item_op=new_list_set_item_op)
    if preallocated is not None:
        return preallocated

    # General path: start from an empty list and append item by item.
    result = builder.new_list_op([], gen.line)

    def append_item() -> None:
        item = builder.accept(gen.left_expr)
        builder.call_c(list_append_op, [result, item], gen.line)

    loops = list(zip(gen.indices, gen.sequences, gen.condlists))
    comprehension_helper(builder, loops, append_item, gen.line)
    return result
201
202
def translate_set_comprehension(builder: IRBuilder, gen: GeneratorExpr) -> Value:
    """Generate IR that evaluates a set comprehension and return the set value."""
    # Start from an empty set and add each produced item to it.
    result = builder.new_set_op([], gen.line)

    def add_item() -> None:
        item = builder.accept(gen.left_expr)
        builder.call_c(set_add_op, [result, item], gen.line)

    loops = list(zip(gen.indices, gen.sequences, gen.condlists))
    comprehension_helper(builder, loops, add_item, gen.line)
    return result
213
214
def comprehension_helper(builder: IRBuilder,
                         loop_params: List[Tuple[Lvalue, Expression, List[Expression]]],
                         gen_inner_stmts: Callable[[], None],
                         line: int) -> None:
    """Generate the nested loops and condition checks of a comprehension.

    Args:
        builder: the IR builder that receives the generated ops
        loop_params: a list of (index, expr, [conditions]) tuples defining nested loops:
            - "index" is the Lvalue indexing that loop;
            - "expr" is the expression for the object to be iterated over;
            - "conditions" is a list of conditions, evaluated in order with short-circuiting,
                that must all be true for the loop body to be executed
        gen_inner_stmts: function to generate the IR for the body of the innermost loop
        line: the line number passed on to for_loop_helper
    """
    def emit_loops(pending: List[Tuple[Lvalue, Expression, List[Expression]]]) -> None:
        """Emit the first loop in "pending"; the rest are emitted from within its body."""
        index, expr, conds = pending[0]
        inner = pending[1:]

        def body() -> None:
            emit_body(conds, inner)

        for_loop_helper(builder, index, expr, body, None, line)

    def emit_body(
            conds: List[Expression],
            inner: List[Tuple[Lvalue, Expression, List[Expression]]],
    ) -> None:
        """Emit the condition checks of one loop level, then whatever it contains.

        Args:
            conds: a list of conditions to be evaluated (in order, with short circuiting)
                to gate the body of the loop
            inner: the parameters for any further nested loops; if it's empty
                the "gen_inner_stmts" function is called instead
        """
        for cond in conds:
            cond_val = builder.accept(cond)
            skip_block = BasicBlock()
            keep_block = BasicBlock()
            # A true condition jumps past the "continue"; a false one runs it.
            builder.add_bool_branch(cond_val, keep_block, skip_block)
            builder.activate_block(skip_block)
            builder.nonlocal_control[-1].gen_continue(builder, cond.line)
            builder.goto_and_activate(keep_block)

        if not inner:
            # Innermost level: generate the actual body of the comprehension.
            gen_inner_stmts()
            return
        # Otherwise the body of this loop is the next nested loop.
        emit_loops(inner)

    emit_loops(loop_params)
271
272
def is_range_ref(expr: RefExpr) -> bool:
    """Return True if expr refers to builtins.range or to the six.moves.xrange alias."""
    if expr.fullname == 'builtins.range':
        return True
    # six.moves.xrange only counts when the reference resolves to a type alias.
    if not isinstance(expr.node, TypeAlias):
        return False
    return expr.fullname == 'six.moves.xrange'
276
277
def make_for_loop_generator(builder: IRBuilder,
                            index: Lvalue,
                            expr: Expression,
                            body_block: BasicBlock,
                            loop_exit: BasicBlock,
                            line: int,
                            nested: bool = False) -> 'ForGenerator':
    """Pick and initialize the loop generator that fits the iterated expression.

    The checks run in this order: sequence, dict, calls to range /
    enumerate / zip / reversed, and argument-less keys() / values() /
    items() calls on a dict. Anything else gets the generic ForIterable.

    If "nested" is True, this is a nested iterator such as "e" in "enumerate(e)".
    """
    iterable_rtype = builder.node_type(expr)

    if is_sequence_rprimitive(iterable_rtype):
        # "for x in <list>"
        seq_reg = builder.accept(expr)
        seq_item_type = builder.get_sequence_type(expr)

        seq_gen = ForSequence(builder, index, body_block, loop_exit, line, nested)
        seq_gen.init(seq_reg, seq_item_type, reverse=False)
        return seq_gen

    if is_dict_rprimitive(iterable_rtype):
        # "for k in <dict>"
        dict_reg = builder.accept(expr)
        key_type = builder.get_dict_key_type(expr)

        keys_gen = ForDictionaryKeys(builder, index, body_block, loop_exit, line, nested)
        keys_gen.init(dict_reg, key_type)
        return keys_gen

    if isinstance(expr, CallExpr) and isinstance(expr.callee, RefExpr):
        args = expr.args

        if (is_range_ref(expr.callee)
                and (len(args) <= 2
                     or (len(args) == 3
                         and builder.extract_int(args[2]) is not None))
                and set(expr.arg_kinds) == {ARG_POS}):
            # "for x in range(...)"
            # The 3 arg form is only supported when the step can be extracted
            # as an int, since it doesn't seem worth the hassle of dynamically
            # determining which direction of comparison to do.
            if len(args) == 1:
                # range(n) counts up from zero.
                start_reg = Integer(0)  # type: Value
                end_reg = builder.accept(args[0])
            else:
                start_reg = builder.accept(args[0])
                end_reg = builder.accept(args[1])

            if len(args) == 3:
                literal_step = builder.extract_int(args[2])
                assert literal_step is not None
                if literal_step == 0:
                    builder.error("range() step can't be zero", args[2].line)
                step = literal_step
            else:
                step = 1

            range_gen = ForRange(builder, index, body_block, loop_exit, line, nested)
            range_gen.init(start_reg, end_reg, step)
            return range_gen

        if (expr.callee.fullname == 'builtins.enumerate'
                and len(args) == 1
                and expr.arg_kinds == [ARG_POS]
                and isinstance(index, TupleExpr)
                and len(index.items) == 2):
            # "for i, x in enumerate(y)"
            counter_lvalue = index.items[0]
            item_lvalue = index.items[1]
            enumerate_gen = ForEnumerate(builder, index, body_block, loop_exit, line,
                                         nested)
            enumerate_gen.init(counter_lvalue, item_lvalue, args[0])
            return enumerate_gen

        if (expr.callee.fullname == 'builtins.zip'
                and len(args) >= 2
                and set(expr.arg_kinds) == {ARG_POS}
                and isinstance(index, TupleExpr)
                and len(index.items) == len(args)):
            # "for x, y in zip(a, b)"
            zip_gen = ForZip(builder, index, body_block, loop_exit, line, nested)
            zip_gen.init(index.items, args)
            return zip_gen

        if (expr.callee.fullname == 'builtins.reversed'
                and len(args) == 1
                and expr.arg_kinds == [ARG_POS]
                and is_sequence_rprimitive(builder.node_type(args[0]))):
            # "for x in reversed(<list>)"
            reversed_reg = builder.accept(args[0])
            reversed_item_type = builder.get_sequence_type(expr)

            reversed_gen = ForSequence(builder, index, body_block, loop_exit, line, nested)
            reversed_gen.init(reversed_reg, reversed_item_type, reverse=True)
            return reversed_gen

    if (isinstance(expr, CallExpr)
            and isinstance(expr.callee, MemberExpr)
            and not expr.args):
        # Dictionary iterator methods, like dict.items().
        receiver = expr.callee.expr
        method = expr.callee.name
        receiver_rtype = builder.node_type(receiver)
        if is_dict_rprimitive(receiver_rtype) and method in ('keys', 'values', 'items'):
            receiver_reg = builder.accept(receiver)
            dict_gen_class = None  # type: Optional[Type[ForGenerator]]
            if method == 'keys':
                target_type = builder.get_dict_key_type(receiver)
                dict_gen_class = ForDictionaryKeys
            elif method == 'values':
                target_type = builder.get_dict_value_type(receiver)
                dict_gen_class = ForDictionaryValues
            else:
                target_type = builder.get_dict_item_type(receiver)
                dict_gen_class = ForDictionaryItems
            dict_gen = dict_gen_class(builder, index, body_block, loop_exit, line, nested)
            dict_gen.init(receiver_reg, target_type)
            return dict_gen

    # Nothing matched: fall back to the generic iterator protocol.
    iterable_reg = builder.accept(expr)
    generic_gen = ForIterable(builder, index, body_block, loop_exit, line, nested)
    item_rtype = builder.type_to_rtype(builder._analyze_iterable_item_type(expr))
    generic_gen.init(iterable_reg, item_rtype)
    return generic_gen
401
402
class ForGenerator:
    """Abstract base class for generating for loops.

    The gen_condition / begin_body / gen_step / gen_cleanup hooks do nothing
    here; subclasses override the ones they need.
    """

    def __init__(self,
                 builder: IRBuilder,
                 index: Lvalue,
                 body_block: BasicBlock,
                 loop_exit: BasicBlock,
                 line: int,
                 nested: bool) -> None:
        self.builder = builder
        self.index = index
        self.body_block = body_block
        self.line = line
        # Some loops must run a cleanup block on exit, and in that case the
        # condition check exits to a fresh block created here. A nested
        # generator (such as the one for "e" in "enumerate(e)") never creates
        # it: the outermost generator is responsible for the cleanup block.
        wants_own_cleanup = self.need_cleanup() and not nested
        self.loop_exit = BasicBlock() if wants_own_cleanup else loop_exit

    def need_cleanup(self) -> bool:
        """Tell whether post-loop cleanup is required (never, in the base class)."""
        return False

    def add_cleanup(self, exit_block: BasicBlock) -> None:
        """Emit the cleanup block, followed by a jump to exit_block, if cleanup is needed."""
        if not self.need_cleanup():
            return
        self.builder.activate_block(self.loop_exit)
        self.gen_cleanup()
        self.builder.goto(exit_block)

    def gen_condition(self) -> None:
        """Hook: generate the loop exit check (e.g. exhaustion of iteration)."""

    def begin_body(self) -> None:
        """Hook: generate ops at the start of the loop body (if needed)."""

    def gen_step(self) -> None:
        """Hook: generate the step to the next item (if needed)."""

    def gen_cleanup(self) -> None:
        """Hook: generate the post-loop cleanup (if needed)."""

    def load_len(self, expr: Union[Value, AssignmentTarget]) -> Value:
        """Read expr and return its length; shared by several subclasses."""
        collection = self.builder.read(expr, self.line)
        return self.builder.builder.builtin_len(collection, self.line)
455
456
class ForIterable(ForGenerator):
    """Generate IR for a for loop over an arbitrary iterable (the normal case)."""

    def need_cleanup(self) -> bool:
        """Always request a cleanup block for when the loop is finished."""
        return True

    def init(self, expr_reg: Value, target_type: RType) -> None:
        """Obtain the iterator for expr_reg and remember the item type."""
        self.target_type = target_type
        # Both the iterated value and its iterator are passed to maybe_spill:
        # inside a generator function they are spilled into the environment
        # class.
        iterator = self.builder.call_c(iter_op, [expr_reg], self.line)
        self.builder.maybe_spill(expr_reg)
        self.iter_target = self.builder.maybe_spill(iterator)

    def gen_condition(self) -> None:
        """Fetch the next item and leave the loop if the result is NULL."""
        # __next__ yields NULL either at the end of the iterable or when an
        # exception is raised. Branch.IS_ERROR only tests for NULL, so an
        # exception does not necessarily have to be set at this point.
        iterator = self.builder.read(self.iter_target, self.line)
        self.next_reg = self.builder.call_c(next_op, [iterator], self.line)
        exit_if_null = Branch(self.next_reg, self.loop_exit, self.body_block, Branch.IS_ERROR)
        self.builder.add(exit_if_null)

    def begin_body(self) -> None:
        """Store the fetched item in the loop index so the body can refer to it."""
        # Coercing to the item type first makes iteration with tuple unpacking
        # use a tuple based unpack instead of an iterator based one.
        item = self.builder.coerce(self.next_reg, self.target_type, self.line)
        target = self.builder.get_assignment_target(self.index)
        self.builder.assign(target, item, self.line)

    def gen_step(self) -> None:
        """Do nothing: the next item is already fetched in gen_condition()."""

    def gen_cleanup(self) -> None:
        """Emit the check for whether the loop ended because of an exception."""
        # Control arrives here when the condition branch is taken. If
        # no_err_occurred_op returns False, the exception is propagated
        # using the ERR_FALSE flag.
        self.builder.call_c(no_err_occurred_op, [], self.line)
504
505
def unsafe_index(
    builder: IRBuilder, target: Value, index: Value, line: int
) -> Value:
    """Emit a potentially unsafe index into a target.

    Lists use the unsafe get-item primitive; every other type goes through
    a regular __getitem__ method call.
    """
    # This is checked by hand because falling back to __getitem__ when there
    # is no unsafe version doesn't fit nicely into the data-driven frameworks.
    if not is_list_rprimitive(target.type):
        return builder.gen_method_call(target, '__getitem__', [index], None, line)
    return builder.call_c(list_get_item_unsafe_op, [target, index], line)
517
518
class ForSequence(ForGenerator):
    """Generate optimized IR for a for loop over a sequence.

    The sequence is walked by index, either forward from 0 or backward
    from len - 1.
    """

    def init(self, expr_reg: Value, target_type: RType, reverse: bool) -> None:
        """Record the sequence and set up the starting index."""
        builder = self.builder
        self.reverse = reverse
        # The sequence and the running index go through maybe_spill: inside a
        # generator function they are spilled into the environment class.
        self.expr_target = builder.maybe_spill(expr_reg)
        if reverse:
            # Start at the last position, len - 1.
            first_index = builder.binary_op(self.load_len(self.expr_target),
                                            Integer(1), '-', self.line)  # type: Value
        else:
            first_index = Integer(0)
        self.index_target = builder.maybe_spill_assignable(first_index)
        self.target_type = target_type

    def gen_condition(self) -> None:
        """Branch to the body while the index is within the sequence bounds."""
        builder = self.builder
        line = self.line
        # TODO: Don't reload the length each time when iterating an immutable sequence?
        if self.reverse:
            # Going backward, the index must not drop below zero. The length
            # check below is still required, since the sequence could shrink
            # out from under us.
            position = builder.read(self.index_target, line)
            non_negative = builder.binary_op(position, Integer(0), '>=', line)
            length_check = BasicBlock()
            builder.add_bool_branch(non_negative, length_check, self.loop_exit)
            builder.activate_block(length_check)
        # The length is recalculated at every iteration for compatibility with
        # python semantics.
        current_len = self.load_len(self.expr_target)
        in_bounds = builder.binary_op(
            builder.read(self.index_target, line), current_len, '<', line)
        builder.add_bool_branch(in_bounds, self.body_block, self.loop_exit)

    def begin_body(self) -> None:
        """Load the item at the current index and assign it to the loop index lvalue."""
        builder = self.builder
        line = self.line
        sequence = builder.read(self.expr_target, line)
        position = builder.read(self.index_target, line)
        raw_item = unsafe_index(builder, sequence, position, line)
        assert raw_item
        # Coercing to the element type makes iteration with tuple unpacking
        # use a tuple based unpack instead of an iterator based one.
        target = builder.get_assignment_target(self.index)
        builder.assign(target, builder.coerce(raw_item, self.target_type, line), line)

    def gen_step(self) -> None:
        """Move the index one position in the iteration direction."""
        builder = self.builder
        line = self.line
        delta = -1 if self.reverse else 1
        current = builder.read(self.index_target, line)
        moved = builder.int_op(short_int_rprimitive, current, Integer(delta), IntOp.ADD, line)
        builder.assign(self.index_target, moved, line)
586
587
class ForDictionaryCommon(ForGenerator):
    """Shared logic for optimized for loops over dictionary keys/values/items.

    Iteration uses the PyDict_Next() API wrapped in a tuple, so that only a
    single register has to be modified. The layout of the tuple:
      * f0: are there more items (bool)
      * f1: current offset (int)
      * f2: next key (object); ForDictionaryValues reads the next value here
      * f3: next value (object); only read by ForDictionaryItems
    For more info see https://docs.python.org/3/c-api/dict.html#c.PyDict_Next.

    Note that for subclasses we fall back to generic PyObject_GetIter() logic,
    since they may override some iteration methods in subtly incompatible manner.
    The fallback logic is implemented in CPy.h via dynamic type check.
    """
    # Concrete generators set these to the primitives for their kind of iteration.
    dict_next_op = None  # type: ClassVar[CFunctionDescription]
    dict_iter_op = None  # type: ClassVar[CFunctionDescription]

    def need_cleanup(self) -> bool:
        """Always request cleanup."""
        # Needed because a dict subclass can technically raise an unrelated
        # exception in __next__().
        return True

    def init(self, expr_reg: Value, target_type: RType) -> None:
        """Set up the dict, offset, initial size and iterator targets."""
        builder = self.builder
        self.target_type = target_type

        # These go into the environment class when needed, so that they can be
        # read across yield.
        self.expr_target = builder.maybe_spill(expr_reg)
        self.offset_target = builder.maybe_spill_assignable(Integer(0))
        initial_size = self.load_len(self.expr_target)
        self.size = builder.maybe_spill(initial_size)

        # For dict class (not a subclass) this is the dictionary itself.
        iterator = builder.call_c(self.dict_iter_op, [expr_reg], self.line)
        self.iter_target = builder.maybe_spill(iterator)

    def gen_condition(self) -> None:
        """Get next key/value pair, set new offset, and check if we should continue."""
        builder = self.builder
        line = self.line
        iterator = builder.read(self.iter_target, line)
        offset = builder.read(self.offset_target, line)
        self.next_tuple = builder.call_c(self.dict_next_op, [iterator, offset], line)

        # The offset is updated here rather than in gen_step() to minimize
        # variables in environment.
        next_offset = builder.add(TupleGet(self.next_tuple, 1, line))
        builder.assign(self.offset_target, next_offset, line)

        has_more = builder.add(TupleGet(self.next_tuple, 0, line))
        builder.add(Branch(has_more, self.body_block, self.loop_exit, Branch.BOOL))

    def gen_step(self) -> None:
        """Check that dictionary didn't change size during iteration.

        Raise RuntimeError if it is not the case to match CPython behavior.
        """
        builder = self.builder
        line = self.line
        # A dedicated primitive isn't strictly necessary for this, but it is simpler.
        dictionary = builder.read(self.expr_target, line)
        expected_size = builder.read(self.size, line)
        builder.call_c(dict_check_size_op, [dictionary, expected_size], line)

    def gen_cleanup(self) -> None:
        """Emit the same error check as the generic ForIterable."""
        self.builder.call_c(no_err_occurred_op, [], self.line)
657
658
class ForDictionaryKeys(ForDictionaryCommon):
    """Generate optimized IR for a for loop over dictionary keys."""
    dict_next_op = dict_next_key_op
    dict_iter_op = dict_key_iter_op

    def begin_body(self) -> None:
        """Assign the key from the current result tuple to the loop index lvalue."""
        # The key sits at index 2 of the tuple, after the flag and the offset.
        raw_key = self.builder.add(TupleGet(self.next_tuple, 2, self.line))
        target = self.builder.get_assignment_target(self.index)
        typed_key = self.builder.coerce(raw_key, self.target_type, self.line)
        self.builder.assign(target, typed_key, self.line)
672
673
class ForDictionaryValues(ForDictionaryCommon):
    """Generate optimized IR for a for loop over dictionary values."""
    dict_next_op = dict_next_value_op
    dict_iter_op = dict_value_iter_op

    def begin_body(self) -> None:
        """Assign the value from the current result tuple to the loop index lvalue."""
        # The value sits at index 2 of the tuple, after the flag and the offset.
        raw_value = self.builder.add(TupleGet(self.next_tuple, 2, self.line))
        target = self.builder.get_assignment_target(self.index)
        typed_value = self.builder.coerce(raw_value, self.target_type, self.line)
        self.builder.assign(target, typed_value, self.line)
687
688
class ForDictionaryItems(ForDictionaryCommon):
    """Generate optimized IR for a for loop over dictionary items."""
    dict_next_op = dict_next_item_op
    dict_iter_op = dict_item_iter_op

    def begin_body(self) -> None:
        """Assign the current (key, value) pair to the loop index lvalue.

        When the index is a tuple target, the key and the value are assigned
        to its two items directly. A tuple target that does not have exactly
        two items is reported as an error and nothing is assigned. For any
        other target a (key, value) tuple is built and assigned as a whole.
        """
        builder = self.builder
        line = self.line

        # Key and value are at indexes 2 and 3 of the result tuple.
        key = builder.add(TupleGet(self.next_tuple, 2, line))
        value = builder.add(TupleGet(self.next_tuple, 3, line))

        # Coerce just in case e.g. key is itself a tuple to be unpacked.
        assert isinstance(self.target_type, RTuple)
        key = builder.coerce(key, self.target_type.types[0], line)
        value = builder.coerce(value, self.target_type.types[1], line)

        target = builder.get_assignment_target(self.index)
        if isinstance(target, AssignmentTargetTuple):
            # Simpler code for common case: for k, v in d.items().
            if len(target.items) != 2:
                builder.error("Expected a pair for dict item iteration", line)
                # Stop here: indexing items[0] / items[1] below would raise
                # IndexError when there are fewer than two targets.
                return
            builder.assign(target.items[0], key, line)
            builder.assign(target.items[1], value, line)
        else:
            rvalue = builder.add(TupleSet([key, value], line))
            builder.assign(target, rvalue, line)
716
717
class ForRange(ForGenerator):
    """Generate optimized IR for a for loop over an integer range."""

    def init(self, start_reg: Value, end_reg: Value, step: int) -> None:
        """Set up the internal counter and the user-visible loop index.

        Args:
            start_reg: value the counter starts from
            end_reg: bound the counter is compared against (strictly)
            step: constant added to the counter on each iteration; its
                sign selects the comparison used by gen_condition()
        """
        builder = self.builder
        self.start_reg = start_reg
        self.end_reg = end_reg
        self.step = step
        self.end_target = builder.maybe_spill(end_reg)

        # Use a short int counter only when both bounds are short ints.
        both_short = (is_short_int_rprimitive(start_reg.type)
                      and is_short_int_rprimitive(end_reg.type))
        counter = Register(short_int_rprimitive if both_short else int_rprimitive)
        builder.assign(counter, start_reg, -1)
        self.index_reg = builder.maybe_spill_assignable(counter)

        # Initialize the loop index to the start value. Assert that the
        # index target is assignable.
        self.index_target = builder.get_assignment_target(
            self.index)  # type: Union[Register, AssignmentTarget]
        initial = builder.read(self.index_reg, self.line)
        builder.assign(self.index_target, initial, self.line)

    def gen_condition(self) -> None:
        """Branch to the body while the counter has not reached the bound."""
        builder = self.builder
        line = self.line

        # Counting up compares with '<', counting down with '>'.
        if self.step > 0:
            op = '<'
        else:
            op = '>'
        current = builder.read(self.index_reg, line)
        bound = builder.read(self.end_target, line)
        in_range = builder.binary_op(current, bound, op, line)
        builder.add_bool_branch(in_range, self.body_block, self.loop_exit)

    def gen_step(self) -> None:
        """Advance the counter by the step and mirror it to the loop index."""
        builder = self.builder
        line = self.line

        # If the range is known to fit in short ints, a plain integer add is
        # used; otherwise fall back to the generic '+' operation.
        both_short = (is_short_int_rprimitive(self.start_reg.type)
                      and is_short_int_rprimitive(self.end_reg.type))
        current = builder.read(self.index_reg, line)
        delta = Integer(self.step)
        if both_short:
            bumped = builder.int_op(short_int_rprimitive, current, delta, IntOp.ADD, line)
        else:
            bumped = builder.binary_op(current, delta, '+', line)

        builder.assign(self.index_reg, bumped, line)
        builder.assign(self.index_target, bumped, line)
765
766
class ForInfiniteCounter(ForGenerator):
    """Generate optimized IR for a for loop counting from 0 to infinity."""

    def init(self) -> None:
        """Create the counter state and set the loop index to 0."""
        builder = self.builder
        # A single Integer(0) seeds both the register that keeps the loop
        # state and the user-visible loop index.
        start = Integer(0)
        self.index_reg = builder.maybe_spill_assignable(start)
        self.index_target = builder.get_assignment_target(
            self.index)  # type: Union[Register, AssignmentTarget]
        builder.assign(self.index_target, start, self.line)

    def gen_step(self) -> None:
        """Add one to the counter and copy the result to the loop index."""
        line = self.line
        builder = self.builder
        # Treating the counter as a short int is safe, because a 63-bit
        # integer is not going to wrap around.
        # NOTE: This would be questionable if short ints could be 32 bits.
        current = builder.read(self.index_reg, line)
        bumped = builder.int_op(short_int_rprimitive, current, Integer(1), IntOp.ADD, line)
        builder.assign(self.index_reg, bumped, line)
        builder.assign(self.index_target, bumped, line)
791
792
class ForEnumerate(ForGenerator):
    """Generate optimized IR for a for loop of form "for i, x in enumerate(it)"."""

    def need_cleanup(self) -> bool:
        # Always ask for a cleanup block, since the wrapped for loop might
        # need one. The block may turn out to be redundant, but that's okay.
        return True

    def init(self, index1: Lvalue, index2: Lvalue, expr: Expression) -> None:
        """Create the two nested generators that drive the loop.

        Args:
            index1: lvalue that receives the running count
            index2: lvalue that receives each item of the iterable
            expr: the expression being enumerated
        """
        builder = self.builder
        body = self.body_block
        exit_block = self.loop_exit
        line = self.line

        # The index lvalue counts from 0 to infinity.
        counter = ForInfiniteCounter(builder, index1, body, exit_block, line, nested=True)
        self.index_gen = counter
        counter.init()

        # The item lvalue iterates over the actual iterable.
        self.main_gen = make_for_loop_generator(
            builder, index2, expr, body, exit_block, line, nested=True)

    def gen_condition(self) -> None:
        # Only the wrapped loop has a condition; the counter is unconditional.
        self.main_gen.gen_condition()

    def begin_body(self) -> None:
        for gen in (self.index_gen, self.main_gen):
            gen.begin_body()

    def gen_step(self) -> None:
        for gen in (self.index_gen, self.main_gen):
            gen.gen_step()

    def gen_cleanup(self) -> None:
        for gen in (self.index_gen, self.main_gen):
            gen.gen_cleanup()
834
835
class ForZip(ForGenerator):
    """Generate IR for a for loop of form `for x, ... in zip(a, ...)`."""

    def need_cleanup(self) -> bool:
        # Always ask for a cleanup block, since the wrapped for loops might
        # need one. The block may turn out to be redundant, but that's okay.
        return True

    def init(self, indexes: List[Lvalue], exprs: List[Expression]) -> None:
        """Create one nested generator per (index, expression) pair.

        Args:
            indexes: loop index lvalues, one per zipped iterable
            exprs: the zipped iterables, in the same order as indexes
        """
        assert len(indexes) == len(exprs)
        # Each condition needs its own basic block: every generator except
        # the last branches to a fresh block holding the next check, and the
        # last one branches to the loop body.
        blocks = [BasicBlock() for _ in indexes[1:]]
        blocks.append(self.body_block)
        self.cond_blocks = blocks
        self.gens = [
            make_for_loop_generator(
                self.builder, index, expr, next_block, self.loop_exit, self.line, nested=True)
            for index, expr, next_block in zip(indexes, exprs, blocks)
        ]  # type: List[ForGenerator]

    def gen_condition(self) -> None:
        """Emit every generator's condition, chained through cond_blocks."""
        final = len(self.gens) - 1
        for pos, gen in enumerate(self.gens):
            gen.gen_condition()
            # After the last check control goes straight to the body, so
            # there is no further condition block to activate.
            if pos < final:
                self.builder.activate_block(self.cond_blocks[pos])

    def begin_body(self) -> None:
        for wrapped in self.gens:
            wrapped.begin_body()

    def gen_step(self) -> None:
        for wrapped in self.gens:
            wrapped.gen_step()

    def gen_cleanup(self) -> None:
        for wrapped in self.gens:
            wrapped.gen_cleanup()
877