"""Classes to represent arbitrary sets (including sets of sets).

This module implements sets using dictionaries whose values are
ignored. The usual operations (union, intersection, deletion, etc.)
are provided as both methods and operators.

Important: sets are not sequences! While they support 'x in s',
'len(s)', and 'for x in s', none of those operations are unique to
sequences; for example, mappings support all three as well. The
characteristic operation for sequences is subscripting with small
integers: s[i], for i in range(len(s)). Sets don't support
subscripting at all. Also, sequences allow multiple occurrences and
their elements have a definite order; sets on the other hand don't
record multiple occurrences and don't remember the order of element
insertion (which is why they don't support s[i]).

The following classes are provided:

BaseSet -- All the operations common to both mutable and immutable
    sets. This is an abstract class, not meant to be directly
    instantiated.

Set -- Mutable sets, subclass of BaseSet; not hashable.

ImmutableSet -- Immutable sets, subclass of BaseSet; hashable.
    The iterable argument is optional; without one, an empty
    ImmutableSet is created.

_TemporarilyImmutableSet -- A wrapper around a Set, hashable,
    giving the same hash value as the immutable set equivalent
    would have. Do not use this class directly.

Only hashable objects can be added to a Set. In particular, you cannot
really add a Set as an element to another Set; if you try, what is
actually added is an ImmutableSet built from it (it compares equal to
the one you tried adding).

When you ask if 'x in y', where x is a Set and y is a Set or
ImmutableSet, x is wrapped into a _TemporarilyImmutableSet z, and
what's tested is actually 'z in y'.

"""

# Code history:
#
# - Greg V. Wilson wrote the first version, using a different approach
#   to the mutable/immutable problem, and inheriting from dict.
#
# - Alex Martelli modified Greg's version to implement the current
#   Set/ImmutableSet approach, and make the data an attribute.
#
# - Guido van Rossum rewrote much of the code, made some API changes,
#   and cleaned up the docstrings.
#
# - Raymond Hettinger added a number of speedups and other
#   improvements.
#
# Portability note: the code below uses dict.__contains__, iter(dict) and
# list(dict) instead of dict.has_key(), dict.iterkeys() and dict.keys()
# (identical behavior on Python 2), and falls back to the Python 3
# spellings of itertools.ifilter/ifilterfalse and xrange, so the module
# also imports and runs under Python 3.

from __future__ import generators
try:
    from itertools import ifilter, ifilterfalse
except ImportError:
    try:
        # Python 3: the builtin filter() is already lazy, and
        # itertools.ifilterfalse was renamed to filterfalse.
        from itertools import filterfalse as ifilterfalse
        ifilter = filter
    except ImportError:
        # Code to make the module run under Py2.2
        def ifilter(predicate, iterable):
            if predicate is None:
                def predicate(x):
                    return x
            for x in iterable:
                if predicate(x):
                    yield x
        def ifilterfalse(predicate, iterable):
            if predicate is None:
                def predicate(x):
                    return x
            for x in iterable:
                if not predicate(x):
                    yield x

try:
    _xrange = xrange
except NameError:
    # Python 3: range is the lazy sequence type formerly called xrange.
    _xrange = range

__all__ = ['BaseSet', 'Set', 'ImmutableSet']

