1import itertools
2
3import numpy as np
4import operator
5
6from numba.core import types, errors
7from numba import prange
8from numba.parfors.parfor import internal_prange
9
10from numba.core.utils import RANGE_ITER_OBJECTS
11from numba.core.typing.templates import (AttributeTemplate, ConcreteTemplate,
12                                         AbstractTemplate, infer_global, infer,
13                                         infer_getattr, signature,
14                                         bound_function, make_callable_template)
15
16from numba.cpython.builtins import get_type_min_value, get_type_max_value
17
18from numba.core.extending import (
19    typeof_impl, type_callable, models, register_model, make_attribute_wrapper,
20    )
21
22
@infer_global(print)
class Print(AbstractTemplate):
    """Typing for the ``print`` builtin.

    Every positional argument must resolve through the ``"print_item"``
    template; the call as a whole is typed as returning ``none``.
    """

    def generic(self, args, kws):
        for item_ty in args:
            item_sig = self.context.resolve_function_type(
                "print_item", (item_ty,), {})
            if item_sig is None:
                raise TypeError("Type %s is not printable." % item_ty)
            # print_item must itself be a void operation.
            assert item_sig.return_type is types.none
        return signature(types.none, *args)
32
@infer
class PrintItem(AbstractTemplate):
    """Typing for the internal ``"print_item"`` operation (one value)."""
    key = "print_item"

    def generic(self, args, kws):
        # The unpacking enforces that exactly one argument was supplied.
        [_] = args
        return signature(types.none, *args)
40
41
@infer_global(abs)
class Abs(ConcreteTemplate):
    """Typing for ``abs()``.

    Integer and real inputs keep their own type; complex inputs map to their
    underlying float type.
    """
    int_cases = [signature(ty, ty) for ty in sorted(types.signed_domain)]
    uint_cases = [signature(ty, ty) for ty in sorted(types.unsigned_domain)]
    real_cases = [signature(ty, ty) for ty in sorted(types.real_domain)]
    complex_cases = [signature(cty.underlying_float, cty)
                     for cty in sorted(types.complex_domain)]
    cases = [*int_cases, *uint_cases, *real_cases, *complex_cases]
50
51
@infer_global(slice)
class Slice(ConcreteTemplate):
    """Typing for the ``slice()`` builtin.

    Every bound is either ``intp`` or ``None``.  One or two arguments build a
    2-item slice; three arguments build a 3-item slice (with step).
    """
    cases = [signature(types.slice2_type, *argtys) for argtys in (
        (types.intp,),
        (types.none,),
        (types.none, types.none),
        (types.none, types.intp),
        (types.intp, types.none),
        (types.intp, types.intp),
    )] + [signature(types.slice3_type, *argtys) for argtys in (
        (types.intp, types.intp, types.intp),
        (types.none, types.intp, types.intp),
        (types.intp, types.none, types.intp),
        (types.intp, types.intp, types.none),
        (types.intp, types.none, types.none),
        (types.none, types.intp, types.none),
        (types.none, types.none, types.intp),
        (types.none, types.none, types.none),
    )]
70
71
class Range(ConcreteTemplate):
    """Typing for range-like callables.

    One, two or three arguments that all share one integer type (int32, int64
    or uint64) select the matching range-state type.
    """
    cases = [
        signature(state_ty, *((int_ty,) * nargs))
        for state_ty, int_ty in (
            (types.range_state32_type, types.int32),
            (types.range_state64_type, types.int64),
            (types.unsigned_range_state64_type, types.uint64),
        )
        for nargs in (1, 2, 3)
    ]


# Every builtin range-like object shares the ``range`` typing key.
for func in RANGE_ITER_OBJECTS:
    infer_global(func, typing_key=range)(Range)

# The parallel range variants reuse the same cases under their own keys.
infer_global(prange, typing_key=prange)(Range)
infer_global(internal_prange, typing_key=internal_prange)(Range)
93
@infer
class GetIter(AbstractTemplate):
    """Typing for the internal ``"getiter"`` operation: iterable -> iterator."""
    key = "getiter"

    def generic(self, args, kws):
        assert not kws
        (iterable,) = args
        if not isinstance(iterable, types.IterableType):
            return None
        return signature(iterable.iterator_type, iterable)
103
104
@infer
class IterNext(AbstractTemplate):
    """Typing for the internal ``"iternext"`` operation on an iterator."""
    key = "iternext"

    def generic(self, args, kws):
        assert not kws
        (iterator,) = args
        if not isinstance(iterator, types.IteratorType):
            return None
        # The result pairs the yielded value with a boolean.
        result_ty = types.Pair(iterator.yield_type, types.boolean)
        return signature(result_ty, iterator)
114
115
@infer
class PairFirst(AbstractTemplate):
    """
    Extract the first element of a heterogeneous ``types.Pair``.
    """
    key = "pair_first"

    def generic(self, args, kws):
        assert not kws
        (pair_ty,) = args
        if not isinstance(pair_ty, types.Pair):
            return None
        return signature(pair_ty.first_type, pair_ty)
