1# util/_collections.py 2# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors 3# <see AUTHORS file> 4# 5# This module is part of SQLAlchemy and is released under 6# the MIT License: https://www.opensource.org/licenses/mit-license.php 7 8"""Collection classes and helpers.""" 9 10from __future__ import absolute_import 11 12import operator 13import types 14import weakref 15 16from .compat import binary_types 17from .compat import collections_abc 18from .compat import itertools_filterfalse 19from .compat import py2k 20from .compat import py37 21from .compat import string_types 22from .compat import threading 23 24 25EMPTY_SET = frozenset() 26 27 28class ImmutableContainer(object): 29 def _immutable(self, *arg, **kw): 30 raise TypeError("%s object is immutable" % self.__class__.__name__) 31 32 __delitem__ = __setitem__ = __setattr__ = _immutable 33 34 35def _immutabledict_py_fallback(): 36 class immutabledict(ImmutableContainer, dict): 37 38 clear = ( 39 pop 40 ) = popitem = setdefault = update = ImmutableContainer._immutable 41 42 def __new__(cls, *args): 43 new = dict.__new__(cls) 44 dict.__init__(new, *args) 45 return new 46 47 def __init__(self, *args): 48 pass 49 50 def __reduce__(self): 51 return _immutabledict_reconstructor, (dict(self),) 52 53 def union(self, __d=None): 54 if not __d: 55 return self 56 57 new = dict.__new__(self.__class__) 58 dict.__init__(new, self) 59 dict.update(new, __d) 60 return new 61 62 def _union_w_kw(self, __d=None, **kw): 63 # not sure if C version works correctly w/ this yet 64 if not __d and not kw: 65 return self 66 67 new = dict.__new__(self.__class__) 68 dict.__init__(new, self) 69 if __d: 70 dict.update(new, __d) 71 dict.update(new, kw) 72 return new 73 74 def merge_with(self, *dicts): 75 new = None 76 for d in dicts: 77 if d: 78 if new is None: 79 new = dict.__new__(self.__class__) 80 dict.__init__(new, self) 81 dict.update(new, d) 82 if new is None: 83 return self 84 85 return new 86 87 def __repr__(self): 88 
return "immutabledict(%s)" % dict.__repr__(self) 89 90 return immutabledict 91 92 93try: 94 from sqlalchemy.cimmutabledict import immutabledict 95 96 collections_abc.Mapping.register(immutabledict) 97 98except ImportError: 99 immutabledict = _immutabledict_py_fallback() 100 101 def _immutabledict_reconstructor(*arg): 102 """do the pickle dance""" 103 return immutabledict(*arg) 104 105 106def coerce_to_immutabledict(d): 107 if not d: 108 return EMPTY_DICT 109 elif isinstance(d, immutabledict): 110 return d 111 else: 112 return immutabledict(d) 113 114 115EMPTY_DICT = immutabledict() 116 117 118class FacadeDict(ImmutableContainer, dict): 119 """A dictionary that is not publicly mutable.""" 120 121 clear = pop = popitem = setdefault = update = ImmutableContainer._immutable 122 123 def __new__(cls, *args): 124 new = dict.__new__(cls) 125 return new 126 127 def copy(self): 128 raise NotImplementedError( 129 "an immutabledict shouldn't need to be copied. use dict(d) " 130 "if you need a mutable dictionary." 
131 ) 132 133 def __reduce__(self): 134 return FacadeDict, (dict(self),) 135 136 def _insert_item(self, key, value): 137 """insert an item into the dictionary directly.""" 138 dict.__setitem__(self, key, value) 139 140 def __repr__(self): 141 return "FacadeDict(%s)" % dict.__repr__(self) 142 143 144class Properties(object): 145 """Provide a __getattr__/__setattr__ interface over a dict.""" 146 147 __slots__ = ("_data",) 148 149 def __init__(self, data): 150 object.__setattr__(self, "_data", data) 151 152 def __len__(self): 153 return len(self._data) 154 155 def __iter__(self): 156 return iter(list(self._data.values())) 157 158 def __dir__(self): 159 return dir(super(Properties, self)) + [ 160 str(k) for k in self._data.keys() 161 ] 162 163 def __add__(self, other): 164 return list(self) + list(other) 165 166 def __setitem__(self, key, obj): 167 self._data[key] = obj 168 169 def __getitem__(self, key): 170 return self._data[key] 171 172 def __delitem__(self, key): 173 del self._data[key] 174 175 def __setattr__(self, key, obj): 176 self._data[key] = obj 177 178 def __getstate__(self): 179 return {"_data": self._data} 180 181 def __setstate__(self, state): 182 object.__setattr__(self, "_data", state["_data"]) 183 184 def __getattr__(self, key): 185 try: 186 return self._data[key] 187 except KeyError: 188 raise AttributeError(key) 189 190 def __contains__(self, key): 191 return key in self._data 192 193 def as_immutable(self): 194 """Return an immutable proxy for this :class:`.Properties`.""" 195 196 return ImmutableProperties(self._data) 197 198 def update(self, value): 199 self._data.update(value) 200 201 def get(self, key, default=None): 202 if key in self: 203 return self[key] 204 else: 205 return default 206 207 def keys(self): 208 return list(self._data) 209 210 def values(self): 211 return list(self._data.values()) 212 213 def items(self): 214 return list(self._data.items()) 215 216 def has_key(self, key): 217 return key in self._data 218 219 def clear(self): 
220 self._data.clear() 221 222 223class OrderedProperties(Properties): 224 """Provide a __getattr__/__setattr__ interface with an OrderedDict 225 as backing store.""" 226 227 __slots__ = () 228 229 def __init__(self): 230 Properties.__init__(self, OrderedDict()) 231 232 233class ImmutableProperties(ImmutableContainer, Properties): 234 """Provide immutable dict/object attribute to an underlying dictionary.""" 235 236 __slots__ = () 237 238 239def _ordered_dictionary_sort(d, key=None): 240 """Sort an OrderedDict in-place.""" 241 242 items = [(k, d[k]) for k in sorted(d, key=key)] 243 244 d.clear() 245 246 d.update(items) 247 248 249if py37: 250 OrderedDict = dict 251 sort_dictionary = _ordered_dictionary_sort 252 253else: 254 # prevent sort_dictionary from being used against a plain dictionary 255 # for Python < 3.7 256 257 def sort_dictionary(d, key=None): 258 """Sort an OrderedDict in place.""" 259 260 d._ordered_dictionary_sort(key=key) 261 262 class OrderedDict(dict): 263 """Dictionary that maintains insertion order. 
264 265 Superseded by Python dict as of Python 3.7 266 267 """ 268 269 __slots__ = ("_list",) 270 271 def _ordered_dictionary_sort(self, key=None): 272 _ordered_dictionary_sort(self, key=key) 273 274 def __reduce__(self): 275 return OrderedDict, (self.items(),) 276 277 def __init__(self, ____sequence=None, **kwargs): 278 self._list = [] 279 if ____sequence is None: 280 if kwargs: 281 self.update(**kwargs) 282 else: 283 self.update(____sequence, **kwargs) 284 285 def clear(self): 286 self._list = [] 287 dict.clear(self) 288 289 def copy(self): 290 return self.__copy__() 291 292 def __copy__(self): 293 return OrderedDict(self) 294 295 def update(self, ____sequence=None, **kwargs): 296 if ____sequence is not None: 297 if hasattr(____sequence, "keys"): 298 for key in ____sequence.keys(): 299 self.__setitem__(key, ____sequence[key]) 300 else: 301 for key, value in ____sequence: 302 self[key] = value 303 if kwargs: 304 self.update(kwargs) 305 306 def setdefault(self, key, value): 307 if key not in self: 308 self.__setitem__(key, value) 309 return value 310 else: 311 return self.__getitem__(key) 312 313 def __iter__(self): 314 return iter(self._list) 315 316 def keys(self): 317 return list(self) 318 319 def values(self): 320 return [self[key] for key in self._list] 321 322 def items(self): 323 return [(key, self[key]) for key in self._list] 324 325 if py2k: 326 327 def itervalues(self): 328 return iter(self.values()) 329 330 def iterkeys(self): 331 return iter(self) 332 333 def iteritems(self): 334 return iter(self.items()) 335 336 def __setitem__(self, key, obj): 337 if key not in self: 338 try: 339 self._list.append(key) 340 except AttributeError: 341 # work around Python pickle loads() with 342 # dict subclass (seems to ignore __setstate__?) 
343 self._list = [key] 344 dict.__setitem__(self, key, obj) 345 346 def __delitem__(self, key): 347 dict.__delitem__(self, key) 348 self._list.remove(key) 349 350 def pop(self, key, *default): 351 present = key in self 352 value = dict.pop(self, key, *default) 353 if present: 354 self._list.remove(key) 355 return value 356 357 def popitem(self): 358 item = dict.popitem(self) 359 self._list.remove(item[0]) 360 return item 361 362 363class OrderedSet(set): 364 def __init__(self, d=None): 365 set.__init__(self) 366 if d is not None: 367 self._list = unique_list(d) 368 set.update(self, self._list) 369 else: 370 self._list = [] 371 372 def add(self, element): 373 if element not in self: 374 self._list.append(element) 375 set.add(self, element) 376 377 def remove(self, element): 378 set.remove(self, element) 379 self._list.remove(element) 380 381 def insert(self, pos, element): 382 if element not in self: 383 self._list.insert(pos, element) 384 set.add(self, element) 385 386 def discard(self, element): 387 if element in self: 388 self._list.remove(element) 389 set.remove(self, element) 390 391 def clear(self): 392 set.clear(self) 393 self._list = [] 394 395 def __getitem__(self, key): 396 return self._list[key] 397 398 def __iter__(self): 399 return iter(self._list) 400 401 def __add__(self, other): 402 return self.union(other) 403 404 def __repr__(self): 405 return "%s(%r)" % (self.__class__.__name__, self._list) 406 407 __str__ = __repr__ 408 409 def update(self, iterable): 410 for e in iterable: 411 if e not in self: 412 self._list.append(e) 413 set.add(self, e) 414 return self 415 416 __ior__ = update 417 418 def union(self, other): 419 result = self.__class__(self) 420 result.update(other) 421 return result 422 423 __or__ = union 424 425 def intersection(self, other): 426 other = set(other) 427 return self.__class__(a for a in self if a in other) 428 429 __and__ = intersection 430 431 def symmetric_difference(self, other): 432 other = set(other) 433 result = 
self.__class__(a for a in self if a not in other) 434 result.update(a for a in other if a not in self) 435 return result 436 437 __xor__ = symmetric_difference 438 439 def difference(self, other): 440 other = set(other) 441 return self.__class__(a for a in self if a not in other) 442 443 __sub__ = difference 444 445 def intersection_update(self, other): 446 other = set(other) 447 set.intersection_update(self, other) 448 self._list = [a for a in self._list if a in other] 449 return self 450 451 __iand__ = intersection_update 452 453 def symmetric_difference_update(self, other): 454 set.symmetric_difference_update(self, other) 455 self._list = [a for a in self._list if a in self] 456 self._list += [a for a in other._list if a in self] 457 return self 458 459 __ixor__ = symmetric_difference_update 460 461 def difference_update(self, other): 462 set.difference_update(self, other) 463 self._list = [a for a in self._list if a in self] 464 return self 465 466 __isub__ = difference_update 467 468 469class IdentitySet(object): 470 """A set that considers only object id() for uniqueness. 471 472 This strategy has edge cases for builtin types- it's possible to have 473 two 'foo' strings in one of these sets, for example. Use sparingly. 
474 475 """ 476 477 def __init__(self, iterable=None): 478 self._members = dict() 479 if iterable: 480 self.update(iterable) 481 482 def add(self, value): 483 self._members[id(value)] = value 484 485 def __contains__(self, value): 486 return id(value) in self._members 487 488 def remove(self, value): 489 del self._members[id(value)] 490 491 def discard(self, value): 492 try: 493 self.remove(value) 494 except KeyError: 495 pass 496 497 def pop(self): 498 try: 499 pair = self._members.popitem() 500 return pair[1] 501 except KeyError: 502 raise KeyError("pop from an empty set") 503 504 def clear(self): 505 self._members.clear() 506 507 def __cmp__(self, other): 508 raise TypeError("cannot compare sets using cmp()") 509 510 def __eq__(self, other): 511 if isinstance(other, IdentitySet): 512 return self._members == other._members 513 else: 514 return False 515 516 def __ne__(self, other): 517 if isinstance(other, IdentitySet): 518 return self._members != other._members 519 else: 520 return True 521 522 def issubset(self, iterable): 523 if isinstance(iterable, self.__class__): 524 other = iterable 525 else: 526 other = self.__class__(iterable) 527 528 if len(self) > len(other): 529 return False 530 for m in itertools_filterfalse( 531 other._members.__contains__, iter(self._members.keys()) 532 ): 533 return False 534 return True 535 536 def __le__(self, other): 537 if not isinstance(other, IdentitySet): 538 return NotImplemented 539 return self.issubset(other) 540 541 def __lt__(self, other): 542 if not isinstance(other, IdentitySet): 543 return NotImplemented 544 return len(self) < len(other) and self.issubset(other) 545 546 def issuperset(self, iterable): 547 if isinstance(iterable, self.__class__): 548 other = iterable 549 else: 550 other = self.__class__(iterable) 551 552 if len(self) < len(other): 553 return False 554 555 for m in itertools_filterfalse( 556 self._members.__contains__, iter(other._members.keys()) 557 ): 558 return False 559 return True 560 561 def 
__ge__(self, other): 562 if not isinstance(other, IdentitySet): 563 return NotImplemented 564 return self.issuperset(other) 565 566 def __gt__(self, other): 567 if not isinstance(other, IdentitySet): 568 return NotImplemented 569 return len(self) > len(other) and self.issuperset(other) 570 571 def union(self, iterable): 572 result = self.__class__() 573 members = self._members 574 result._members.update(members) 575 result._members.update((id(obj), obj) for obj in iterable) 576 return result 577 578 def __or__(self, other): 579 if not isinstance(other, IdentitySet): 580 return NotImplemented 581 return self.union(other) 582 583 def update(self, iterable): 584 self._members.update((id(obj), obj) for obj in iterable) 585 586 def __ior__(self, other): 587 if not isinstance(other, IdentitySet): 588 return NotImplemented 589 self.update(other) 590 return self 591 592 def difference(self, iterable): 593 result = self.__class__() 594 members = self._members 595 if isinstance(iterable, self.__class__): 596 other = set(iterable._members.keys()) 597 else: 598 other = {id(obj) for obj in iterable} 599 result._members.update( 600 ((k, v) for k, v in members.items() if k not in other) 601 ) 602 return result 603 604 def __sub__(self, other): 605 if not isinstance(other, IdentitySet): 606 return NotImplemented 607 return self.difference(other) 608 609 def difference_update(self, iterable): 610 self._members = self.difference(iterable)._members 611 612 def __isub__(self, other): 613 if not isinstance(other, IdentitySet): 614 return NotImplemented 615 self.difference_update(other) 616 return self 617 618 def intersection(self, iterable): 619 result = self.__class__() 620 members = self._members 621 if isinstance(iterable, self.__class__): 622 other = set(iterable._members.keys()) 623 else: 624 other = {id(obj) for obj in iterable} 625 result._members.update( 626 (k, v) for k, v in members.items() if k in other 627 ) 628 return result 629 630 def __and__(self, other): 631 if not 
isinstance(other, IdentitySet): 632 return NotImplemented 633 return self.intersection(other) 634 635 def intersection_update(self, iterable): 636 self._members = self.intersection(iterable)._members 637 638 def __iand__(self, other): 639 if not isinstance(other, IdentitySet): 640 return NotImplemented 641 self.intersection_update(other) 642 return self 643 644 def symmetric_difference(self, iterable): 645 result = self.__class__() 646 members = self._members 647 if isinstance(iterable, self.__class__): 648 other = iterable._members 649 else: 650 other = {id(obj): obj for obj in iterable} 651 result._members.update( 652 ((k, v) for k, v in members.items() if k not in other) 653 ) 654 result._members.update( 655 ((k, v) for k, v in other.items() if k not in members) 656 ) 657 return result 658 659 def __xor__(self, other): 660 if not isinstance(other, IdentitySet): 661 return NotImplemented 662 return self.symmetric_difference(other) 663 664 def symmetric_difference_update(self, iterable): 665 self._members = self.symmetric_difference(iterable)._members 666 667 def __ixor__(self, other): 668 if not isinstance(other, IdentitySet): 669 return NotImplemented 670 self.symmetric_difference(other) 671 return self 672 673 def copy(self): 674 return type(self)(iter(self._members.values())) 675 676 __copy__ = copy 677 678 def __len__(self): 679 return len(self._members) 680 681 def __iter__(self): 682 return iter(self._members.values()) 683 684 def __hash__(self): 685 raise TypeError("set objects are unhashable") 686 687 def __repr__(self): 688 return "%s(%r)" % (type(self).__name__, list(self._members.values())) 689 690 691class WeakSequence(object): 692 def __init__(self, __elements=()): 693 # adapted from weakref.WeakKeyDictionary, prevent reference 694 # cycles in the collection itself 695 def _remove(item, selfref=weakref.ref(self)): 696 self = selfref() 697 if self is not None: 698 self._storage.remove(item) 699 700 self._remove = _remove 701 self._storage = [ 702 
weakref.ref(element, _remove) for element in __elements 703 ] 704 705 def append(self, item): 706 self._storage.append(weakref.ref(item, self._remove)) 707 708 def __len__(self): 709 return len(self._storage) 710 711 def __iter__(self): 712 return ( 713 obj for obj in (ref() for ref in self._storage) if obj is not None 714 ) 715 716 def __getitem__(self, index): 717 try: 718 obj = self._storage[index] 719 except KeyError: 720 raise IndexError("Index %s out of range" % index) 721 else: 722 return obj() 723 724 725class OrderedIdentitySet(IdentitySet): 726 def __init__(self, iterable=None): 727 IdentitySet.__init__(self) 728 self._members = OrderedDict() 729 if iterable: 730 for o in iterable: 731 self.add(o) 732 733 734class PopulateDict(dict): 735 """A dict which populates missing values via a creation function. 736 737 Note the creation function takes a key, unlike 738 collections.defaultdict. 739 740 """ 741 742 def __init__(self, creator): 743 self.creator = creator 744 745 def __missing__(self, key): 746 self[key] = val = self.creator(key) 747 return val 748 749 750class WeakPopulateDict(dict): 751 """Like PopulateDict, but assumes a self + a method and does not create 752 a reference cycle. 753 754 """ 755 756 def __init__(self, creator_method): 757 self.creator = creator_method.__func__ 758 weakself = creator_method.__self__ 759 self.weakself = weakref.ref(weakself) 760 761 def __missing__(self, key): 762 self[key] = val = self.creator(self.weakself(), key) 763 return val 764 765 766# Define collections that are capable of storing 767# ColumnElement objects as hashable keys/elements. 768# At this point, these are mostly historical, things 769# used to be more complicated. 
column_set = set
column_dict = dict
ordered_column_set = OrderedSet


_getters = PopulateDict(operator.itemgetter)

_property_getters = PopulateDict(
    lambda idx: property(operator.itemgetter(idx))
)


def unique_list(seq, hashfunc=None):
    """Return the items of ``seq`` as a list with duplicates removed.

    The first occurrence of each item is kept, in original order.

    :param seq: iterable of items.
    :param hashfunc: optional callable producing the hashable value used
      to decide uniqueness; when omitted the items themselves are used.
    """
    seen = set()
    seen_add = seen.add
    if not hashfunc:
        # ``seen_add`` returns None, so ``not seen_add(x)`` is always
        # true and merely records ``x`` as a side effect.
        return [x for x in seq if x not in seen and not seen_add(x)]

    result = []
    for x in seq:
        # Compute the uniqueness value once per item.
        hashed = hashfunc(x)
        if hashed not in seen:
            seen_add(hashed)
            result.append(x)
    return result


class UniqueAppender(object):
    """Appends items to a collection ensuring uniqueness.

    Additional appends() of the same object are ignored. Membership is
    determined by identity (``is a``) not equality (``==``).
    """

    def __init__(self, data, via=None):
        """Wrap ``data``.

        :param data: the target collection.
        :param via: optional name of the method on ``data`` used to add
          an item; otherwise ``append`` is used if present, else ``add``.
        """
        self.data = data
        self._unique = {}
        if via:
            self._data_appender = getattr(data, via)
        elif hasattr(data, "append"):
            self._data_appender = data.append
        elif hasattr(data, "add"):
            self._data_appender = data.add

    def append(self, item):
        """Add ``item`` unless this same object was already appended."""
        id_ = id(item)
        if id_ not in self._unique:
            self._data_appender(item)
            self._unique[id_] = True

    def __iter__(self):
        return iter(self.data)


def coerce_generator_arg(arg):
    """Expand ``arg`` to a list if it is a sequence of one generator.

    Any other ``arg`` is returned unchanged.
    """
    if len(arg) == 1 and isinstance(arg[0], types.GeneratorType):
        return list(arg[0])
    else:
        return arg


def to_list(x, default=None):
    """Coerce ``x`` to a list.

    ``None`` yields ``default``; a non-iterable or a string/bytes value
    is wrapped in a one-element list; a list is returned as-is; any
    other iterable is copied into a new list.
    """
    if x is None:
        return default
    if not isinstance(x, collections_abc.Iterable) or isinstance(
        x, string_types + binary_types
    ):
        return [x]
    elif isinstance(x, list):
        return x
    else:
        return list(x)


def has_intersection(set_, iterable):
    r"""return True if any items of set\_ are present in iterable.

    Goes through special effort to ensure __hash__ is not called
    on items in iterable that don't support it.

    """
    # TODO: optimize, write in C, etc.
    return bool(set_.intersection([i for i in iterable if i.__hash__]))


def to_set(x):
    """Coerce ``x`` to a ``set``; ``None`` yields a new empty set.

    An existing set is returned as-is.
    """
    if x is None:
        return set()
    if not isinstance(x, set):
        return set(to_list(x))
    else:
        return x


def to_column_set(x):
    """Coerce ``x`` to a ``column_set``; ``None`` yields a new empty one.

    An existing ``column_set`` is returned as-is.
    """
    if x is None:
        return column_set()
    if not isinstance(x, column_set):
        return column_set(to_list(x))
    else:
        return x


def update_copy(d, _new=None, **kw):
    """Copy the given dict and update with the given values."""

    d = d.copy()
    if _new:
        d.update(_new)
    d.update(**kw)
    return d


def flatten_iterator(x):
    """Given an iterator of which further sub-elements may also be
    iterators, flatten the sub-elements into a single iterator.

    """
    for elem in x:
        # Strings are yielded whole rather than descended into.
        if not isinstance(elem, str) and hasattr(elem, "__iter__"):
            for y in flatten_iterator(elem):
                yield y
        else:
            yield elem


class LRUCache(dict):
    """Dictionary with 'squishy' removal of least
    recently used items.

    Note that either get() or [] should be used here, but
    generally its not safe to do an "in" check first as the dictionary
    can change subsequent to that call.

    """

    __slots__ = "capacity", "threshold", "size_alert", "_counter", "_mutex"

    def __init__(self, capacity=100, threshold=0.5, size_alert=None):
        """Construct the cache.

        :param capacity: number of entries kept when the cache is pruned.
        :param threshold: fraction of ``capacity`` by which the cache may
          overshoot before pruning starts.
        :param size_alert: optional callable invoked with the cache the
          first time a pruning pass finds it over the threshold.
        """
        self.capacity = capacity
        self.threshold = threshold
        self.size_alert = size_alert
        self._counter = 0
        self._mutex = threading.Lock()

    def _inc_counter(self):
        """Advance and return the usage counter."""
        self._counter += 1
        return self._counter

    # Each underlying dict value is a mutable ``[key, value, counter]``
    # list; ``counter`` is refreshed on every read.

    def get(self, key, default=None):
        item = dict.get(self, key, default)
        if item is not default:
            item[2] = self._inc_counter()
            return item[1]
        else:
            return default

    def __getitem__(self, key):
        item = dict.__getitem__(self, key)
        item[2] = self._inc_counter()
        return item[1]

    def values(self):
        return [i[1] for i in dict.values(self)]

    def setdefault(self, key, value):
        if key in self:
            return self[key]
        else:
            self[key] = value
            return value

    def __setitem__(self, key, value):
        item = dict.get(self, key)
        if item is None:
            item = [key, value, self._inc_counter()]
            dict.__setitem__(self, key, item)
        else:
            item[1] = value
        self._manage_size()

    @property
    def size_threshold(self):
        """Entry count above which pruning is triggered."""
        return self.capacity + self.capacity * self.threshold

    def _manage_size(self):
        """Prune the least recently used entries down to ``capacity``.

        Returns immediately if a pruning pass already holds the lock.
        """
        if not self._mutex.acquire(False):
            return
        try:
            size_alert = bool(self.size_alert)
            while len(self) > self.size_threshold:
                if size_alert:
                    # Notify at most once per pruning pass.
                    size_alert = False
                    self.size_alert(self)
                # Highest counter first; everything past ``capacity``
                # is the least recently used tail.
                by_counter = sorted(
                    dict.values(self), key=operator.itemgetter(2), reverse=True
                )
                for item in by_counter[self.capacity :]:
                    try:
                        del self[item[0]]
                    except KeyError:
                        # deleted elsewhere; skip
                        continue
        finally:
            self._mutex.release()


class ScopedRegistry(object):
    """A Registry that can store one or multiple instances of a single
    class on the basis of a "scope" function.

    The object implements ``__call__`` as the "getter", so by
    calling ``myregistry()`` the contained object is returned
    for the current scope.

    :param createfunc:
      a callable that returns a new object to be placed in the registry

    :param scopefunc:
      a callable that will return a key to store/retrieve an object.
    """

    def __init__(self, createfunc, scopefunc):
        """Construct a new :class:`.ScopedRegistry`.

        :param createfunc: A creation function that will generate
          a new value for the current scope, if none is present.

        :param scopefunc: A function that returns a hashable
          token representing the current scope (such as, current
          thread identifier).

        """
        self.createfunc = createfunc
        self.scopefunc = scopefunc
        self.registry = {}

    def __call__(self):
        """Return the object for the current scope, creating it if needed."""
        key = self.scopefunc()
        try:
            return self.registry[key]
        except KeyError:
            return self.registry.setdefault(key, self.createfunc())

    def has(self):
        """Return True if an object is present in the current scope."""

        return self.scopefunc() in self.registry

    def set(self, obj):
        """Set the value for the current scope."""

        self.registry[self.scopefunc()] = obj

    def clear(self):
        """Clear the current scope, if any."""

        try:
            del self.registry[self.scopefunc()]
        except KeyError:
            pass


class ThreadLocalRegistry(ScopedRegistry):
    """A :class:`.ScopedRegistry` that uses a ``threading.local()``
    variable for storage.

    """

    def __init__(self, createfunc):
        self.createfunc = createfunc
        self.registry = threading.local()

    def __call__(self):
        """Return the stored object, creating it on first access."""
        try:
            return self.registry.value
        except AttributeError:
            val = self.registry.value = self.createfunc()
            return val

    def has(self):
        """Return True if a value has been stored."""
        return hasattr(self.registry, "value")

    def set(self, obj):
        """Store ``obj`` as the value."""
        self.registry.value = obj

    def clear(self):
        """Discard the stored value, if any."""
        try:
            del self.registry.value
        except AttributeError:
            pass


def has_dupes(sequence, target):
    """Given a sequence and search object, return True if there's more
    than one, False if zero or one of them.

    Occurrences are matched by identity (``is``), not equality.

    """
    # compare to .index version below, this version introduces less function
    # overhead and is usually the same speed. At 15000 items (way bigger than
    # a relationship-bound collection in memory usually is) it begins to
    # fall behind the other version only by microseconds.
    c = 0
    for item in sequence:
        if item is target:
            c += 1
            if c > 1:
                return True
    return False


# .index version. the two __contains__ calls as well
# as .index() and isinstance() slow this down.
# def has_dupes(sequence, target):
#     if target not in sequence:
#         return False
#     elif not isinstance(sequence, collections_abc.Sequence):
#         return False
#
#     idx = sequence.index(target)
#     return target in sequence[idx + 1:]