1# (C) Copyright 2005-2021 Enthought, Inc., Austin, TX
2# All rights reserved.
3#
4# This software is provided without warranty under the terms of the BSD
5# license included in LICENSE.txt and may be redistributed only under
6# the conditions described in the aforementioned license. The license
7# is also available online at http://www.enthought.com/licenses/BSD.txt
8#
9# Thanks for using Enthought open source!
10
11import copy
12import copyreg
13from itertools import chain
14from weakref import ref
15
16from traits.observation.i_observable import IObservable
17from traits.trait_base import _validate_everything
18from traits.trait_errors import TraitError
19
20
class TraitSetEvent(object):
    """ Description of an in-place change made to a trait set.

    Parameters
    ----------
    removed : set, optional
        Values that were taken out of the set. Defaults to a new empty set.
    added : set, optional
        Values that were put into the set. Defaults to a new empty set.

    Attributes
    ----------
    removed : set
        Values that were taken out of the set.
    added : set
        Values that were put into the set.
    """

    def __init__(self, *, removed=None, added=None):
        # A fresh empty set is created per instance when an argument is
        # omitted, so events never share a mutable default.
        self.removed = set() if removed is None else removed
        self.added = set() if added is None else added

    def __repr__(self):
        class_name = type(self).__name__
        return "{}(removed={!r}, added={!r})".format(
            class_name, self.removed, self.added
        )