128
129
@infer
class PairSecond(AbstractTemplate):
    """
    Extract the second element of a heterogeneous ``types.Pair``.
    """
    key = "pair_second"

    def generic(self, args, kws):
        assert not kws
        (pair_ty,) = args
        if not isinstance(pair_ty, types.Pair):
            return None
        return signature(pair_ty.second_type, pair_ty)
142
143
def choose_result_bitwidth(*inputs):
    """Return the widest input bitwidth, but never less than intp's width."""
    widths = [tp.bitwidth for tp in inputs]
    return max(types.intp.bitwidth, *widths)
146
def choose_result_int(*inputs):
    """
    Pick the integer result type of an operation on integer *inputs*
    (integer typing NBEP): at least intp-wide, and signed if any input is
    signed.
    """
    width = choose_result_bitwidth(*inputs)
    is_signed = any(tp.signed for tp in inputs)
    return types.Integer.from_bitwidth(width, is_signed)
155
156
# The "machine" integer types to take into consideration for operator typing
# (according to the integer typing NBEP).  Each pair is de-duplicated, so when
# intp is int64 the signed half holds a single type (likewise for unsigned).
machine_ints = (
    sorted({types.intp, types.int64}) + sorted({types.uintp, types.uint64})
)

# Explicit integer rules for binary operators; smaller ints will be
# automatically upcast.
integer_binop_cases = tuple(
    signature(choose_result_int(lhs, rhs), lhs, rhs)
    for lhs in machine_ints
    for rhs in machine_ints
)
170
171
class BinOp(ConcreteTemplate):
    """Shared case table for arithmetic binary operators.

    It combines the machine-integer rules with same-type real and complex
    operations.
    """
    cases = [
        *integer_binop_cases,
        *(signature(op, op, op) for op in sorted(types.real_domain)),
        *(signature(op, op, op) for op in sorted(types.complex_domain)),
    ]
176
177
# NOTE: each in-place variant reuses its operator's class name, so the
# module-level names below end up bound to the in-place classes; both
# variants are still registered by their decorators.

@infer_global(operator.add)
class BinOpAdd(BinOp):
    """Typing for ``a + b`` (BinOp cases)."""


@infer_global(operator.iadd)
class BinOpAdd(BinOp):
    """Typing for ``a += b`` (BinOp cases)."""


@infer_global(operator.sub)
class BinOpSub(BinOp):
    """Typing for ``a - b`` (BinOp cases)."""


@infer_global(operator.isub)
class BinOpSub(BinOp):
    """Typing for ``a -= b`` (BinOp cases)."""


@infer_global(operator.mul)
class BinOpMul(BinOp):
    """Typing for ``a * b`` (BinOp cases)."""


@infer_global(operator.imul)
class BinOpMul(BinOp):
    """Typing for ``a *= b`` (BinOp cases)."""
206
207
@infer_global(operator.mod)
class BinOpMod(ConcreteTemplate):
    """Typing for ``a % b``: machine integers and same-type reals."""
    cases = [*integer_binop_cases,
             *(signature(op, op, op) for op in sorted(types.real_domain))]


@infer_global(operator.imod)
class BinOpMod(ConcreteTemplate):
    """Typing for ``a %= b``: machine integers and same-type reals."""
    cases = [*integer_binop_cases,
             *(signature(op, op, op) for op in sorted(types.real_domain))]
218
219
@infer_global(operator.truediv)
class BinOpTrueDiv(ConcreteTemplate):
    """Typing for ``a / b``.

    Any pair of machine ints gives float64; reals and complexes keep their
    own type.
    """
    cases = [signature(types.float64, lhs, rhs)
             for lhs in machine_ints for rhs in machine_ints]
    cases += [signature(op, op, op)
              for op in (*sorted(types.real_domain),
                         *sorted(types.complex_domain))]


@infer_global(operator.itruediv)
class BinOpTrueDiv(ConcreteTemplate):
    """Typing for ``a /= b``; same cases as ``a / b``."""
    cases = [signature(types.float64, lhs, rhs)
             for lhs in machine_ints for rhs in machine_ints]
    cases += [signature(op, op, op)
              for op in (*sorted(types.real_domain),
                         *sorted(types.complex_domain))]
234
235
@infer_global(operator.floordiv)
class BinOpFloorDiv(ConcreteTemplate):
    """Typing for ``a // b``: machine integers and same-type reals."""
    cases = [*integer_binop_cases,
             *(signature(op, op, op) for op in sorted(types.real_domain))]


@infer_global(operator.ifloordiv)
class BinOpFloorDiv(ConcreteTemplate):
    """Typing for ``a //= b``: machine integers and same-type reals."""
    cases = [*integer_binop_cases,
             *(signature(op, op, op) for op in sorted(types.real_domain))]
246
247
@infer_global(divmod)
class DivMod(ConcreteTemplate):
    """Typing for ``divmod(a, b)``.

    Two operands of one machine-int or real type give a homogeneous 2-tuple
    of that type.
    """
    _tys = [*machine_ints, *sorted(types.real_domain)]
    cases = [signature(types.UniTuple(ty, 2), ty, ty) for ty in _tys]
