import itertools

import numpy as np
import operator

from numba.core import types, errors
from numba import prange
from numba.parfors.parfor import internal_prange

from numba.core.utils import RANGE_ITER_OBJECTS
from numba.core.typing.templates import (AttributeTemplate, ConcreteTemplate,
                                         AbstractTemplate, infer_global, infer,
                                         infer_getattr, signature,
                                         bound_function, make_callable_template)

from numba.cpython.builtins import get_type_min_value, get_type_max_value

from numba.core.extending import (
    typeof_impl, type_callable, models, register_model, make_attribute_wrapper,
    )


@infer_global(print)
class Print(AbstractTemplate):
    """
    Typing for the builtin ``print()``.

    Every positional argument is checked against the "print_item" typing
    key; if any of them cannot be typed there, a TypeError is raised.  The
    call itself is typed as returning None.
    """
    def generic(self, args, kws):
        for item_ty in args:
            item_sig = self.context.resolve_function_type(
                "print_item", (item_ty,), {})
            if item_sig is None:
                raise TypeError("Type %s is not printable." % item_ty)
            # "print_item" must itself be typed as returning None.
            assert item_sig.return_type is types.none
        return signature(types.none, *args)


@infer
class PrintItem(AbstractTemplate):
    """
    Typing for the internal "print_item" key: accepts exactly one value of
    any type and returns None.
    """
    key = "print_item"

    def generic(self, args, kws):
        # Single-element unpacking rejects any other number of arguments.
        (_,) = args
        return signature(types.none, *args)


@infer_global(abs)
class Abs(ConcreteTemplate):
    """
    Typing for ``abs()``: signed/unsigned integers and reals map to their
    own type; complex numbers map to their underlying float type.
    """
    int_cases = [signature(num, num) for num in sorted(types.signed_domain)]
    uint_cases = [signature(num, num) for num in sorted(types.unsigned_domain)]
    real_cases = [signature(num, num) for num in sorted(types.real_domain)]
    complex_cases = [signature(num.underlying_float, num)
                     for num in sorted(types.complex_domain)]
    cases = int_cases + uint_cases + real_cases + complex_cases


@infer_global(slice)
class Slice(ConcreteTemplate):
    """
    Typing for ``slice()``: every bound is either an intp or None.  The one-
    and two-argument forms produce a slice2; the three-argument forms
    produce a slice3.
    """
    cases = [
        # one and two arguments
        signature(types.slice2_type, types.intp),
        signature(types.slice2_type, types.none),
        signature(types.slice2_type, types.none, types.none),
        signature(types.slice2_type, types.none, types.intp),
        signature(types.slice2_type, types.intp, types.none),
        signature(types.slice2_type, types.intp, types.intp),
        # three arguments
        signature(types.slice3_type, types.intp, types.intp, types.intp),
        signature(types.slice3_type, types.none, types.intp, types.intp),
        signature(types.slice3_type, types.intp, types.none, types.intp),
        signature(types.slice3_type, types.intp, types.intp, types.none),
        signature(types.slice3_type, types.intp, types.none, types.none),
        signature(types.slice3_type, types.none, types.intp, types.none),
        signature(types.slice3_type, types.none, types.none, types.intp),
        signature(types.slice3_type, types.none, types.none, types.none),
    ]


class Range(ConcreteTemplate):
    """
    Typing shared by ``range()`` (and the objects in RANGE_ITER_OBJECTS),
    ``prange`` and ``internal_prange``: the argument type selects a
    32-bit, 64-bit or unsigned 64-bit range state.
    """
    cases = [
        signature(types.range_state32_type, types.int32),
        signature(types.range_state32_type, types.int32, types.int32),
        signature(types.range_state32_type, types.int32, types.int32,
                  types.int32),
        signature(types.range_state64_type, types.int64),
        signature(types.range_state64_type, types.int64, types.int64),
        signature(types.range_state64_type, types.int64, types.int64,
                  types.int64),
        signature(types.unsigned_range_state64_type, types.uint64),
        signature(types.unsigned_range_state64_type, types.uint64, types.uint64),
        signature(types.unsigned_range_state64_type, types.uint64, types.uint64,
                  types.uint64),
    ]

# Every range-like builtin shares the typing key ``range``.
# (``func`` intentionally stays a module-level name, as before.)
for func in RANGE_ITER_OBJECTS:
    infer_global(func, typing_key=range)(Range)

infer_global(prange, typing_key=prange)(Range)
infer_global(internal_prange, typing_key=internal_prange)(Range)

@infer
class GetIter(AbstractTemplate):
    """Typing for "getiter": an iterable type yields its iterator type."""
    key = "getiter"

    def generic(self, args, kws):
        assert not kws
        [iterable] = args
        if not isinstance(iterable, types.IterableType):
            return None
        return signature(iterable.iterator_type, iterable)


