1"""
2Implementation of various iterable and iterator types.
3"""
4
5from numba.core import types, cgutils
6from numba.core.imputils import (
7    lower_builtin, iternext_impl, call_iternext, call_getiter,
8    impl_ret_borrowed, impl_ret_new_ref, RefType)
9
10
11
@lower_builtin('getiter', types.IteratorType)
def iterator_getiter(context, builder, sig, args):
    """Lower ``iter(x)`` where *x* is already an iterator.

    The argument itself is the result, so it is returned through
    ``impl_ret_borrowed``: the returned value aliases the argument rather
    than being a freshly created object.
    """
    (iterator,) = args
    ret_ty = sig.return_type
    return impl_ret_borrowed(context, builder, ret_ty, iterator)
16
17#-------------------------------------------------------------------------------
18# builtin `enumerate` implementation
19
@lower_builtin(enumerate, types.IterableType)
@lower_builtin(enumerate, types.IterableType, types.Integer)
def make_enumerate_object(context, builder, sig, args):
    """Lower ``enumerate(iterable)`` and ``enumerate(iterable, start)``.

    Builds the enumerate helper struct:
      * ``iter``  - the iterator obtained from the source iterable via getiter;
      * ``count`` - a pointer to a slot holding the next index, initialised
        to *start* cast to ``intp``, or to 0 when no start is given.

    Returns the helper value as a new reference.
    """
    nargs = len(args)
    assert nargs in (1, 2)  # enumerate(it) or enumerate(it, start)

    iterable_ty = sig.args[0]
    iterable = args[0]
    # The counter is always kept as an intp, whatever integer type the
    # caller passed for ``start``.
    first_index = (
        context.cast(builder, args[1], sig.args[1], types.intp)
        if nargs == 2
        else context.get_constant(types.intp, 0)
    )

    source_iter = call_getiter(context, builder, iterable_ty, iterable)

    enum_struct = context.make_helper(builder, sig.return_type)

    # Mutable counter cell shared by every iternext call on this object.
    counter_slot = cgutils.alloca_once(builder, first_index.type)
    builder.store(first_index, counter_slot)

    enum_struct.count = counter_slot
    enum_struct.iter = source_iter

    return impl_ret_new_ref(context, builder, sig.return_type,
                            enum_struct._getvalue())
45
@lower_builtin('iternext', types.EnumerateType)
@iternext_impl(RefType.NEW)
def iternext_enumerate(context, builder, sig, args, result):
    """Lower ``next()`` on an enumerate object.

    Loads the current count and stores count + 1 back into the counter slot.
    It then advances the wrapped source iterator. If the source produced a
    value, the result is the tuple ``(count, value)``. Otherwise the
    enumerate is reported as exhausted.

    Reference counting: ``call_iternext`` hands back a *new* reference to the
    source element. That reference is moved directly into the yielded tuple,
    so the lowering is declared ``RefType.NEW``. The previous
    ``RefType.BORROWED`` variant decref'd the element here and relied on the
    wrapper to incref it again afterwards. If that decref dropped the count
    to zero, the element could be freed before it was re-acquired.
    """
    [enumty] = sig.args
    [enum] = args

    enum = context.make_helper(builder, enumty, value=enum)

    count = builder.load(enum.count)
    ncount = builder.add(count, context.get_constant(types.intp, 1))
    builder.store(ncount, enum.count)

    srcres = call_iternext(context, builder, enumty.source_type, enum.iter)
    is_valid = srcres.is_valid()
    result.set_valid(is_valid)

    with builder.if_then(is_valid):
        srcval = srcres.yielded_value()
        # Ownership of srcval (a new reference) transfers into the tuple.
        result.yield_(context.make_tuple(builder, enumty.yield_type,
                                         [count, srcval]))