252
253
@infer_global(operator.pow)
class BinOpPower(ConcreteTemplate):
    """Typing for ``a ** b``."""
    cases = list(integer_binop_cases)
    # Ensure that float32 ** int doesn't go through DP computations
    cases += [signature(fty, fty, ity)
              for fty in (types.float32, types.float64)
              for ity in (types.int32, types.int64, types.uint64)]
    cases += [signature(op, op, op)
              for op in (*sorted(types.real_domain),
                         *sorted(types.complex_domain))]


@infer_global(operator.ipow)
class BinOpPower(ConcreteTemplate):
    """Typing for ``a **= b``; same cases as ``a ** b``."""
    cases = list(integer_binop_cases)
    # Ensure that float32 ** int doesn't go through DP computations
    cases += [signature(fty, fty, ity)
              for fty in (types.float32, types.float64)
              for ity in (types.int32, types.int64, types.uint64)]
    cases += [signature(op, op, op)
              for op in (*sorted(types.real_domain),
                         *sorted(types.complex_domain))]


@infer_global(pow)
class PowerBuiltin(BinOpPower):
    """Typing for the two-argument ``pow()`` builtin."""
    # TODO add 3 operand version
286
287
class BitwiseShiftOperation(ConcreteTemplate):
    # For bitshifts, only the first operand's signedness matters
    # to choose the operation's signedness (the second operand
    # should always be positive but will generally be considered
    # signed anyway, since it's often a constant integer).
    # (also, see issue #1995 for right-shifts)

    # The RHS type is fixed to 64-bit signed/unsigned ints.
    # The implementation will always cast the operands to the width of the
    # result type, which is the widest between the LHS type and (u)intp.
    cases = [signature(max(op, widest), op, rhs)
             for domain, widest in ((types.signed_domain, types.intp),
                                    (types.unsigned_domain, types.uintp))
             for op in sorted(domain)
             for rhs in (types.uint64, types.int64)]
    unsafe_casting = False
305
306
@infer_global(operator.lshift)
class BitwiseLeftShift(BitwiseShiftOperation):
    """Typing for ``a << b``."""


@infer_global(operator.ilshift)
class BitwiseLeftShift(BitwiseShiftOperation):
    """Typing for ``a <<= b``."""


@infer_global(operator.rshift)
class BitwiseRightShift(BitwiseShiftOperation):
    """Typing for ``a >> b``."""


@infer_global(operator.irshift)
class BitwiseRightShift(BitwiseShiftOperation):
    """Typing for ``a >>= b``."""
324
325
class BitwiseLogicOperation(BinOp):
    """Shared typing for ``&``, ``|`` and ``^``: booleans and machine ints."""
    cases = [signature(types.boolean, types.boolean, types.boolean),
             *integer_binop_cases]
    unsafe_casting = False
330
331
@infer_global(operator.and_)
class BitwiseAnd(BitwiseLogicOperation):
    """Typing for ``a & b``."""


@infer_global(operator.iand)
class BitwiseAnd(BitwiseLogicOperation):
    """Typing for ``a &= b``."""


@infer_global(operator.or_)
class BitwiseOr(BitwiseLogicOperation):
    """Typing for ``a | b``."""


@infer_global(operator.ior)
class BitwiseOr(BitwiseLogicOperation):
    """Typing for ``a |= b``."""


@infer_global(operator.xor)
class BitwiseXor(BitwiseLogicOperation):
    """Typing for ``a ^ b``."""


@infer_global(operator.ixor)
class BitwiseXor(BitwiseLogicOperation):
    """Typing for ``a ^= b``."""
360
361
# Bitwise invert and negate are special: we must not upcast the operand
# for unsigned numbers, as that would change the result
# (e.g. ~np.uint8(0) == 255 but ~np.uint32(0) == 4294967295).
365
@infer_global(operator.invert)
class BitwiseInvert(ConcreteTemplate):
    # Note Numba follows the Numpy semantics of returning a bool,
    # while Python returns an int.  This makes it consistent with
    # np.invert() and makes array expressions correct.
    cases = [signature(types.boolean, types.boolean)]
    cases += [signature(choose_result_int(op), op)
              for op in (*sorted(types.unsigned_domain),
                         *sorted(types.signed_domain))]

    unsafe_casting = False
376
377
class UnaryOp(ConcreteTemplate):
    """Shared case table for unary ``-`` and ``+``.

    Integers are widened through choose_result_int, reals and complexes keep
    their type, and a boolean operand gives intp.
    """
    cases = [signature(choose_result_int(op), op)
             for op in (*sorted(types.unsigned_domain),
                        *sorted(types.signed_domain))]
    cases += [signature(op, op)
              for op in (*sorted(types.real_domain),
                         *sorted(types.complex_domain))]
    cases += [signature(types.intp, types.boolean)]
384
385
@infer_global(operator.neg)
class UnaryNegate(UnaryOp):
    """Typing for ``-a``."""


@infer_global(operator.pos)
class UnaryPositive(UnaryOp):
    """Typing for ``+a``."""