@infer
class IterNext(AbstractTemplate):
    """
    Typing for "iternext": an iterator produces a Pair of
    (yielded value, boolean).
    """
    key = "iternext"

    def generic(self, args, kws):
        assert not kws
        [iterator] = args
        if not isinstance(iterator, types.IteratorType):
            return None
        result_ty = types.Pair(iterator.yield_type, types.boolean)
        return signature(result_ty, iterator)


@infer
class PairFirst(AbstractTemplate):
    """
    Typing for "pair_first": extract the first element of a heterogeneous
    Pair.
    """
    key = "pair_first"

    def generic(self, args, kws):
        assert not kws
        [pair_ty] = args
        if not isinstance(pair_ty, types.Pair):
            return None
        return signature(pair_ty.first_type, pair_ty)


@infer
class PairSecond(AbstractTemplate):
    """
    Typing for "pair_second": extract the second element of a heterogeneous
    Pair.
    """
    key = "pair_second"

    def generic(self, args, kws):
        assert not kws
        [pair_ty] = args
        if not isinstance(pair_ty, types.Pair):
            return None
        return signature(pair_ty.second_type, pair_ty)


def choose_result_bitwidth(*inputs):
    """
    Return the bitwidth of an integer operation's result: the widest of
    the *inputs*, but never narrower than intp.
    """
    input_widths = [tp.bitwidth for tp in inputs]
    return max(types.intp.bitwidth, *input_widths)

def choose_result_int(*inputs):
    """
    Choose the integer result type for an operation on integer inputs,
    according to the integer typing NBEP: the result is as wide as
    choose_result_bitwidth() says, and signed if any input is signed.
    """
    width = choose_result_bitwidth(*inputs)
    is_signed = any(tp.signed for tp in inputs)
    return types.Integer.from_bitwidth(width, is_signed)


# The "machine" integer types to take into consideration for operator typing
# (according to the integer typing NBEP)
machine_ints = (
    sorted(set((types.intp, types.int64))) +
    sorted(set((types.uintp, types.uint64)))
    )

# Explicit integer rules for binary operators; smaller ints will be
# automatically upcast.
integer_binop_cases = tuple(
    signature(choose_result_int(lhs, rhs), lhs, rhs)
    for lhs, rhs in itertools.product(machine_ints, machine_ints)
    )


class BinOp(ConcreteTemplate):
    """
    Signatures shared by ``+``, ``-`` and ``*`` (and their in-place forms):
    the machine-integer rules, then same-typed reals, then same-typed
    complex numbers.
    """
    cases = [
        *integer_binop_cases,
        *[signature(ty, ty, ty) for ty in sorted(types.real_domain)],
        *[signature(ty, ty, ty) for ty in sorted(types.complex_domain)],
    ]


@infer_global(operator.add)
class BinOpAdd(BinOp):
    pass


@infer_global(operator.iadd)
class BinOpAdd(BinOp):
    pass


@infer_global(operator.sub)
class BinOpSub(BinOp):
    pass


@infer_global(operator.isub)
class BinOpSub(BinOp):
    pass


@infer_global(operator.mul)
class BinOpMul(BinOp):
    pass


@infer_global(operator.imul)
class BinOpMul(BinOp):
    pass


@infer_global(operator.mod)
class BinOpMod(ConcreteTemplate):
    """Typing for ``%``: machine integers and same-typed reals."""
    cases = list(integer_binop_cases)
    cases.extend(signature(ty, ty, ty) for ty in sorted(types.real_domain))


@infer_global(operator.imod)
class BinOpMod(ConcreteTemplate):
    """Typing for ``%=``: machine integers and same-typed reals."""
    cases = list(integer_binop_cases)
    cases.extend(signature(ty, ty, ty) for ty in sorted(types.real_domain))


@infer_global(operator.truediv)
class BinOpTrueDiv(ConcreteTemplate):
    """
    Typing for ``/``: any pair of machine integers gives float64; reals and
    complex numbers keep their own type.
    """
    cases = [signature(types.float64, lhs, rhs)
             for lhs, rhs in itertools.product(machine_ints, machine_ints)]
    cases.extend(signature(ty, ty, ty) for ty in sorted(types.real_domain))
    cases.extend(signature(ty, ty, ty) for ty in sorted(types.complex_domain))


@infer_global(operator.itruediv)
class BinOpTrueDiv(ConcreteTemplate):
    """
    Typing for ``/=``: any pair of machine integers gives float64; reals and
    complex numbers keep their own type.
    """
    cases = [signature(types.float64, lhs, rhs)
             for lhs, rhs in itertools.product(machine_ints, machine_ints)]
    cases.extend(signature(ty, ty, ty) for ty in sorted(types.real_domain))
    cases.extend(signature(ty, ty, ty) for ty in sorted(types.complex_domain))


@infer_global(operator.floordiv)
class BinOpFloorDiv(ConcreteTemplate):
    """Typing for ``//``: machine integers and same-typed reals."""
    cases = list(integer_binop_cases)
    cases.extend(signature(ty, ty, ty) for ty in sorted(types.real_domain))