class BaseSet(object):
    """Common base class for mutable and immutable sets."""

    __slots__ = ['_data']

    # Constructor

    def __init__(self):
        """This is an abstract class."""
        # Don't call this from a concrete subclass!
        if self.__class__ is BaseSet:
            raise TypeError("BaseSet is an abstract class. "
                            "Use Set or ImmutableSet.")

    # Standard protocols: __len__, __repr__, __str__, __iter__

    def __len__(self):
        """Return the number of elements of a set."""
        return len(self._data)

    def __repr__(self):
        """Return string representation of a set.

        This looks like 'Set([<list of elements>])'.
        """
        return self._repr()

    # __str__ is the same as __repr__
    __str__ = __repr__

    def _repr(self, sorted=False):
        # Render as 'ClassName([elements])'.  When 'sorted' is true the
        # element list is sorted first, so the text does not depend on the
        # underlying dict's ordering.  (The parameter shadows the builtin
        # name but is kept so existing keyword callers keep working.)
        elements = list(self._data)
        if sorted:
            elements.sort()
        return '%s(%r)' % (self.__class__.__name__, elements)

    def __iter__(self):
        """Return an iterator over the elements of a set.

        This is the keys iterator for the underlying dict.
        """
        return iter(self._data)

    # Three-way comparison is not supported.  However, because __eq__ is
    # tried before __cmp__, if Set x == Set y, x.__eq__(y) returns True and
    # then cmp(x, y) returns 0 (Python doesn't actually call __cmp__ in this
    # case).

    def __cmp__(self, other):
        raise TypeError("can't compare sets using cmp()")

    # Equality comparisons using the underlying dicts.  Mixed-type comparisons
    # are allowed here, where Set == z for non-Set z always returns False,
    # and Set != z always True.  This allows expressions like "x in y" to
    # give the expected result when y is a sequence of mixed types, not
    # raising a pointless TypeError just because y contains a Set, or x is
    # a Set and y contains a non-set ("in" invokes only __eq__).
    # Subtle:  it would be nicer if __eq__ and __ne__ could return
    # NotImplemented instead of True or False.  Then the other comparand
    # would get a chance to determine the result, and if the other comparand
    # also returned NotImplemented then it would fall back to object address
    # comparison (which would always return False for __eq__ and always
    # True for __ne__).  However, that doesn't work, because this type
    # *also* implements __cmp__:  if, e.g., __eq__ returns NotImplemented,
    # Python tries __cmp__ next, and the __cmp__ here then raises TypeError.

    def __eq__(self, other):
        if isinstance(other, BaseSet):
            return self._data == other._data
        else:
            return False

    def __ne__(self, other):
        if isinstance(other, BaseSet):
            return self._data != other._data
        else:
            return True

    # Copying operations

    def copy(self):
        """Return a shallow copy of a set."""
        result = self.__class__()
        result._data.update(self._data)
        return result

    __copy__ = copy  # For the copy module

    def __deepcopy__(self, memo):
        """Return a deep copy of a set; used by copy module."""
        # This pre-creates the result and inserts it in the memo
        # early, in case the deep copy recurses into another reference
        # to this same set.  A set can't be an element of itself, but
        # it can certainly contain an object that has a reference to
        # itself.
        from copy import deepcopy
        result = self.__class__()
        memo[id(self)] = result
        data = result._data
        value = True
        for elt in self:
            data[deepcopy(elt, memo)] = value
        return result

    # Standard set operations: union, intersection, both differences.
    # Each has an operator version (e.g. __or__, invoked with |) and a
    # method version (e.g. union).
    # Subtle:  Each pair requires distinct code so that the outcome is
    # correct when the type of other isn't suitable.  For example, if
    # we did "union = __or__" instead, then Set().union(3) would return
    # NotImplemented instead of raising TypeError (albeit that *why* it
    # raises TypeError as-is is also a bit subtle).

    def __or__(self, other):
        """Return the union of two sets as a new set.

        (I.e. all elements that are in either set.)
        """
        if not isinstance(other, BaseSet):
            return NotImplemented
        return self.union(other)

    def union(self, other):
        """Return the union of two sets as a new set.

        (I.e. all elements that are in either set.)
        """
        result = self.__class__(self)
        result._update(other)
        return result

    def __and__(self, other):
        """Return the intersection of two sets as a new set.

        (I.e. all elements that are in both sets.)
        """
        if not isinstance(other, BaseSet):
            return NotImplemented
        return self.intersection(other)

    def intersection(self, other):
        """Return the intersection of two sets as a new set.

        (I.e. all elements that are in both sets.)
        """
        if not isinstance(other, BaseSet):
            other = Set(other)
        # Iterate over the smaller set and probe the larger one.
        if len(self) <= len(other):
            little, big = self, other
        else:
            little, big = other, self
        common = ifilter(big._data.__contains__, little)
        return self.__class__(common)

    def __xor__(self, other):
        """Return the symmetric difference of two sets as a new set.

        (I.e. all elements that are in exactly one of the sets.)
        """
        if not isinstance(other, BaseSet):
            return NotImplemented
        return self.symmetric_difference(other)

    def symmetric_difference(self, other):
        """Return the symmetric difference of two sets as a new set.

        (I.e. all elements that are in exactly one of the sets.)
        """
        result = self.__class__()
        data = result._data
        value = True
        selfdata = self._data
        try:
            otherdata = other._data
        except AttributeError:
            otherdata = Set(other)._data
        for elt in ifilterfalse(otherdata.__contains__, selfdata):
            data[elt] = value
        for elt in ifilterfalse(selfdata.__contains__, otherdata):
            data[elt] = value
        return result

    def __sub__(self, other):
        """Return the difference of two sets as a new set.

        (I.e. all elements that are in this set and not in the other.)
        """
        if not isinstance(other, BaseSet):
            return NotImplemented
        return self.difference(other)

    def difference(self, other):
        """Return the difference of two sets as a new set.

        (I.e. all elements that are in this set and not in the other.)
        """
        result = self.__class__()
        data = result._data
        try:
            otherdata = other._data
        except AttributeError:
            otherdata = Set(other)._data
        value = True
        for elt in ifilterfalse(otherdata.__contains__, self):
            data[elt] = value
        return result

    # Membership test

    def __contains__(self, element):
        """Report whether an element is a member of a set.

        (Called in response to the expression 'element in self'.)
        """
        try:
            return element in self._data
        except TypeError:
            # Unhashable element (e.g. a Set): retry with its temporarily
            # immutable stand-in, if it offers one.
            transform = getattr(element, "__as_temporarily_immutable__", None)
            if transform is None:
                raise # re-raise the TypeError exception we caught
            return transform() in self._data

    # Subset and superset test

    def issubset(self, other):
        """Report whether another set contains this set."""
        self._binary_sanity_check(other)
        if len(self) > len(other):  # Fast check for obvious cases
            return False
        for elt in ifilterfalse(other._data.__contains__, self):
            return False
        return True

    def issuperset(self, other):
        """Report whether this set contains another set."""
        self._binary_sanity_check(other)
        if len(self) < len(other):  # Fast check for obvious cases
            return False
        for elt in ifilterfalse(self._data.__contains__, other):
            return False
        return True

    # Inequality comparisons using the is-subset relation.
    __le__ = issubset
    __ge__ = issuperset

    def __lt__(self, other):
        self._binary_sanity_check(other)
        return len(self) < len(other) and self.issubset(other)

    def __gt__(self, other):
        self._binary_sanity_check(other)
        return len(self) > len(other) and self.issuperset(other)

    # Assorted helpers

    def _binary_sanity_check(self, other):
        # Check that the other argument to a binary operation is also
        # a set, raising a TypeError otherwise.
        if not isinstance(other, BaseSet):
            raise TypeError("Binary operation only permitted between sets")

    def _compute_hash(self):
        # Calculate hash code for a set by xor'ing the hash codes of
        # the elements.  This ensures that the hash code does not depend
        # on the order in which elements are added to the set.  This is
        # not called __hash__ because a BaseSet should not be hashable;
        # only an ImmutableSet is hashable.
        result = 0
        for elt in self:
            result ^= hash(elt)
        return result

    def _update(self, iterable):
        # The main loop for update() and the subclass __init__() methods.
        # Unhashable elements that provide __as_immutable__ (e.g. a Set)
        # are stored as the immutable object that method returns.
        data = self._data

        # Use the fast update() method when a dictionary is available.
        if isinstance(iterable, BaseSet):
            data.update(iterable._data)
            return

        value = True

        if type(iterable) in (list, tuple, _xrange):
            # Optimized: we know that __iter__() and next() can't
            # raise TypeError, so we can move 'try:' out of the loop.
            it = iter(iterable)
            while True:
                try:
                    for element in it:
                        data[element] = value
                    return
                except TypeError:
                    transform = getattr(element, "__as_immutable__", None)
                    if transform is None:
                        raise # re-raise the TypeError exception we caught
                    data[transform()] = value
        else:
            # Safe: only catch TypeError where intended
            for element in iterable:
                try:
                    data[element] = value
                except TypeError:
                    transform = getattr(element, "__as_immutable__", None)
                    if transform is None:
                        raise # re-raise the TypeError exception we caught
                    data[transform()] = value