394
395
@infer_global(operator.not_)
class UnaryNot(ConcreteTemplate):
    """Typing for ``not a``: booleans and every numeric domain give a bool."""
    cases = [signature(types.boolean, types.boolean)]
    cases += [signature(types.boolean, op)
              for domain in (types.signed_domain, types.unsigned_domain,
                             types.real_domain, types.complex_domain)
              for op in sorted(domain)]
403
404
class OrderedCmpOp(ConcreteTemplate):
    """Cases for ordering comparisons: booleans, integers and reals."""
    cases = [signature(types.boolean, types.boolean, types.boolean)]
    cases += [signature(types.boolean, op, op)
              for domain in (types.signed_domain, types.unsigned_domain,
                             types.real_domain)
              for op in sorted(domain)]


class UnorderedCmpOp(ConcreteTemplate):
    """Cases for (in)equality: the ordered cases plus complex numbers."""
    cases = [
        *OrderedCmpOp.cases,
        *(signature(types.boolean, op, op)
          for op in sorted(types.complex_domain)),
    ]
415
416
@infer_global(operator.lt)
class CmpOpLt(OrderedCmpOp):
    """Typing for ``a < b``."""


@infer_global(operator.le)
class CmpOpLe(OrderedCmpOp):
    """Typing for ``a <= b``."""


@infer_global(operator.gt)
class CmpOpGt(OrderedCmpOp):
    """Typing for ``a > b``."""


@infer_global(operator.ge)
class CmpOpGe(OrderedCmpOp):
    """Typing for ``a >= b``."""


@infer_global(operator.eq)
class CmpOpEq(UnorderedCmpOp):
    """Typing for ``a == b`` on numbers and booleans."""
440
441
@infer_global(operator.eq)
class ConstOpEq(AbstractTemplate):
    """Comparing two literal-typed values always yields a boolean."""

    def generic(self, args, kws):
        assert not kws
        lhs, rhs = args
        both_literal = all(isinstance(a, types.Literal) for a in (lhs, rhs))
        if both_literal:
            return signature(types.boolean, lhs, rhs)
        return None


@infer_global(operator.ne)
class ConstOpNotEq(ConstOpEq):
    """Literal-vs-literal ``!=``; reuses the ``==`` typing."""


@infer_global(operator.ne)
class CmpOpNe(UnorderedCmpOp):
    """Typing for ``a != b`` on numbers and booleans."""
459
460
class TupleCompare(AbstractTemplate):
    """Shared typing for tuple comparisons.

    Two tuples compare (yielding a boolean) when every zipped element pair is
    comparable under the same operator (``self.key``).
    """

    def generic(self, args, kws):
        lhs, rhs = args
        if not (isinstance(lhs, types.BaseTuple)
                and isinstance(rhs, types.BaseTuple)):
            return None
        for left_el, right_el in zip(lhs, rhs):
            elem_sig = self.context.resolve_function_type(
                self.key, (left_el, right_el), {})
            if elem_sig is None:
                return None
        return signature(types.boolean, lhs, rhs)
472
473
@infer_global(operator.eq)
class TupleEq(TupleCompare):
    """Typing for ``tuple == tuple``."""


@infer_global(operator.ne)
class TupleNe(TupleCompare):
    """Typing for ``tuple != tuple``."""


@infer_global(operator.ge)
class TupleGe(TupleCompare):
    """Typing for ``tuple >= tuple``."""


@infer_global(operator.gt)
class TupleGt(TupleCompare):
    """Typing for ``tuple > tuple``."""


@infer_global(operator.le)
class TupleLe(TupleCompare):
    """Typing for ``tuple <= tuple``."""


@infer_global(operator.lt)
class TupleLt(TupleCompare):
    """Typing for ``tuple < tuple``."""
502
503
@infer_global(operator.add)
class TupleAdd(AbstractTemplate):
    """Typing for ``tuple + tuple`` (named tuples excluded)."""

    def generic(self, args, kws):
        if len(args) != 2:
            return None
        a, b = args
        plain_tuples = all(
            isinstance(t, types.BaseTuple)
            and not isinstance(t, types.BaseNamedTuple)
            for t in (a, b)
        )
        if not plain_tuples:
            return None
        # The result's element types are the concatenation of both inputs'.
        joined = types.BaseTuple.from_types(tuple(a) + tuple(b))
        return signature(joined, a, b)
514
515
class CmpOpIdentity(AbstractTemplate):
    """Identity comparisons accept any two operands and yield a boolean."""

    def generic(self, args, kws):
        lhs, rhs = args
        return signature(types.boolean, lhs, rhs)


@infer_global(operator.is_)
class CmpOpIs(CmpOpIdentity):
    """Typing for ``a is b``."""


@infer_global(operator.is_not)
class CmpOpIsNot(CmpOpIdentity):
    """Typing for ``a is not b``."""
530
531
def normalize_1d_index(index):
    """
    Normalize the *index* type used to index a 1D sequence.

    Slice types pass through unchanged, integers widen to ``intp`` or
    ``uintp`` according to their signedness, and anything else gives None.
    """
    if isinstance(index, types.SliceType):
        return index
    if not isinstance(index, types.Integer):
        return None
    return types.intp if index.signed else types.uintp