70
71
72#-------------------------------------------------------------------------------
73# builtin `zip` implementation
74
@lower_builtin(zip, types.VarArg(types.Any))
def make_zip_object(context, builder, sig, args):
    """Lower ``zip(*iterables)``.

    Each argument is turned into an iterator with getiter. That iterator is
    stored in the slot of the zip helper struct matching its argument
    position. Returns the helper value as a new reference.
    """
    ret_ty = sig.return_type
    assert len(args) == len(ret_ty.source_types)

    zip_helper = context.make_helper(builder, ret_ty)

    for slot, (value, value_ty) in enumerate(zip(args, sig.args)):
        zip_helper[slot] = call_getiter(context, builder, value_ty, value)

    return impl_ret_new_ref(context, builder, ret_ty,
                            zip_helper._getvalue())
88
@lower_builtin('iternext', types.ZipType)
@iternext_impl(RefType.NEW)
def iternext_zip(context, builder, sig, args, result):
    """Lower ``next()`` on a zip object.

    Advances each source iterator in argument order. Every yielded value is
    written into its position of a temporary result tuple (``alloca_once``).
    Once one source reports exhaustion, the remaining sources are not
    advanced and the zip as a whole is reported exhausted. The tuple is
    yielded only if every source produced a value. The lowering is declared
    with ``RefType.NEW``, and the elements come from ``call_iternext``.

    NOTE(review): if an earlier source yields a value and a later source is
    then exhausted, the earlier value is left in the discarded tuple without
    a decref. That value is a new reference from ``call_iternext``, so this
    may leak a reference for refcounted element types. Confirm before
    relying on it.
    """
    [zip_type] = sig.args
    [zipobj] = args

    zipobj = context.make_helper(builder, zip_type, value=zipobj)

    if len(zipobj) == 0:
        # zip() is an empty iterator
        result.set_exhausted()
        return

    p_ret_tup = cgutils.alloca_once(builder,
                                    context.get_value_type(zip_type.yield_type))
    # Running "all sources valid so far" flag, starts true.
    p_is_valid = cgutils.alloca_once_value(builder, value=cgutils.true_bit)

    # Python-level loop: emits one guarded iternext call per source.
    for i, (iterobj, srcty) in enumerate(zip(zipobj, zip_type.source_types)):
        is_valid = builder.load(p_is_valid)
        # Avoid calling the remaining iternext if a iterator has been exhausted
        with builder.if_then(is_valid):
            srcres = call_iternext(context, builder, srcty, iterobj)
            is_valid = builder.and_(is_valid, srcres.is_valid())
            builder.store(is_valid, p_is_valid)
            # Stored unconditionally; the tuple is only read back below when
            # every source was valid.
            val = srcres.yielded_value()
            ptr = cgutils.gep_inbounds(builder, p_ret_tup, 0, i)
            builder.store(val, ptr)

    is_valid = builder.load(p_is_valid)
    result.set_valid(is_valid)

    with builder.if_then(is_valid):
        result.yield_(builder.load(p_ret_tup))
122
123
124#-------------------------------------------------------------------------------
125# generator implementation
126
@lower_builtin('iternext', types.Generator)
@iternext_impl(RefType.BORROWED)
def iternext_generator(context, builder, sig, args, result):
    """Lower ``next()`` on a compiled generator.

    Calls the implementation returned by ``context.get_generator_impl`` and
    maps the resulting status:
      * ok             -> the produced value is yielded;
      * StopIteration  -> the generator is reported exhausted;
      * any other error -> the status is propagated to the caller.

    This function was previously also named ``iternext_zip``. That rebound
    the module-level name of the zip ``iternext`` lowering to this generator
    lowering (pyflakes F811). Registration was unaffected, because it happens
    in the decorator.
    """
    genty, = sig.args
    impl = context.get_generator_impl(genty)
    status, retval = impl(context, builder, sig, args)
    # Link any libraries attached to the generator implementation.
    context.add_linking_libs(getattr(impl, 'libs', ()))

    with cgutils.if_likely(builder, status.is_ok):
        result.set_valid(True)
        result.yield_(retval)
    with cgutils.if_unlikely(builder, status.is_stop_iteration):
        result.set_exhausted()
    # StopIteration is reported as exhaustion above; only genuine errors
    # are propagated.
    with cgutils.if_unlikely(builder,
                             builder.and_(status.is_error,
                                          builder.not_(status.is_stop_iteration))):
        context.call_conv.return_status_propagate(builder, status)
145