@infer_global(operator.ifloordiv)
class BinOpFloorDiv(ConcreteTemplate):
    """Typing for ``//=``: machine integers and same-typed reals."""
    cases = list(integer_binop_cases)
    cases.extend(signature(ty, ty, ty) for ty in sorted(types.real_domain))


@infer_global(divmod)
class DivMod(ConcreteTemplate):
    """Typing for ``divmod()``: a homogeneous 2-tuple of the operand type."""
    _tys = machine_ints + sorted(types.real_domain)
    cases = [signature(types.UniTuple(ty, 2), ty, ty) for ty in _tys]


@infer_global(operator.pow)
class BinOpPower(ConcreteTemplate):
    """Typing for ``**``."""
    cases = list(integer_binop_cases)
    # Ensure that float32 ** int doesn't go through DP computations
    cases += [signature(types.float32, types.float32, exp)
              for exp in (types.int32, types.int64, types.uint64)]
    cases += [signature(types.float64, types.float64, exp)
              for exp in (types.int32, types.int64, types.uint64)]
    cases += [signature(ty, ty, ty) for ty in sorted(types.real_domain)]
    cases += [signature(ty, ty, ty) for ty in sorted(types.complex_domain)]


@infer_global(operator.ipow)
class BinOpPower(ConcreteTemplate):
    """Typing for ``**=`` (same signatures as ``**``)."""
    cases = list(integer_binop_cases)
    # Ensure that float32 ** int doesn't go through DP computations
    cases += [signature(types.float32, types.float32, exp)
              for exp in (types.int32, types.int64, types.uint64)]
    cases += [signature(types.float64, types.float64, exp)
              for exp in (types.int32, types.int64, types.uint64)]
    cases += [signature(ty, ty, ty) for ty in sorted(types.real_domain)]
    cases += [signature(ty, ty, ty) for ty in sorted(types.complex_domain)]


@infer_global(pow)
class PowerBuiltin(BinOpPower):
    # TODO add 3 operand version
    pass


class BitwiseShiftOperation(ConcreteTemplate):
    # For bitshifts, only the first operand's signedness matters
    # to choose the operation's signedness (the second operand
    # should always be positive but will generally be considered
    # signed anyway, since it's often a constant integer).
    # (also, see issue #1995 for right-shifts)

    # The RHS type is fixed to 64-bit signed/unsigned ints.
    # The implementation will always cast the operands to the width of the
    # result type, which is the widest between the LHS type and (u)intp.
    cases = [signature(max(lhs, types.intp), lhs, rhs)
             for lhs in sorted(types.signed_domain)
             for rhs in [types.uint64, types.int64]]
    cases += [signature(max(lhs, types.uintp), lhs, rhs)
              for lhs in sorted(types.unsigned_domain)
              for rhs in [types.uint64, types.int64]]
    unsafe_casting = False


@infer_global(operator.lshift)
class BitwiseLeftShift(BitwiseShiftOperation):
    pass

@infer_global(operator.ilshift)
class BitwiseLeftShift(BitwiseShiftOperation):
    pass


@infer_global(operator.rshift)
class BitwiseRightShift(BitwiseShiftOperation):
    pass


@infer_global(operator.irshift)
class BitwiseRightShift(BitwiseShiftOperation):
    pass


class BitwiseLogicOperation(BinOp):
    """
    Signatures shared by ``&``, ``|`` and ``^`` (and their in-place forms):
    bool op bool gives bool, followed by the machine-integer rules.
    """
    cases = [signature(types.boolean, types.boolean, types.boolean),
             *integer_binop_cases]
    unsafe_casting = False


@infer_global(operator.and_)
class BitwiseAnd(BitwiseLogicOperation):
    pass


@infer_global(operator.iand)
class BitwiseAnd(BitwiseLogicOperation):
    pass


@infer_global(operator.or_)
class BitwiseOr(BitwiseLogicOperation):
    pass


@infer_global(operator.ior)
class BitwiseOr(BitwiseLogicOperation):
    pass


@infer_global(operator.xor)
class BitwiseXor(BitwiseLogicOperation):
    pass


@infer_global(operator.ixor)
class BitwiseXor(BitwiseLogicOperation):
    pass


# Bitwise invert and negate are special: we must not upcast the operand
# for unsigned numbers, as that would change the result.
# (i.e. ~np.int8(0) == 255 but ~np.int32(0) == 4294967295).
@infer_global(operator.invert)
class BitwiseInvert(ConcreteTemplate):
    """Typing for ``~x`` on booleans and integers."""
    # Note Numba follows the Numpy semantics of returning a bool,
    # while Python returns an int. This makes it consistent with
    # np.invert() and makes array expressions correct.
    cases = [signature(types.boolean, types.boolean)]
    cases += [signature(choose_result_int(op), op)
              for op in sorted(types.unsigned_domain)]
    cases += [signature(choose_result_int(op), op)
              for op in sorted(types.signed_domain)]

    unsafe_casting = False