542
543
@infer_global(operator.getitem)
class GetItemCPointer(AbstractTemplate):
    """Typing for ``ptr[i]`` on a C pointer with an integer index."""

    def generic(self, args, kws):
        assert not kws
        ptr, idx = args
        if not (isinstance(ptr, types.CPointer)
                and isinstance(idx, types.Integer)):
            return None
        return signature(ptr.dtype, ptr, normalize_1d_index(idx))
551
552
@infer_global(operator.setitem)
class SetItemCPointer(AbstractTemplate):
    """Typing for ``ptr[i] = v`` on a C pointer with an integer index."""

    def generic(self, args, kws):
        assert not kws
        ptr, idx, _val = args
        if not (isinstance(ptr, types.CPointer)
                and isinstance(idx, types.Integer)):
            return None
        # The stored value is coerced to the pointee type.
        return signature(types.none, ptr, normalize_1d_index(idx), ptr.dtype)
560
561
@infer_global(len)
class Len(AbstractTemplate):
    """Typing for ``len()`` on buffers, tuples and ranges."""

    def generic(self, args, kws):
        assert not kws
        [val] = args
        if isinstance(val, (types.Buffer, types.BaseTuple)):
            return signature(types.intp, val)
        if isinstance(val, types.RangeType):
            # Range lengths use the range's own integer type.
            return signature(val.dtype, val)
        return None
571
@infer_global(tuple)
class TupleConstructor(AbstractTemplate):
    """Typing for ``tuple()`` (empty) and ``tuple(t)`` on a tuple."""

    def generic(self, args, kws):
        assert not kws
        if not args:
            # tuple() -> the empty tuple type
            return signature(types.Tuple(()))
        [val] = args
        if not isinstance(val, types.BaseTuple):
            return None
        # tuple(t) on a tuple keeps the same type
        return signature(val, val)
583
584
@infer_global(operator.contains)
class Contains(AbstractTemplate):
    """Typing for ``operator.contains(seq, val)`` on sequences."""

    def generic(self, args, kws):
        assert not kws
        seq, val = args
        if not isinstance(seq, types.Sequence):
            return None
        return signature(types.boolean, seq, val)
593
@infer_global(operator.truth)
class TupleBool(AbstractTemplate):
    """Typing for the truth value of a tuple."""

    def generic(self, args, kws):
        assert not kws
        [val] = args
        if not isinstance(val, types.BaseTuple):
            return None
        return signature(types.boolean, val)
601
602
@infer
class StaticGetItemTuple(AbstractTemplate):
    """Static (compile-time constant) indexing or slicing of a tuple type."""
    key = "static_getitem"

    def generic(self, args, kws):
        tup, idx = args
        if not isinstance(tup, types.BaseTuple):
            return None
        if isinstance(idx, int):
            result_ty = tup.types[idx]
        elif isinstance(idx, slice):
            result_ty = types.BaseTuple.from_types(tup.types[idx])
        else:
            return None
        if result_ty is None:
            return None
        return signature(result_ty, *args)
619
620
@infer
class StaticGetItemLiteralList(AbstractTemplate):
    """Static integer indexing of a literal list type."""
    key = "static_getitem"

    def generic(self, args, kws):
        tup, idx = args
        if not isinstance(tup, types.LiteralList):
            return None
        if not isinstance(idx, int):
            return None
        result_ty = tup.types[idx]
        if result_ty is None:
            return None
        return signature(result_ty, *args)
635
636
@infer
class StaticGetItemLiteralStrKeyDict(AbstractTemplate):
    """Static string-key lookup in a literal string-keyed dict type."""
    key = "static_getitem"

    def generic(self, args, kws):
        tup, idx = args
        if not isinstance(tup, types.LiteralStrKeyDict):
            return None
        if not isinstance(idx, str):
            return None
        # The value type sits at the same position as the key in `fields`.
        position = tup.fields.index(idx)
        result_ty = tup.types[position]
        if result_ty is None:
            return None
        return signature(result_ty, *args)
652
653# Generic implementation for "not in"
654
@infer
class GenericNotIn(AbstractTemplate):
    """Typing for ``item not in container`` via ``operator.contains``."""
    key = "not in"

    def generic(self, args, kws):
        # operator.contains takes (container, item), which is the reverse of
        # the operand order of ``item not in container``.
        sig = self.context.resolve_function_type(operator.contains,
                                                 args[::-1], kws)
        if sig is None:
            # No containment typing exists for these operands.  Report "no
            # match" instead of crashing with an AttributeError on None.
            return None
        return signature(sig.return_type, *sig.args[::-1])
