1# util/_collections.py 2# Copyright (C) 2005-2019 the SQLAlchemy authors and contributors 3# <see AUTHORS file> 4# 5# This module is part of SQLAlchemy and is released under 6# the MIT License: http://www.opensource.org/licenses/mit-license.php 7 8"""Collection classes and helpers.""" 9 10from __future__ import absolute_import 11 12import operator 13import types 14import weakref 15 16from .compat import binary_types 17from .compat import collections_abc 18from .compat import itertools_filterfalse 19from .compat import py2k 20from .compat import string_types 21from .compat import threading 22 23 24EMPTY_SET = frozenset() 25 26 27class AbstractKeyedTuple(tuple): 28 __slots__ = () 29 30 def keys(self): 31 """Return a list of string key names for this :class:`.KeyedTuple`. 32 33 .. seealso:: 34 35 :attr:`.KeyedTuple._fields` 36 37 """ 38 39 return list(self._fields) 40 41 42class KeyedTuple(AbstractKeyedTuple): 43 """``tuple`` subclass that adds labeled names. 44 45 E.g.:: 46 47 >>> k = KeyedTuple([1, 2, 3], labels=["one", "two", "three"]) 48 >>> k.one 49 1 50 >>> k.two 51 2 52 53 Result rows returned by :class:`.Query` that contain multiple 54 ORM entities and/or column expressions make use of this 55 class to return rows. 56 57 The :class:`.KeyedTuple` exhibits similar behavior to the 58 ``collections.namedtuple()`` construct provided in the Python 59 standard library, however is architected very differently. 60 Unlike ``collections.namedtuple()``, :class:`.KeyedTuple` is 61 does not rely on creation of custom subtypes in order to represent 62 a new series of keys, instead each :class:`.KeyedTuple` instance 63 receives its list of keys in place. The subtype approach 64 of ``collections.namedtuple()`` introduces significant complexity 65 and performance overhead, which is not necessary for the 66 :class:`.Query` object's use case. 67 68 .. seealso:: 69 70 :ref:`ormtutorial_querying` 71 72 """ 73 74 def __new__(cls, vals, labels=None): 75 t = tuple.__new__(cls, vals) 76 if labels: 77 t.__dict__.update(zip(labels, vals)) 78 else: 79 labels = [] 80 t.__dict__["_labels"] = labels 81 return t 82 83 @property 84 def _fields(self): 85 """Return a tuple of string key names for this :class:`.KeyedTuple`. 86 87 This method provides compatibility with ``collections.namedtuple()``. 88 89 .. seealso:: 90 91 :meth:`.KeyedTuple.keys` 92 93 """ 94 return tuple([l for l in self._labels if l is not None]) 95 96 def __setattr__(self, key, value): 97 raise AttributeError("Can't set attribute: %s" % key) 98 99 def _asdict(self): 100 """Return the contents of this :class:`.KeyedTuple` as a dictionary. 101 102 This method provides compatibility with ``collections.namedtuple()``, 103 with the exception that the dictionary returned is **not** ordered. 104 105 """ 106 return {key: self.__dict__[key] for key in self.keys()} 107 108 109class _LW(AbstractKeyedTuple): 110 __slots__ = () 111 112 def __new__(cls, vals): 113 return tuple.__new__(cls, vals) 114 115 def __reduce__(self): 116 # for pickling, degrade down to the regular 117 # KeyedTuple, thus avoiding anonymous class pickling 118 # difficulties 119 return KeyedTuple, (list(self), self._real_fields) 120 121 def _asdict(self): 122 """Return the contents of this :class:`.KeyedTuple` as a dictionary.""" 123 124 d = dict(zip(self._real_fields, self)) 125 d.pop(None, None) 126 return d 127 128 129class ImmutableContainer(object): 130 def _immutable(self, *arg, **kw): 131 raise TypeError("%s object is immutable" % self.__class__.__name__) 132 133 __delitem__ = __setitem__ = __setattr__ = _immutable 134 135 136class immutabledict(ImmutableContainer, dict): 137 138 clear = pop = popitem = setdefault = update = ImmutableContainer._immutable 139 140 def __new__(cls, *args): 141 new = dict.__new__(cls) 142 dict.__init__(new, *args) 143 return new 144 145 def __init__(self, *args): 146 pass 147 148 def __reduce__(self): 149 return immutabledict, (dict(self),) 150 151 def union(self, d): 152 if not d: 153 return self 154 elif not self: 155 if isinstance(d, immutabledict): 156 return d 157 else: 158 return immutabledict(d) 159 else: 160 d2 = immutabledict(self) 161 dict.update(d2, d) 162 return d2 163 164 def __repr__(self): 165 return "immutabledict(%s)" % dict.__repr__(self) 166 167 168class Properties(object): 169 """Provide a __getattr__/__setattr__ interface over a dict.""" 170 171 __slots__ = ("_data",) 172 173 def __init__(self, data): 174 object.__setattr__(self, "_data", data) 175 176 def __len__(self): 177 return len(self._data) 178 179 def __iter__(self): 180 return iter(list(self._data.values())) 181 182 def __add__(self, other): 183 return list(self) + list(other) 184 185 def __setitem__(self, key, obj): 186 self._data[key] = obj 187 188 def __getitem__(self, key): 189 return self._data[key] 190 191 def __delitem__(self, key): 192 del self._data[key] 193 194 def __setattr__(self, key, obj): 195 self._data[key] = obj 196 197 def __getstate__(self): 198 return {"_data": self._data} 199 200 def __setstate__(self, state): 201 object.__setattr__(self, "_data", state["_data"]) 202 203 def __getattr__(self, key): 204 try: 205 return self._data[key] 206 except KeyError: 207 raise AttributeError(key) 208 209 def __contains__(self, key): 210 return key in self._data 211 212 def as_immutable(self): 213 """Return an immutable proxy for this :class:`.Properties`.""" 214 215 return ImmutableProperties(self._data) 216 217 def update(self, value): 218 self._data.update(value) 219 220 def get(self, key, default=None): 221 if key in self: 222 return self[key] 223 else: 224 return default 225 226 def keys(self): 227 return list(self._data) 228 229 def values(self): 230 return list(self._data.values()) 231 232 def items(self): 233 return list(self._data.items()) 234 235 def has_key(self, key): 236 return key in self._data 237 238 def clear(self): 239 self._data.clear() 240 241 242class OrderedProperties(Properties): 243 """Provide a __getattr__/__setattr__ interface with an OrderedDict 244 as backing store.""" 245 246 __slots__ = () 247 248 def __init__(self): 249 Properties.__init__(self, OrderedDict()) 250 251 252class ImmutableProperties(ImmutableContainer, Properties): 253 """Provide immutable dict/object attribute to an underlying dictionary.""" 254 255 __slots__ = () 256 257 258class OrderedDict(dict): 259 """A dict that returns keys/values/items in the order they were added.""" 260 261 __slots__ = ("_list",) 262 263 def __reduce__(self): 264 return OrderedDict, (self.items(),) 265 266 def __init__(self, ____sequence=None, **kwargs): 267 self._list = [] 268 if ____sequence is None: 269 if kwargs: 270 self.update(**kwargs) 271 else: 272 self.update(____sequence, **kwargs) 273 274 def clear(self): 275 self._list = [] 276 dict.clear(self) 277 278 def copy(self): 279 return self.__copy__() 280 281 def __copy__(self): 282 return OrderedDict(self) 283 284 def sort(self, *arg, **kw): 285 self._list.sort(*arg, **kw) 286 287 def update(self, ____sequence=None, **kwargs): 288 if ____sequence is not None: 289 if hasattr(____sequence, "keys"): 290 for key in ____sequence.keys(): 291 self.__setitem__(key, ____sequence[key]) 292 else: 293 for key, value in ____sequence: 294 self[key] = value 295 if kwargs: 296 self.update(kwargs) 297 298 def setdefault(self, key, value): 299 if key not in self: 300 self.__setitem__(key, value) 301 return value 302 else: 303 return self.__getitem__(key) 304 305 def __iter__(self): 306 return iter(self._list) 307 308 def keys(self): 309 return list(self) 310 311 def values(self): 312 return [self[key] for key in self._list] 313 314 def items(self): 315 return [(key, self[key]) for key in self._list] 316 317 if py2k: 318 319 def itervalues(self): 320 return iter(self.values()) 321 322 def iterkeys(self): 323 return iter(self) 324 325 def iteritems(self): 326 return iter(self.items()) 327 328 def __setitem__(self, key, obj): 329 if key not in self: 330 try: 331 self._list.append(key) 332 except AttributeError: 333 # work around Python pickle loads() with 334 # dict subclass (seems to ignore __setstate__?) 335 self._list = [key] 336 dict.__setitem__(self, key, obj) 337 338 def __delitem__(self, key): 339 dict.__delitem__(self, key) 340 self._list.remove(key) 341 342 def pop(self, key, *default): 343 present = key in self 344 value = dict.pop(self, key, *default) 345 if present: 346 self._list.remove(key) 347 return value 348 349 def popitem(self): 350 item = dict.popitem(self) 351 self._list.remove(item[0]) 352 return item 353 354 355class OrderedSet(set): 356 def __init__(self, d=None): 357 set.__init__(self) 358 self._list = [] 359 if d is not None: 360 self._list = unique_list(d) 361 set.update(self, self._list) 362 else: 363 self._list = [] 364 365 def add(self, element): 366 if element not in self: 367 self._list.append(element) 368 set.add(self, element) 369 370 def remove(self, element): 371 set.remove(self, element) 372 self._list.remove(element) 373 374 def insert(self, pos, element): 375 if element not in self: 376 self._list.insert(pos, element) 377 set.add(self, element) 378 379 def discard(self, element): 380 if element in self: 381 self._list.remove(element) 382 set.remove(self, element) 383 384 def clear(self): 385 set.clear(self) 386 self._list = [] 387 388 def __getitem__(self, key): 389 return self._list[key] 390 391 def __iter__(self): 392 return iter(self._list) 393 394 def __add__(self, other): 395 return self.union(other) 396 397 def __repr__(self): 398 return "%s(%r)" % (self.__class__.__name__, self._list) 399 400 __str__ = __repr__ 401 402 def update(self, iterable): 403 for e in iterable: 404 if e not in self: 405 self._list.append(e) 406 set.add(self, e) 407 return self 408 409 __ior__ = update 410 411 def union(self, other): 412 result = self.__class__(self) 413 result.update(other) 414 return result 415 416 __or__ = union 417 418 def intersection(self, other): 419 other = set(other) 420 return self.__class__(a for a in self if a in other) 421 422 __and__ = intersection 423 424 def symmetric_difference(self, other): 425 other = set(other) 426 result = self.__class__(a for a in self if a not in other) 427 result.update(a for a in other if a not in self) 428 return result 429 430 __xor__ = symmetric_difference 431 432 def difference(self, other): 433 other = set(other) 434 return self.__class__(a for a in self if a not in other) 435 436 __sub__ = difference 437 438 def intersection_update(self, other): 439 other = set(other) 440 set.intersection_update(self, other) 441 self._list = [a for a in self._list if a in other] 442 return self 443 444 __iand__ = intersection_update 445 446 def symmetric_difference_update(self, other): 447 set.symmetric_difference_update(self, other) 448 self._list = [a for a in self._list if a in self] 449 self._list += [a for a in other._list if a in self] 450 return self 451 452 __ixor__ = symmetric_difference_update 453 454 def difference_update(self, other): 455 set.difference_update(self, other) 456 self._list = [a for a in self._list if a in self] 457 return self 458 459 __isub__ = difference_update 460 461 462class IdentitySet(object): 463 """A set that considers only object id() for uniqueness. 464 465 This strategy has edge cases for builtin types- it's possible to have 466 two 'foo' strings in one of these sets, for example. Use sparingly. 467 468 """ 469 470 _working_set = set 471 472 def __init__(self, iterable=None): 473 self._members = dict() 474 if iterable: 475 for o in iterable: 476 self.add(o) 477 478 def add(self, value): 479 self._members[id(value)] = value 480 481 def __contains__(self, value): 482 return id(value) in self._members 483 484 def remove(self, value): 485 del self._members[id(value)] 486 487 def discard(self, value): 488 try: 489 self.remove(value) 490 except KeyError: 491 pass 492 493 def pop(self): 494 try: 495 pair = self._members.popitem() 496 return pair[1] 497 except KeyError: 498 raise KeyError("pop from an empty set") 499 500 def clear(self): 501 self._members.clear() 502 503 def __cmp__(self, other): 504 raise TypeError("cannot compare sets using cmp()") 505 506 def __eq__(self, other): 507 if isinstance(other, IdentitySet): 508 return self._members == other._members 509 else: 510 return False 511 512 def __ne__(self, other): 513 if isinstance(other, IdentitySet): 514 return self._members != other._members 515 else: 516 return True 517 518 def issubset(self, iterable): 519 other = type(self)(iterable) 520 521 if len(self) > len(other): 522 return False 523 for m in itertools_filterfalse( 524 other._members.__contains__, iter(self._members.keys()) 525 ): 526 return False 527 return True 528 529 def __le__(self, other): 530 if not isinstance(other, IdentitySet): 531 return NotImplemented 532 return self.issubset(other) 533 534 def __lt__(self, other): 535 if not isinstance(other, IdentitySet): 536 return NotImplemented 537 return len(self) < len(other) and self.issubset(other) 538 539 def issuperset(self, iterable): 540 other = type(self)(iterable) 541 542 if len(self) < len(other): 543 return False 544 545 for m in itertools_filterfalse( 546 self._members.__contains__, iter(other._members.keys()) 547 ): 548 return False 549 return True 550 551 def __ge__(self, other): 552 if not isinstance(other, IdentitySet): 553 return NotImplemented 554 return self.issuperset(other) 555 556 def __gt__(self, other): 557 if not isinstance(other, IdentitySet): 558 return NotImplemented 559 return len(self) > len(other) and self.issuperset(other) 560 561 def union(self, iterable): 562 result = type(self)() 563 # testlib.pragma exempt:__hash__ 564 members = self._member_id_tuples() 565 other = _iter_id(iterable) 566 result._members.update(self._working_set(members).union(other)) 567 return result 568 569 def __or__(self, other): 570 if not isinstance(other, IdentitySet): 571 return NotImplemented 572 return self.union(other) 573 574 def update(self, iterable): 575 self._members = self.union(iterable)._members 576 577 def __ior__(self, other): 578 if not isinstance(other, IdentitySet): 579 return NotImplemented 580 self.update(other) 581 return self 582 583 def difference(self, iterable): 584 result = type(self)() 585 # testlib.pragma exempt:__hash__ 586 members = self._member_id_tuples() 587 other = _iter_id(iterable) 588 result._members.update(self._working_set(members).difference(other)) 589 return result 590 591 def __sub__(self, other): 592 if not isinstance(other, IdentitySet): 593 return NotImplemented 594 return self.difference(other) 595 596 def difference_update(self, iterable): 597 self._members = self.difference(iterable)._members 598 599 def __isub__(self, other): 600 if not isinstance(other, IdentitySet): 601 return NotImplemented 602 self.difference_update(other) 603 return self 604 605 def intersection(self, iterable): 606 result = type(self)() 607 # testlib.pragma exempt:__hash__ 608 members = self._member_id_tuples() 609 other = _iter_id(iterable) 610 result._members.update(self._working_set(members).intersection(other)) 611 return result 612 613 def __and__(self, other): 614 if not isinstance(other, IdentitySet): 615 return NotImplemented 616 return self.intersection(other) 617 618 def intersection_update(self, iterable): 619 self._members = self.intersection(iterable)._members 620 621 def __iand__(self, other): 622 if not isinstance(other, IdentitySet): 623 return NotImplemented 624 self.intersection_update(other) 625 return self 626 627 def symmetric_difference(self, iterable): 628 result = type(self)() 629 # testlib.pragma exempt:__hash__ 630 members = self._member_id_tuples() 631 other = _iter_id(iterable) 632 result._members.update( 633 self._working_set(members).symmetric_difference(other) 634 ) 635 return result 636 637 def _member_id_tuples(self): 638 return ((id(v), v) for v in self._members.values()) 639 640 def __xor__(self, other): 641 if not isinstance(other, IdentitySet): 642 return NotImplemented 643 return self.symmetric_difference(other) 644 645 def symmetric_difference_update(self, iterable): 646 self._members = self.symmetric_difference(iterable)._members 647 648 def __ixor__(self, other): 649 if not isinstance(other, IdentitySet): 650 return NotImplemented 651 self.symmetric_difference(other) 652 return self 653 654 def copy(self): 655 return type(self)(iter(self._members.values())) 656 657 __copy__ = copy 658 659 def __len__(self): 660 return len(self._members) 661 662 def __iter__(self): 663 return iter(self._members.values()) 664 665 def __hash__(self): 666 raise TypeError("set objects are unhashable") 667 668 def __repr__(self): 669 return "%s(%r)" % (type(self).__name__, list(self._members.values())) 670 671 672class WeakSequence(object): 673 def __init__(self, __elements=()): 674 self._storage = [ 675 weakref.ref(element, self._remove) for element in __elements 676 ] 677 678 def append(self, item): 679 self._storage.append(weakref.ref(item, self._remove)) 680 681 def _remove(self, ref): 682 self._storage.remove(ref) 683 684 def __len__(self): 685 return len(self._storage) 686 687 def __iter__(self): 688 return ( 689 obj for obj in (ref() for ref in self._storage) if obj is not None 690 ) 691 692 def __getitem__(self, index): 693 try: 694 obj = self._storage[index] 695 except KeyError: 696 raise IndexError("Index %s out of range" % index) 697 else: 698 return obj() 699 700 701class OrderedIdentitySet(IdentitySet): 702 class _working_set(OrderedSet): 703 # a testing pragma: exempt the OIDS working set from the test suite's 704 # "never call the user's __hash__" assertions. this is a big hammer, 705 # but it's safe here: IDS operates on (id, instance) tuples in the 706 # working set. 707 __sa_hash_exempt__ = True 708 709 def __init__(self, iterable=None): 710 IdentitySet.__init__(self) 711 self._members = OrderedDict() 712 if iterable: 713 for o in iterable: 714 self.add(o) 715 716 717class PopulateDict(dict): 718 """A dict which populates missing values via a creation function. 719 720 Note the creation function takes a key, unlike 721 collections.defaultdict. 722 723 """ 724 725 def __init__(self, creator): 726 self.creator = creator 727 728 def __missing__(self, key): 729 self[key] = val = self.creator(key) 730 return val 731 732 733# Define collections that are capable of storing 734# ColumnElement objects as hashable keys/elements. 735# At this point, these are mostly historical, things 736# used to be more complicated. 737column_set = set 738column_dict = dict 739ordered_column_set = OrderedSet 740populate_column_dict = PopulateDict 741 742 743_getters = PopulateDict(operator.itemgetter) 744 745_property_getters = PopulateDict( 746 lambda idx: property(operator.itemgetter(idx)) 747) 748 749 750def unique_list(seq, hashfunc=None): 751 seen = set() 752 seen_add = seen.add 753 if not hashfunc: 754 return [x for x in seq if x not in seen and not seen_add(x)] 755 else: 756 return [ 757 x 758 for x in seq 759 if hashfunc(x) not in seen and not seen_add(hashfunc(x)) 760 ] 761 762 763class UniqueAppender(object): 764 """Appends items to a collection ensuring uniqueness. 765 766 Additional appends() of the same object are ignored. Membership is 767 determined by identity (``is a``) not equality (``==``). 768 """ 769 770 def __init__(self, data, via=None): 771 self.data = data 772 self._unique = {} 773 if via: 774 self._data_appender = getattr(data, via) 775 elif hasattr(data, "append"): 776 self._data_appender = data.append 777 elif hasattr(data, "add"): 778 self._data_appender = data.add 779 780 def append(self, item): 781 id_ = id(item) 782 if id_ not in self._unique: 783 self._data_appender(item) 784 self._unique[id_] = True 785 786 def __iter__(self): 787 return iter(self.data) 788 789 790def coerce_generator_arg(arg): 791 if len(arg) == 1 and isinstance(arg[0], types.GeneratorType): 792 return list(arg[0]) 793 else: 794 return arg 795 796 797def to_list(x, default=None): 798 if x is None: 799 return default 800 if not isinstance(x, collections_abc.Iterable) or isinstance( 801 x, string_types + binary_types 802 ): 803 return [x] 804 elif isinstance(x, list): 805 return x 806 else: 807 return list(x) 808 809 810def has_intersection(set_, iterable): 811 r"""return True if any items of set\_ are present in iterable. 812 813 Goes through special effort to ensure __hash__ is not called 814 on items in iterable that don't support it. 815 816 """ 817 # TODO: optimize, write in C, etc. 818 return bool(set_.intersection([i for i in iterable if i.__hash__])) 819 820 821def to_set(x): 822 if x is None: 823 return set() 824 if not isinstance(x, set): 825 return set(to_list(x)) 826 else: 827 return x 828 829 830def to_column_set(x): 831 if x is None: 832 return column_set() 833 if not isinstance(x, column_set): 834 return column_set(to_list(x)) 835 else: 836 return x 837 838 839def update_copy(d, _new=None, **kw): 840 """Copy the given dict and update with the given values.""" 841 842 d = d.copy() 843 if _new: 844 d.update(_new) 845 d.update(**kw) 846 return d 847 848 849def flatten_iterator(x): 850 """Given an iterator of which further sub-elements may also be 851 iterators, flatten the sub-elements into a single iterator. 852 853 """ 854 for elem in x: 855 if not isinstance(elem, str) and hasattr(elem, "__iter__"): 856 for y in flatten_iterator(elem): 857 yield y 858 else: 859 yield elem 860 861 862class LRUCache(dict): 863 """Dictionary with 'squishy' removal of least 864 recently used items. 865 866 Note that either get() or [] should be used here, but 867 generally its not safe to do an "in" check first as the dictionary 868 can change subsequent to that call. 869 870 """ 871 872 __slots__ = "capacity", "threshold", "size_alert", "_counter", "_mutex" 873 874 def __init__(self, capacity=100, threshold=0.5, size_alert=None): 875 self.capacity = capacity 876 self.threshold = threshold 877 self.size_alert = size_alert 878 self._counter = 0 879 self._mutex = threading.Lock() 880 881 def _inc_counter(self): 882 self._counter += 1 883 return self._counter 884 885 def get(self, key, default=None): 886 item = dict.get(self, key, default) 887 if item is not default: 888 item[2] = self._inc_counter() 889 return item[1] 890 else: 891 return default 892 893 def __getitem__(self, key): 894 item = dict.__getitem__(self, key) 895 item[2] = self._inc_counter() 896 return item[1] 897 898 def values(self): 899 return [i[1] for i in dict.values(self)] 900 901 def setdefault(self, key, value): 902 if key in self: 903 return self[key] 904 else: 905 self[key] = value 906 return value 907 908 def __setitem__(self, key, value): 909 item = dict.get(self, key) 910 if item is None: 911 item = [key, value, self._inc_counter()] 912 dict.__setitem__(self, key, item) 913 else: 914 item[1] = value 915 self._manage_size() 916 917 @property 918 def size_threshold(self): 919 return self.capacity + self.capacity * self.threshold 920 921 def _manage_size(self): 922 if not self._mutex.acquire(False): 923 return 924 try: 925 size_alert = bool(self.size_alert) 926 while len(self) > self.capacity + self.capacity * self.threshold: 927 if size_alert: 928 size_alert = False 929 self.size_alert(self) 930 by_counter = sorted( 931 dict.values(self), key=operator.itemgetter(2), reverse=True 932 ) 933 for item in by_counter[self.capacity :]: 934 try: 935 del self[item[0]] 936 except KeyError: 937 # deleted elsewhere; skip 938 continue 939 finally: 940 self._mutex.release() 941 942 943_lw_tuples = LRUCache(100) 944 945 946def lightweight_named_tuple(name, fields): 947 hash_ = (name,) + tuple(fields) 948 tp_cls = _lw_tuples.get(hash_) 949 if tp_cls: 950 return tp_cls 951 952 tp_cls = type( 953 name, 954 (_LW,), 955 dict( 956 [ 957 (field, _property_getters[idx]) 958 for idx, field in enumerate(fields) 959 if field is not None 960 ] 961 + [("__slots__", ())] 962 ), 963 ) 964 965 tp_cls._real_fields = fields 966 tp_cls._fields = tuple([f for f in fields if f is not None]) 967 968 _lw_tuples[hash_] = tp_cls 969 return tp_cls 970 971 972class ScopedRegistry(object): 973 """A Registry that can store one or multiple instances of a single 974 class on the basis of a "scope" function. 975 976 The object implements ``__call__`` as the "getter", so by 977 calling ``myregistry()`` the contained object is returned 978 for the current scope. 979 980 :param createfunc: 981 a callable that returns a new object to be placed in the registry 982 983 :param scopefunc: 984 a callable that will return a key to store/retrieve an object. 985 """ 986 987 def __init__(self, createfunc, scopefunc): 988 """Construct a new :class:`.ScopedRegistry`. 989 990 :param createfunc: A creation function that will generate 991 a new value for the current scope, if none is present. 992 993 :param scopefunc: A function that returns a hashable 994 token representing the current scope (such as, current 995 thread identifier). 996 997 """ 998 self.createfunc = createfunc 999 self.scopefunc = scopefunc 1000 self.registry = {} 1001 1002 def __call__(self): 1003 key = self.scopefunc() 1004 try: 1005 return self.registry[key] 1006 except KeyError: 1007 return self.registry.setdefault(key, self.createfunc()) 1008 1009 def has(self): 1010 """Return True if an object is present in the current scope.""" 1011 1012 return self.scopefunc() in self.registry 1013 1014 def set(self, obj): 1015 """Set the value for the current scope.""" 1016 1017 self.registry[self.scopefunc()] = obj 1018 1019 def clear(self): 1020 """Clear the current scope, if any.""" 1021 1022 try: 1023 del self.registry[self.scopefunc()] 1024 except KeyError: 1025 pass 1026 1027 1028class ThreadLocalRegistry(ScopedRegistry): 1029 """A :class:`.ScopedRegistry` that uses a ``threading.local()`` 1030 variable for storage. 1031 1032 """ 1033 1034 def __init__(self, createfunc): 1035 self.createfunc = createfunc 1036 self.registry = threading.local() 1037 1038 def __call__(self): 1039 try: 1040 return self.registry.value 1041 except AttributeError: 1042 val = self.registry.value = self.createfunc() 1043 return val 1044 1045 def has(self): 1046 return hasattr(self.registry, "value") 1047 1048 def set(self, obj): 1049 self.registry.value = obj 1050 1051 def clear(self): 1052 try: 1053 del self.registry.value 1054 except AttributeError: 1055 pass 1056 1057 1058def _iter_id(iterable): 1059 """Generator: ((id(o), o) for o in iterable).""" 1060 1061 for item in iterable: 1062 yield id(item), item 1063