class UnaryOp(ConcreteTemplate):
    """Signatures shared by unary ``-x`` and ``+x``."""
    cases = [signature(choose_result_int(op), op)
             for op in sorted(types.unsigned_domain)]
    cases += [signature(choose_result_int(op), op)
              for op in sorted(types.signed_domain)]
    cases += [signature(op, op) for op in sorted(types.real_domain)]
    cases += [signature(op, op) for op in sorted(types.complex_domain)]
    # A boolean operand is typed as producing intp.
    cases += [signature(types.intp, types.boolean)]


@infer_global(operator.neg)
class UnaryNegate(UnaryOp):
    pass


@infer_global(operator.pos)
class UnaryPositive(UnaryOp):
    pass


@infer_global(operator.not_)
class UnaryNot(ConcreteTemplate):
    """Typing for ``not x``: any boolean or number gives a boolean."""
    cases = [signature(types.boolean, types.boolean)]
    cases += [signature(types.boolean, op) for op in sorted(types.signed_domain)]
    cases += [signature(types.boolean, op) for op in sorted(types.unsigned_domain)]
    cases += [signature(types.boolean, op) for op in sorted(types.real_domain)]
    cases += [signature(types.boolean, op) for op in sorted(types.complex_domain)]


class OrderedCmpOp(ConcreteTemplate):
    """
    Signatures for ordering comparisons (``<``, ``<=``, ``>``, ``>=``):
    same-typed booleans, integers and reals; complex numbers are excluded.
    """
    cases = [signature(types.boolean, types.boolean, types.boolean)]
    cases += [signature(types.boolean, op, op) for op in sorted(types.signed_domain)]
    cases += [signature(types.boolean, op, op) for op in sorted(types.unsigned_domain)]
    cases += [signature(types.boolean, op, op) for op in sorted(types.real_domain)]


class UnorderedCmpOp(ConcreteTemplate):
    """
    Signatures for equality comparisons (``==``, ``!=``): the ordered cases
    plus same-typed complex numbers.
    """
    cases = OrderedCmpOp.cases + [
        signature(types.boolean, op, op) for op in sorted(types.complex_domain)]


@infer_global(operator.lt)
class CmpOpLt(OrderedCmpOp):
    pass


@infer_global(operator.le)
class CmpOpLe(OrderedCmpOp):
    pass


@infer_global(operator.gt)
class CmpOpGt(OrderedCmpOp):
    pass


@infer_global(operator.ge)
class CmpOpGe(OrderedCmpOp):
    pass


@infer_global(operator.eq)
class CmpOpEq(UnorderedCmpOp):
    pass


@infer_global(operator.eq)
class ConstOpEq(AbstractTemplate):
    """Comparison of two Literal-typed operands, typed as a boolean."""
    def generic(self, args, kws):
        assert not kws
        (arg1, arg2) = args
        if isinstance(arg1, types.Literal) and isinstance(arg2, types.Literal):
            return signature(types.boolean, arg1, arg2)


@infer_global(operator.ne)
class ConstOpNotEq(ConstOpEq):
    pass


@infer_global(operator.ne)
class CmpOpNe(UnorderedCmpOp):
    pass


class TupleCompare(AbstractTemplate):
    """
    Comparison of two tuples.

    Typed as returning a boolean only if, for every pair of elements at the
    same position (``zip`` stops at the shorter tuple), the same comparison
    key (``self.key`` -- presumably the operator a subclass is registered
    for) resolves to a signature.
    """
    def generic(self, args, kws):
        [lhs, rhs] = args
        if isinstance(lhs, types.BaseTuple) and isinstance(rhs, types.BaseTuple):
            for u, v in zip(lhs, rhs):
                # Check element-wise comparability
                res = self.context.resolve_function_type(self.key, (u, v), {})
                if res is None:
                    break
            else:
                return signature(types.boolean, lhs, rhs)


@infer_global(operator.eq)
class TupleEq(TupleCompare):
    pass


@infer_global(operator.ne)
class TupleNe(TupleCompare):
    pass


@infer_global(operator.ge)
class TupleGe(TupleCompare):
    pass


@infer_global(operator.gt)
class TupleGt(TupleCompare):
    pass


@infer_global(operator.le)
class TupleLe(TupleCompare):
    pass


@infer_global(operator.lt)
class TupleLt(TupleCompare):
    pass


@infer_global(operator.add)
class TupleAdd(AbstractTemplate):
    """
    Concatenation of two tuples; named tuples are excluded.  The result
    type is built from the element types of both operands.
    """
    def generic(self, args, kws):
        if len(args) == 2:
            a, b = args
            if (isinstance(a, types.BaseTuple) and isinstance(b, types.BaseTuple)
                    and not isinstance(a, types.BaseNamedTuple)
                    and not isinstance(b, types.BaseNamedTuple)):
                res = types.BaseTuple.from_types(tuple(a) + tuple(b))
                return signature(res, a, b)