663
664
665#-------------------------------------------------------------------------------
666
@infer_getattr
class MemoryViewAttribute(AttributeTemplate):
    """Attribute typing for ``memoryview`` objects."""
    key = types.MemoryView

    def resolve_contiguous(self, buf):
        """``contiguous`` flag -> boolean."""
        return types.boolean

    def resolve_c_contiguous(self, buf):
        """``c_contiguous`` flag -> boolean."""
        return types.boolean

    def resolve_f_contiguous(self, buf):
        """``f_contiguous`` flag -> boolean."""
        return types.boolean

    def resolve_itemsize(self, buf):
        """``itemsize`` -> intp."""
        return types.intp

    def resolve_nbytes(self, buf):
        """``nbytes`` -> intp."""
        return types.intp

    def resolve_readonly(self, buf):
        """``readonly`` flag -> boolean."""
        return types.boolean

    def resolve_shape(self, buf):
        """``shape`` -> ndim-length tuple of intp."""
        ndim = buf.ndim
        return types.UniTuple(types.intp, ndim)

    def resolve_strides(self, buf):
        """``strides`` -> ndim-length tuple of intp."""
        ndim = buf.ndim
        return types.UniTuple(types.intp, ndim)

    def resolve_ndim(self, buf):
        """``ndim`` -> intp."""
        return types.intp
697
698
699#-------------------------------------------------------------------------------
700
701
@infer_getattr
class BooleanAttribute(AttributeTemplate):
    """Attribute typing for boolean scalars."""
    key = types.Boolean

    def resolve___class__(self, ty):
        return types.NumberClass(ty)

    @bound_function("number.item")
    def resolve_item(self, ty, args, kws):
        """``b.item()`` with no arguments returns the same boolean type."""
        assert not kws
        if args:
            return None
        return signature(ty)
714
715
@infer_getattr
class NumberAttribute(AttributeTemplate):
    """Attribute typing for numeric scalars."""
    key = types.Number

    def resolve___class__(self, ty):
        return types.NumberClass(ty)

    def resolve_real(self, ty):
        """``.real``: a complex type's underlying float, else the type itself."""
        return getattr(ty, "underlying_float", ty)

    def resolve_imag(self, ty):
        """``.imag``: a complex type's underlying float, else the type itself."""
        return getattr(ty, "underlying_float", ty)

    @bound_function("complex.conjugate")
    def resolve_conjugate(self, ty, args, kws):
        """``.conjugate()`` takes no arguments and keeps the number's type."""
        assert not args
        assert not kws
        return signature(ty)

    @bound_function("number.item")
    def resolve_item(self, ty, args, kws):
        """``.item()`` with no arguments keeps the number's type."""
        assert not kws
        if args:
            return None
        return signature(ty)
740
741
@infer_getattr
class NPTimedeltaAttribute(AttributeTemplate):
    """Attribute typing for NumPy timedelta scalars."""
    key = types.NPTimedelta

    def resolve___class__(self, ty):
        """``type(td)`` resolves to the timedelta's number class."""
        return types.NumberClass(ty)


@infer_getattr
class NPDatetimeAttribute(AttributeTemplate):
    """Attribute typing for NumPy datetime scalars."""
    key = types.NPDatetime

    def resolve___class__(self, ty):
        """``type(dt)`` resolves to the datetime's number class."""
        return types.NumberClass(ty)
756
757
@infer_getattr
class SliceAttribute(AttributeTemplate):
    """Attribute typing for slice objects; every bound is an ``intp``."""
    key = types.SliceType

    def resolve_start(self, ty):
        return types.intp

    def resolve_stop(self, ty):
        return types.intp

    def resolve_step(self, ty):
        return types.intp

    @bound_function("slice.indices")
    def resolve_indices(self, ty, args, kws):
        """``s.indices(n)`` with a single integer returns a 3-tuple of intp."""
        assert not kws
        nargs = len(args)
        if nargs != 1:
            raise TypeError(
                "indices() takes exactly one argument (%d given)" % nargs
            )
        (length_ty,) = args
        if not isinstance(length_ty, types.Integer):
            raise TypeError(
                "'%s' object cannot be interpreted as an integer" % length_ty
            )
        return signature(types.UniTuple(types.intp, 3), types.intp)
784
785
786#-------------------------------------------------------------------------------
787
788
@infer_getattr
class NumberClassAttribute(AttributeTemplate):
    key = types.NumberClass

    def resolve___call__(self, classty):
        """
        Resolve a number class's constructor (e.g. calling int(...)).

        A tuple or sequence argument is typed as ``np.array(val, dtype=...)``;
        any other argument is typed as a scalar of the class's instance type.
        """
        ty = classty.instance_type

        def typer(val):
            if isinstance(val, (types.BaseTuple, types.Sequence)):
                # Array constructor, e.g. np.int32([1, 2])
                sig = self.context.resolve_function_type(
                    np.array, (val,), {'dtype': types.DType(ty)})
                if sig is None:
                    # np.array cannot be typed for this input.  Report "no
                    # match" instead of failing with an AttributeError.
                    return None
                return sig.return_type
            # Scalar constructor, e.g. np.int32(42)
            return ty

        return types.Function(make_callable_template(key=ty, typer=typer))
