1""" 2Support for native homogeneous sets. 3""" 4 5 6import collections 7import contextlib 8import math 9import operator 10 11from llvmlite import ir 12from numba.core import types, typing, cgutils 13from numba.core.imputils import (lower_builtin, lower_cast, 14 iternext_impl, impl_ret_borrowed, 15 impl_ret_new_ref, impl_ret_untracked, 16 for_iter, call_len, RefType) 17from numba.core.utils import cached_property 18from numba.misc import quicksort 19from numba.cpython import slicing 20from numba.extending import intrinsic 21 22 23def get_payload_struct(context, builder, set_type, ptr): 24 """ 25 Given a set value and type, get its payload structure (as a 26 reference, so that mutations are seen by all). 27 """ 28 payload_type = types.SetPayload(set_type) 29 ptrty = context.get_data_type(payload_type).as_pointer() 30 payload = builder.bitcast(ptr, ptrty) 31 return context.make_data_helper(builder, payload_type, ref=payload) 32 33 34def get_entry_size(context, set_type): 35 """ 36 Return the entry size for the given set type. 37 """ 38 llty = context.get_data_type(types.SetEntry(set_type)) 39 return context.get_abi_sizeof(llty) 40 41 42# Note these values are special: 43# - EMPTY is obtained by issuing memset(..., 0xFF) 44# - (unsigned) EMPTY > (unsigned) DELETED > any other hash value 45EMPTY = -1 46DELETED = -2 47FALLBACK = -43 48 49# Minimal size of entries table. Must be a power of 2! 50MINSIZE = 16 51 52# Number of cache-friendly linear probes before switching to non-linear probing 53LINEAR_PROBES = 3 54 55DEBUG_ALLOCS = False 56 57 58def get_hash_value(context, builder, typ, value): 59 """ 60 Compute the hash of the given value. 61 """ 62 typingctx = context.typing_context 63 fnty = typingctx.resolve_value_type(hash) 64 sig = fnty.get_call_type(typingctx, (typ,), {}) 65 fn = context.get_function(fnty, sig) 66 h = fn(builder, (value,)) 67 # Fixup reserved values 68 is_ok = is_hash_used(context, builder, h) 69 fallback = ir.Constant(h.type, FALLBACK) 70 return builder.select(is_ok, h, fallback) 71 72 73@intrinsic 74def _get_hash_value_intrinsic(typingctx, value): 75 def impl(context, builder, typ, args): 76 return get_hash_value(context, builder, value, args[0]) 77 fnty = typingctx.resolve_value_type(hash) 78 sig = fnty.get_call_type(typingctx, (value,), {}) 79 return sig, impl 80 81 82def is_hash_empty(context, builder, h): 83 """ 84 Whether the hash value denotes an empty entry. 85 """ 86 empty = ir.Constant(h.type, EMPTY) 87 return builder.icmp_unsigned('==', h, empty) 88 89def is_hash_deleted(context, builder, h): 90 """ 91 Whether the hash value denotes a deleted entry. 92 """ 93 deleted = ir.Constant(h.type, DELETED) 94 return builder.icmp_unsigned('==', h, deleted) 95 96def is_hash_used(context, builder, h): 97 """ 98 Whether the hash value denotes an active entry. 99 """ 100 # Everything below DELETED is an used entry 101 deleted = ir.Constant(h.type, DELETED) 102 return builder.icmp_unsigned('<', h, deleted) 103 104 105SetLoop = collections.namedtuple('SetLoop', ('index', 'entry', 'do_break')) 106 107 108class _SetPayload(object): 109 110 def __init__(self, context, builder, set_type, ptr): 111 payload = get_payload_struct(context, builder, set_type, ptr) 112 self._context = context 113 self._builder = builder 114 self._ty = set_type 115 self._payload = payload 116 self._entries = payload._get_ptr_by_name('entries') 117 self._ptr = ptr 118 119 @property 120 def mask(self): 121 return self._payload.mask 122 123 @mask.setter 124 def mask(self, value): 125 # CAUTION: mask must be a power of 2 minus 1 126 self._payload.mask = value 127 128 @property 129 def used(self): 130 return self._payload.used 131 132 @used.setter 133 def used(self, value): 134 self._payload.used = value 135 136 @property 137 def fill(self): 138 return self._payload.fill 139 140 @fill.setter 141 def fill(self, value): 142 self._payload.fill = value 143 144 @property 145 def finger(self): 146 return self._payload.finger 147 148 @finger.setter 149 def finger(self, value): 150 self._payload.finger = value 151 152 @property 153 def dirty(self): 154 return self._payload.dirty 155 156 @dirty.setter 157 def dirty(self, value): 158 self._payload.dirty = value 159 160 @property 161 def entries(self): 162 """ 163 A pointer to the start of the entries array. 164 """ 165 return self._entries 166 167 @property 168 def ptr(self): 169 """ 170 A pointer to the start of the NRT-allocated area. 171 """ 172 return self._ptr 173 174 def get_entry(self, idx): 175 """ 176 Get entry number *idx*. 177 """ 178 entry_ptr = cgutils.gep(self._builder, self._entries, idx) 179 entry = self._context.make_data_helper(self._builder, 180 types.SetEntry(self._ty), 181 ref=entry_ptr) 182 return entry 183 184 def _lookup(self, item, h, for_insert=False): 185 """ 186 Lookup the *item* with the given hash values in the entries. 187 188 Return a (found, entry index) tuple: 189 - If found is true, <entry index> points to the entry containing 190 the item. 191 - If found is false, <entry index> points to the empty entry that 192 the item can be written to (only if *for_insert* is true) 193 """ 194 context = self._context 195 builder = self._builder 196 197 intp_t = h.type 198 199 mask = self.mask 200 dtype = self._ty.dtype 201 eqfn = context.get_function(operator.eq, 202 typing.signature(types.boolean, dtype, dtype)) 203 204 one = ir.Constant(intp_t, 1) 205 five = ir.Constant(intp_t, 5) 206 207 # The perturbation value for probing 208 perturb = cgutils.alloca_once_value(builder, h) 209 # The index of the entry being considered: start with (hash & mask) 210 index = cgutils.alloca_once_value(builder, 211 builder.and_(h, mask)) 212 if for_insert: 213 # The index of the first deleted entry in the lookup chain 214 free_index_sentinel = mask.type(-1) # highest unsigned index 215 free_index = cgutils.alloca_once_value(builder, free_index_sentinel) 216 217 bb_body = builder.append_basic_block("lookup.body") 218 bb_found = builder.append_basic_block("lookup.found") 219 bb_not_found = builder.append_basic_block("lookup.not_found") 220 bb_end = builder.append_basic_block("lookup.end") 221 222 def check_entry(i): 223 """ 224 Check entry *i* against the value being searched for. 225 """ 226 entry = self.get_entry(i) 227 entry_hash = entry.hash 228 229 with builder.if_then(builder.icmp_unsigned('==', h, entry_hash)): 230 # Hashes are equal, compare values 231 # (note this also ensures the entry is used) 232 eq = eqfn(builder, (item, entry.key)) 233 with builder.if_then(eq): 234 builder.branch(bb_found) 235 236 with builder.if_then(is_hash_empty(context, builder, entry_hash)): 237 builder.branch(bb_not_found) 238 239 if for_insert: 240 # Memorize the index of the first deleted entry 241 with builder.if_then(is_hash_deleted(context, builder, entry_hash)): 242 j = builder.load(free_index) 243 j = builder.select(builder.icmp_unsigned('==', j, free_index_sentinel), 244 i, j) 245 builder.store(j, free_index) 246 247 # First linear probing. When the number of collisions is small, 248 # the lineary probing loop achieves better cache locality and 249 # is also slightly cheaper computationally. 250 with cgutils.for_range(builder, ir.Constant(intp_t, LINEAR_PROBES)): 251 i = builder.load(index) 252 check_entry(i) 253 i = builder.add(i, one) 254 i = builder.and_(i, mask) 255 builder.store(i, index) 256 257 # If not found after linear probing, switch to a non-linear 258 # perturbation keyed on the unmasked hash value. 259 # XXX how to tell LLVM this branch is unlikely? 260 builder.branch(bb_body) 261 with builder.goto_block(bb_body): 262 i = builder.load(index) 263 check_entry(i) 264 265 # Perturb to go to next entry: 266 # perturb >>= 5 267 # i = (i * 5 + 1 + perturb) & mask 268 p = builder.load(perturb) 269 p = builder.lshr(p, five) 270 i = builder.add(one, builder.mul(i, five)) 271 i = builder.and_(mask, builder.add(i, p)) 272 builder.store(i, index) 273 builder.store(p, perturb) 274 # Loop 275 builder.branch(bb_body) 276 277 with builder.goto_block(bb_not_found): 278 if for_insert: 279 # Not found => for insertion, return the index of the first 280 # deleted entry (if any), to avoid creating an infinite 281 # lookup chain (issue #1913). 282 i = builder.load(index) 283 j = builder.load(free_index) 284 i = builder.select(builder.icmp_unsigned('==', j, free_index_sentinel), 285 i, j) 286 builder.store(i, index) 287 builder.branch(bb_end) 288 289 with builder.goto_block(bb_found): 290 builder.branch(bb_end) 291 292 builder.position_at_end(bb_end) 293 294 found = builder.phi(ir.IntType(1), 'found') 295 found.add_incoming(cgutils.true_bit, bb_found) 296 found.add_incoming(cgutils.false_bit, bb_not_found) 297 298 return found, builder.load(index) 299 300 @contextlib.contextmanager 301 def _iterate(self, start=None): 302 """ 303 Iterate over the payload's entries. Yield a SetLoop. 304 """ 305 context = self._context 306 builder = self._builder 307 308 intp_t = context.get_value_type(types.intp) 309 one = ir.Constant(intp_t, 1) 310 size = builder.add(self.mask, one) 311 312 with cgutils.for_range(builder, size, start=start) as range_loop: 313 entry = self.get_entry(range_loop.index) 314 is_used = is_hash_used(context, builder, entry.hash) 315 with builder.if_then(is_used): 316 loop = SetLoop(index=range_loop.index, entry=entry, 317 do_break=range_loop.do_break) 318 yield loop 319 320 @contextlib.contextmanager 321 def _next_entry(self): 322 """ 323 Yield a random entry from the payload. Caller must ensure the 324 set isn't empty, otherwise the function won't end. 325 """ 326 context = self._context 327 builder = self._builder 328 329 intp_t = context.get_value_type(types.intp) 330 zero = ir.Constant(intp_t, 0) 331 one = ir.Constant(intp_t, 1) 332 mask = self.mask 333 334 # Start walking the entries from the stored "search finger" and 335 # break as soon as we find a used entry. 336 337 bb_body = builder.append_basic_block('next_entry_body') 338 bb_end = builder.append_basic_block('next_entry_end') 339 340 index = cgutils.alloca_once_value(builder, self.finger) 341 builder.branch(bb_body) 342 343 with builder.goto_block(bb_body): 344 i = builder.load(index) 345 # ANDing with mask ensures we stay inside the table boundaries 346 i = builder.and_(mask, builder.add(i, one)) 347 builder.store(i, index) 348 entry = self.get_entry(i) 349 is_used = is_hash_used(context, builder, entry.hash) 350 builder.cbranch(is_used, bb_end, bb_body) 351 352 builder.position_at_end(bb_end) 353 354 # Update the search finger with the next position. This avoids 355 # O(n**2) behaviour when pop() is called in a loop. 356 i = builder.load(index) 357 self.finger = i 358 yield self.get_entry(i) 359 360 361class SetInstance(object): 362 363 def __init__(self, context, builder, set_type, set_val): 364 self._context = context 365 self._builder = builder 366 self._ty = set_type 367 self._entrysize = get_entry_size(context, set_type) 368 self._set = context.make_helper(builder, set_type, set_val) 369 370 @property 371 def dtype(self): 372 return self._ty.dtype 373 374 @property 375 def payload(self): 376 """ 377 The _SetPayload for this set. 378 """ 379 # This cannot be cached as the pointer can move around! 380 context = self._context 381 builder = self._builder 382 383 ptr = self._context.nrt.meminfo_data(builder, self.meminfo) 384 return _SetPayload(context, builder, self._ty, ptr) 385 386 @property 387 def value(self): 388 return self._set._getvalue() 389 390 @property 391 def meminfo(self): 392 return self._set.meminfo 393 394 @property 395 def parent(self): 396 return self._set.parent 397 398 @parent.setter 399 def parent(self, value): 400 self._set.parent = value 401 402 def get_size(self): 403 """ 404 Return the number of elements in the size. 405 """ 406 return self.payload.used 407 408 def set_dirty(self, val): 409 if self._ty.reflected: 410 self.payload.dirty = cgutils.true_bit if val else cgutils.false_bit 411 412 def _add_entry(self, payload, entry, item, h, do_resize=True): 413 context = self._context 414 builder = self._builder 415 416 old_hash = entry.hash 417 entry.hash = h 418 entry.key = item 419 # used++ 420 used = payload.used 421 one = ir.Constant(used.type, 1) 422 used = payload.used = builder.add(used, one) 423 # fill++ if entry wasn't a deleted one 424 with builder.if_then(is_hash_empty(context, builder, old_hash), 425 likely=True): 426 payload.fill = builder.add(payload.fill, one) 427 # Grow table if necessary 428 if do_resize: 429 self.upsize(used) 430 self.set_dirty(True) 431 432 def _add_key(self, payload, item, h, do_resize=True): 433 context = self._context 434 builder = self._builder 435 436 found, i = payload._lookup(item, h, for_insert=True) 437 not_found = builder.not_(found) 438 439 with builder.if_then(not_found): 440 # Not found => add it 441 entry = payload.get_entry(i) 442 old_hash = entry.hash 443 entry.hash = h 444 entry.key = item 445 # used++ 446 used = payload.used 447 one = ir.Constant(used.type, 1) 448 used = payload.used = builder.add(used, one) 449 # fill++ if entry wasn't a deleted one 450 with builder.if_then(is_hash_empty(context, builder, old_hash), 451 likely=True): 452 payload.fill = builder.add(payload.fill, one) 453 # Grow table if necessary 454 if do_resize: 455 self.upsize(used) 456 self.set_dirty(True) 457 458 def _remove_entry(self, payload, entry, do_resize=True): 459 # Mark entry deleted 460 entry.hash = ir.Constant(entry.hash.type, DELETED) 461 # used-- 462 used = payload.used 463 one = ir.Constant(used.type, 1) 464 used = payload.used = self._builder.sub(used, one) 465 # Shrink table if necessary 466 if do_resize: 467 self.downsize(used) 468 self.set_dirty(True) 469 470 def _remove_key(self, payload, item, h, do_resize=True): 471 context = self._context 472 builder = self._builder 473 474 found, i = payload._lookup(item, h) 475 476 with builder.if_then(found): 477 entry = payload.get_entry(i) 478 self._remove_entry(payload, entry, do_resize) 479 480 return found 481 482 def add(self, item, do_resize=True): 483 context = self._context 484 builder = self._builder 485 486 payload = self.payload 487 h = get_hash_value(context, builder, self._ty.dtype, item) 488 self._add_key(payload, item, h, do_resize) 489 490 def add_pyapi(self, pyapi, item, do_resize=True): 491 """A version of .add for use inside functions following Python calling 492 convention. 493 """ 494 context = self._context 495 builder = self._builder 496 497 payload = self.payload 498 h = self._pyapi_get_hash_value(pyapi, context, builder, item) 499 self._add_key(payload, item, h, do_resize) 500 501 def _pyapi_get_hash_value(self, pyapi, context, builder, item): 502 """Python API compatible version of `get_hash_value()`. 503 """ 504 argtypes = [self._ty.dtype] 505 resty = types.intp 506 507 def wrapper(val): 508 return _get_hash_value_intrinsic(val) 509 510 args = [item] 511 sig = typing.signature(resty, *argtypes) 512 is_error, retval = pyapi.call_jit_code(wrapper, sig, args) 513 # Handle return status 514 with builder.if_then(is_error, likely=False): 515 # Raise nopython exception as a Python exception 516 builder.ret(pyapi.get_null_object()) 517 return retval 518 519 def contains(self, item): 520 context = self._context 521 builder = self._builder 522 523 payload = self.payload 524 h = get_hash_value(context, builder, self._ty.dtype, item) 525 found, i = payload._lookup(item, h) 526 return found 527 528 def discard(self, item): 529 context = self._context 530 builder = self._builder 531 532 payload = self.payload 533 h = get_hash_value(context, builder, self._ty.dtype, item) 534 found = self._remove_key(payload, item, h) 535 return found 536 537 def pop(self): 538 context = self._context 539 builder = self._builder 540 541 lty = context.get_value_type(self._ty.dtype) 542 key = cgutils.alloca_once(builder, lty) 543 544 payload = self.payload 545 with payload._next_entry() as entry: 546 builder.store(entry.key, key) 547 self._remove_entry(payload, entry) 548 549 return builder.load(key) 550 551 def clear(self): 552 context = self._context 553 builder = self._builder 554 555 intp_t = context.get_value_type(types.intp) 556 minsize = ir.Constant(intp_t, MINSIZE) 557 self._replace_payload(minsize) 558 self.set_dirty(True) 559 560 def copy(self): 561 """ 562 Return a copy of this set. 563 """ 564 context = self._context 565 builder = self._builder 566 567 payload = self.payload 568 used = payload.used 569 fill = payload.fill 570 571 other = type(self)(context, builder, self._ty, None) 572 573 no_deleted_entries = builder.icmp_unsigned('==', used, fill) 574 with builder.if_else(no_deleted_entries, likely=True) \ 575 as (if_no_deleted, if_deleted): 576 with if_no_deleted: 577 # No deleted entries => raw copy the payload 578 ok = other._copy_payload(payload) 579 with builder.if_then(builder.not_(ok), likely=False): 580 context.call_conv.return_user_exc(builder, MemoryError, 581 ("cannot copy set",)) 582 583 with if_deleted: 584 # Deleted entries => re-insert entries one by one 585 nentries = self.choose_alloc_size(context, builder, used) 586 ok = other._allocate_payload(nentries) 587 with builder.if_then(builder.not_(ok), likely=False): 588 context.call_conv.return_user_exc(builder, MemoryError, 589 ("cannot copy set",)) 590 591 other_payload = other.payload 592 with payload._iterate() as loop: 593 entry = loop.entry 594 other._add_key(other_payload, entry.key, entry.hash, 595 do_resize=False) 596 597 return other 598 599 def intersect(self, other): 600 """ 601 In-place intersection with *other* set. 602 """ 603 context = self._context 604 builder = self._builder 605 payload = self.payload 606 other_payload = other.payload 607 608 with payload._iterate() as loop: 609 entry = loop.entry 610 found, _ = other_payload._lookup(entry.key, entry.hash) 611 with builder.if_then(builder.not_(found)): 612 self._remove_entry(payload, entry, do_resize=False) 613 614 # Final downsize 615 self.downsize(payload.used) 616 617 def difference(self, other): 618 """ 619 In-place difference with *other* set. 620 """ 621 context = self._context 622 builder = self._builder 623 payload = self.payload 624 other_payload = other.payload 625 626 with other_payload._iterate() as loop: 627 entry = loop.entry 628 self._remove_key(payload, entry.key, entry.hash, do_resize=False) 629 630 # Final downsize 631 self.downsize(payload.used) 632 633 def symmetric_difference(self, other): 634 """ 635 In-place symmetric difference with *other* set. 636 """ 637 context = self._context 638 builder = self._builder 639 other_payload = other.payload 640 641 with other_payload._iterate() as loop: 642 key = loop.entry.key 643 h = loop.entry.hash 644 # We must reload our payload as it may be resized during the loop 645 payload = self.payload 646 found, i = payload._lookup(key, h, for_insert=True) 647 entry = payload.get_entry(i) 648 with builder.if_else(found) as (if_common, if_not_common): 649 with if_common: 650 self._remove_entry(payload, entry, do_resize=False) 651 with if_not_common: 652 self._add_entry(payload, entry, key, h) 653 654 # Final downsize 655 self.downsize(self.payload.used) 656 657 def issubset(self, other, strict=False): 658 context = self._context 659 builder = self._builder 660 payload = self.payload 661 other_payload = other.payload 662 663 cmp_op = '<' if strict else '<=' 664 665 res = cgutils.alloca_once_value(builder, cgutils.true_bit) 666 with builder.if_else( 667 builder.icmp_unsigned(cmp_op, payload.used, other_payload.used) 668 ) as (if_smaller, if_larger): 669 with if_larger: 670 # self larger than other => self cannot possibly a subset 671 builder.store(cgutils.false_bit, res) 672 with if_smaller: 673 # check whether each key of self is in other 674 with payload._iterate() as loop: 675 entry = loop.entry 676 found, _ = other_payload._lookup(entry.key, entry.hash) 677 with builder.if_then(builder.not_(found)): 678 builder.store(cgutils.false_bit, res) 679 loop.do_break() 680 681 return builder.load(res) 682 683 def isdisjoint(self, other): 684 context = self._context 685 builder = self._builder 686 payload = self.payload 687 other_payload = other.payload 688 689 res = cgutils.alloca_once_value(builder, cgutils.true_bit) 690 691 def check(smaller, larger): 692 # Loop over the smaller of the two, and search in the larger 693 with smaller._iterate() as loop: 694 entry = loop.entry 695 found, _ = larger._lookup(entry.key, entry.hash) 696 with builder.if_then(found): 697 builder.store(cgutils.false_bit, res) 698 loop.do_break() 699 700 with builder.if_else( 701 builder.icmp_unsigned('>', payload.used, other_payload.used) 702 ) as (if_larger, otherwise): 703 704 with if_larger: 705 # len(self) > len(other) 706 check(other_payload, payload) 707 708 with otherwise: 709 # len(self) <= len(other) 710 check(payload, other_payload) 711 712 return builder.load(res) 713 714 def equals(self, other): 715 context = self._context 716 builder = self._builder 717 payload = self.payload 718 other_payload = other.payload 719 720 res = cgutils.alloca_once_value(builder, cgutils.true_bit) 721 with builder.if_else( 722 builder.icmp_unsigned('==', payload.used, other_payload.used) 723 ) as (if_same_size, otherwise): 724 with if_same_size: 725 # same sizes => check whether each key of self is in other 726 with payload._iterate() as loop: 727 entry = loop.entry 728 found, _ = other_payload._lookup(entry.key, entry.hash) 729 with builder.if_then(builder.not_(found)): 730 builder.store(cgutils.false_bit, res) 731 loop.do_break() 732 with otherwise: 733 # different sizes => cannot possibly be equal 734 builder.store(cgutils.false_bit, res) 735 736 return builder.load(res) 737 738 @classmethod 739 def allocate_ex(cls, context, builder, set_type, nitems=None): 740 """ 741 Allocate a SetInstance with its storage. 742 Return a (ok, instance) tuple where *ok* is a LLVM boolean and 743 *instance* is a SetInstance object (the object's contents are 744 only valid when *ok* is true). 745 """ 746 intp_t = context.get_value_type(types.intp) 747 748 if nitems is None: 749 nentries = ir.Constant(intp_t, MINSIZE) 750 else: 751 if isinstance(nitems, int): 752 nitems = ir.Constant(intp_t, nitems) 753 nentries = cls.choose_alloc_size(context, builder, nitems) 754 755 self = cls(context, builder, set_type, None) 756 ok = self._allocate_payload(nentries) 757 return ok, self 758 759 @classmethod 760 def allocate(cls, context, builder, set_type, nitems=None): 761 """ 762 Allocate a SetInstance with its storage. Same as allocate_ex(), 763 but return an initialized *instance*. If allocation failed, 764 control is transferred to the caller using the target's current 765 call convention. 766 """ 767 ok, self = cls.allocate_ex(context, builder, set_type, nitems) 768 with builder.if_then(builder.not_(ok), likely=False): 769 context.call_conv.return_user_exc(builder, MemoryError, 770 ("cannot allocate set",)) 771 return self 772 773 @classmethod 774 def from_meminfo(cls, context, builder, set_type, meminfo): 775 """ 776 Allocate a new set instance pointing to an existing payload 777 (a meminfo pointer). 778 Note the parent field has to be filled by the caller. 779 """ 780 self = cls(context, builder, set_type, None) 781 self._set.meminfo = meminfo 782 self._set.parent = context.get_constant_null(types.pyobject) 783 context.nrt.incref(builder, set_type, self.value) 784 # Payload is part of the meminfo, no need to touch it 785 return self 786 787 @classmethod 788 def choose_alloc_size(cls, context, builder, nitems): 789 """ 790 Choose a suitable number of entries for the given number of items. 791 """ 792 intp_t = nitems.type 793 one = ir.Constant(intp_t, 1) 794 minsize = ir.Constant(intp_t, MINSIZE) 795 796 # Ensure number of entries >= 2 * used 797 min_entries = builder.shl(nitems, one) 798 # Find out first suitable power of 2, starting from MINSIZE 799 size_p = cgutils.alloca_once_value(builder, minsize) 800 801 bb_body = builder.append_basic_block("calcsize.body") 802 bb_end = builder.append_basic_block("calcsize.end") 803 804 builder.branch(bb_body) 805 806 with builder.goto_block(bb_body): 807 size = builder.load(size_p) 808 is_large_enough = builder.icmp_unsigned('>=', size, min_entries) 809 with builder.if_then(is_large_enough, likely=False): 810 builder.branch(bb_end) 811 next_size = builder.shl(size, one) 812 builder.store(next_size, size_p) 813 builder.branch(bb_body) 814 815 builder.position_at_end(bb_end) 816 return builder.load(size_p) 817 818 def upsize(self, nitems): 819 """ 820 When adding to the set, ensure it is properly sized for the given 821 number of used entries. 822 """ 823 context = self._context 824 builder = self._builder 825 intp_t = nitems.type 826 827 one = ir.Constant(intp_t, 1) 828 two = ir.Constant(intp_t, 2) 829 830 payload = self.payload 831 832 # Ensure number of entries >= 2 * used 833 min_entries = builder.shl(nitems, one) 834 size = builder.add(payload.mask, one) 835 need_resize = builder.icmp_unsigned('>=', min_entries, size) 836 837 with builder.if_then(need_resize, likely=False): 838 # Find out next suitable size 839 new_size_p = cgutils.alloca_once_value(builder, size) 840 841 bb_body = builder.append_basic_block("calcsize.body") 842 bb_end = builder.append_basic_block("calcsize.end") 843 844 builder.branch(bb_body) 845 846 with builder.goto_block(bb_body): 847 # Multiply by 4 (ensuring size remains a power of two) 848 new_size = builder.load(new_size_p) 849 new_size = builder.shl(new_size, two) 850 builder.store(new_size, new_size_p) 851 is_too_small = builder.icmp_unsigned('>=', min_entries, new_size) 852 builder.cbranch(is_too_small, bb_body, bb_end) 853 854 builder.position_at_end(bb_end) 855 856 new_size = builder.load(new_size_p) 857 if DEBUG_ALLOCS: 858 context.printf(builder, 859 "upsize to %zd items: current size = %zd, " 860 "min entries = %zd, new size = %zd\n", 861 nitems, size, min_entries, new_size) 862 self._resize(payload, new_size, "cannot grow set") 863 864 def downsize(self, nitems): 865 """ 866 When removing from the set, ensure it is properly sized for the given 867 number of used entries. 868 """ 869 context = self._context 870 builder = self._builder 871 intp_t = nitems.type 872 873 one = ir.Constant(intp_t, 1) 874 two = ir.Constant(intp_t, 2) 875 minsize = ir.Constant(intp_t, MINSIZE) 876 877 payload = self.payload 878 879 # Ensure entries >= max(2 * used, MINSIZE) 880 min_entries = builder.shl(nitems, one) 881 min_entries = builder.select(builder.icmp_unsigned('>=', min_entries, minsize), 882 min_entries, minsize) 883 # Shrink only if size >= 4 * min_entries && size > MINSIZE 884 max_size = builder.shl(min_entries, two) 885 size = builder.add(payload.mask, one) 886 need_resize = builder.and_( 887 builder.icmp_unsigned('<=', max_size, size), 888 builder.icmp_unsigned('<', minsize, size)) 889 890 with builder.if_then(need_resize, likely=False): 891 # Find out next suitable size 892 new_size_p = cgutils.alloca_once_value(builder, size) 893 894 bb_body = builder.append_basic_block("calcsize.body") 895 bb_end = builder.append_basic_block("calcsize.end") 896 897 builder.branch(bb_body) 898 899 with builder.goto_block(bb_body): 900 # Divide by 2 (ensuring size remains a power of two) 901 new_size = builder.load(new_size_p) 902 new_size = builder.lshr(new_size, one) 903 # Keep current size if new size would be < min_entries 904 is_too_small = builder.icmp_unsigned('>', min_entries, new_size) 905 with builder.if_then(is_too_small): 906 builder.branch(bb_end) 907 builder.store(new_size, new_size_p) 908 builder.branch(bb_body) 909 910 builder.position_at_end(bb_end) 911 912 # Ensure new_size >= MINSIZE 913 new_size = builder.load(new_size_p) 914 # At this point, new_size should be < size if the factors 915 # above were chosen carefully! 916 917 if DEBUG_ALLOCS: 918 context.printf(builder, 919 "downsize to %zd items: current size = %zd, " 920 "min entries = %zd, new size = %zd\n", 921 nitems, size, min_entries, new_size) 922 self._resize(payload, new_size, "cannot shrink set") 923 924 def _resize(self, payload, nentries, errmsg): 925 """ 926 Resize the payload to the given number of entries. 927 928 CAUTION: *nentries* must be a power of 2! 929 """ 930 context = self._context 931 builder = self._builder 932 933 # Allocate new entries 934 old_payload = payload 935 936 ok = self._allocate_payload(nentries, realloc=True) 937 with builder.if_then(builder.not_(ok), likely=False): 938 context.call_conv.return_user_exc(builder, MemoryError, 939 (errmsg,)) 940 941 # Re-insert old entries 942 payload = self.payload 943 with old_payload._iterate() as loop: 944 entry = loop.entry 945 self._add_key(payload, entry.key, entry.hash, 946 do_resize=False) 947 948 self._free_payload(old_payload.ptr) 949 950 def _replace_payload(self, nentries): 951 """ 952 Replace the payload with a new empty payload with the given number 953 of entries. 954 955 CAUTION: *nentries* must be a power of 2! 956 """ 957 context = self._context 958 builder = self._builder 959 960 # Free old payload 961 self._free_payload(self.payload.ptr) 962 963 ok = self._allocate_payload(nentries, realloc=True) 964 with builder.if_then(builder.not_(ok), likely=False): 965 context.call_conv.return_user_exc(builder, MemoryError, 966 ("cannot reallocate set",)) 967 968 def _allocate_payload(self, nentries, realloc=False): 969 """ 970 Allocate and initialize payload for the given number of entries. 971 If *realloc* is True, the existing meminfo is reused. 972 973 CAUTION: *nentries* must be a power of 2! 974 """ 975 context = self._context 976 builder = self._builder 977 978 ok = cgutils.alloca_once_value(builder, cgutils.true_bit) 979 980 intp_t = context.get_value_type(types.intp) 981 zero = ir.Constant(intp_t, 0) 982 one = ir.Constant(intp_t, 1) 983 984 payload_type = context.get_data_type(types.SetPayload(self._ty)) 985 payload_size = context.get_abi_sizeof(payload_type) 986 entry_size = self._entrysize 987 # Account for the fact that the payload struct already contains an entry 988 payload_size -= entry_size 989 990 # Total allocation size = <payload header size> + nentries * entry_size 991 allocsize, ovf = cgutils.muladd_with_overflow(builder, nentries, 992 ir.Constant(intp_t, entry_size), 993 ir.Constant(intp_t, payload_size)) 994 with builder.if_then(ovf, likely=False): 995 builder.store(cgutils.false_bit, ok) 996 997 with builder.if_then(builder.load(ok), likely=True): 998 if realloc: 999 meminfo = self._set.meminfo 1000 ptr = context.nrt.meminfo_varsize_alloc(builder, meminfo, 1001 size=allocsize) 1002 alloc_ok = cgutils.is_null(builder, ptr) 1003 else: 1004 meminfo = context.nrt.meminfo_new_varsize(builder, size=allocsize) 1005 alloc_ok = cgutils.is_null(builder, meminfo) 1006 1007 with builder.if_else(cgutils.is_null(builder, meminfo), 1008 likely=False) as (if_error, if_ok): 1009 with if_error: 1010 builder.store(cgutils.false_bit, ok) 1011 with if_ok: 1012 if not realloc: 1013 self._set.meminfo = meminfo 1014 self._set.parent = context.get_constant_null(types.pyobject) 1015 payload = self.payload 1016 # Initialize entries to 0xff (EMPTY) 1017 cgutils.memset(builder, payload.ptr, allocsize, 0xFF) 1018 payload.used = zero 1019 payload.fill = zero 1020 payload.finger = zero 1021 new_mask = builder.sub(nentries, one) 1022 payload.mask = new_mask 1023 1024 if DEBUG_ALLOCS: 1025 context.printf(builder, 1026 "allocated %zd bytes for set at %p: mask = %zd\n", 1027 allocsize, payload.ptr, new_mask) 1028 1029 return builder.load(ok) 1030 1031 def _free_payload(self, ptr): 1032 """ 1033 Free an allocated old payload at *ptr*. 1034 """ 1035 self._context.nrt.meminfo_varsize_free(self._builder, self.meminfo, ptr) 1036 1037 def _copy_payload(self, src_payload): 1038 """ 1039 Raw-copy the given payload into self. 1040 """ 1041 context = self._context 1042 builder = self._builder 1043 1044 ok = cgutils.alloca_once_value(builder, cgutils.true_bit) 1045 1046 intp_t = context.get_value_type(types.intp) 1047 zero = ir.Constant(intp_t, 0) 1048 one = ir.Constant(intp_t, 1) 1049 1050 payload_type = context.get_data_type(types.SetPayload(self._ty)) 1051 payload_size = context.get_abi_sizeof(payload_type) 1052 entry_size = self._entrysize 1053 # Account for the fact that the payload struct already contains an entry 1054 payload_size -= entry_size 1055 1056 mask = src_payload.mask 1057 nentries = builder.add(one, mask) 1058 1059 # Total allocation size = <payload header size> + nentries * entry_size 1060 # (note there can't be any overflow since we're reusing an existing 1061 # payload's parameters) 1062 allocsize = builder.add(ir.Constant(intp_t, payload_size), 1063 builder.mul(ir.Constant(intp_t, entry_size), 1064 nentries)) 1065 1066 with builder.if_then(builder.load(ok), likely=True): 1067 meminfo = context.nrt.meminfo_new_varsize(builder, size=allocsize) 1068 alloc_ok = cgutils.is_null(builder, meminfo) 1069 1070 with builder.if_else(cgutils.is_null(builder, meminfo), 1071 likely=False) as (if_error, if_ok): 1072 with if_error: 1073 builder.store(cgutils.false_bit, ok) 1074 with if_ok: 1075 self._set.meminfo = meminfo 1076 payload = self.payload 1077 payload.used = src_payload.used 1078 payload.fill = src_payload.fill 1079 payload.finger = zero 1080 payload.mask = mask 1081 cgutils.raw_memcpy(builder, payload.entries, 1082 src_payload.entries, nentries, 1083 entry_size) 1084 1085 if DEBUG_ALLOCS: 1086 context.printf(builder, 1087 "allocated %zd bytes for set at %p: mask = %zd\n", 1088 allocsize, payload.ptr, mask) 1089 1090 return builder.load(ok) 1091 1092 1093class SetIterInstance(object): 1094 1095 def __init__(self, context, builder, iter_type, iter_val): 1096 self._context = context 1097 self._builder = builder 1098 self._ty = iter_type 1099 self._iter = context.make_helper(builder, iter_type, iter_val) 1100 ptr = self._context.nrt.meminfo_data(builder, self.meminfo) 1101 self._payload = _SetPayload(context, builder, self._ty.container, ptr) 1102 1103 @classmethod 1104 def from_set(cls, context, builder, iter_type, set_val): 1105 set_inst = SetInstance(context, builder, iter_type.container, set_val) 1106 self = cls(context, builder, iter_type, None) 1107 index = context.get_constant(types.intp, 0) 1108 self._iter.index = cgutils.alloca_once_value(builder, index) 1109 self._iter.meminfo = set_inst.meminfo 1110 return self 1111 1112 @property 1113 def value(self): 1114 return self._iter._getvalue() 1115 1116 @property 1117 def meminfo(self): 1118 return self._iter.meminfo 1119 1120 @property 1121 def index(self): 1122 return self._builder.load(self._iter.index) 1123 1124 @index.setter 1125 def index(self, value): 1126 self._builder.store(value, self._iter.index) 1127 1128 def iternext(self, result): 1129 index = self.index 1130 payload = self._payload 1131 one = ir.Constant(index.type, 1) 1132 1133 result.set_exhausted() 1134 1135 with payload._iterate(start=index) as loop: 1136 # An entry was found 1137 entry = loop.entry 1138 result.set_valid() 1139 result.yield_(entry.key) 1140 self.index = self._builder.add(loop.index, one) 1141 loop.do_break() 1142 1143 1144#------------------------------------------------------------------------------- 1145# Constructors 1146 1147def build_set(context, builder, set_type, items): 1148 """ 1149 Build a set of the given type, containing the given items. 1150 """ 1151 nitems = len(items) 1152 inst = SetInstance.allocate(context, builder, set_type, nitems) 1153 1154 # Populate set. Inlining the insertion code for each item would be very 1155 # costly, instead we create a LLVM array and iterate over it. 1156 array = cgutils.pack_array(builder, items) 1157 array_ptr = cgutils.alloca_once_value(builder, array) 1158 1159 count = context.get_constant(types.intp, nitems) 1160 with cgutils.for_range(builder, count) as loop: 1161 item = builder.load(cgutils.gep(builder, array_ptr, 0, loop.index)) 1162 inst.add(item) 1163 1164 return impl_ret_new_ref(context, builder, set_type, inst.value) 1165 1166 1167@lower_builtin(set) 1168def set_empty_constructor(context, builder, sig, args): 1169 set_type = sig.return_type 1170 inst = SetInstance.allocate(context, builder, set_type) 1171 return impl_ret_new_ref(context, builder, set_type, inst.value) 1172 1173@lower_builtin(set, types.IterableType) 1174def set_constructor(context, builder, sig, args): 1175 set_type = sig.return_type 1176 items_type, = sig.args 1177 items, = args 1178 1179 # If the argument has a len(), preallocate the set so as to 1180 # avoid resizes. 1181 n = call_len(context, builder, items_type, items) 1182 inst = SetInstance.allocate(context, builder, set_type, n) 1183 with for_iter(context, builder, items_type, items) as loop: 1184 inst.add(loop.value) 1185 1186 return impl_ret_new_ref(context, builder, set_type, inst.value) 1187 1188 1189#------------------------------------------------------------------------------- 1190# Various operations 1191 1192@lower_builtin(len, types.Set) 1193def set_len(context, builder, sig, args): 1194 inst = SetInstance(context, builder, sig.args[0], args[0]) 1195 return inst.get_size() 1196 1197@lower_builtin(operator.contains, types.Set, types.Any) 1198def in_set(context, builder, sig, args): 1199 inst = SetInstance(context, builder, sig.args[0], args[0]) 1200 return inst.contains(args[1]) 1201 1202@lower_builtin('getiter', types.Set) 1203def getiter_set(context, builder, sig, args): 1204 inst = SetIterInstance.from_set(context, builder, sig.return_type, args[0]) 1205 return impl_ret_borrowed(context, builder, sig.return_type, inst.value) 1206 1207@lower_builtin('iternext', types.SetIter) 1208@iternext_impl(RefType.BORROWED) 1209def iternext_listiter(context, builder, sig, args, result): 1210 inst = SetIterInstance(context, builder, sig.args[0], args[0]) 1211 inst.iternext(result) 1212 1213 1214#------------------------------------------------------------------------------- 1215# Methods 1216 1217# One-item-at-a-time operations 1218 1219@lower_builtin("set.add", types.Set, types.Any) 1220def set_add(context, builder, sig, args): 1221 inst = SetInstance(context, builder, sig.args[0], args[0]) 1222 item = args[1] 1223 inst.add(item) 1224 1225 return context.get_dummy_value() 1226 1227@lower_builtin("set.discard", types.Set, types.Any) 1228def set_discard(context, builder, sig, args): 1229 inst = SetInstance(context, builder, sig.args[0], args[0]) 1230 item = args[1] 1231 inst.discard(item) 1232 1233 return context.get_dummy_value() 1234 1235@lower_builtin("set.pop", types.Set) 1236def set_pop(context, builder, sig, args): 1237 inst = SetInstance(context, builder, sig.args[0], args[0]) 1238 used = inst.payload.used 1239 with builder.if_then(cgutils.is_null(builder, used), likely=False): 1240 context.call_conv.return_user_exc(builder, KeyError, 1241 ("set.pop(): empty set",)) 1242 1243 return inst.pop() 1244 1245@lower_builtin("set.remove", types.Set, types.Any) 1246def set_remove(context, builder, sig, args): 1247 inst = SetInstance(context, builder, sig.args[0], args[0]) 1248 item = args[1] 1249 found = inst.discard(item) 1250 with builder.if_then(builder.not_(found), likely=False): 1251 context.call_conv.return_user_exc(builder, KeyError, 1252 ("set.remove(): key not in set",)) 1253 1254 return context.get_dummy_value() 1255 1256 1257# Mutating set operations 1258 1259@lower_builtin("set.clear", types.Set) 1260def set_clear(context, builder, sig, args): 1261 inst = SetInstance(context, builder, sig.args[0], args[0]) 1262 inst.clear() 1263 return context.get_dummy_value() 1264 1265@lower_builtin("set.copy", types.Set) 1266def set_copy(context, builder, sig, args): 1267 inst = SetInstance(context, builder, sig.args[0], args[0]) 1268 other = inst.copy() 1269 return impl_ret_new_ref(context, builder, sig.return_type, other.value) 1270 1271@lower_builtin("set.difference_update", types.Set, types.IterableType) 1272def set_difference_update(context, builder, sig, args): 1273 inst = SetInstance(context, builder, sig.args[0], args[0]) 1274 other = SetInstance(context, builder, sig.args[1], args[1]) 1275 1276 inst.difference(other) 1277 1278 return context.get_dummy_value() 1279 1280@lower_builtin("set.intersection_update", types.Set, types.Set) 1281def set_intersection_update(context, builder, sig, args): 1282 inst = SetInstance(context, builder, sig.args[0], args[0]) 1283 other = SetInstance(context, builder, sig.args[1], args[1]) 1284 1285 inst.intersect(other) 1286 1287 return context.get_dummy_value() 1288 1289@lower_builtin("set.symmetric_difference_update", types.Set, types.Set) 1290def set_symmetric_difference_update(context, builder, sig, args): 1291 inst = SetInstance(context, builder, sig.args[0], args[0]) 1292 other = SetInstance(context, builder, sig.args[1], args[1]) 1293 1294 inst.symmetric_difference(other) 1295 1296 return context.get_dummy_value() 1297 1298@lower_builtin("set.update", types.Set, types.IterableType) 1299def set_update(context, builder, sig, args): 1300 inst = SetInstance(context, builder, sig.args[0], args[0]) 1301 items_type = sig.args[1] 1302 items = args[1] 1303 1304 # If the argument has a len(), assume there are few collisions and 1305 # presize to len(set) + len(items) 1306 n = call_len(context, builder, items_type, items) 1307 if n is not None: 1308 new_size = builder.add(inst.payload.used, n) 1309 inst.upsize(new_size) 1310 1311 with for_iter(context, builder, items_type, items) as loop: 1312 inst.add(loop.value) 1313 1314 if n is not None: 1315 # If we pre-grew the set, downsize in case there were many collisions 1316 inst.downsize(inst.payload.used) 1317 1318 return context.get_dummy_value() 1319 1320for op_, op_impl in [ 1321 (operator.iand, set_intersection_update), 1322 (operator.ior, set_update), 1323 (operator.isub, set_difference_update), 1324 (operator.ixor, set_symmetric_difference_update), 1325 ]: 1326 @lower_builtin(op_, types.Set, types.Set) 1327 def set_inplace(context, builder, sig, args, op_impl=op_impl): 1328 assert sig.return_type == sig.args[0] 1329 op_impl(context, builder, sig, args) 1330 return impl_ret_borrowed(context, builder, sig.args[0], args[0]) 1331 1332 1333# Set operations creating a new set 1334 1335@lower_builtin(operator.sub, types.Set, types.Set) 1336@lower_builtin("set.difference", types.Set, types.Set) 1337def set_difference(context, builder, sig, args): 1338 def difference_impl(a, b): 1339 s = a.copy() 1340 s.difference_update(b) 1341 return s 1342 1343 return context.compile_internal(builder, difference_impl, sig, args) 1344 1345@lower_builtin(operator.and_, types.Set, types.Set) 1346@lower_builtin("set.intersection", types.Set, types.Set) 1347def set_intersection(context, builder, sig, args): 1348 def intersection_impl(a, b): 1349 if len(a) < len(b): 1350 s = a.copy() 1351 s.intersection_update(b) 1352 return s 1353 else: 1354 s = b.copy() 1355 s.intersection_update(a) 1356 return s 1357 1358 return context.compile_internal(builder, intersection_impl, sig, args) 1359 1360@lower_builtin(operator.xor, types.Set, types.Set) 1361@lower_builtin("set.symmetric_difference", types.Set, types.Set) 1362def set_symmetric_difference(context, builder, sig, args): 1363 def symmetric_difference_impl(a, b): 1364 if len(a) > len(b): 1365 s = a.copy() 1366 s.symmetric_difference_update(b) 1367 return s 1368 else: 1369 s = b.copy() 1370 s.symmetric_difference_update(a) 1371 return s 1372 1373 return context.compile_internal(builder, symmetric_difference_impl, 1374 sig, args) 1375 1376@lower_builtin(operator.or_, types.Set, types.Set) 1377@lower_builtin("set.union", types.Set, types.Set) 1378def set_union(context, builder, sig, args): 1379 def union_impl(a, b): 1380 if len(a) > len(b): 1381 s = a.copy() 1382 s.update(b) 1383 return s 1384 else: 1385 s = b.copy() 1386 s.update(a) 1387 return s 1388 1389 return context.compile_internal(builder, union_impl, sig, args) 1390 1391 1392# Predicates 1393 1394@lower_builtin("set.isdisjoint", types.Set, types.Set) 1395def set_isdisjoint(context, builder, sig, args): 1396 inst = SetInstance(context, builder, sig.args[0], args[0]) 1397 other = SetInstance(context, builder, sig.args[1], args[1]) 1398 1399 return inst.isdisjoint(other) 1400 1401@lower_builtin(operator.le, types.Set, types.Set) 1402@lower_builtin("set.issubset", types.Set, types.Set) 1403def set_issubset(context, builder, sig, args): 1404 inst = SetInstance(context, builder, sig.args[0], args[0]) 1405 other = SetInstance(context, builder, sig.args[1], args[1]) 1406 1407 return inst.issubset(other) 1408 1409@lower_builtin(operator.ge, types.Set, types.Set) 1410@lower_builtin("set.issuperset", types.Set, types.Set) 1411def set_issuperset(context, builder, sig, args): 1412 def superset_impl(a, b): 1413 return b.issubset(a) 1414 1415 return context.compile_internal(builder, superset_impl, sig, args) 1416 1417@lower_builtin(operator.eq, types.Set, types.Set) 1418def set_isdisjoint(context, builder, sig, args): 1419 inst = SetInstance(context, builder, sig.args[0], args[0]) 1420 other = SetInstance(context, builder, sig.args[1], args[1]) 1421 1422 return inst.equals(other) 1423 1424@lower_builtin(operator.ne, types.Set, types.Set) 1425def set_ne(context, builder, sig, args): 1426 def ne_impl(a, b): 1427 return not a == b 1428 1429 return context.compile_internal(builder, ne_impl, sig, args) 1430 1431@lower_builtin(operator.lt, types.Set, types.Set) 1432def set_lt(context, builder, sig, args): 1433 inst = SetInstance(context, builder, sig.args[0], args[0]) 1434 other = SetInstance(context, builder, sig.args[1], args[1]) 1435 1436 return inst.issubset(other, strict=True) 1437 1438@lower_builtin(operator.gt, types.Set, types.Set) 1439def set_gt(context, builder, sig, args): 1440 def gt_impl(a, b): 1441 return b < a 1442 1443 return context.compile_internal(builder, gt_impl, sig, args) 1444 1445@lower_builtin(operator.is_, types.Set, types.Set) 1446def set_is(context, builder, sig, args): 1447 a = SetInstance(context, builder, sig.args[0], args[0]) 1448 b = SetInstance(context, builder, sig.args[1], args[1]) 1449 ma = builder.ptrtoint(a.meminfo, cgutils.intp_t) 1450 mb = builder.ptrtoint(b.meminfo, cgutils.intp_t) 1451 return builder.icmp_signed('==', ma, mb) 1452 1453 1454# ----------------------------------------------------------------------------- 1455# Implicit casting 1456 1457@lower_cast(types.Set, types.Set) 1458def set_to_set(context, builder, fromty, toty, val): 1459 # Casting from non-reflected to reflected 1460 assert fromty.dtype == toty.dtype 1461 return val 1462