class CmpOpIdentity(AbstractTemplate):
    """Identity comparison (``is`` / ``is not``): always typed as boolean."""
    def generic(self, args, kws):
        [lhs, rhs] = args
        return signature(types.boolean, lhs, rhs)


@infer_global(operator.is_)
class CmpOpIs(CmpOpIdentity):
    pass


@infer_global(operator.is_not)
class CmpOpIsNot(CmpOpIdentity):
    pass


def normalize_1d_index(index):
    """
    Normalize the *index* type (an integer or slice) for indexing a 1D
    sequence.

    Slice types are returned unchanged; integers become intp (signed) or
    uintp (unsigned).  Any other type yields None.
    """
    if isinstance(index, types.SliceType):
        return index

    elif isinstance(index, types.Integer):
        return types.intp if index.signed else types.uintp


@infer_global(operator.getitem)
class GetItemCPointer(AbstractTemplate):
    """``ptr[i]`` on a CPointer with an integer index yields the pointee."""
    def generic(self, args, kws):
        assert not kws
        ptr, idx = args
        if isinstance(ptr, types.CPointer) and isinstance(idx, types.Integer):
            return signature(ptr.dtype, ptr, normalize_1d_index(idx))


@infer_global(operator.setitem)
class SetItemCPointer(AbstractTemplate):
    """``ptr[i] = v`` on a CPointer; the value is typed as the pointee."""
    def generic(self, args, kws):
        assert not kws
        ptr, idx, val = args
        if isinstance(ptr, types.CPointer) and isinstance(idx, types.Integer):
            return signature(types.none, ptr, normalize_1d_index(idx), ptr.dtype)


@infer_global(len)
class Len(AbstractTemplate):
    """
    ``len()`` of buffers and tuples gives intp; a range gives its own
    dtype.
    """
    def generic(self, args, kws):
        assert not kws
        (val,) = args
        if isinstance(val, (types.Buffer, types.BaseTuple)):
            return signature(types.intp, val)
        elif isinstance(val, (types.RangeType)):
            return signature(val.dtype, val)

@infer_global(tuple)
class TupleConstructor(AbstractTemplate):
    """``tuple()`` gives an empty tuple; ``tuple(t)`` of a tuple is ``t``."""
    def generic(self, args, kws):
        assert not kws
        # empty tuple case
        if len(args) == 0:
            return signature(types.Tuple(()))
        (val,) = args
        # tuple as input
        if isinstance(val, types.BaseTuple):
            return signature(val, val)


@infer_global(operator.contains)
class Contains(AbstractTemplate):
    """``val in seq`` for any Sequence type is typed as a boolean."""
    def generic(self, args, kws):
        assert not kws
        (seq, val) = args

        if isinstance(seq, (types.Sequence)):
            return signature(types.boolean, seq, val)

@infer_global(operator.truth)
class TupleBool(AbstractTemplate):
    """Truth value of a tuple, typed as a boolean."""
    def generic(self, args, kws):
        assert not kws
        (val,) = args
        if isinstance(val, (types.BaseTuple)):
            return signature(types.boolean, val)


@infer
class StaticGetItemTuple(AbstractTemplate):
    """
    Typing for "static_getitem" on tuples with a compile-time int or slice
    index: an int picks one element type, a slice builds a new tuple type.
    """
    key = "static_getitem"

    def generic(self, args, kws):
        tup, idx = args
        ret = None
        if not isinstance(tup, types.BaseTuple):
            return
        if isinstance(idx, int):
            ret = tup.types[idx]
        elif isinstance(idx, slice):
            ret = types.BaseTuple.from_types(tup.types[idx])
        if ret is not None:
            sig = signature(ret, *args)
            return sig


@infer
class StaticGetItemLiteralList(AbstractTemplate):
    """
    Typing for "static_getitem" on a LiteralList with a compile-time int
    index.
    """
    key = "static_getitem"

    def generic(self, args, kws):
        tup, idx = args
        ret = None
        if not isinstance(tup, types.LiteralList):
            return
        if isinstance(idx, int):
            ret = tup.types[idx]
        if ret is not None:
            sig = signature(ret, *args)
            return sig


@infer
class StaticGetItemLiteralStrKeyDict(AbstractTemplate):
    """
    Typing for "static_getitem" on a LiteralStrKeyDict with a compile-time
    string key: the key's position in ``fields`` selects the value type.
    """
    key = "static_getitem"

    def generic(self, args, kws):
        tup, idx = args
        ret = None
        if not isinstance(tup, types.LiteralStrKeyDict):
            return
        if isinstance(idx, str):
            # NOTE(review): ``.index`` raises ValueError for a key that is
            # not in ``fields``; confirm that is the intended error.
            lookup = tup.fields.index(idx)
            ret = tup.types[lookup]
        if ret is not None:
            sig = signature(ret, *args)
            return sig

# Generic implementation for "not in"

