"""
Type inference based on CPA.
The algorithm guarantees monotonic growth of type-sets for each variable.

Steps:
    1. seed initial types
    2. build constraints
    3. propagate constraints
    4. unify types

Constraint propagation is precise and does not regret (no backtracking).
Constraints push types forward following the dataflow.
"""


import logging
import operator
import contextlib
import itertools
from pprint import pprint
from collections import OrderedDict, defaultdict
from functools import reduce

from numba.core import types, utils, typing, ir, config
from numba.core.typing.templates import Signature
from numba.core.errors import (TypingError, UntypedAttributeError,
                               new_error_context, termcolor, UnsupportedError,
                               ForceLiteralArg, CompilerError)
from numba.core.funcdesc import qualifying_prefix

_logger = logging.getLogger(__name__)


class NOTSET:
    """Sentinel: a TypeVar whose literal value has not been recorded."""
    pass


# terminal color markup
_termcolor = termcolor()


class TypeVar(object):
    """Inference state (at most one type) of a single IR variable.

    While unlocked, each incoming type is unified with the current one via
    ``context.unify_pairs``.  Once locked (see ``lock``), the type is fixed
    and incoming types are only checked for convertibility to it.
    """

    def __init__(self, context, var):
        self.context = context
        self.var = var
        self.type = None
        self.locked = False
        # Stores source location of first definition
        self.define_loc = None
        # Qualifiers
        self.literal_value = NOTSET

    def add_type(self, tp, loc):
        """Merge *tp* into this variable and return the resulting type.

        Unlocked: unify *tp* with the current type, or adopt it (recording
        *loc* as the definition site) if no type is known yet.
        Locked: keep the locked type but require *tp* to convert to it.

        Raises TypingError when unification/conversion is impossible.
        """
        assert isinstance(tp, types.Type), type(tp)

        if self.locked:
            if tp != self.type:
                if self.context.can_convert(tp, self.type) is None:
                    msg = ("No conversion from %s to %s for '%s', "
                           "defined at %s")
                    raise TypingError(msg % (tp, self.type, self.var,
                                             self.define_loc),
                                      loc=loc)
        else:
            if self.type is not None:
                unified = self.context.unify_pairs(self.type, tp)
                if unified is None:
                    msg = "Cannot unify %s and %s for '%s', defined at %s"
                    raise TypingError(msg % (self.type, tp, self.var,
                                             self.define_loc),
                                      loc=self.define_loc)
            else:
                # First time definition
                unified = tp
                self.define_loc = loc

            self.type = unified

        return self.type

    def lock(self, tp, loc, literal_value=NOTSET):
        """Fix this variable's type to *tp*, optionally with a literal value.

        Raises CompilerError if the variable is already locked, and
        TypingError if a previously inferred type cannot convert to *tp*.
        """
        assert isinstance(tp, types.Type), type(tp)

        if self.locked:
            msg = ("Invalid reassignment of a type-variable detected, type "
                   "variables are locked according to the user provided "
                   "function signature or from an ir.Const node. This is a "
                   "bug! Type={}. {}").format(tp, self.type)
            raise CompilerError(msg, loc)

        # If there is already a type, ensure we can convert it to the
        # locked type.
        if (self.type is not None and
                self.context.can_convert(self.type, tp) is None):
            # The conversion that failed is self.type -> tp; report it in
            # that direction.
            raise TypingError("No conversion from %s to %s for "
                              "'%s'" % (self.type, tp, self.var), loc=loc)

        self.type = tp
        self.locked = True
        if self.define_loc is None:
            self.define_loc = loc
        self.literal_value = literal_value

    def union(self, other, loc):
        """Add *other*'s type (when decided) to this variable.

        Returns this variable's resulting type (possibly None).
        """
        if other.type is not None:
            self.add_type(other.type, loc=loc)

        return self.type

    def __repr__(self):
        return '%s := %s' % (self.var, self.type or "<undecided>")

    @property
    def defined(self):
        """True once a type has been assigned."""
        return self.type is not None

    def get(self):
        """Return the type as a type-set: a 1-tuple, or () if undecided."""
        return (self.type,) if self.type is not None else ()

    def getone(self):
        """Return the type; raise TypingError if it is still undecided."""
        if self.type is None:
            raise TypingError("Undecided type {}".format(self))
        return self.type

    def __len__(self):
        return 1 if self.type is not None else 0


class ConstraintNetwork(object):
    """
    Ordered list of constraints, all executed on each propagation round.

    TODO: It is possible to optimize constraint propagation to consider only
    dirty type variables.
    """

    def __init__(self):
        self.constraints = []

    def append(self, constraint):
        self.constraints.append(constraint)

    def propagate(self, typeinfer):
        """
        Execute all constraints. Errors are caught and returned as a list.
        This allows progressing even though some constraints may fail
        due to lack of information
        (e.g. imprecise types such as List(undefined)).

        ForceLiteralArg errors are collected as-is.  TypingErrors and any
        other exception are re-wrapped in a TypingError located at the
        failing constraint and linked to the original through
        ``utils.chain_exception``.
        """
        errors = []
        for constraint in self.constraints:
            loc = constraint.loc
            with typeinfer.warnings.catch_warnings(filename=loc.filename,
                                                   lineno=loc.line):
                try:
                    constraint(typeinfer)
                except ForceLiteralArg as e:
                    errors.append(e)
                except TypingError as e:
                    _logger.debug("captured error", exc_info=e)
                    new_exc = TypingError(
                        str(e), loc=constraint.loc,
                        highlighting=False,
                    )
                    errors.append(utils.chain_exception(new_exc, e))
                except Exception as e:
                    _logger.debug("captured error", exc_info=e)
                    msg = ("Internal error at {con}.\n"
                           "{err}\nEnable logging at debug level for details.")
                    new_exc = TypingError(
                        msg.format(con=constraint, err=str(e)),
                        loc=constraint.loc,
                        highlighting=False,
                    )
                    errors.append(utils.chain_exception(new_exc, e))
        return errors


class Propagate(object):
    """
    A simple constraint for direct propagation of types for assignments.
    """

    def __init__(self, dst, src, loc):
        self.dst = dst
        self.src = src
        self.loc = loc

    def __call__(self, typeinfer):
        with new_error_context("typing of assignment at {0}", self.loc,
                               loc=self.loc):
            typeinfer.copy_type(self.src, self.dst, loc=self.loc)
            # If `dst` is refined, notify us
            typeinfer.refine_map[self.dst] = self

    def refine(self, typeinfer, target_type):
        """Push a refined type of `dst` back onto `src`."""
        # Do not back-propagate to locked variables (e.g. constants)
        assert target_type.is_precise()
        typeinfer.add_type(self.src, target_type, unless_locked=True,
                           loc=self.loc)


class ArgConstraint(object):
    """Types a function argument variable from its (mangled) source."""

    def __init__(self, dst, src, loc):
        self.dst = dst
        self.src = src
        self.loc = loc

    def __call__(self, typeinfer):
        with new_error_context("typing of argument at {0}", self.loc):
            typevars = typeinfer.typevars
            src = typevars[self.src]
            if not src.defined:
                return
            ty = src.getone()
            # An omitted argument takes the type of its default value.
            if isinstance(ty, types.Omitted):
                ty = typeinfer.context.resolve_value_type_prefer_literal(
                    ty.value,
                )
            if not ty.is_precise():
                raise TypingError('non-precise type {}'.format(ty))
            typeinfer.add_type(self.dst, ty, loc=self.loc)


class BuildTupleConstraint(object):
    """Types a tuple display from the types of its items.

    Homogeneous items give a UniTuple; otherwise (including the empty
    tuple) a heterogeneous Tuple is produced.
    """

    def __init__(self, target, items, loc):
        self.target = target
        self.items = items
        self.loc = loc

    def __call__(self, typeinfer):
        with new_error_context("typing of tuple at {0}", self.loc):
            typevars = typeinfer.typevars
            tsets = [typevars[i.name].get() for i in self.items]
            for vals in itertools.product(*tsets):
                if vals and all(vals[0] == v for v in vals):
                    tup = types.UniTuple(dtype=vals[0], count=len(vals))
                else:
                    # empty tuples fall here as well
                    tup = types.Tuple(vals)
                assert tup.is_precise()
                typeinfer.add_type(self.target, tup, loc=self.loc)