810
811
@infer_getattr
class TypeRefAttribute(AttributeTemplate):
    key = types.TypeRef

    def resolve___call__(self, classty):
        """
        Resolve a call on a type reference whose instance type is a
        ``types.Type`` subclass (e.g. calling ``DictType(...)``).

        Note:

        This is needed because of the limitation of the current type-system
        implementation.  Specifically, the lack of a higher-order type
        (i.e. passing the ``DictType`` vs ``DictType(key_type, value_type)``)
        """
        ty = classty.instance_type
        if not (isinstance(ty, type) and issubclass(ty, types.Type)):
            return None

        # Forward typing to whatever was registered for the type class, e.g.
        #   @type_callable(DictType)
        #   def typeddict_call(context):
        #        ...
        # (see numba/typed/typeddict.py).
        class Redirect(object):

            def __init__(self, context):
                self.context = context

            def __call__(self, *args, **kwargs):
                result = self.context.resolve_function_type(ty, args, kwargs)
                # Expose the resolved Python signature, when there is one.
                if hasattr(result, "pysig"):
                    self.pysig = result.pysig
                return result

        redirect = Redirect(self.context)
        return types.Function(make_callable_template(key=ty, typer=redirect))
849
850
851#------------------------------------------------------------------------------
852
853
class MinMaxBase(AbstractTemplate):
    """Shared typing for ``min()`` and ``max()`` over numbers."""

    def _unify_minmax(self, tys):
        # Only numeric operands are supported; unify them to one result type.
        if not all(isinstance(ty, types.Number) for ty in tys):
            return None
        return self.context.unify_types(*tys)

    def generic(self, args, kws):
        """
        Resolve a min() or max() call.
        """
        assert not kws
        if not args:
            return None
        if len(args) > 1:
            # max(*args)
            tys = args
        elif isinstance(args[0], types.BaseTuple):
            # max(tup): compare the tuple's element types
            tys = list(args[0])
            if not tys:
                raise TypeError("%s() argument is an empty tuple"
                                % (self.key.__name__,))
        else:
            # a single non-tuple argument is not supported
            return None
        retty = self._unify_minmax(tys)
        if retty is None:
            return None
        return signature(retty, *args)
885
886
@infer_global(max)
class Max(MinMaxBase):
    """Typing for the ``max`` builtin."""


@infer_global(min)
class Min(MinMaxBase):
    """Typing for the ``min`` builtin."""
895
896
@infer_global(round)
class Round(ConcreteTemplate):
    """Typing for ``round()``."""
    cases = [
        # round(x): a float rounds to an integer type
        signature(types.intp, types.float32),
        signature(types.int64, types.float64),
        # round(x, ndigits): the float type is preserved
        signature(types.float32, types.float32, types.intp),
        signature(types.float64, types.float64, types.intp),
    ]
905
906
907#------------------------------------------------------------------------------
908
909
@infer_global(bool)
class Bool(AbstractTemplate):
    """Typing for ``bool(x)``."""

    def generic(self, args, kws):
        assert not kws
        (arg,) = args
        if not isinstance(arg, (types.Boolean, types.Number)):
            # XXX typing for bool cannot be polymorphic because of the
            # types.Function thing, so we redirect to the operator.truth
            # intrinsic.
            return self.context.resolve_function_type(operator.truth, args, kws)
        return signature(types.boolean, arg)
922
923
@infer_global(int)
class Int(AbstractTemplate):
    """Typing for ``int(x)``.

    Integers keep their type; floats and booleans become ``intp``.
    """

    def generic(self, args, kws):
        assert not kws
        (arg,) = args
        if isinstance(arg, types.Integer):
            result_ty = arg
        elif isinstance(arg, (types.Float, types.Boolean)):
            result_ty = types.intp
        else:
            return None
        return signature(result_ty, arg)
936
937
@infer_global(float)
class Float(AbstractTemplate):
    """Typing for ``float(x)``.

    Integers become float64 and reals keep their type.  Complex and
    non-numeric arguments raise TypeError.
    """

    def generic(self, args, kws):
        assert not kws
        [arg] = args
        if arg not in types.number_domain:
            raise TypeError("float() only support for numbers")
        if arg in types.complex_domain:
            raise TypeError("float() does not support complex")
        if arg in types.integer_domain:
            result_ty = types.float64
        elif arg in types.real_domain:
            result_ty = arg
        else:
            return None
        return signature(result_ty, arg)
957
958
@infer_global(complex)
class Complex(AbstractTemplate):
    """Typing for ``complex(re)`` and ``complex(re, im)``.

    The result is complex64 only when every argument is float32; otherwise
    it is complex128.
    """

    def generic(self, args, kws):
        assert not kws
        nargs = len(args)
        if nargs not in (1, 2):
            return None
        if any(a not in types.number_domain for a in args):
            raise TypeError("complex() only support for numbers")
        if nargs == 1:
            single_precision = args[0] == types.float32
        else:
            real, imag = args
            single_precision = real == imag == types.float32
        result_ty = types.complex64 if single_precision else types.complex128
        return signature(result_ty, *args)
983
984
985#------------------------------------------------------------------------------
986
@infer_global(enumerate)
class Enumerate(AbstractTemplate):
    """Typing for ``enumerate(iterable[, start])``.

    ``start`` must be an integer.  Wrong argument counts are delegated to
    Python's own ``enumerate`` so that it raises its usual TypeError.
    """

    def generic(self, args, kws):
        assert not kws
        if not args:
            # No iterable given: let Python raise its own arity error rather
            # than failing with an IndexError on ``args[0]`` below.
            enumerate(*args)
        it = args[0]
        if len(args) > 1 and not isinstance(args[1], types.Integer):
            raise TypeError("Only integers supported as start value in "
                            "enumerate")
        elif len(args) > 2:
            # let python raise its own error
            enumerate(*args)

        if isinstance(it, types.IterableType):
            enumerate_type = types.EnumerateType(it)
            return signature(enumerate_type, *args)