55
56
@IObservable.register
class TraitSet(set):
    """ A subclass of set that validates and notifies listeners of changes.

    Parameters
    ----------
    value : iterable, optional
        Iterable providing the items for the set.
    item_validator : callable, optional
        Called to validate and/or transform items added to the set. The
        callable should accept a single item and return the transformed
        item, raising TraitError for invalid items. If not given, no
        item validation is performed.
    notifiers : list of callable, optional
        A list of callables with the signature::

            notifier(trait_set, removed, added)

        where 'added' is a set containing new values that have been added
        and 'removed' is a set containing old values that have been removed.

        If this argument is not given, the list of notifiers is initially
        empty.

    Attributes
    ----------
    item_validator : callable
        Called to validate and/or transform items added to the set. The
        callable should accept a single item and return the transformed
        item, raising TraitError for invalid items.
    notifiers : list of callable
        A list of callables with the signature::

            notifier(trait_set, removed, added)

        where 'added' is a set containing new values that have been added
        and 'removed' is a set containing old values that have been removed.
    """

    def __new__(cls, *args, **kwargs):
        # Default attributes are assigned at allocation time, so they exist
        # before (or even without) a call to __init__.
        self = super().__new__(cls)
        self.item_validator = _validate_everything
        self.notifiers = []
        return self

    def __init__(self, value=(), *, item_validator=None, notifiers=None):
        if item_validator is not None:
            self.item_validator = item_validator
        # Initial items go through the (possibly custom) validator.
        super().__init__(self.item_validator(item) for item in value)
        if notifiers is not None:
            self.notifiers = notifiers

    def notify(self, removed, added):
        """ Call all notifiers.

        This simply calls all notifiers provided by the class, if any.
        The notifiers are expected to have the signature::

            notifier(trait_set, removed, added)

        Any return values are ignored. Any exceptions raised are not
        handled. Notifiers are therefore expected not to raise any
        exceptions under normal use.

        Parameters
        ----------
        removed : set
            The items that have been removed.
        added : set
            The new items that have been added to the set.
        """
        for notifier in self.notifiers:
            notifier(self, removed, added)

    # -- set interface -------------------------------------------------------

    def __iand__(self, value):
        """ Return self &= value.

        Notifiers are called with the removed items, if any.

        Parameters
        ----------
        value : set or frozenset
            A value.

        Returns
        -------
        self : TraitSet
            The updated set.
        """
        old_set = self.copy()
        retval = super().__iand__(value)
        removed = old_set.difference(self)

        if removed:
            self.notify(removed, set())

        return retval

    def __ior__(self, value):
        """ Return self |= value.

        New items are validated; notifiers are called with the added items,
        if any.

        Parameters
        ----------
        value : set or frozenset
            A value.

        Returns
        -------
        self : TraitSet
            The updated set.
        """
        old_set = self.copy()

        # Validate each item in value, only if value is a set or frozenset.
        # We do not want to convert any other iterable type to a set
        # so that super().__ior__ raises the appropriate error message
        # for all other iterables.
        if isinstance(value, (set, frozenset)):
            value = {self.item_validator(item) for item in value}

        retval = super().__ior__(value)

        added = self.difference(old_set)

        if added:
            self.notify(set(), added)

        return retval

    def __isub__(self, value):
        """ Return self -= value.

        Notifiers are called with the removed items, if any.

        Parameters
        ----------
        value : set or frozenset
            A value.

        Returns
        -------
        self : TraitSet
            The updated set.
        """
        old_set = self.copy()
        retval = super().__isub__(value)
        removed = old_set.difference(self)

        if removed:
            self.notify(removed, set())

        return retval

    def __ixor__(self, value):
        """ Return self ^= value.

        Items not already present are validated; notifiers are called with
        the removed and added items, if any.

        Parameters
        ----------
        value : set or frozenset
            A value.

        Returns
        -------
        self : TraitSet
            The updated set.
        """
        removed = set()
        added = set()

        # Validate each item in value, only if value is a set or frozenset.
        # We do not want to convert any other iterable type to a set
        # so that super().__ixor__ raises the appropriate error message
        # for all other iterables.
        if isinstance(value, (set, frozenset)):
            values = set(value)
            removed = self.intersection(values)
            raw_added = values.difference(removed)
            validated_added = {self.item_validator(item) for item in
                               raw_added}
            added = validated_added.difference(self)
            value = added | removed

        retval = super().__ixor__(value)

        if removed or added:
            self.notify(removed, added)

        return retval

    def add(self, value):
        """ Add an element to a set.

        This has no effect if the element is already present.

        Parameters
        ----------
        value : any
            The value to add to the set.
        """
        value = self.item_validator(value)
        value_in_self = value in self
        super().add(value)
        if not value_in_self:
            self.notify(set(), {value})

    def clear(self):
        """ Remove all elements from this set. """
        removed = set(self)
        super().clear()
        if removed:
            self.notify(removed, set())

    def discard(self, value):
        """ Remove an element from the set if it is a member.

        If the element is not a member, do nothing.

        Parameters
        ----------
        value : any
            An item in the set
        """
        value_in_self = value in self
        super().discard(value)

        if value_in_self:
            self.notify({value}, set())

    def difference_update(self, *args):
        """ Remove all elements of another set from this set.

        Parameters
        ----------
        args : iterables
            The other iterables.
        """
        old_set = self.copy()
        super().difference_update(*args)
        removed = old_set.difference(self)

        if removed:
            self.notify(removed, set())

    def intersection_update(self, *args):
        """ Update the set with the intersection of itself and another set.

        Parameters
        ----------
        args : iterables
            The other iterables.
        """
        old_set = self.copy()
        super().intersection_update(*args)
        removed = old_set.difference(self)

        if removed:
            self.notify(removed, set())

    def pop(self):
        """ Remove and return an arbitrary set element.

        Returns
        -------
        item : any
            An element from the set.

        Raises
        ------
        KeyError
            If the set is empty.
        """
        removed = super().pop()
        self.notify({removed}, set())
        return removed

    def remove(self, value):
        """ Remove an element that is a member of the set.

        If the element is not a member, raise a KeyError.

        Parameters
        ----------
        value : any
            An element in the set

        Raises
        ------
        KeyError
            If the value is not found in the set.
        """
        super().remove(value)
        self.notify({value}, set())

    def symmetric_difference_update(self, value):
        """ Update the set with the symmetric difference of itself and another.

        Items not already present are validated before being added.

        Parameters
        ----------
        value : iterable
            The other iterable.
        """
        values = set(value)
        removed = self.intersection(values)
        raw_result = values.difference(removed)
        validated_result = {self.item_validator(item) for item in raw_result}
        added = validated_result.difference(self)

        super().symmetric_difference_update(removed | added)
        if removed or added:
            self.notify(removed, added)

    def update(self, *args):
        """ Update the set with the union of itself and others.

        Parameters
        ----------
        args : iterables
            The other iterables.
        """
        validated_values = {self.item_validator(item)
                            for item in chain.from_iterable(args)}
        added = validated_values.difference(self)
        super().update(added)

        if added:
            self.notify(set(), added)

    # -- pickle and copy support ----------------------------------------------

    def __deepcopy__(self, memo):
        """ Perform a deepcopy operation.

        The items and the item validator are deep-copied. Notifiers are
        transient and are not copied: the result starts with no notifiers.

        Parameters
        ----------
        memo : dict
            The memo dictionary maintained by the ``copy`` module.

        Returns
        -------
        result : TraitSet
            The deep copy.
        """
        # The validator is stored as ``item_validator`` (see __new__ and
        # __init__); there is no ``validator`` attribute.
        result = TraitSet(
            [copy.deepcopy(x, memo) for x in self],
            item_validator=copy.deepcopy(self.item_validator, memo),
            notifiers=[],
        )

        return result

    def __getstate__(self):
        """ Get the state of the object for serialization.

        Notifiers are transient and should not be serialized.
        """
        result = self.__dict__.copy()
        # notifiers are transient and should not be serialized
        del result["notifiers"]
        return result

    def __setstate__(self, state):
        """ Restore the state of the object after serialization.

        Notifiers are transient and are restored to the empty list.
        """
        state['notifiers'] = []
        self.__dict__.update(state)

    # -- Implement IObservable ------------------------------------------------

    def _notifiers(self, force_create):
        """ Return a list of callables where each callable is a notifier.

        The list is expected to be mutated for contributing or removing
        notifiers from the object.

        Parameters
        ----------
        force_create : bool
            Not used here.

        Returns
        -------
        notifiers : list of callable
            The list stored in ``self.notifiers`` (not a copy).
        """
        return self.notifiers