class _BuildContainerConstraint(object):
    """Base for homogeneous container displays.

    Subclasses must define ``container_type`` (a callable taking the
    unified item type).  Empty displays get ``container_type(undefined)``.
    """

    def __init__(self, target, items, loc):
        self.target = target
        self.items = items
        self.loc = loc

    def __call__(self, typeinfer):
        with new_error_context("typing of {0} at {1}",
                               self.container_type, self.loc):
            typevars = typeinfer.typevars
            tsets = [typevars[i.name].get() for i in self.items]
            if not tsets:
                typeinfer.add_type(self.target,
                                   self.container_type(types.undefined),
                                   loc=self.loc)
            else:
                for typs in itertools.product(*tsets):
                    unified = typeinfer.context.unify_types(*typs)
                    if unified is not None:
                        typeinfer.add_type(self.target,
                                           self.container_type(unified),
                                           loc=self.loc)


class BuildListConstraint(_BuildContainerConstraint):
    """Types a list display.

    Unifiable item types give ``List(unified)`` (carrying the literal
    values as ``initial_value`` when every item is a Literal); otherwise
    the items are typed as a heterogeneous ``LiteralList``.
    """

    def __init__(self, target, items, loc):
        self.target = target
        self.items = items
        self.loc = loc

    def __call__(self, typeinfer):
        with new_error_context("typing of {0} at {1}",
                               types.List, self.loc):
            typevars = typeinfer.typevars
            tsets = [typevars[i.name].get() for i in self.items]
            if not tsets:
                typeinfer.add_type(self.target,
                                   types.List(types.undefined),
                                   loc=self.loc)
            else:
                for typs in itertools.product(*tsets):
                    unified = typeinfer.context.unify_types(*typs)
                    if unified is not None:
                        # pull out literals if available
                        islit = [isinstance(x, types.Literal) for x in typs]
                        iv = None
                        if all(islit):
                            iv = [x.literal_value for x in typs]
                        typeinfer.add_type(self.target,
                                           types.List(unified,
                                                      initial_value=iv),
                                           loc=self.loc)
                    else:
                        typeinfer.add_type(self.target,
                                           types.LiteralList(typs),
                                           loc=self.loc)


class BuildSetConstraint(_BuildContainerConstraint):
    """Types a set display as ``types.Set`` of the unified item type."""
    container_type = types.Set


class BuildMapConstraint(object):
    """Types a dict display.

    String-literal keys with non-homogeneous values produce a
    ``LiteralStrKeyDict``; everything else produces a ``DictType`` keyed on
    the first key/value pair's types.
    """

    def __init__(self, target, items, special_value, value_indexes, loc):
        self.target = target
        self.items = items
        self.special_value = special_value
        self.value_indexes = value_indexes
        self.loc = loc

    def __call__(self, typeinfer):

        with new_error_context("typing of dict at {0}", self.loc):
            typevars = typeinfer.typevars

            # figure out what sort of dict is being dealt with
            # (getone() raises if any key/value type is still undecided;
            # propagation collects that error and retries next round)
            tsets = [(typevars[k.name].getone(), typevars[v.name].getone())
                     for k, v in self.items]

            if not tsets:
                typeinfer.add_type(self.target,
                                   types.DictType(types.undefined,
                                                  types.undefined,
                                                  self.special_value),
                                   loc=self.loc)
            else:
                # all the info is known about the dict, if it has
                # str keys -> random heterogeneous values treat as literalstrkey
                ktys = [x[0] for x in tsets]
                vtys = [x[1] for x in tsets]
                strkey = all(isinstance(x, types.StringLiteral) for x in ktys)
                literalvty = all(isinstance(x, types.Literal) for x in vtys)
                vt0 = types.unliteral(vtys[0])
                # homogeneous values comes in the form of being able to cast
                # all the other values in the ctor to the type of the first

                def check(other):
                    return typeinfer.context.can_convert(other, vt0) is not None
                homogeneous = all([check(types.unliteral(x)) for x in vtys])
                # Special cases:
                # Single key:value in ctor, key is str, value is an otherwise
                # illegal container type, e.g. LiteralStrKeyDict or
                # List, there's no way to put this into a typed.Dict, so make it
                # a LiteralStrKeyDict, same goes for LiteralList.
                if len(vtys) == 1:
                    valty = vtys[0]
                    if isinstance(valty, (types.LiteralStrKeyDict,
                                          types.List,
                                          types.LiteralList)):
                        homogeneous = False

                if strkey and not homogeneous:
                    resolved_dict = dict(zip(ktys, vtys))
                    ty = types.LiteralStrKeyDict(resolved_dict,
                                                 self.value_indexes)
                    typeinfer.add_type(self.target, ty, loc=self.loc)
                else:
                    init_value = self.special_value if literalvty else None
                    key_type, value_type = tsets[0]
                    typeinfer.add_type(self.target,
                                       types.DictType(key_type,
                                                      value_type,
                                                      init_value),
                                       loc=self.loc)


class ExhaustIterConstraint(object):
    """Types the tuple obtained by unpacking *iterator* into *count* items.

    Note: a tuple of the wrong length raises a plain ValueError, which
    ``ConstraintNetwork.propagate`` reports as an internal error.
    """

    def __init__(self, target, count, iterator, loc):
        self.target = target
        self.count = count
        self.iterator = iterator
        self.loc = loc

    def __call__(self, typeinfer):
        with new_error_context("typing of exhaust iter at {0}", self.loc):
            typevars = typeinfer.typevars
            for tp in typevars[self.iterator.name].get():
                # unpack optional
                tp = tp.type if isinstance(tp, types.Optional) else tp
                if isinstance(tp, types.BaseTuple):
                    if len(tp) == self.count:
                        assert tp.is_precise()
                        typeinfer.add_type(self.target, tp, loc=self.loc)
                        break
                    else:
                        raise ValueError("wrong tuple length for %r: "
                                         "expected %d, got %d"
                                         % (self.iterator.name, self.count,
                                            len(tp)))
                elif isinstance(tp, types.IterableType):
                    tup = types.UniTuple(dtype=tp.iterator_type.yield_type,
                                         count=self.count)
                    assert tup.is_precise()
                    typeinfer.add_type(self.target, tup, loc=self.loc)
                    break
                else:
                    raise TypingError("failed to unpack {}".format(tp),
                                      loc=self.loc)


class PairFirstConstraint(object):
    """Types *target* as the ``first_type`` of a ``types.Pair``."""

    def __init__(self, target, pair, loc):
        self.target = target
        self.pair = pair
        self.loc = loc

    def __call__(self, typeinfer):
        with new_error_context("typing of pair-first at {0}", self.loc):
            typevars = typeinfer.typevars
            for tp in typevars[self.pair.name].get():
                if not isinstance(tp, types.Pair):
                    # XXX is this an error?
                    continue
                assert (isinstance(tp.first_type, types.UndefinedFunctionType)
                        or tp.first_type.is_precise())
                typeinfer.add_type(self.target, tp.first_type, loc=self.loc)


class PairSecondConstraint(object):
    """Types *target* as the ``second_type`` of a ``types.Pair``."""

    def __init__(self, target, pair, loc):
        self.target = target
        self.pair = pair
        self.loc = loc

    def __call__(self, typeinfer):
        with new_error_context("typing of pair-second at {0}", self.loc):
            typevars = typeinfer.typevars
            for tp in typevars[self.pair.name].get():
                if not isinstance(tp, types.Pair):
                    # XXX is this an error?
                    continue
                assert tp.second_type.is_precise()
                typeinfer.add_type(self.target, tp.second_type, loc=self.loc)