1003
1004
@infer_global(zip)
class Zip(AbstractTemplate):
    """Typing for ``zip(*iterables)``; every argument must be iterable."""

    def generic(self, args, kws):
        assert not kws
        if not all(isinstance(it, types.IterableType) for it in args):
            return None
        return signature(types.ZipType(args), *args)
1013
1014
@infer_global(iter)
class Iter(AbstractTemplate):
    """Typing for the one-argument ``iter(iterable)`` form."""

    def generic(self, args, kws):
        assert not kws
        if len(args) != 1:
            return None
        (iterable,) = args
        if not isinstance(iterable, types.IterableType):
            return None
        return signature(iterable.iterator_type, *args)
1024
1025
@infer_global(next)
class Next(AbstractTemplate):
    """Typing for the one-argument ``next(iterator)`` form."""

    def generic(self, args, kws):
        assert not kws
        if len(args) != 1:
            return None
        (iterator,) = args
        if not isinstance(iterator, types.IteratorType):
            return None
        return signature(iterator.yield_type, *args)
1035
1036
1037#------------------------------------------------------------------------------
1038
@infer_global(type)
class TypeBuiltin(AbstractTemplate):
    """Typing for the one-argument ``type(obj)`` form via ``__class__``."""

    def generic(self, args, kws):
        assert not kws
        if len(args) != 1:
            return None
        # Resolve against the non-literal type so literal types are avoided.
        plain_ty = types.unliteral(args[0])
        classty = self.context.resolve_getattr(plain_ty, "__class__")
        if classty is None:
            return None
        return signature(classty, *args)
1051
1052
1053#------------------------------------------------------------------------------
1054
@infer_getattr
class OptionalAttribute(AttributeTemplate):
    """Attribute lookups on an Optional are forwarded to its wrapped type."""
    key = types.Optional

    def generic_resolve(self, optional, attr):
        wrapped_ty = optional.type
        return self.context.resolve_getattr(wrapped_ty, attr)
1061
1062#------------------------------------------------------------------------------
1063
@infer_getattr
class DeferredAttribute(AttributeTemplate):
    """Attribute lookups on a deferred type use the type it resolves to."""
    key = types.DeferredType

    def generic_resolve(self, deferred, attr):
        actual_ty = deferred.get()
        return self.context.resolve_getattr(actual_ty, attr)
1070
1071#------------------------------------------------------------------------------
1072
@infer_global(get_type_min_value)
@infer_global(get_type_max_value)
class MinValInfer(AbstractTemplate):
    """Typing shared by get_type_min_value() and get_type_max_value().

    The single argument is a DType or NumberClass, and the result has that
    argument's dtype.
    """

    def generic(self, args, kws):
        assert not kws
        assert len(args) == 1
        (typeref,) = args
        assert isinstance(typeref, (types.DType, types.NumberClass))
        return signature(typeref.dtype, *args)
1081
1082
1083#------------------------------------------------------------------------------
1084
1085
class IndexValue(object):
    """
    Index and value pair.

    Stores an ``index`` together with an arbitrary ``value``.
    """
    def __init__(self, ind, val):
        self.index = ind
        self.value = val

    def __repr__(self):
        # %r keeps the integer index from being rendered as a float, and
        # still works for values that are not real numbers (%f raised
        # TypeError for those).
        return 'IndexValue(%r, %r)' % (self.index, self.value)
1096
1097
class IndexValueType(types.Type):
    """Numba type of an IndexValue, parameterized by the value's type."""

    def __init__(self, val_typ):
        self.val_typ = val_typ
        type_name = 'IndexValueType({})'.format(val_typ)
        super(IndexValueType, self).__init__(name=type_name)
1103
1104
@typeof_impl.register(IndexValue)
def typeof_index(val, c):
    """Type an IndexValue instance from the type of its ``value``."""
    return IndexValueType(typeof_impl(val.value, c))
1109
1110
@type_callable(IndexValue)
def type_index_value(context):
    """Typing for ``IndexValue(ind, mval)`` in compiled code.

    The index must be typed as intp or uintp.
    """
    def typer(ind, mval):
        if not (ind == types.intp or ind == types.uintp):
            return None
        return IndexValueType(mval)
    return typer
1117
1118
@register_model(IndexValueType)
class IndexValueModel(models.StructModel):
    """Struct data model: an intp ``index`` plus the ``value`` payload."""

    def __init__(self, dmm, fe_type):
        members = [('index', types.intp), ('value', fe_type.val_typ)]
        super(IndexValueModel, self).__init__(dmm, fe_type, members)
1127
1128
# Expose the IndexValueModel struct members as attributes of the same names
# on IndexValueType values in compiled code.
make_attribute_wrapper(IndexValueType, 'index', 'index')
make_attribute_wrapper(IndexValueType, 'value', 'value')
1131