class ImmutableSet(BaseSet):
    """Immutable set class."""

    __slots__ = ['_hashcode']

    # BaseSet + hashing

    def __init__(self, iterable=None):
        """Construct an immutable set from an optional iterable."""
        self._hashcode = None
        self._data = {}
        if iterable is not None:
            self._update(iterable)

    def __hash__(self):
        # Computed on first use and cached in _hashcode; ImmutableSet
        # defines no mutating methods, so the cached value stays valid.
        if self._hashcode is None:
            self._hashcode = self._compute_hash()
        return self._hashcode

    def __getstate__(self):
        # Keep the historical (data, hashcode) shape so older readers can
        # still unpickle, but never persist the hash itself: element hashes
        # (default id()-based hashes, randomized string hashes) need not be
        # the same in the unpickling process, and a stale stored hash would
        # disagree with that of an equal set built there.
        return self._data, None

    def __setstate__(self, state):
        # Accept (data, hashcode) pairs, including those written by older
        # versions that stored a real hash, but ignore the stored hash;
        # __hash__ recomputes it lazily from the restored elements.
        self._data, _stale_hashcode = state
        self._hashcode = None

class Set(BaseSet):
    """ Mutable set class."""

    __slots__ = []

    # BaseSet + operations requiring mutability; no hashing

    def __init__(self, iterable=None):
        """Construct a set from an optional iterable."""
        self._data = {}
        if iterable is not None:
            self._update(iterable)

    def __getstate__(self):
        # Return a 1-tuple (note the trailing comma), not the dict itself:
        # Python 2's pickle skips __setstate__ when the state is a false
        # value, which an empty dict would be.
        return self._data,

    def __setstate__(self, data):
        self._data, = data

    def __hash__(self):
        """A Set cannot be hashed."""
        # We inherit object.__hash__, so we must deny this explicitly
        raise TypeError("Can't hash a Set, only an ImmutableSet.")

    # In-place union, intersection, differences.
    # Subtle:  The xyz_update() functions deliberately return None,
    # as do all mutating operations on built-in container types.
    # The __xyz__ spellings have to return self, though.

    def __ior__(self, other):
        """Update a set with the union of itself and another."""
        self._binary_sanity_check(other)
        self._data.update(other._data)
        return self

    def union_update(self, other):
        """Update a set with the union of itself and another."""
        self._update(other)

    def __iand__(self, other):
        """Update a set with the intersection of itself and another."""
        self._binary_sanity_check(other)
        self._data = (self & other)._data
        return self

    def intersection_update(self, other):
        """Update a set with the intersection of itself and another."""
        if isinstance(other, BaseSet):
            self &= other
        else:
            self._data = (self.intersection(other))._data

    def __ixor__(self, other):
        """Update a set with the symmetric difference of itself and another."""
        self._binary_sanity_check(other)
        self.symmetric_difference_update(other)
        return self

    def symmetric_difference_update(self, other):
        """Update a set with the symmetric difference of itself and another."""
        data = self._data
        value = True
        if not isinstance(other, BaseSet):
            other = Set(other)
        if self is other:
            # x ^ x is empty; clearing first also avoids mutating the dict
            # while iterating over it below.
            self.clear()
        for elt in other:
            if elt in data:
                del data[elt]
            else:
                data[elt] = value

    def __isub__(self, other):
        """Remove all elements of another set from this set."""
        self._binary_sanity_check(other)
        self.difference_update(other)
        return self

    def difference_update(self, other):
        """Remove all elements of another set from this set."""
        data = self._data
        if not isinstance(other, BaseSet):
            other = Set(other)
        if self is other:
            # x - x is empty; clearing first also avoids mutating the dict
            # while iterating over it below.
            self.clear()
        for elt in ifilter(data.__contains__, other):
            del data[elt]

    # Python dict-like mass mutations: update, clear

    def update(self, iterable):
        """Add all values from an iterable (such as a list or file)."""
        self._update(iterable)

    def clear(self):
        """Remove all elements from this set."""
        self._data.clear()

    # Single-element mutations: add, remove, discard

    def add(self, element):
        """Add an element to a set.

        This has no effect if the element is already present.
        """
        try:
            self._data[element] = True
        except TypeError:
            transform = getattr(element, "__as_immutable__", None)
            if transform is None:
                raise # re-raise the TypeError exception we caught
            self._data[transform()] = True

    def remove(self, element):
        """Remove an element from a set; it must be a member.

        If the element is not a member, raise a KeyError.
        """
        try:
            del self._data[element]
        except TypeError:
            transform = getattr(element, "__as_temporarily_immutable__", None)
            if transform is None:
                raise # re-raise the TypeError exception we caught
            del self._data[transform()]

    def discard(self, element):
        """Remove an element from a set if it is a member.

        If the element is not a member, do nothing.
        """
        try:
            self.remove(element)
        except KeyError:
            pass

    def pop(self):
        """Remove and return an arbitrary set element."""
        return self._data.popitem()[0]

    def __as_immutable__(self):
        # Return a copy of self as an immutable set
        return ImmutableSet(self)

    def __as_temporarily_immutable__(self):
        # Return self wrapped in a temporarily immutable set
        return _TemporarilyImmutableSet(self)


class _TemporarilyImmutableSet(BaseSet):
    # Wrap a mutable set as if it was temporarily immutable.
    # This only supplies hashing and equality comparisons.

    def __init__(self, set):
        self._set = set
        # Share the wrapped set's dict: BaseSet.__eq__ (which ImmutableSet
        # inherits) compares other._data, so this lets the wrapper compare
        # equal to an ImmutableSet holding the same elements.
        self._data = set._data

    def __hash__(self):
        # Recomputed on every call rather than cached; uses the same
        # _compute_hash as ImmutableSet, so equal contents hash equally.
        return self._set._compute_hash()

# Local Variables:
# tab-width:4
# indent-tabs-mode:nil
# End:
# vim: set expandtab tabstop=4 shiftwidth=4: