# (C) Copyright 2005-2021 Enthought, Inc., Austin, TX
# All rights reserved.
#
# This software is provided without warranty under the terms of the BSD
# license included in LICENSE.txt and may be redistributed only under
# the conditions described in the aforementioned license. The license
# is also available online at http://www.enthought.com/licenses/BSD.txt
#
# Thanks for using Enthought open source!

import copy
import copyreg
from itertools import chain
from weakref import ref

from traits.observation.i_observable import IObservable
from traits.trait_base import _validate_everything
from traits.trait_errors import TraitError


class TraitSetEvent(object):
    """ An object reporting in-place changes to a traits set.

    Parameters
    ----------
    removed : set, optional
        Old values that were removed from the set.
    added : set, optional
        New values added to the set.

    Attributes
    ----------
    removed : set
        Old values that were removed from the set.
    added : set
        New values added to the set.
    """

    def __init__(self, *, removed=None, added=None):

        if removed is None:
            removed = set()
        self.removed = removed

        if added is None:
            added = set()
        self.added = added

    def __repr__(self):
        return (
            f"{self.__class__.__name__}("
            f"removed={self.removed!r}, "
            f"added={self.added!r})"
        )


@IObservable.register
class TraitSet(set):
    """ A subclass of set that validates and notifies listeners of changes.

    Parameters
    ----------
    value : iterable, optional
        Iterable providing the items for the set.
    item_validator : callable, optional
        Called to validate and/or transform items added to the set. The
        callable should accept a single item and return the transformed
        item, raising TraitError for invalid items. If not given, no
        item validation is performed.
    notifiers : list of callable, optional
        A list of callables with the signature::

            notifier(trait_set, removed, added)

        where 'added' is a set containing new values that have been added
        and 'removed' is a set containing old values that have been removed.

        If this argument is not given, the list of notifiers is initially
        empty.

    Attributes
    ----------
    item_validator : callable
        Called to validate and/or transform items added to the set. The
        callable should accept a single item and return the transformed
        item, raising TraitError for invalid items.
    notifiers : list of callable
        A list of callables with the signature::

            notifier(trait_set, removed, added)

        where 'added' is a set containing new values that have been added
        and 'removed' is a set containing old values that have been removed.
    """

    def __new__(cls, *args, **kwargs):
        self = super().__new__(cls)
        # Give every instance default attributes before __init__ runs, so
        # __init__ can always call self.item_validator.
        self.item_validator = _validate_everything
        self.notifiers = []
        return self

    def __init__(self, value=(), *, item_validator=None, notifiers=None):
        if item_validator is not None:
            self.item_validator = item_validator
        super().__init__(self.item_validator(item) for item in value)
        if notifiers is not None:
            self.notifiers = notifiers

    def notify(self, removed, added):
        """ Call all notifiers.

        This simply calls all notifiers provided by the class, if any.
        The notifiers are expected to have the signature::

            notifier(trait_set, removed, added)

        Any return values are ignored. Any exceptions raised are not
        handled. Notifiers are therefore expected not to raise any
        exceptions under normal use.

        Parameters
        ----------
        removed : set
            The items that have been removed.
        added : set
            The new items that have been added to the set.
        """
        for notifier in self.notifiers:
            notifier(self, removed, added)

    # -- set interface -------------------------------------------------------

    def __iand__(self, value):
        """ Return self &= value.

        Parameters
        ----------
        value : set or frozenset
            A value.

        Returns
        -------
        self : TraitSet
            The updated set.
        """

        old_set = self.copy()
        retval = super().__iand__(value)
        removed = old_set.difference(self)

        if removed:
            self.notify(removed, set())

        return retval

    def __ior__(self, value):
        """ Return self |= value.

        Parameters
        ----------
        value : set or frozenset
            A value.

        Returns
        -------
        self : TraitSet
            The updated set.
        """
        old_set = self.copy()

        # Validate each item in value, only if value is a set or frozenset.
        # We do not want to convert any other iterable type to a set
        # so that super().__ior__ raises the appropriate error message
        # for all other iterables.
        if isinstance(value, (set, frozenset)):
            value = {self.item_validator(item) for item in value}

        retval = super().__ior__(value)

        added = self.difference(old_set)

        if added:
            self.notify(set(), added)

        return retval

    def __isub__(self, value):
        """ Return self -= value.

        Parameters
        ----------
        value : set or frozenset
            A value.

        Returns
        -------
        self : TraitSet
            The updated set.
        """

        old_set = self.copy()
        retval = super().__isub__(value)
        removed = old_set.difference(self)

        if removed:
            self.notify(removed, set())

        return retval

    def __ixor__(self, value):
        """ Return self ^= value.

        Parameters
        ----------
        value : set or frozenset
            A value.

        Returns
        -------
        self : TraitSet
            The updated set.
        """

        removed = set()
        added = set()

        # Validate each item in value, only if value is a set or frozenset.
        # We do not want to convert any other iterable type to a set
        # so that super().__ixor__ raises the appropriate error message
        # for all other iterables.
        if isinstance(value, (set, frozenset)):
            values = set(value)
            removed = self.intersection(values)
            raw_added = values.difference(removed)
            validated_added = {self.item_validator(item) for item in
                               raw_added}
            added = validated_added.difference(self)
            # XOR with exactly the items to drop and the (validated) items
            # to add, so only validated values ever enter the set.
            value = added | removed

        retval = super().__ixor__(value)

        if removed or added:
            self.notify(removed, added)

        return retval

    def add(self, value):
        """ Add an element to a set.

        This has no effect if the element is already present.

        Parameters
        ----------
        value : any
            The value to add to the set.
        """

        value = self.item_validator(value)
        value_in_self = value in self
        super().add(value)
        if not value_in_self:
            self.notify(set(), {value})

    def clear(self):
        """ Remove all elements from this set. """

        removed = set(self)
        super().clear()
        if removed:
            self.notify(removed, set())

    def discard(self, value):
        """ Remove an element from the set if it is a member.

        If the element is not a member, do nothing.

        Parameters
        ----------
        value : any
            An item in the set
        """

        value_in_self = value in self
        super().discard(value)

        if value_in_self:
            self.notify({value}, set())

    def difference_update(self, *args):
        """ Remove all elements of another set from this set.

        Parameters
        ----------
        args : iterables
            The other iterables.
        """

        old_set = self.copy()
        super().difference_update(*args)
        removed = old_set.difference(self)

        if removed:
            self.notify(removed, set())

    def intersection_update(self, *args):
        """ Update the set with the intersection of itself and another set.

        Parameters
        ----------
        args : iterables
            The other iterables.
        """

        old_set = self.copy()
        super().intersection_update(*args)
        removed = old_set.difference(self)

        if removed:
            self.notify(removed, set())

    def pop(self):
        """ Remove and return an arbitrary set element.

        Returns
        -------
        item : any
            An element from the set.

        Raises
        ------
        KeyError
            If the set is empty.
        """

        removed = super().pop()
        self.notify({removed}, set())
        return removed

    def remove(self, value):
        """ Remove an element that is a member of the set.

        If the element is not a member, raise a KeyError.

        Parameters
        ----------
        value : any
            An element in the set

        Raises
        ------
        KeyError
            If the value is not found in the set.
        """

        super().remove(value)
        self.notify({value}, set())

    def symmetric_difference_update(self, value):
        """ Update the set with the symmetric difference of itself and another.

        Parameters
        ----------
        value : iterable
            The other iterable.
        """

        values = set(value)
        removed = self.intersection(values)
        raw_result = values.difference(removed)
        validated_result = {self.item_validator(item) for item in raw_result}
        added = validated_result.difference(self)

        super().symmetric_difference_update(removed | added)
        if removed or added:
            self.notify(removed, added)

    def update(self, *args):
        """ Update the set with the union of itself and others.

        Parameters
        ----------
        args : iterables
            The other iterables.
        """

        validated_values = {self.item_validator(item)
                            for item in chain.from_iterable(args)}
        added = validated_values.difference(self)
        super().update(added)

        if added:
            self.notify(set(), added)

    # -- pickle and copy support ----------------------------------------------

    def __deepcopy__(self, memo):
        """ Perform a deepcopy operation.

        Notifiers are transient and should not be copied.
        """
        # notifiers are transient and should not be copied
        result = TraitSet(
            [copy.deepcopy(x, memo) for x in self],
            item_validator=copy.deepcopy(self.item_validator, memo),
            notifiers=[],
        )

        return result

    def __getstate__(self):
        """ Get the state of the object for serialization.

        Notifiers are transient and should not be serialized.
        """
        result = self.__dict__.copy()
        # notifiers are transient and should not be serialized
        del result["notifiers"]
        return result

    def __setstate__(self, state):
        """ Restore the state of the object after serialization.

        Notifiers are transient and are restored to the empty list.
        """
        # Work on a copy: the caller's dict must never be mutated (it may
        # be another live instance's __dict__).
        state = dict(state)
        state["notifiers"] = []
        self.__dict__.update(state)

    def __reduce_ex__(self, protocol=None):
        """ Overridden to make sure we call our custom __getstate__.

        Before Python 3.11, ``set.__reduce__`` reports the instance
        ``__dict__`` itself as the state and never calls ``__getstate__``.
        That would pickle the transient notifiers. It would also make
        ``copy.copy`` hand this instance's own ``__dict__`` to the copy's
        ``__setstate__``.
        """
        return (
            copyreg._reconstructor,
            (type(self), set, list(self)),
            self.__getstate__(),
        )

    # -- Implement IObservable ------------------------------------------------

    def _notifiers(self, force_create):
        """ Return a list of callables where each callable is a notifier.
        The list is expected to be mutated for contributing or removing
        notifiers from the object.

        Parameters
        ----------
        force_create : boolean
            Not used here.
        """
        return self.notifiers