class StaticGetItemConstraint(object):
    """getitem with a compile-time constant *index*.

    When ``resolve_static_getitem`` fails and an index variable exists, a
    generic ``operator.getitem`` call constraint is used as a fallback.
    """

    def __init__(self, target, value, index, index_var, loc):
        self.target = target
        self.value = value
        self.index = index
        if index_var is not None:
            self.fallback = IntrinsicCallConstraint(target, operator.getitem,
                                                    (value, index_var), {},
                                                    None, loc)
        else:
            self.fallback = None
        self.loc = loc

    def __call__(self, typeinfer):
        with new_error_context("typing of static-get-item at {0}", self.loc):
            typevars = typeinfer.typevars
            for ty in typevars[self.value.name].get():
                sig = typeinfer.context.resolve_static_getitem(
                    value=ty, index=self.index,
                )

                if sig is not None:
                    itemty = sig.return_type
                    # if the itemty is not precise, let it through, unification
                    # will catch it and produce a better error message
                    typeinfer.add_type(self.target, itemty, loc=self.loc)
                elif self.fallback is not None:
                    self.fallback(typeinfer)

    def get_call_signature(self):
        # The signature is only needed for the fallback case in lowering
        return self.fallback and self.fallback.get_call_signature()


class TypedGetItemConstraint(object):
    """getitem whose result type *dtype* is already known."""

    def __init__(self, target, value, dtype, index, loc):
        self.target = target
        self.value = value
        self.dtype = dtype
        self.index = index
        self.loc = loc

    def __call__(self, typeinfer):
        with new_error_context("typing of typed-get-item at {0}", self.loc):
            typevars = typeinfer.typevars
            idx_ty = typevars[self.index.name].get()
            ty = typevars[self.value.name].get()
            self.signature = Signature(self.dtype, ty + idx_ty, None)
            typeinfer.add_type(self.target, self.dtype, loc=self.loc)

    def get_call_signature(self):
        return self.signature


def fold_arg_vars(typevars, args, vararg, kws):
    """
    Fold and resolve the argument variables of a function call.

    Returns ``(pos_args, kw_args)``: a tuple of positional argument types,
    with any ``*args`` expanded, and a dict mapping keyword names to types.
    Returns None if any argument type is still undefined.

    Raises TypeError if ``*args`` is neither a tuple type nor a literal
    holding a tuple.
    """
    # Fetch all argument types, bail if any is unknown
    n_pos_args = len(args)
    kwds = [kw for (kw, var) in kws]
    argtypes = [typevars[a.name] for a in args]
    argtypes += [typevars[var.name] for (kw, var) in kws]
    if vararg is not None:
        argtypes.append(typevars[vararg.name])

    if not all(a.defined for a in argtypes):
        return

    args = tuple(a.getone() for a in argtypes)

    pos_args = args[:n_pos_args]
    if vararg is not None:
        errmsg = "*args in function call should be a tuple, got %s"
        # Handle constant literal used for `*args`
        if isinstance(args[-1], types.Literal):
            const_val = args[-1].literal_value
            # Is the constant value a tuple?
            if not isinstance(const_val, tuple):
                raise TypeError(errmsg % (args[-1],))
            # Append the elements in the const tuple to the positional args
            # NOTE(review): these are the raw Python values, not types --
            # confirm callers expect that.
            pos_args += const_val
        # Handle non-constant
        elif not isinstance(args[-1], types.BaseTuple):
            # Unsuitable for *args
            # (Python is more lenient and accepts all iterables)
            raise TypeError(errmsg % (args[-1],))
        else:
            # Append the elements in the tuple to the positional args
            pos_args += args[-1].types
        # Drop the last arg
        args = args[:-1]
    kw_args = dict(zip(kwds, args[n_pos_args:]))
    return pos_args, kw_args


def _is_array_not_precise(arrty):
    """Check type is array and it is not precise
    """
    return isinstance(arrty, types.Array) and not arrty.is_precise()


class CallConstraint(object):
    """Constraint for calling functions.
    Perform case analysis foreach combinations of argument types.

    ``self.func`` is the name of the variable holding the callee;
    ``self.kws`` is a sequence of ``(name, var)`` pairs.
    """
    signature = None

    def __init__(self, target, func, args, kws, vararg, loc):
        self.target = target
        self.func = func
        self.args = args
        self.kws = kws or {}
        self.vararg = vararg
        self.loc = loc

    def __call__(self, typeinfer):
        msg = "typing of call at {0}\n".format(self.loc)
        with new_error_context(msg):
            typevars = typeinfer.typevars
            with new_error_context(
                    "resolving caller type: {}".format(self.func)):
                fnty = typevars[self.func].getone()
            with new_error_context("resolving callee type: {0}", fnty):
                self.resolve(typeinfer, typevars, fnty)

    def resolve(self, typeinfer, typevars, fnty):
        """Resolve the call for function type *fnty* and type the target.

        Returns early (leaving the target untouched) while any argument
        type is unknown or imprecise (arrays excepted).
        """
        assert fnty
        context = typeinfer.context

        r = fold_arg_vars(typevars, self.args, self.vararg, self.kws)
        if r is None:
            # Cannot resolve call type until all argument types are known
            return
        pos_args, kw_args = r

        # Check argument to be precise
        for a in itertools.chain(pos_args, kw_args.values()):
            # Forbids imprecise type except array of undefined dtype
            if not a.is_precise() and not isinstance(a, types.Array):
                return

        # Resolve call type
        try:
            sig = typeinfer.resolve_call(fnty, pos_args, kw_args)
        except ForceLiteralArg as e:
            # The callee asked for literal arguments: translate its
            # argument indices into this function's own argument indices.
            # Adjust for bound methods
            folding_args = ((fnty.this,) + tuple(self.args)
                            if isinstance(fnty, types.BoundFunction)
                            else self.args)
            folded = e.fold_arguments(folding_args, self.kws)
            requested = set()
            unsatisfied = set()
            for idx in e.requested_args:
                maybe_arg = typeinfer.func_ir.get_definition(folded[idx])
                if isinstance(maybe_arg, ir.Arg):
                    requested.add(maybe_arg.index)
                else:
                    unsatisfied.add(idx)
            if unsatisfied:
                raise TypingError("Cannot request literal type.", loc=self.loc)
            elif requested:
                raise ForceLiteralArg(requested, loc=self.loc)
            # NOTE(review): if e.requested_args were empty, execution falls
            # through with `sig` unbound -- confirm that cannot happen.
        if sig is None:
            # Note: duplicated error checking.
            #       See types.BaseFunction.get_call_type
            # Arguments are invalid => explain why
            headtemp = "Invalid use of {0} with parameters ({1})"
            args = [str(a) for a in pos_args]
            args += ["%s=%s" % (k, v) for k, v in sorted(kw_args.items())]
            head = headtemp.format(fnty, ', '.join(args))
            desc = context.explain_function_type(fnty)
            msg = '\n'.join([head, desc])
            raise TypingError(msg)

        typeinfer.add_type(self.target, sig.return_type, loc=self.loc)

        # If the function is a bound function and its receiver type
        # was refined, propagate it.
        if (isinstance(fnty, types.BoundFunction)
                and sig.recvr is not None
                and sig.recvr != fnty.this):
            refined_this = context.unify_pairs(sig.recvr, fnty.this)
            if (refined_this is None and
                    fnty.this.is_precise() and
                    sig.recvr.is_precise()):
                msg = "Cannot refine type {} to {}".format(
                    sig.recvr, fnty.this,
                )
                raise TypingError(msg, loc=self.loc)
            if refined_this is not None and refined_this.is_precise():
                refined_fnty = fnty.copy(this=refined_this)
                typeinfer.propagate_refined_type(self.func, refined_fnty)

        # If the return type is imprecise but can be unified with the
        # target variable's inferred type, use the latter.
        # Useful for code such as::
        #    s = set()
        #    s.add(1)
        # (the set() call must be typed as int64(), not undefined())
        if not sig.return_type.is_precise():
            target = typevars[self.target]
            if target.defined:
                targetty = target.getone()
                if context.unify_pairs(targetty, sig.return_type) == targetty:
                    sig = sig.replace(return_type=targetty)

        self.signature = sig
        self._add_refine_map(typeinfer, typevars, sig)

    def _add_refine_map(self, typeinfer, typevars, sig):
        """Add this expression to the refine_map based on the type of
        target_type
        """
        target_type = typevars[self.target].getone()
        # Array
        if (isinstance(target_type, types.Array)
                and isinstance(sig.return_type.dtype, types.Undefined)):
            typeinfer.refine_map[self.target] = self
        # DictType
        if (isinstance(target_type, types.DictType) and
                not target_type.is_precise()):
            typeinfer.refine_map[self.target] = self

    def refine(self, typeinfer, updated_type):
        """Back-propagate a refined result type to the call's arguments.

        Only ``operator.getitem`` on an imprecise array is handled (its
        dtype is refined); any other callee raises TypingError.
        """
        # Is getitem?
        if self.func == operator.getitem:
            aryty = typeinfer.typevars[self.args[0].name].getone()
            # is array not precise?
            if _is_array_not_precise(aryty):
                # allow refinement of dtype
                assert updated_type.is_precise()
                newtype = aryty.copy(dtype=updated_type.dtype)
                typeinfer.add_type(self.args[0].name, newtype, loc=self.loc)
        else:
            m = 'no type refinement implemented for function {} updating to {}'
            raise TypingError(m.format(self.func, updated_type))

    def get_call_signature(self):
        return self.signature


