1# Access WeakSet through the weakref module.
2# This code is separated-out because it is needed
3# by abc.py to load everything else at startup.
4
5from _weakref import ref
6
7__all__ = ['WeakSet']
8
9
10class _IterationGuard(object):
11    # This context manager registers itself in the current iterators of the
12    # weak container, such as to delay all removals until the context manager
13    # exits.
14    # This technique should be relatively thread-safe (since sets are).
15
16    def __init__(self, weakcontainer):
17        # Don't create cycles
18        self.weakcontainer = ref(weakcontainer)
19
20    def __enter__(self):
21        w = self.weakcontainer()
22        if w is not None:
23            w._iterating.add(self)
24        return self
25
26    def __exit__(self, e, t, b):
27        w = self.weakcontainer()
28        if w is not None:
29            s = w._iterating
30            s.remove(self)
31            if not s:
32                w._commit_removals()
33
34
35class WeakSet(object):
36    def __init__(self, data=None):
37        self.data = set()
38        def _remove(item, selfref=ref(self)):
39            self = selfref()
40            if self is not None:
41                if self._iterating:
42                    self._pending_removals.append(item)
43                else:
44                    self.data.discard(item)
45        self._remove = _remove
46        # A list of keys to be removed
47        self._pending_removals = []
48        self._iterating = set()
49        if data is not None:
50            self.update(data)
51
52    def _commit_removals(self):
53        l = self._pending_removals
54        discard = self.data.discard
55        while l:
56            discard(l.pop())
57
58    def __iter__(self):
59        with _IterationGuard(self):
60            for itemref in self.data:
61                item = itemref()
62                if item is not None:
63                    yield item
64
65    def __len__(self):
66        return sum(x() is not None for x in self.data)
67
68    def __contains__(self, item):
69        try:
70            wr = ref(item)
71        except TypeError:
72            return False
73        return wr in self.data
74
75    def __reduce__(self):
76        return (self.__class__, (list(self),),
77                getattr(self, '__dict__', None))
78
79    __hash__ = None
80
81    def add(self, item):
82        if self._pending_removals:
83            self._commit_removals()
84        self.data.add(ref(item, self._remove))
85
86    def clear(self):
87        if self._pending_removals:
88            self._commit_removals()
89        self.data.clear()
90
91    def copy(self):
92        return self.__class__(self)
93
94    def pop(self):
95        if self._pending_removals:
96            self._commit_removals()
97        while True:
98            try:
99                itemref = self.data.pop()
100            except KeyError:
101                raise KeyError('pop from empty WeakSet')
102            item = itemref()
103            if item is not None:
104                return item
105
106    def remove(self, item):
107        if self._pending_removals:
108            self._commit_removals()
109        self.data.remove(ref(item))
110
111    def discard(self, item):
112        if self._pending_removals:
113            self._commit_removals()
114        self.data.discard(ref(item))
115
116    def update(self, other):
117        if self._pending_removals:
118            self._commit_removals()
119        if isinstance(other, self.__class__):
120            self.data.update(other.data)
121        else:
122            for element in other:
123                self.add(element)
124
125    def __ior__(self, other):
126        self.update(other)
127        return self
128
129    # Helper functions for simple delegating methods.
130    def _apply(self, other, method):
131        if not isinstance(other, self.__class__):
132            other = self.__class__(other)
133        newdata = method(other.data)
134        newset = self.__class__()
135        newset.data = newdata
136        return newset
137
138    def difference(self, other):
139        return self._apply(other, self.data.difference)
140    __sub__ = difference
141
142    def difference_update(self, other):
143        if self._pending_removals:
144            self._commit_removals()
145        if self is other:
146            self.data.clear()
147        else:
148            self.data.difference_update(ref(item) for item in other)
149    def __isub__(self, other):
150        if self._pending_removals:
151            self._commit_removals()
152        if self is other:
153            self.data.clear()
154        else:
155            self.data.difference_update(ref(item) for item in other)
156        return self
157
158    def intersection(self, other):
159        return self._apply(other, self.data.intersection)
160    __and__ = intersection
161
162    def intersection_update(self, other):
163        if self._pending_removals:
164            self._commit_removals()
165        self.data.intersection_update(ref(item) for item in other)
166    def __iand__(self, other):
167        if self._pending_removals:
168            self._commit_removals()
169        self.data.intersection_update(ref(item) for item in other)
170        return self
171
172    def issubset(self, other):
173        return self.data.issubset(ref(item) for item in other)
174    __lt__ = issubset
175
176    def __le__(self, other):
177        return self.data <= set(ref(item) for item in other)
178
179    def issuperset(self, other):
180        return self.data.issuperset(ref(item) for item in other)
181    __gt__ = issuperset
182
183    def __ge__(self, other):
184        return self.data >= set(ref(item) for item in other)
185
186    def __eq__(self, other):
187        if not isinstance(other, self.__class__):
188            return NotImplemented
189        return self.data == set(ref(item) for item in other)
190
191    def symmetric_difference(self, other):
192        return self._apply(other, self.data.symmetric_difference)
193    __xor__ = symmetric_difference
194
195    def symmetric_difference_update(self, other):
196        if self._pending_removals:
197            self._commit_removals()
198        if self is other:
199            self.data.clear()
200        else:
201            self.data.symmetric_difference_update(ref(item) for item in other)
202    def __ixor__(self, other):
203        if self._pending_removals:
204            self._commit_removals()
205        if self is other:
206            self.data.clear()
207        else:
208            self.data.symmetric_difference_update(ref(item) for item in other)
209        return self
210
211    def union(self, other):
212        return self._apply(other, self.data.union)
213    __or__ = union
214
215    def isdisjoint(self, other):
216        return len(self.intersection(other)) == 0
217