443
444
class TraitSetObject(TraitSet):
    """ A TraitSet bound to a trait on a HasTraits object.

    This specialization supplies a default item validator (delegating to
    the trait's item trait handler) and a default notifier (which fires the
    ``<name>_items`` event on the owning object), for compatibility with
    Traits versions before 6.0.

    Parameters
    ----------
    trait : CTrait
        The trait that the set has been assigned to.
    object : HasTraits
        The object the set belongs to. Only a weak reference is kept.
    name : str
        The name of the trait on the object.
    value : iterable
        The initial value of the set.

    Attributes
    ----------
    trait : CTrait
        The trait that the set has been assigned to.
    object : weakref.ref
        A weak reference to the object the set belongs to.
    name : str
        The name of the trait on the object.
    name_items : str or None
        ``name + "_items"`` when the trait has items, otherwise None.
    """

    def __init__(self, trait, object, name, value):
        # These attributes must be in place before the base initializer
        # runs: it validates the initial items through _validator, which
        # reads them.
        self.trait = trait
        self.object = ref(object)
        self.name = name
        self.name_items = name + "_items" if trait.has_items else None

        super().__init__(
            value,
            item_validator=self._validator,
            notifiers=[self.notifier],
        )

    def _validator(self, value):
        """ Validate a single item using the item trait's handler.

        Parameters
        ----------
        value : any
            The item to be validated.

        Returns
        -------
        value : any
            The validated (possibly transformed) item. The item is returned
            unchanged when no trait/owner reference is available or the
            handler has no ``validate``.

        Raises
        ------
        TraitError
            On validation failure for the inner trait; the error message is
            prefixed with "Each element of the".
        """
        object_ref = getattr(self, 'object', None)
        trait = getattr(self, 'trait', None)

        # Attributes may be missing, or trait may be None (as set by
        # __setstate__ after unpickling); then no validation is possible.
        if object_ref is None or trait is None:
            return value

        owner = object_ref()
        validate = trait.item_trait.handler.validate
        if validate is None:
            return value

        try:
            return validate(owner, self.name, value)
        except TraitError as excp:
            excp.set_prefix("Each element of the")
            raise excp

    def notifier(self, trait_set, removed, added):
        """ Wrap a change in a TraitSetEvent and fire the items event.

        Nothing is fired when the trait has no items event, when the owner
        has been garbage collected, or when this set is not the owner's
        current value for the trait.

        Parameters
        ----------
        trait_set : set
            The complete set.
        removed : set
            Set of values that were removed.
        added : set
            Set of values that were added.
        """
        if self.name_items is None:
            return

        owner = self.object()
        # The identity check works around this set living inside another
        # container which also uses the name_items trait for notification.
        # Similar to enthought/traits#25
        if owner is None or getattr(owner, self.name) is not self:
            return

        event = TraitSetEvent(removed=removed, added=added)
        items_event = self.trait.items_event()
        owner.trait_items_event(self.name_items, event, items_event)

    # -- pickle and copy support ----------------------------------------------
    def __deepcopy__(self, memo):
        """ Perform a deepcopy operation.

        Items are deep-copied and the trait is shared. The copy's owner is a
        temporary ``lambda: None`` of which only a weak reference is kept,
        so the copy is effectively detached from any HasTraits object.
        """
        copied_items = {copy.deepcopy(item, memo) for item in self}
        return TraitSetObject(
            self.trait, lambda: None, self.name, copied_items
        )

    def __getstate__(self):
        """ Get the state of the object for serialization.

        Notifiers, the owner reference and the trait are transient and are
        not serialized.
        """
        state = super().__getstate__()
        for transient_key in ("object", "trait"):
            del state[transient_key]
        return state

    def __setstate__(self, state):
        """ Restore the state of the object after serialization.

        The notifier list is reset to this object's own ``notifier``; the
        owner reference becomes a callable returning None, the trait is set
        to None, and a missing name defaults to "".
        """
        state.setdefault("name", "")
        state.update(
            notifiers=[self.notifier],
            object=lambda: None,
            trait=None,
        )
        self.__dict__.update(state)

    def __reduce_ex__(self, protocol=None):
        """ Overridden to make sure we call our custom __getstate__.
        """
        constructor_args = (type(self), set, list(self))
        return copyreg._reconstructor, constructor_args, self.__getstate__()
601