class IntrinsicCallConstraint(CallConstraint):
    """Call constraint whose callee ``self.func`` is a Python object (e.g.
    an ``operator`` function) rather than the name of a variable."""

    def __call__(self, typeinfer):
        with new_error_context("typing of intrinsic-call at {0}", self.loc):
            fnty = self.func
            if fnty in utils.OPERATORS_TO_BUILTINS:
                fnty = typeinfer.resolve_value_type(None, fnty)
            self.resolve(typeinfer, typeinfer.typevars, fnty=fnty)


class GetAttrConstraint(object):
    """Types ``value.attr`` by resolving the attribute on each value type."""

    def __init__(self, target, attr, value, loc, inst):
        self.target = target
        self.attr = attr
        self.value = value
        self.loc = loc
        self.inst = inst

    def __call__(self, typeinfer):
        with new_error_context("typing of get attribute at {0}", self.loc):
            typevars = typeinfer.typevars
            valtys = typevars[self.value.name].get()
            for ty in valtys:
                attrty = typeinfer.context.resolve_getattr(ty, self.attr)
                if attrty is None:
                    raise UntypedAttributeError(ty, self.attr,
                                                loc=self.inst.loc)
                else:
                    assert attrty.is_precise()
                    typeinfer.add_type(self.target, attrty, loc=self.loc)
            typeinfer.refine_map[self.target] = self

    def refine(self, typeinfer, target_type):
        """When the attribute refines to a BoundFunction, push its refined
        receiver type back onto ``value`` (and further upstream)."""
        if isinstance(target_type, types.BoundFunction):
            recvr = target_type.this
            assert recvr.is_precise()
            typeinfer.add_type(self.value.name, recvr, loc=self.loc)
            source_constraint = typeinfer.refine_map.get(self.value.name)
            if source_constraint is not None:
                source_constraint.refine(typeinfer, recvr)

    def __repr__(self):
        return 'resolving type of attribute "{attr}" of "{value}"'.format(
            value=self.value, attr=self.attr)
