1"""
2An OrderedSet is a custom MutableSet that remembers its order, so that every
3entry has an index that can be looked up.
4
Based on a recipe originally posted to ActiveState Recipes by Raymond Hettinger,
6and released under the MIT license.
7"""
8import itertools as it
9from collections import deque
10
11try:
12    # Python 3
13    from collections.abc import MutableSet, Sequence
14except ImportError:
15    # Python 2.7
16    from collections import MutableSet, Sequence
17
# Slice equal to ``[:]``; __getitem__ answers that request with a copy.
SLICE_ALL = slice(None, None, None)
__version__ = "3.1"
20
21
def is_iterable(obj):
    """
    Are we being asked to look up a list of things, instead of a single thing?
    We check for the `__iter__` attribute so that this can cover types that
    don't have to be known by this module, such as NumPy arrays.

    Strings and byte strings, however, should be considered as atomic values
    to look up, not iterables: in Python 3 `bytes` has `__iter__` (yielding
    ints), so without this exclusion a bytes entry would be looked up one
    integer at a time. The same goes for tuples, since they are immutable and
    therefore valid entries.

    We don't need to check for the Python 2 `unicode` type, because it doesn't
    have an `__iter__` attribute anyway. (In Python 2, `bytes` is `str`.)

    Example:
        >>> is_iterable([1, 2])
        True
        >>> is_iterable("ab"), is_iterable(b"ab"), is_iterable((1, 2))
        (False, False, False)
    """
    return hasattr(obj, "__iter__") and not isinstance(obj, (str, bytes, tuple))