@infer
class GenericNotIn(AbstractTemplate):
    """
    Typing for "not in": ``a not in b`` is typed as
    ``operator.contains(b, a)`` and the argument order is swapped back.
    """
    key = "not in"

    def generic(self, args, kws):
        args = args[::-1]
        sig = self.context.resolve_function_type(operator.contains, args, kws)
        if sig is None:
            # operator.contains cannot type these operands: report "no
            # match" instead of failing with AttributeError on None.
            return None
        return signature(sig.return_type, *sig.args[::-1])


#-------------------------------------------------------------------------------

@infer_getattr
class MemoryViewAttribute(AttributeTemplate):
    """Attribute typing for memoryview objects."""
    key = types.MemoryView

    def resolve_contiguous(self, buf):
        return types.boolean

    def resolve_c_contiguous(self, buf):
        return types.boolean

    def resolve_f_contiguous(self, buf):
        return types.boolean

    def resolve_itemsize(self, buf):
        return types.intp

    def resolve_nbytes(self, buf):
        return types.intp

    def resolve_readonly(self, buf):
        return types.boolean

    def resolve_shape(self, buf):
        # One intp per dimension.
        return types.UniTuple(types.intp, buf.ndim)

    def resolve_strides(self, buf):
        # One intp per dimension.
        return types.UniTuple(types.intp, buf.ndim)

    def resolve_ndim(self, buf):
        return types.intp


#-------------------------------------------------------------------------------


@infer_getattr
class BooleanAttribute(AttributeTemplate):
    """Attribute typing for booleans (``__class__`` and ``item()``)."""
    key = types.Boolean

    def resolve___class__(self, ty):
        return types.NumberClass(ty)

    @bound_function("number.item")
    def resolve_item(self, ty, args, kws):
        assert not kws
        # Only the zero-argument form is typed.
        if not args:
            return signature(ty)


@infer_getattr
class NumberAttribute(AttributeTemplate):
    """
    Attribute typing for numbers.  ``real``/``imag`` resolve to the
    underlying float type of complex numbers and to the number type itself
    otherwise.
    """
    key = types.Number

    def resolve___class__(self, ty):
        return types.NumberClass(ty)

    def resolve_real(self, ty):
        return getattr(ty, "underlying_float", ty)

    def resolve_imag(self, ty):
        return getattr(ty, "underlying_float", ty)

    @bound_function("complex.conjugate")
    def resolve_conjugate(self, ty, args, kws):
        assert not args
        assert not kws
        return signature(ty)

    @bound_function("number.item")
    def resolve_item(self, ty, args, kws):
        assert not kws
        # Only the zero-argument form is typed.
        if not args:
            return signature(ty)
@infer_getattr
class NPTimedeltaAttribute(AttributeTemplate):
    """Attribute typing for NumPy timedelta scalars (``__class__``)."""
    key = types.NPTimedelta

    def resolve___class__(self, ty):
        return types.NumberClass(ty)


@infer_getattr
class NPDatetimeAttribute(AttributeTemplate):
    """Attribute typing for NumPy datetime scalars (``__class__``)."""
    key = types.NPDatetime

    def resolve___class__(self, ty):
        return types.NumberClass(ty)


@infer_getattr
class SliceAttribute(AttributeTemplate):
    """Attribute typing for slice objects: bounds are intp."""
    key = types.SliceType

    def resolve_start(self, ty):
        return types.intp

    def resolve_stop(self, ty):
        return types.intp

    def resolve_step(self, ty):
        return types.intp

    @bound_function("slice.indices")
    def resolve_indices(self, ty, args, kws):
        """
        Type ``slice.indices(length)``: exactly one integer argument, giving
        a 3-tuple of intp.  Raises TypeError otherwise.
        """
        assert not kws
        if len(args) != 1:
            raise TypeError(
                "indices() takes exactly one argument (%d given)" % len(args)
            )
        typ, = args
        if not isinstance(typ, types.Integer):
            raise TypeError(
                "'%s' object cannot be interpreted as an integer" % typ
            )
        return signature(types.UniTuple(types.intp, 3), types.intp)


#-------------------------------------------------------------------------------


@infer_getattr
class NumberClassAttribute(AttributeTemplate):
    key = types.NumberClass

    def resolve___call__(self, classty):
        """
        Resolve a number class's constructor (e.g. calling int(...))

        A tuple or sequence argument is typed like ``np.array(val,
        dtype=<this type>)``; any other argument is typed as a scalar
        conversion to the class's instance type.
        """
        ty = classty.instance_type

        def typer(val):
            if isinstance(val, (types.BaseTuple, types.Sequence)):
                # Array constructor, e.g. np.int32([1, 2])
                sig = self.context.resolve_function_type(
                    np.array, (val,), {'dtype': types.DType(ty)})
                if sig is None:
                    # np.array cannot type this input: report "no match"
                    # rather than failing with AttributeError on None.
                    return None
                return sig.return_type
            else:
                # Scalar constructor, e.g. np.int32(42)
                return ty

        return types.Function(make_callable_template(key=ty, typer=typer))


@infer_getattr
class TypeRefAttribute(AttributeTemplate):
    key = types.TypeRef

    def resolve___call__(self, classty):
        """
        Resolve a call on a TypeRef whose instance type is itself a
        ``types.Type`` subclass (e.g. calling ``DictType(...)``) by
        redirecting the typing to whatever is registered for that class.

        Note:

        This is needed because of the limitation of the current type-system
        implementation. Specifically, the lack of a higher-order type
        (i.e. passing the ``DictType`` vs ``DictType(key_type, value_type)``)

        Returns None when the instance type is not a ``types.Type`` subclass.
        """
        ty = classty.instance_type

        if isinstance(ty, type) and issubclass(ty, types.Type):
            # Redirect the typing to a:
            #   @type_callable(ty)
            #   def typeddict_call(context):
            #        ...
            # For example, see numba/typed/typeddict.py
            #   @type_callable(DictType)
            #   def typeddict_call(context):
            class Redirect(object):

                def __init__(self, context):
                    self.context = context

                def __call__(self, *args, **kwargs):
                    result = self.context.resolve_function_type(ty, args, kwargs)
                    # Propagate the resolved Python signature, if any.
                    if hasattr(result, "pysig"):
                        self.pysig = result.pysig
                    return result

            return types.Function(make_callable_template(key=ty,
                                                         typer=Redirect(self.context)))


#------------------------------------------------------------------------------


class MinMaxBase(AbstractTemplate):
    """Shared typing for ``min()`` and ``max()`` over numbers."""

    def _unify_minmax(self, tys):
        # Only numbers are supported; returns None if any type is not one.
        for ty in tys:
            if not isinstance(ty, types.Number):
                return
        return self.context.unify_types(*tys)

    def generic(self, args, kws):
        """
        Resolve a min() or max() call.

        Supported forms are ``f(tuple)`` with a non-empty tuple and
        ``f(a, b, ...)``; the result is the unified type of the elements.
        """
        assert not kws

        if not args:
            return
        if len(args) == 1:
            # f(arg) is only supported when arg is a tuple
            if isinstance(args[0], types.BaseTuple):
                tys = list(args[0])
                if not tys:
                    raise TypeError("%s() argument is an empty tuple"
                                    % (self.key.__name__,))
            else:
                return
        else:
            # max(*args)
            tys = args
        retty = self._unify_minmax(tys)
        if retty is not None:
            return signature(retty, *args)


@infer_global(max)
class Max(MinMaxBase):
    pass


@infer_global(min)
class Min(MinMaxBase):
    pass


@infer_global(round)
class Round(ConcreteTemplate):
    """
    Typing for ``round()``: one argument gives an integer, two arguments
    (with an intp number of digits) keep the float type.
    """
    cases = [
        signature(types.intp, types.float32),
        signature(types.int64, types.float64),
        signature(types.float32, types.float32, types.intp),
        signature(types.float64, types.float64, types.intp),
    ]


#------------------------------------------------------------------------------


@infer_global(bool)
class Bool(AbstractTemplate):
    """Typing for ``bool()``."""

    def generic(self, args, kws):
        assert not kws
        [arg] = args
        if isinstance(arg, (types.Boolean, types.Number)):
            return signature(types.boolean, arg)
        # XXX typing for bool cannot be polymorphic because of the
        # types.Function thing, so we redirect to the operator.truth
        # intrinsic.
        return self.context.resolve_function_type(operator.truth, args, kws)


@infer_global(int)
class Int(AbstractTemplate):
    """
    Typing for ``int()``: integers keep their type; floats and booleans
    give intp.
    """

    def generic(self, args, kws):
        assert not kws

        [arg] = args

        if isinstance(arg, types.Integer):
            return signature(arg, arg)
        if isinstance(arg, (types.Float, types.Boolean)):
            return signature(types.intp, arg)


@infer_global(float)
class Float(AbstractTemplate):
    """
    Typing for ``float()``: integers give float64, reals keep their type.
    Non-numbers and complex numbers raise TypeError.
    """

    def generic(self, args, kws):
        assert not kws

        [arg] = args

        if arg not in types.number_domain:
            raise TypeError("float() only support for numbers")

        if arg in types.complex_domain:
            raise TypeError("float() does not support complex")

        if arg in types.integer_domain:
            return signature(types.float64, arg)

        elif arg in types.real_domain:
            return signature(arg, arg)


@infer_global(complex)
class Complex(AbstractTemplate):
    """
    Typing for ``complex()`` with one or two numeric arguments: complex64
    when every argument is float32, complex128 otherwise.
    """

    def generic(self, args, kws):
        assert not kws

        if len(args) == 1:
            [arg] = args
            if arg not in types.number_domain:
                raise TypeError("complex() only support for numbers")
            if arg == types.float32:
                return signature(types.complex64, arg)
            else:
                return signature(types.complex128, arg)

        elif len(args) == 2:
            [real, imag] = args
            if (real not in types.number_domain or
                imag not in types.number_domain):
                raise TypeError("complex() only support for numbers")
            if real == imag == types.float32:
                return signature(types.complex64, real, imag)
            else:
                return signature(types.complex128, real, imag)