726 727 728class SetItemRefinement(object): 729 """A mixin class to provide the common refinement logic in setitem 730 and static setitem. 731 """ 732 733 def _refine_target_type(self, typeinfer, targetty, idxty, valty, sig): 734 """Refine the target-type given the known index type and value type. 735 """ 736 # For array setitem, refine imprecise array dtype 737 if _is_array_not_precise(targetty): 738 typeinfer.add_type(self.target.name, sig.args[0], loc=self.loc) 739 # For Dict setitem 740 if isinstance(targetty, types.DictType): 741 if not targetty.is_precise(): 742 refined = targetty.refine(idxty, valty) 743 typeinfer.add_type( 744 self.target.name, refined, 745 loc=self.loc, 746 ) 747 elif isinstance(targetty, types.LiteralStrKeyDict): 748 typeinfer.add_type( 749 self.target.name, types.DictType(idxty, valty), 750 loc=self.loc, 751 ) 752 753 754class SetItemConstraint(SetItemRefinement): 755 def __init__(self, target, index, value, loc): 756 self.target = target 757 self.index = index 758 self.value = value 759 self.loc = loc 760 761 def __call__(self, typeinfer): 762 with new_error_context("typing of setitem at {0}", self.loc): 763 typevars = typeinfer.typevars 764 if not all(typevars[var.name].defined 765 for var in (self.target, self.index, self.value)): 766 return 767 targetty = typevars[self.target.name].getone() 768 idxty = typevars[self.index.name].getone() 769 valty = typevars[self.value.name].getone() 770 771 sig = typeinfer.context.resolve_setitem(targetty, idxty, valty) 772 if sig is None: 773 raise TypingError("Cannot resolve setitem: %s[%s] = %s" % 774 (targetty, idxty, valty), loc=self.loc) 775 776 self.signature = sig 777 self._refine_target_type(typeinfer, targetty, idxty, valty, sig) 778 779 def get_call_signature(self): 780 return self.signature 781 782 783class StaticSetItemConstraint(SetItemRefinement): 784 def __init__(self, target, index, index_var, value, loc): 785 self.target = target 786 self.index = index 787 self.index_var = index_var 
788 self.value = value 789 self.loc = loc 790 791 def __call__(self, typeinfer): 792 with new_error_context("typing of staticsetitem at {0}", self.loc): 793 typevars = typeinfer.typevars 794 if not all(typevars[var.name].defined 795 for var in (self.target, self.index_var, self.value)): 796 return 797 targetty = typevars[self.target.name].getone() 798 idxty = typevars[self.index_var.name].getone() 799 valty = typevars[self.value.name].getone() 800 801 sig = typeinfer.context.resolve_static_setitem(targetty, 802 self.index, valty) 803 if sig is None: 804 sig = typeinfer.context.resolve_setitem(targetty, idxty, valty) 805 if sig is None: 806 raise TypingError("Cannot resolve setitem: %s[%r] = %s" % 807 (targetty, self.index, valty), loc=self.loc) 808 self.signature = sig 809 self._refine_target_type(typeinfer, targetty, idxty, valty, sig) 810 811 def get_call_signature(self): 812 return self.signature 813 814 815class DelItemConstraint(object): 816 def __init__(self, target, index, loc): 817 self.target = target 818 self.index = index 819 self.loc = loc 820 821 def __call__(self, typeinfer): 822 with new_error_context("typing of delitem at {0}", self.loc): 823 typevars = typeinfer.typevars 824 if not all(typevars[var.name].defined 825 for var in (self.target, self.index)): 826 return 827 targetty = typevars[self.target.name].getone() 828 idxty = typevars[self.index.name].getone() 829 830 sig = typeinfer.context.resolve_delitem(targetty, idxty) 831 if sig is None: 832 raise TypingError("Cannot resolve delitem: %s[%s]" % 833 (targetty, idxty), loc=self.loc) 834 self.signature = sig 835 836 def get_call_signature(self): 837 return self.signature 838 839 840class SetAttrConstraint(object): 841 def __init__(self, target, attr, value, loc): 842 self.target = target 843 self.attr = attr 844 self.value = value 845 self.loc = loc 846 847 def __call__(self, typeinfer): 848 with new_error_context("typing of set attribute {0!r} at {1}", 849 self.attr, self.loc): 850 typevars = 
typeinfer.typevars 851 if not all(typevars[var.name].defined 852 for var in (self.target, self.value)): 853 return 854 targetty = typevars[self.target.name].getone() 855 valty = typevars[self.value.name].getone() 856 sig = typeinfer.context.resolve_setattr(targetty, self.attr, 857 valty) 858 if sig is None: 859 raise TypingError("Cannot resolve setattr: (%s).%s = %s" % 860 (targetty, self.attr, valty), 861 loc=self.loc) 862 self.signature = sig 863 864 def get_call_signature(self): 865 return self.signature 866 867 868class PrintConstraint(object): 869 def __init__(self, args, vararg, loc): 870 self.args = args 871 self.vararg = vararg 872 self.loc = loc 873 874 def __call__(self, typeinfer): 875 typevars = typeinfer.typevars 876 877 r = fold_arg_vars(typevars, self.args, self.vararg, {}) 878 if r is None: 879 # Cannot resolve call type until all argument types are known 880 return 881 pos_args, kw_args = r 882 883 fnty = typeinfer.context.resolve_value_type(print) 884 assert fnty is not None 885 sig = typeinfer.resolve_call(fnty, pos_args, kw_args) 886 self.signature = sig 887 888 def get_call_signature(self): 889 return self.signature 890 891 892class TypeVarMap(dict): 893 def set_context(self, context): 894 self.context = context 895 896 def __getitem__(self, name): 897 if name not in self: 898 self[name] = TypeVar(self.context, name) 899 return super(TypeVarMap, self).__getitem__(name) 900 901 def __setitem__(self, name, value): 902 assert isinstance(name, str) 903 if name in self: 904 raise KeyError("Cannot redefine typevar %s" % name) 905 else: 906 super(TypeVarMap, self).__setitem__(name, value) 907 908 909# A temporary mapping of {function name: dispatcher object} 910_temporary_dispatcher_map = {} 911# A temporary mapping of {function name: dispatcher object reference count} 912# Reference: https://github.com/numba/numba/issues/3658 913_temporary_dispatcher_map_ref_count = defaultdict(int) 914 915 916@contextlib.contextmanager 917def 
register_dispatcher(disp): 918 """ 919 Register a Dispatcher for inference while it is not yet stored 920 as global or closure variable (e.g. during execution of the @jit() 921 call). This allows resolution of recursive calls with eager 922 compilation. 923 """ 924 assert callable(disp) 925 assert callable(disp.py_func) 926 name = disp.py_func.__name__ 927 _temporary_dispatcher_map[name] = disp 928 _temporary_dispatcher_map_ref_count[name] += 1 929 try: 930 yield 931 finally: 932 _temporary_dispatcher_map_ref_count[name] -= 1 933 if not _temporary_dispatcher_map_ref_count[name]: 934 del _temporary_dispatcher_map[name] 935 936 937typeinfer_extensions = {} 938 939 940class TypeInferer(object): 941 """ 942 Operates on block that shares the same ir.Scope. 943 """ 944 945 def __init__(self, context, func_ir, warnings): 946 self.context = context 947 # sort based on label, ensure iteration order! 948 self.blocks = OrderedDict() 949 for k in sorted(func_ir.blocks.keys()): 950 self.blocks[k] = func_ir.blocks[k] 951 self.generator_info = func_ir.generator_info 952 self.func_id = func_ir.func_id 953 self.func_ir = func_ir 954 955 self.typevars = TypeVarMap() 956 self.typevars.set_context(context) 957 self.constraints = ConstraintNetwork() 958 self.warnings = warnings 959 960 # { index: mangled name } 961 self.arg_names = {} 962 # self.return_type = None 963 # Set of assumed immutable globals 964 self.assumed_immutables = set() 965 # Track all calls and associated constraints 966 self.calls = [] 967 # The inference result of the above calls 968 self.calltypes = utils.UniqueDict() 969 # Target var -> constraint with refine hook 970 self.refine_map = {} 971 972 if config.DEBUG or config.DEBUG_TYPEINFER: 973 self.debug = TypeInferDebug(self) 974 else: 975 self.debug = NullDebug() 976 977 self._skip_recursion = False 978 979 def copy(self, skip_recursion=False): 980 clone = TypeInferer(self.context, self.func_ir, self.warnings) 981 clone.arg_names = self.arg_names.copy() 982 
clone._skip_recursion = skip_recursion 983 984 for k, v in self.typevars.items(): 985 if not v.locked and v.defined: 986 clone.typevars[k].add_type(v.getone(), loc=v.define_loc) 987 988 return clone 989 990 def _mangle_arg_name(self, name): 991 # Disambiguise argument name 992 return "arg.%s" % (name,) 993 994 def _get_return_vars(self): 995 rets = [] 996 for blk in utils.itervalues(self.blocks): 997 inst = blk.terminator 998 if isinstance(inst, ir.Return): 999 rets.append(inst.value) 1000 return rets 1001 1002 def get_argument_types(self): 1003 return [self.typevars[k].getone() for k in self.arg_names.values()] 1004 1005 def seed_argument(self, name, index, typ): 1006 name = self._mangle_arg_name(name) 1007 self.seed_type(name, typ) 1008 self.arg_names[index] = name 1009 1010 def seed_type(self, name, typ): 1011 """All arguments should be seeded. 1012 """ 1013 self.lock_type(name, typ, loc=None) 1014 1015 def seed_return(self, typ): 1016 """Seeding of return value is optional. 1017 """ 1018 for var in self._get_return_vars(): 1019 self.lock_type(var.name, typ, loc=None) 1020 1021 def build_constraint(self): 1022 for blk in utils.itervalues(self.blocks): 1023 for inst in blk.body: 1024 self.constrain_statement(inst) 1025 1026 def return_types_from_partial(self): 1027 """ 1028 Resume type inference partially to deduce the return type. 1029 Note: No side-effect to `self`. 1030 1031 Returns the inferred return type or None if it cannot deduce the return 1032 type. 
1033 """ 1034 # Clone the typeinferer and disable typing recursive calls 1035 cloned = self.copy(skip_recursion=True) 1036 # rebuild constraint network 1037 cloned.build_constraint() 1038 # propagate without raising 1039 cloned.propagate(raise_errors=False) 1040 # get return types 1041 rettypes = set() 1042 for retvar in cloned._get_return_vars(): 1043 if retvar.name in cloned.typevars: 1044 typevar = cloned.typevars[retvar.name] 1045 if typevar and typevar.defined: 1046 rettypes.add(types.unliteral(typevar.getone())) 1047 if not rettypes: 1048 return 1049 # unify return types 1050 return cloned._unify_return_types(rettypes) 1051 1052 def propagate(self, raise_errors=True): 1053 newtoken = self.get_state_token() 1054 oldtoken = None 1055 # Since the number of types are finite, the typesets will eventually 1056 # stop growing. 1057 1058 while newtoken != oldtoken: 1059 self.debug.propagate_started() 1060 oldtoken = newtoken 1061 # Errors can appear when the type set is incomplete; only 1062 # raise them when there is no progress anymore. 
1063 errors = self.constraints.propagate(self) 1064 newtoken = self.get_state_token() 1065 self.debug.propagate_finished() 1066 if errors: 1067 if raise_errors: 1068 force_lit_args = [e for e in errors 1069 if isinstance(e, ForceLiteralArg)] 1070 if not force_lit_args: 1071 raise errors[0] 1072 else: 1073 raise reduce(operator.or_, force_lit_args) 1074 else: 1075 return errors 1076 1077 def add_type(self, var, tp, loc, unless_locked=False): 1078 assert isinstance(var, str), type(var) 1079 tv = self.typevars[var] 1080 if unless_locked and tv.locked: 1081 return 1082 oldty = tv.type 1083 unified = tv.add_type(tp, loc=loc) 1084 if unified != oldty: 1085 self.propagate_refined_type(var, unified) 1086 1087 def add_calltype(self, inst, signature): 1088 assert signature is not None 1089 self.calltypes[inst] = signature 1090 1091 def copy_type(self, src_var, dest_var, loc): 1092 self.typevars[dest_var].union(self.typevars[src_var], loc=loc) 1093 1094 def lock_type(self, var, tp, loc, literal_value=NOTSET): 1095 tv = self.typevars[var] 1096 tv.lock(tp, loc=loc, literal_value=literal_value) 1097 1098 def propagate_refined_type(self, updated_var, updated_type): 1099 source_constraint = self.refine_map.get(updated_var) 1100 if source_constraint is not None: 1101 source_constraint.refine(self, updated_type) 1102 1103 def unify(self, raise_errors=True): 1104 """ 1105 Run the final unification pass over all inferred types, and 1106 catch imprecise types. 1107 """ 1108 typdict = utils.UniqueDict() 1109 1110 def find_offender(name, exhaustive=False): 1111 # finds the offending variable definition by name 1112 # if exhaustive is set it will try and trace through temporary 1113 # variables to find a concrete offending definition. 
1114 offender = None 1115 for block in self.func_ir.blocks.values(): 1116 offender = block.find_variable_assignment(name) 1117 if offender is not None: 1118 if not exhaustive: 1119 break 1120 try: # simple assignment 1121 hasattr(offender.value, 'name') 1122 offender_value = offender.value.name 1123 except (AttributeError, KeyError): 1124 break 1125 orig_offender = offender 1126 if offender_value.startswith('$'): 1127 offender = find_offender(offender_value, 1128 exhaustive=exhaustive) 1129 if offender is None: 1130 offender = orig_offender 1131 break 1132 return offender 1133 1134 def diagnose_imprecision(offender): 1135 # helper for diagnosing imprecise types 1136 1137 list_msg = """\n 1138For Numba to be able to compile a list, the list must have a known and 1139precise type that can be inferred from the other variables. Whilst sometimes 1140the type of empty lists can be inferred, this is not always the case, see this 1141documentation for help: 1142 1143https://numba.pydata.org/numba-doc/latest/user/troubleshoot.html#my-code-has-an-untyped-list-problem 1144""" 1145 if offender is not None: 1146 # This block deals with imprecise lists 1147 if hasattr(offender, 'value'): 1148 if hasattr(offender.value, 'op'): 1149 # might be `foo = []` 1150 if offender.value.op == 'build_list': 1151 return list_msg 1152 # or might be `foo = list()` 1153 elif offender.value.op == 'call': 1154 try: # assignment involving a call 1155 call_name = offender.value.func.name 1156 # find the offender based on the call name 1157 offender = find_offender(call_name) 1158 if isinstance(offender.value, ir.Global): 1159 if offender.value.name == 'list': 1160 return list_msg 1161 except (AttributeError, KeyError): 1162 pass 1163 return "" # no help possible 1164 1165 def check_var(name): 1166 tv = self.typevars[name] 1167 if not tv.defined: 1168 if raise_errors: 1169 offender = find_offender(name) 1170 val = getattr(offender, 'value', 'unknown operation') 1171 loc = getattr(offender, 'loc', 
ir.unknown_loc) 1172 msg = ("Type of variable '%s' cannot be determined, " 1173 "operation: %s, location: %s") 1174 raise TypingError(msg % (var, val, loc), loc) 1175 else: 1176 typdict[var] = types.unknown 1177 return 1178 tp = tv.getone() 1179 1180 if isinstance(tp, types.UndefinedFunctionType): 1181 tp = tp.get_precise() 1182 1183 if not tp.is_precise(): 1184 offender = find_offender(name, exhaustive=True) 1185 msg = ("Cannot infer the type of variable '%s'%s, " 1186 "have imprecise type: %s. %s") 1187 istmp = " (temporary variable)" if var.startswith('$') else "" 1188 loc = getattr(offender, 'loc', ir.unknown_loc) 1189 # is this an untyped list? try and provide help 1190 extra_msg = diagnose_imprecision(offender) 1191 if raise_errors: 1192 raise TypingError(msg % (var, istmp, tp, extra_msg), loc) 1193 else: 1194 typdict[var] = types.unknown 1195 return 1196 else: # type is precise, hold it 1197 typdict[var] = tp 1198 1199 # For better error display, check first user-visible vars, then 1200 # temporaries 1201 temps = set(k for k in self.typevars if not k[0].isalpha()) 1202 others = set(self.typevars) - temps 1203 for var in sorted(others): 1204 check_var(var) 1205 for var in sorted(temps): 1206 check_var(var) 1207 1208 try: 1209 retty = self.get_return_type(typdict) 1210 except Exception as e: 1211 # partial type inference may raise e.g. attribute error if a 1212 # constraint has no computable signature, ignore this as needed 1213 if raise_errors: 1214 raise e 1215 else: 1216 retty = None 1217 1218 try: 1219 fntys = self.get_function_types(typdict) 1220 except Exception as e: 1221 # partial type inference may raise e.g. 
attribute error if a 1222 # constraint has no computable signature, ignore this as needed 1223 if raise_errors: 1224 raise e 1225 else: 1226 fntys = None 1227 1228 if self.generator_info: 1229 retty = self.get_generator_type(typdict, retty, 1230 raise_errors=raise_errors) 1231 1232 self.debug.unify_finished(typdict, retty, fntys) 1233 1234 return typdict, retty, fntys 1235 1236 def get_generator_type(self, typdict, retty, raise_errors=True): 1237 gi = self.generator_info 1238 arg_types = [None] * len(self.arg_names) 1239 for index, name in self.arg_names.items(): 1240 arg_types[index] = typdict[name] 1241 1242 state_types = None 1243 try: 1244 state_types = [typdict[var_name] for var_name in gi.state_vars] 1245 except KeyError: 1246 msg = "Cannot type generator: state variable types cannot be found" 1247 if raise_errors: 1248 raise TypingError(msg) 1249 state_types = [types.unknown for _ in gi.state_vars] 1250 1251 yield_types = None 1252 try: 1253 yield_types = [typdict[y.inst.value.name] 1254 for y in gi.get_yield_points()] 1255 except KeyError: 1256 msg = "Cannot type generator: yield type cannot be found" 1257 if raise_errors: 1258 raise TypingError(msg) 1259 if not yield_types: 1260 msg = "Cannot type generator: it does not yield any value" 1261 if raise_errors: 1262 raise TypingError(msg) 1263 yield_types = [types.unknown for _ in gi.get_yield_points()] 1264 1265 if not yield_types or all(yield_types) == types.unknown: 1266 # unknown yield, probably partial type inference, escape 1267 return types.Generator(self.func_id.func, types.unknown, arg_types, 1268 state_types, has_finalizer=True) 1269 1270 yield_type = self.context.unify_types(*yield_types) 1271 if yield_type is None or isinstance(yield_type, types.Optional): 1272 msg = "Cannot type generator: cannot unify yielded types %s" 1273 yp_highlights = [] 1274 for y in gi.get_yield_points(): 1275 msg = (_termcolor.errmsg("Yield of: IR '%s', type '%s', " 1276 "location: %s")) 1277 yp_highlights.append(msg % 
(str(y.inst), 1278 typdict[y.inst.value.name], 1279 y.inst.loc.strformat())) 1280 1281 explain_ty = set() 1282 for ty in yield_types: 1283 if isinstance(ty, types.Optional): 1284 explain_ty.add(ty.type) 1285 explain_ty.add(types.NoneType('none')) 1286 else: 1287 explain_ty.add(ty) 1288 if raise_errors: 1289 raise TypingError("Can't unify yield type from the " 1290 "following types: %s" 1291 % ", ".join(sorted(map(str, explain_ty))) + 1292 "\n\n" + "\n".join(yp_highlights)) 1293 1294 return types.Generator(self.func_id.func, yield_type, arg_types, 1295 state_types, has_finalizer=True) 1296 1297 def get_function_types(self, typemap): 1298 """ 1299 Fill and return the calltypes map. 1300 """ 1301 # XXX why can't this be done on the fly? 1302 calltypes = self.calltypes 1303 for call, constraint in self.calls: 1304 calltypes[call] = constraint.get_call_signature() 1305 return calltypes 1306 1307 def _unify_return_types(self, rettypes): 1308 if rettypes: 1309 unified = self.context.unify_types(*rettypes) 1310 if isinstance(unified, types.FunctionType): 1311 # unified is allowed to be UndefinedFunctionType 1312 # instance (that is imprecise). 
1313 return unified 1314 if unified is None or not unified.is_precise(): 1315 def check_type(atype): 1316 lst = [] 1317 for k, v in self.typevars.items(): 1318 if atype == v.type: 1319 lst.append(k) 1320 returns = {} 1321 for x in reversed(lst): 1322 for block in self.func_ir.blocks.values(): 1323 for instr in block.find_insts(ir.Return): 1324 value = instr.value 1325 if isinstance(value, ir.Var): 1326 name = value.name 1327 else: 1328 pass 1329 if x == name: 1330 returns[x] = instr 1331 break 1332 1333 interped = "" 1334 for name, offender in returns.items(): 1335 loc = getattr(offender, 'loc', ir.unknown_loc) 1336 msg = ("Return of: IR name '%s', type '%s', " 1337 "location: %s") 1338 interped = msg % (name, atype, loc.strformat()) 1339 return interped 1340 1341 problem_str = [] 1342 for xtype in rettypes: 1343 problem_str.append(_termcolor.errmsg(check_type(xtype))) 1344 1345 raise TypingError("Can't unify return type from the " 1346 "following types: %s" 1347 % ", ".join(sorted(map(str, rettypes))) + 1348 "\n" + "\n".join(problem_str)) 1349 return unified 1350 else: 1351 # Function without a successful return path 1352 return types.none 1353 1354 def get_return_type(self, typemap): 1355 rettypes = set() 1356 for var in self._get_return_vars(): 1357 rettypes.add(typemap[var.name]) 1358 return self._unify_return_types(rettypes) 1359 1360 def get_state_token(self): 1361 """The algorithm is monotonic. It can only grow or "refine" the 1362 typevar map. 
1363 """ 1364 return [tv.type for name, tv in sorted(self.typevars.items())] 1365 1366 def constrain_statement(self, inst): 1367 if isinstance(inst, ir.Assign): 1368 self.typeof_assign(inst) 1369 elif isinstance(inst, ir.SetItem): 1370 self.typeof_setitem(inst) 1371 elif isinstance(inst, ir.StaticSetItem): 1372 self.typeof_static_setitem(inst) 1373 elif isinstance(inst, ir.DelItem): 1374 self.typeof_delitem(inst) 1375 elif isinstance(inst, ir.SetAttr): 1376 self.typeof_setattr(inst) 1377 elif isinstance(inst, ir.Print): 1378 self.typeof_print(inst) 1379 elif isinstance(inst, ir.StoreMap): 1380 self.typeof_storemap(inst) 1381 elif isinstance(inst, (ir.Jump, ir.Branch, ir.Return, ir.Del)): 1382 pass 1383 elif isinstance(inst, (ir.StaticRaise, ir.StaticTryRaise)): 1384 pass 1385 elif type(inst) in typeinfer_extensions: 1386 # let external calls handle stmt if type matches 1387 f = typeinfer_extensions[type(inst)] 1388 f(inst, self) 1389 else: 1390 msg = "Unsupported constraint encountered: %s" % inst 1391 raise UnsupportedError(msg, loc=inst.loc) 1392 1393 def typeof_setitem(self, inst): 1394 constraint = SetItemConstraint(target=inst.target, index=inst.index, 1395 value=inst.value, loc=inst.loc) 1396 self.constraints.append(constraint) 1397 self.calls.append((inst, constraint)) 1398 1399 def typeof_storemap(self, inst): 1400 constraint = SetItemConstraint(target=inst.dct, index=inst.key, 1401 value=inst.value, loc=inst.loc) 1402 self.constraints.append(constraint) 1403 self.calls.append((inst, constraint)) 1404 1405 def typeof_static_setitem(self, inst): 1406 constraint = StaticSetItemConstraint(target=inst.target, 1407 index=inst.index, 1408 index_var=inst.index_var, 1409 value=inst.value, loc=inst.loc) 1410 self.constraints.append(constraint) 1411 self.calls.append((inst, constraint)) 1412 1413 def typeof_delitem(self, inst): 1414 constraint = DelItemConstraint(target=inst.target, index=inst.index, 1415 loc=inst.loc) 1416 self.constraints.append(constraint) 1417 
self.calls.append((inst, constraint)) 1418 1419 def typeof_setattr(self, inst): 1420 constraint = SetAttrConstraint(target=inst.target, attr=inst.attr, 1421 value=inst.value, loc=inst.loc) 1422 self.constraints.append(constraint) 1423 self.calls.append((inst, constraint)) 1424 1425 def typeof_print(self, inst): 1426 constraint = PrintConstraint(args=inst.args, vararg=inst.vararg, 1427 loc=inst.loc) 1428 self.constraints.append(constraint) 1429 self.calls.append((inst, constraint)) 1430 1431 def typeof_assign(self, inst): 1432 value = inst.value 1433 if isinstance(value, ir.Const): 1434 self.typeof_const(inst, inst.target, value.value) 1435 elif isinstance(value, ir.Var): 1436 self.constraints.append(Propagate(dst=inst.target.name, 1437 src=value.name, loc=inst.loc)) 1438 elif isinstance(value, (ir.Global, ir.FreeVar)): 1439 self.typeof_global(inst, inst.target, value) 1440 elif isinstance(value, ir.Arg): 1441 self.typeof_arg(inst, inst.target, value) 1442 elif isinstance(value, ir.Expr): 1443 self.typeof_expr(inst, inst.target, value) 1444 elif isinstance(value, ir.Yield): 1445 self.typeof_yield(inst, inst.target, value) 1446 else: 1447 msg = ("Unsupported assignment encountered: %s %s" % 1448 (type(value), str(value))) 1449 raise UnsupportedError(msg, loc=inst.loc) 1450 1451 def resolve_value_type(self, inst, val): 1452 """ 1453 Resolve the type of a simple Python value, such as can be 1454 represented by literals. 
1455 """ 1456 try: 1457 return self.context.resolve_value_type(val) 1458 except ValueError as e: 1459 msg = str(e) 1460 raise TypingError(msg, loc=inst.loc) 1461 1462 def typeof_arg(self, inst, target, arg): 1463 src_name = self._mangle_arg_name(arg.name) 1464 self.constraints.append(ArgConstraint(dst=target.name, 1465 src=src_name, 1466 loc=inst.loc)) 1467 1468 def typeof_const(self, inst, target, const): 1469 ty = self.resolve_value_type(inst, const) 1470 if inst.value.use_literal_type: 1471 lit = types.maybe_literal(value=const) 1472 else: 1473 lit = None 1474 self.add_type(target.name, lit or ty, loc=inst.loc) 1475 1476 def typeof_yield(self, inst, target, yield_): 1477 # Sending values into generators isn't supported. 1478 self.add_type(target.name, types.none, loc=inst.loc) 1479 1480 def sentry_modified_builtin(self, inst, gvar): 1481 """ 1482 Ensure that builtins are not modified. 1483 """ 1484 if (gvar.name in ('range', 'xrange') and 1485 gvar.value not in utils.RANGE_ITER_OBJECTS): 1486 bad = True 1487 elif gvar.name == 'slice' and gvar.value is not slice: 1488 bad = True 1489 elif gvar.name == 'len' and gvar.value is not len: 1490 bad = True 1491 else: 1492 bad = False 1493 1494 if bad: 1495 raise TypingError("Modified builtin '%s'" % gvar.name, 1496 loc=inst.loc) 1497 1498 def resolve_call(self, fnty, pos_args, kw_args): 1499 """ 1500 Resolve a call to a given function type. A signature is returned. 
1501 """ 1502 if isinstance(fnty, types.FunctionType): 1503 return fnty.get_call_type(self, pos_args, kw_args) 1504 if isinstance(fnty, types.RecursiveCall) and not self._skip_recursion: 1505 # Recursive call 1506 disp = fnty.dispatcher_type.dispatcher 1507 pysig, args = disp.fold_argument_types(pos_args, kw_args) 1508 1509 frame = self.context.callstack.match(disp.py_func, args) 1510 1511 # If the signature is not being compiled 1512 if frame is None: 1513 sig = self.context.resolve_function_type(fnty.dispatcher_type, 1514 pos_args, kw_args) 1515 fndesc = disp.overloads[args].fndesc 1516 fnty.overloads[args] = qualifying_prefix(fndesc.modname, 1517 fndesc.unique_name) 1518 return sig 1519 1520 fnid = frame.func_id 1521 fnty.overloads[args] = qualifying_prefix(fnid.modname, 1522 fnid.unique_name) 1523 # Resume propagation in parent frame 1524 return_type = frame.typeinfer.return_types_from_partial() 1525 # No known return type 1526 if return_type is None: 1527 raise TypingError("cannot type infer runaway recursion") 1528 1529 sig = typing.signature(return_type, *args) 1530 sig = sig.replace(pysig=pysig) 1531 # Keep track of unique return_type 1532 frame.add_return_type(return_type) 1533 return sig 1534 else: 1535 # Normal non-recursive call 1536 return self.context.resolve_function_type(fnty, pos_args, kw_args) 1537 1538 def typeof_global(self, inst, target, gvar): 1539 try: 1540 typ = self.resolve_value_type(inst, gvar.value) 1541 except TypingError as e: 1542 if (gvar.name == self.func_id.func_name 1543 and gvar.name in _temporary_dispatcher_map): 1544 # Self-recursion case where the dispatcher is not (yet?) 
known 1545 # as a global variable 1546 typ = types.Dispatcher(_temporary_dispatcher_map[gvar.name]) 1547 else: 1548 from numba.misc import special 1549 1550 nm = gvar.name 1551 # check if the problem is actually a name error 1552 func_glbls = self.func_id.func.__globals__ 1553 if (nm not in func_glbls.keys() and 1554 nm not in special.__all__ and 1555 nm not in __builtins__.keys() and 1556 nm not in self.func_id.code.co_freevars): 1557 errstr = "NameError: name '%s' is not defined" 1558 msg = _termcolor.errmsg(errstr % nm) 1559 e.patch_message(msg) 1560 raise 1561 else: 1562 msg = _termcolor.errmsg("Untyped global name '%s':" % nm) 1563 msg += " %s" # interps the actual error 1564 1565 # if the untyped global is a numba internal function then add 1566 # to the error message asking if it's been imported. 1567 1568 if nm in special.__all__: 1569 tmp = ("\n'%s' looks like a Numba internal function, has " 1570 "it been imported (i.e. 'from numba import %s')?\n" % 1571 (nm, nm)) 1572 msg += _termcolor.errmsg(tmp) 1573 e.patch_message(msg % e) 1574 raise 1575 1576 if isinstance(typ, types.Dispatcher) and typ.dispatcher.is_compiling: 1577 # Recursive call 1578 callstack = self.context.callstack 1579 callframe = callstack.findfirst(typ.dispatcher.py_func) 1580 if callframe is not None: 1581 typ = types.RecursiveCall(typ) 1582 else: 1583 raise NotImplementedError( 1584 "call to %s: unsupported recursion" 1585 % typ.dispatcher) 1586 1587 if isinstance(typ, types.Array): 1588 # Global array in nopython mode is constant 1589 typ = typ.copy(readonly=True) 1590 1591 if isinstance(typ, types.BaseAnonymousTuple): 1592 # if it's a tuple of literal types, swap the type for the more 1593 # specific literal version 1594 literaled = [types.maybe_literal(x) for x in gvar.value] 1595 if all(literaled): 1596 typ = types.Tuple(literaled) 1597 1598 self.sentry_modified_builtin(inst, gvar) 1599 # Setting literal_value for globals because they are handled 1600 # like const value in numba 1601 
lit = types.maybe_literal(gvar.value) 1602 self.lock_type(target.name, lit or typ, loc=inst.loc) 1603 self.assumed_immutables.add(inst) 1604 1605 def typeof_expr(self, inst, target, expr): 1606 if expr.op == 'call': 1607 if isinstance(expr.func, ir.Intrinsic): 1608 sig = expr.func.type 1609 self.add_type(target.name, sig.return_type, loc=inst.loc) 1610 self.add_calltype(expr, sig) 1611 else: 1612 self.typeof_call(inst, target, expr) 1613 elif expr.op in ('getiter', 'iternext'): 1614 self.typeof_intrinsic_call(inst, target, expr.op, expr.value) 1615 elif expr.op == 'exhaust_iter': 1616 constraint = ExhaustIterConstraint(target.name, count=expr.count, 1617 iterator=expr.value, 1618 loc=expr.loc) 1619 self.constraints.append(constraint) 1620 elif expr.op == 'pair_first': 1621 constraint = PairFirstConstraint(target.name, pair=expr.value, 1622 loc=expr.loc) 1623 self.constraints.append(constraint) 1624 elif expr.op == 'pair_second': 1625 constraint = PairSecondConstraint(target.name, pair=expr.value, 1626 loc=expr.loc) 1627 self.constraints.append(constraint) 1628 elif expr.op == 'binop': 1629 self.typeof_intrinsic_call(inst, target, expr.fn, expr.lhs, 1630 expr.rhs) 1631 elif expr.op == 'inplace_binop': 1632 self.typeof_intrinsic_call(inst, target, expr.fn, 1633 expr.lhs, expr.rhs) 1634 elif expr.op == 'unary': 1635 self.typeof_intrinsic_call(inst, target, expr.fn, expr.value) 1636 elif expr.op == 'static_getitem': 1637 constraint = StaticGetItemConstraint(target.name, value=expr.value, 1638 index=expr.index, 1639 index_var=expr.index_var, 1640 loc=expr.loc) 1641 self.constraints.append(constraint) 1642 self.calls.append((inst.value, constraint)) 1643 elif expr.op == 'getitem': 1644 self.typeof_intrinsic_call(inst, target, operator.getitem, 1645 expr.value, expr.index,) 1646 elif expr.op == 'typed_getitem': 1647 constraint = TypedGetItemConstraint(target.name, value=expr.value, 1648 dtype=expr.dtype, 1649 index=expr.index, 1650 loc=expr.loc) 1651 
self.constraints.append(constraint) 1652 self.calls.append((inst.value, constraint)) 1653 1654 elif expr.op == 'getattr': 1655 constraint = GetAttrConstraint(target.name, attr=expr.attr, 1656 value=expr.value, loc=inst.loc, 1657 inst=inst) 1658 self.constraints.append(constraint) 1659 elif expr.op == 'build_tuple': 1660 constraint = BuildTupleConstraint(target.name, items=expr.items, 1661 loc=inst.loc) 1662 self.constraints.append(constraint) 1663 elif expr.op == 'build_list': 1664 constraint = BuildListConstraint(target.name, items=expr.items, 1665 loc=inst.loc) 1666 self.constraints.append(constraint) 1667 elif expr.op == 'build_set': 1668 constraint = BuildSetConstraint(target.name, items=expr.items, 1669 loc=inst.loc) 1670 self.constraints.append(constraint) 1671 elif expr.op == 'build_map': 1672 constraint = BuildMapConstraint( 1673 target.name, 1674 items=expr.items, 1675 special_value=expr.literal_value, 1676 value_indexes=expr.value_indexes, 1677 loc=inst.loc) 1678 self.constraints.append(constraint) 1679 elif expr.op == 'cast': 1680 self.constraints.append(Propagate(dst=target.name, 1681 src=expr.value.name, 1682 loc=inst.loc)) 1683 elif expr.op == 'phi': 1684 for iv in expr.incoming_values: 1685 if iv is not ir.UNDEFINED: 1686 self.constraints.append(Propagate(dst=target.name, 1687 src=iv.name, 1688 loc=inst.loc)) 1689 elif expr.op == 'make_function': 1690 self.lock_type(target.name, types.MakeFunctionLiteral(expr), 1691 loc=inst.loc, literal_value=expr) 1692 else: 1693 msg = "Unsupported op-code encountered: %s" % expr 1694 raise UnsupportedError(msg, loc=inst.loc) 1695 1696 def typeof_call(self, inst, target, call): 1697 constraint = CallConstraint(target.name, call.func.name, call.args, 1698 call.kws, call.vararg, loc=inst.loc) 1699 self.constraints.append(constraint) 1700 self.calls.append((inst.value, constraint)) 1701 1702 def typeof_intrinsic_call(self, inst, target, func, *args): 1703 constraint = IntrinsicCallConstraint(target.name, func, args, 
1704 kws=(), vararg=None, loc=inst.loc) 1705 self.constraints.append(constraint) 1706 self.calls.append((inst.value, constraint)) 1707 1708 1709class NullDebug(object): 1710 1711 def propagate_started(self): 1712 pass 1713 1714 def propagate_finished(self): 1715 pass 1716 1717 def unify_finished(self, typdict, retty, fntys): 1718 pass 1719 1720 1721class TypeInferDebug(object): 1722 1723 def __init__(self, typeinfer): 1724 self.typeinfer = typeinfer 1725 1726 def _dump_state(self): 1727 print('---- type variables ----') 1728 pprint([v for k, v in sorted(self.typeinfer.typevars.items())]) 1729 1730 def propagate_started(self): 1731 print("propagate".center(80, '-')) 1732 1733 def propagate_finished(self): 1734 self._dump_state() 1735 1736 def unify_finished(self, typdict, retty, fntys): 1737 print("Variable types".center(80, "-")) 1738 pprint(typdict) 1739 print("Return type".center(80, "-")) 1740 pprint(retty) 1741 print("Call types".center(80, "-")) 1742 pprint(fntys) 1743