40
41
class OrderedSet(MutableSet, Sequence):
    """
    An OrderedSet is a custom MutableSet that remembers its order, so that
    every entry has an index that can be looked up.

    Internally, ``self.items`` is the list of entries in insertion order and
    ``self.map`` maps each entry to its position in ``self.items``; every
    mutating method keeps the two in sync.

    Example:
        >>> OrderedSet([1, 1, 2, 3, 2])
        OrderedSet([1, 2, 3])
    """

    def __init__(self, iterable=None):
        self.items = []
        self.map = {}
        if iterable is not None:
            # No __ior__ is defined here, so this uses MutableSet.__ior__,
            # which add()s each element in iteration order.
            self |= iterable

    def __len__(self):
        """
        Returns the number of unique elements in the ordered set

        Example:
            >>> len(OrderedSet([]))
            0
            >>> len(OrderedSet([1, 2]))
            2
        """
        return len(self.items)

    def __getitem__(self, index):
        """
        Get the item at a given index.

        If `index` is a slice, you will get back that slice of items, as a
        new OrderedSet.

        If `index` is a list or a similar iterable, you'll get a list of
        items corresponding to those indices. This is similar to NumPy's
        "fancy indexing". The result is not an OrderedSet because you may ask
        for duplicate indices, and the number of elements returned should be
        the number of elements asked for.

        Any other kind of index (e.g. a tuple or a string) raises TypeError.

        Example:
            >>> oset = OrderedSet([1, 2, 3])
            >>> oset[1]
            2
            >>> oset[1:]
            OrderedSet([2, 3])
            >>> oset[[0, 2, 0]]
            [1, 3, 1]
        """
        if isinstance(index, slice) and index == SLICE_ALL:
            return self.copy()
        elif is_iterable(index):
            return [self.items[i] for i in index]
        elif hasattr(index, "__index__") or isinstance(index, slice):
            result = self.items[index]
            if isinstance(result, list):
                return self.__class__(result)
            else:
                return result
        else:
            # Wrap `index` in a 1-tuple: a bare tuple on the right of `%`
            # would be unpacked as the format arguments, producing the wrong
            # error message (or a formatting error) instead of this one.
            raise TypeError("Don't know how to index an OrderedSet by %r" % (index,))

    def copy(self):
        """
        Return a shallow copy of this object.

        Example:
            >>> this = OrderedSet([1, 2, 3])
            >>> other = this.copy()
            >>> this == other
            True
            >>> this is other
            False
        """
        return self.__class__(self)

    def __getstate__(self):
        """
        Pickle support: return the entries as a list, or ``(None,)`` when
        empty.
        """
        if len(self) == 0:
            # The state can't be an empty list.
            # We need to return a truthy value, or else __setstate__ won't be run.
            #
            # This could have been done more gracefully by always putting the state
            # in a tuple, but this way is backwards- and forwards- compatible with
            # previous versions of OrderedSet.
            return (None,)
        else:
            return list(self)

    def __setstate__(self, state):
        """
        Pickle support: rebuild from the state produced by __getstate__.
        """
        if state == (None,):
            self.__init__([])
        else:
            self.__init__(state)

    def __contains__(self, key):
        """
        Test if the item is in this ordered set

        Example:
            >>> 1 in OrderedSet([1, 3, 2])
            True
            >>> 5 in OrderedSet([1, 3, 2])
            False
        """
        return key in self.map

    def add(self, key):
        """
        Add `key` as an item to this OrderedSet, then return its index.

        If `key` is already in the OrderedSet, return the index it already
        had.

        Example:
            >>> oset = OrderedSet()
            >>> oset.append(3)
            0
            >>> print(oset)
            OrderedSet([3])
        """
        if key not in self.map:
            self.map[key] = len(self.items)
            self.items.append(key)
        return self.map[key]

    append = add

    def update(self, sequence):
        """
        Update the set with the given iterable sequence, then return the index
        of the last element inserted (None if `sequence` is empty).

        Raises ValueError if `sequence` is not iterable. Note that any
        TypeError raised while adding (e.g. an unhashable element) is also
        reported as ValueError, and elements added before the failure remain.

        Example:
            >>> oset = OrderedSet([1, 2, 3])
            >>> oset.update([3, 1, 5, 1, 4])
            4
            >>> print(oset)
            OrderedSet([1, 2, 3, 5, 4])
        """
        item_index = None
        try:
            for item in sequence:
                item_index = self.add(item)
        except TypeError:
            raise ValueError(
                "Argument needs to be an iterable, got %s" % type(sequence)
            )
        return item_index

    def index(self, key):
        """
        Get the index of a given entry, raising a KeyError if it's not
        present.

        `key` can be an iterable of entries that is not a string, in which case
        this returns a list of indices.

        Example:
            >>> oset = OrderedSet([1, 2, 3])
            >>> oset.index(2)
            1
            >>> oset.index([3, 1])
            [2, 0]
        """
        if is_iterable(key):
            return [self.index(subkey) for subkey in key]
        return self.map[key]

    # Provide some compatibility with pd.Index
    get_loc = index
    get_indexer = index

    def pop(self):
        """
        Remove and return the last element from the set.

        Raises KeyError if the set is empty.

        Example:
            >>> oset = OrderedSet([1, 2, 3])
            >>> oset.pop()
            3
        """
        if not self.items:
            raise KeyError("Set is empty")

        elem = self.items[-1]
        del self.items[-1]
        del self.map[elem]
        return elem

    def discard(self, key):
        """
        Remove an element.  Do not raise an exception if absent.

        The MutableSet mixin uses this to implement the .remove() method, which
        *does* raise an error when asked to remove a non-existent item.

        Example:
            >>> oset = OrderedSet([1, 2, 3])
            >>> oset.discard(2)
            >>> print(oset)
            OrderedSet([1, 3])
            >>> oset.discard(2)
            >>> print(oset)
            OrderedSet([1, 3])
        """
        if key in self:
            i = self.map[key]
            del self.items[i]
            del self.map[key]
            # Only entries after position i moved (each left by one), so
            # renumber just that tail rather than scanning the whole map.
            for new_index, item in enumerate(self.items[i:], i):
                self.map[item] = new_index

    def clear(self):
        """
        Remove all items from this OrderedSet.
        """
        del self.items[:]
        self.map.clear()

    def __iter__(self):
        """
        Example:
            >>> list(iter(OrderedSet([1, 2, 3])))
            [1, 2, 3]
        """
        return iter(self.items)

    def __reversed__(self):
        """
        Example:
            >>> list(reversed(OrderedSet([1, 2, 3])))
            [3, 2, 1]
        """
        return reversed(self.items)

    def __repr__(self):
        if not self:
            return "%s()" % (self.__class__.__name__,)
        return "%s(%r)" % (self.__class__.__name__, list(self))

    def __eq__(self, other):
        """
        Returns true if the containers have the same items. If `other` is a
        Sequence, then order is checked, otherwise it is ignored.

        Example:
            >>> oset = OrderedSet([1, 3, 2])
            >>> oset == [1, 3, 2]
            True
            >>> oset == [1, 2, 3]
            False
            >>> oset == [2, 3]
            False
            >>> oset == OrderedSet([3, 2, 1])
            False
        """
        # In Python 2 deque is not a Sequence, so treat it as one for
        # consistent behavior with Python 3.
        if isinstance(other, (Sequence, deque)):
            # Check that this OrderedSet contains the same elements, in the
            # same order, as the other object.
            return list(self) == list(other)
        try:
            other_as_set = set(other)
        except TypeError:
            # If `other` can't be converted into a set, it's not equal.
            return False
        else:
            return set(self) == other_as_set

    def union(self, *sets):
        """
        Combines all unique items.
        Each items order is defined by its first appearance.

        Example:
            >>> oset = OrderedSet.union(OrderedSet([3, 1, 4, 1, 5]), [1, 3], [2, 0])
            >>> print(oset)
            OrderedSet([3, 1, 4, 5, 2, 0])
            >>> oset.union([8, 9])
            OrderedSet([3, 1, 4, 5, 2, 0, 8, 9])
            >>> oset | {10}
            OrderedSet([3, 1, 4, 5, 2, 0, 10])
        """
        # Allow unbound calls such as OrderedSet.union(some_list, ...).
        cls = self.__class__ if isinstance(self, OrderedSet) else OrderedSet
        containers = map(list, it.chain([self], sets))
        items = it.chain.from_iterable(containers)
        return cls(items)

    def __and__(self, other):
        # the parent implementation of this is backwards (it keeps the order
        # of `other`); intersection() keeps the order of `self`.
        return self.intersection(other)

    def intersection(self, *sets):
        """
        Returns elements in common between all sets. Order is defined only
        by the first set.

        Example:
            >>> oset = OrderedSet.intersection(OrderedSet([0, 1, 2, 3]), [1, 2, 3])
            >>> print(oset)
            OrderedSet([1, 2, 3])
            >>> oset.intersection([2, 4, 5], [1, 2, 3, 4])
            OrderedSet([2])
            >>> oset.intersection()
            OrderedSet([1, 2, 3])
        """
        cls = self.__class__ if isinstance(self, OrderedSet) else OrderedSet
        if sets:
            common = set.intersection(*map(set, sets))
            items = (item for item in self if item in common)
        else:
            items = self
        return cls(items)

    def difference(self, *sets):
        """
        Returns all elements that are in this set but not the others.

        Like union() and intersection(), this may be called unbound on a
        plain iterable and still returns an OrderedSet.

        Example:
            >>> OrderedSet([1, 2, 3]).difference(OrderedSet([2]))
            OrderedSet([1, 3])
            >>> OrderedSet([1, 2, 3]).difference(OrderedSet([2]), OrderedSet([3]))
            OrderedSet([1])
            >>> OrderedSet([1, 2, 3]) - OrderedSet([2])
            OrderedSet([1, 3])
            >>> OrderedSet([1, 2, 3]).difference()
            OrderedSet([1, 2, 3])
            >>> OrderedSet.difference([1, 2, 3], [2])
            OrderedSet([1, 3])
        """
        cls = self.__class__ if isinstance(self, OrderedSet) else OrderedSet
        if sets:
            other = set.union(*map(set, sets))
            items = (item for item in self if item not in other)
        else:
            items = self
        return cls(items)

    def issubset(self, other):
        """
        Report whether another set contains this set.

        `other` may be any iterable; one-shot iterables such as generators
        are materialized into a set first.

        Example:
            >>> OrderedSet([1, 2, 3]).issubset({1, 2})
            False
            >>> OrderedSet([1, 2, 3]).issubset({1, 2, 3, 4})
            True
            >>> OrderedSet([1, 2, 3]).issubset({1, 4, 3, 5})
            False
            >>> OrderedSet([1, 2]).issubset(x for x in [1, 2, 3])
            True
        """
        if not (hasattr(other, "__len__") and hasattr(other, "__contains__")):
            # An iterator has no len() and would be consumed by repeated
            # membership tests, so read it exactly once.
            other = set(other)
        # Fast check for obvious cases. Valid even if `other` has duplicates,
        # because len(other) is never less than its number of distinct items.
        if len(self) > len(other):
            return False
        return all(item in other for item in self)

    def issuperset(self, other):
        """
        Report whether this set contains another set.

        `other` may be any iterable, including one with repeated elements.

        Example:
            >>> OrderedSet([1, 2]).issuperset([1, 2, 3])
            False
            >>> OrderedSet([1, 2, 3, 4]).issuperset({1, 2, 3})
            True
            >>> OrderedSet([1, 4, 3, 5]).issuperset({1, 2, 3})
            False
            >>> OrderedSet([1, 2]).issuperset([1, 1, 1])
            True
        """
        # The size shortcut is only sound when `other` cannot contain
        # duplicates; for a list like [1, 1, 1] len() overcounts.
        if isinstance(other, (set, frozenset, OrderedSet)) and len(self) < len(other):
            return False
        return all(item in self for item in other)

    def symmetric_difference(self, other):
        """
        Return the symmetric difference of two OrderedSets as a new set.
        That is, the new set will contain all elements that are in exactly
        one of the sets.

        Their order will be preserved, with elements from `self` preceding
        elements from `other`.

        Example:
            >>> this = OrderedSet([1, 4, 3, 5, 7])
            >>> other = OrderedSet([9, 7, 1, 3, 2])
            >>> this.symmetric_difference(other)
            OrderedSet([4, 5, 9, 2])
        """
        cls = self.__class__ if isinstance(self, OrderedSet) else OrderedSet
        # Materialize `other` once: it is read twice below, and a one-shot
        # iterator would otherwise be exhausted by the first difference().
        other = cls(other)
        diff1 = cls(self).difference(other)
        diff2 = other.difference(self)
        return diff1.union(diff2)

    def _update_items(self, items):
        """
        Replace the 'items' list of this OrderedSet with a new one, updating
        self.map accordingly.

        `items` must not contain duplicates, or self.map will disagree with
        self.items.
        """
        self.items = items
        self.map = {item: idx for (idx, item) in enumerate(items)}

    def difference_update(self, *sets):
        """
        Update this OrderedSet to remove items from one or more other sets.

        Example:
            >>> this = OrderedSet([1, 2, 3])
            >>> this.difference_update(OrderedSet([2, 4]))
            >>> print(this)
            OrderedSet([1, 3])

            >>> this = OrderedSet([1, 2, 3, 4, 5])
            >>> this.difference_update(OrderedSet([2, 4]), OrderedSet([1, 4, 6]))
            >>> print(this)
            OrderedSet([3, 5])
        """
        items_to_remove = set()
        for other in sets:
            items_to_remove |= set(other)
        self._update_items([item for item in self.items if item not in items_to_remove])

    def intersection_update(self, other):
        """
        Update this OrderedSet to keep only items in another set, preserving
        their order in this set.

        Example:
            >>> this = OrderedSet([1, 4, 3, 5, 7])
            >>> other = OrderedSet([9, 7, 1, 3, 2])
            >>> this.intersection_update(other)
            >>> print(this)
            OrderedSet([1, 3, 7])
        """
        other = set(other)
        self._update_items([item for item in self.items if item in other])

    def symmetric_difference_update(self, other):
        """
        Update this OrderedSet to remove items from another set, then
        add items from the other set that were not present in this set.

        Added items keep their first-appearance order from `other`; repeated
        elements in `other` are added only once.

        Example:
            >>> this = OrderedSet([1, 4, 3, 5, 7])
            >>> other = OrderedSet([9, 7, 1, 3, 2])
            >>> this.symmetric_difference_update(other)
            >>> print(this)
            OrderedSet([4, 5, 9, 2])
            >>> this = OrderedSet([1, 2])
            >>> this.symmetric_difference_update([3, 3, 2])
            >>> print(this)
            OrderedSet([1, 3])
        """
        # Dedupe `other` (keeping its order) and read it exactly once: a raw
        # iterable with repeats would append the same item twice and
        # desynchronize self.items from self.map, and a one-shot iterator
        # would be exhausted before the removal pass.
        other = OrderedSet(other)
        items_to_add = [item for item in other if item not in self]
        self._update_items(
            [item for item in self.items if item not in other] + items_to_add
        )
489