#------------------------------------------------------------------------------

@infer_global(enumerate)
class Enumerate(AbstractTemplate):
    """
    Typing for ``enumerate(iterable[, start])``; ``start`` must be an
    integer.
    """

    def generic(self, args, kws):
        assert not kws
        if not args:
            # let Python raise its own "missing argument" TypeError rather
            # than failing with IndexError on args[0] below
            enumerate(*args)
        it = args[0]
        if len(args) > 1 and not isinstance(args[1], types.Integer):
            raise TypeError("Only integers supported as start value in "
                            "enumerate")
        elif len(args) > 2:
            #let python raise its own error
            enumerate(*args)

        if isinstance(it, types.IterableType):
            enumerate_type = types.EnumerateType(it)
            return signature(enumerate_type, *args)


@infer_global(zip)
class Zip(AbstractTemplate):
    """Typing for ``zip()`` over iterables."""

    def generic(self, args, kws):
        assert not kws
        if all(isinstance(it, types.IterableType) for it in args):
            zip_type = types.ZipType(args)
            return signature(zip_type, *args)


@infer_global(iter)
class Iter(AbstractTemplate):
    """Typing for one-argument ``iter()``: gives the iterator type."""

    def generic(self, args, kws):
        assert not kws
        if len(args) == 1:
            it = args[0]
            if isinstance(it, types.IterableType):
                return signature(it.iterator_type, *args)


@infer_global(next)
class Next(AbstractTemplate):
    """Typing for one-argument ``next()``: gives the yielded type."""

    def generic(self, args, kws):
        assert not kws
        if len(args) == 1:
            it = args[0]
            if isinstance(it, types.IteratorType):
                return signature(it.yield_type, *args)


#------------------------------------------------------------------------------

@infer_global(type)
class TypeBuiltin(AbstractTemplate):
    """Typing for one-argument ``type()``."""

    def generic(self, args, kws):
        assert not kws
        if len(args) == 1:
            # One-argument type() -> return the __class__
            # Avoid literal types
            arg = types.unliteral(args[0])
            classty = self.context.resolve_getattr(arg, "__class__")
            if classty is not None:
                return signature(classty, *args)


#------------------------------------------------------------------------------

@infer_getattr
class OptionalAttribute(AttributeTemplate):
    """Attribute lookups on Optional(T) are delegated to T."""
    key = types.Optional

    def generic_resolve(self, optional, attr):
        return self.context.resolve_getattr(optional.type, attr)

#------------------------------------------------------------------------------

@infer_getattr
class DeferredAttribute(AttributeTemplate):
    """Attribute lookups on a deferred type are delegated to its target."""
    key = types.DeferredType

    def generic_resolve(self, deferred, attr):
        return self.context.resolve_getattr(deferred.get(), attr)

#------------------------------------------------------------------------------

@infer_global(get_type_min_value)
@infer_global(get_type_max_value)
class MinValInfer(AbstractTemplate):
    """
    Typing for get_type_min_value() / get_type_max_value(): takes a DType or
    NumberClass and returns a value of its dtype.
    """
    def generic(self, args, kws):
        assert not kws
        assert len(args) == 1
        assert isinstance(args[0], (types.DType, types.NumberClass))
        return signature(args[0].dtype, *args)


#------------------------------------------------------------------------------


class IndexValue(object):
    """
    Index and value pair (typed as IndexValueType).
    """
    def __init__(self, ind, val):
        self.index = ind
        self.value = val

    def __repr__(self):
        # %r rather than %f: repr() must not raise for None, complex or
        # other non-real values.
        return 'IndexValue(%r, %r)' % (self.index, self.value)


class IndexValueType(types.Type):
    """Numba type for IndexValue, parametrized by the value's type."""
    def __init__(self, val_typ):
        self.val_typ = val_typ
        super(IndexValueType, self).__init__(
                                    name='IndexValueType({})'.format(val_typ))


@typeof_impl.register(IndexValue)
def typeof_index(val, c):
    """Type an IndexValue instance from the type of its ``value``."""
    val_typ = typeof_impl(val.value, c)
    return IndexValueType(val_typ)


@type_callable(IndexValue)
def type_index_value(context):
    """Type ``IndexValue(ind, val)`` calls; ``ind`` must be intp or uintp."""
    def typer(ind, mval):
        if ind == types.intp or ind == types.uintp:
            return IndexValueType(mval)
    return typer


@register_model(IndexValueType)
class IndexValueModel(models.StructModel):
    """Data model: a struct of an intp index and a value of val_typ."""
    def __init__(self, dmm, fe_type):
        members = [
            ('index', types.intp),
            ('value', fe_type.val_typ),
            ]
        models.StructModel.__init__(self, dmm, fe_type, members)


make_attribute_wrapper(IndexValueType, 'index', 'index')
make_attribute_wrapper(IndexValueType, 'value', 'value')