class TraitSetObject(TraitSet):
    """ A specialization of TraitSet with a default validator and notifier
    for compatibility with Traits versions before 6.0.

    Parameters
    ----------
    trait : CTrait
        The trait that the set has been assigned to.
    object : HasTraits
        The object the set belongs to.
    name : str
        The name of the trait on the object.
    value : iterable
        The initial value of the set.

    Attributes
    ----------
    trait : CTrait
        The trait that the set has been assigned to.
    object : callable
        A weak reference to the object the set belongs to; call it to get
        the object.
    name : str
        The name of the trait on the object.
    name_items : str or None
        The name of the trait's items event (``name + "_items"``), or None
        if the trait has no items event.
    """

    def __init__(self, trait, object, name, value):

        self.trait = trait
        self.object = ref(object)
        self.name = name
        self.name_items = None
        if trait.has_items:
            self.name_items = name + "_items"

        super().__init__(value, item_validator=self._validator,
                         notifiers=[self.notifier])

    def _validator(self, value):
        """ Validates the value by calling the inner trait's validate method.

        If the set is not attached to a trait (for example after
        unpickling), the value is returned unchanged.

        Parameters
        ----------
        value : any
            The value to be validated.

        Returns
        -------
        value : any
            The validated value.

        Raises
        ------
        TraitError
            On validation failure for the inner trait.
        """

        # getattr with defaults: these attributes may not be set yet.
        object_ref = getattr(self, 'object', None)
        trait = getattr(self, 'trait', None)

        if object_ref is None or trait is None:
            return value

        object = object_ref()

        # validate the new value(s)
        validate = trait.item_trait.handler.validate

        if validate is None:
            return value

        try:
            return validate(object, self.name, value)
        except TraitError as excp:
            excp.set_prefix("Each element of the")
            raise excp

    def notifier(self, trait_set, removed, added):
        """ Converts and consolidates the parameters to a TraitSetEvent and
        then fires the event.

        Parameters
        ----------
        trait_set : set
            The complete set
        removed : set
            Set of values that were removed.
        added : set
            Set of values that were added.
        """

        if self.name_items is None:
            return

        object = self.object()
        if object is None:
            return

        if getattr(object, self.name) is not self:
            # Workaround having this set inside another container which
            # also uses the name_items trait for notification.
            # Similar to enthought/traits#25
            return

        event = TraitSetEvent(removed=removed, added=added)
        items_event = self.trait.items_event()
        object.trait_items_event(self.name_items, event, items_event)

    # -- pickle and copy support ----------------------------------------------

    def __deepcopy__(self, memo):
        """ Perform a deepcopy operation.

        The elements are deep-copied. The copy shares the original's trait
        but is not attached to the original's HasTraits object.
        """

        # NOTE(review): the placeholder lambda is what gets weakly referenced
        # by __init__; the copy never refers to the original's object.
        result = TraitSetObject(
            self.trait,
            lambda: None,
            self.name,
            {copy.deepcopy(x, memo) for x in self},
        )

        return result

    def __getstate__(self):
        """ Get the state of the object for serialization.

        Notifiers, the object reference and the trait are transient and are
        not serialized.
        """

        result = super().__getstate__()
        del result["object"]
        del result["trait"]
        return result

    def __setstate__(self, state):
        """ Restore the state of the object after serialization.

        The notifier list is reset to contain only this object's own
        notifier. The restored set has no trait and no HasTraits object, so
        items are not validated and no items events are fired.
        """

        # Work on a copy so the caller's state dict is never mutated.
        state = dict(state)
        state.setdefault("name", "")
        state["notifiers"] = [self.notifier]
        state["object"] = lambda: None
        state["trait"] = None
        self.__dict__.update(state)

    # __reduce_ex__ is inherited from TraitSet, so our __getstate__ is used.