1import unittest
2from test import support
3from test.support import warnings_helper
4import gc
5import weakref
6import operator
7import copy
8import pickle
9from random import randrange, shuffle
10import warnings
11import collections
12import collections.abc
13import itertools
14
class PassThru(Exception):
    """Marker exception used to check that errors raised by an iterable
    argument propagate unchanged through set operations."""
17
def check_pass_thru():
    """Generator that raises PassThru as soon as it is first advanced."""
    yield from ()       # yields nothing; only makes this function a generator
    raise PassThru()
21
class BadCmp:
    """Object whose instances all share one hash and cannot be compared.

    Because every instance hashes to 1, two of them always collide, which
    forces an ``==`` comparison -- and that comparison raises RuntimeError.
    """

    def __eq__(self, other):
        raise RuntimeError()

    def __hash__(self):
        return 1
27
class ReprWrapper:
    """Used to test self-referential repr() calls.

    The ``value`` attribute is assigned by the test after construction;
    repr() of the wrapper is simply repr() of that value.
    """

    def __repr__(self):
        wrapped = self.value
        return repr(wrapped)
32
class HashCountingInt(int):
    """int-like object that counts the number of times __hash__ is called."""

    def __init__(self, *args):
        # The integer value itself is set by int.__new__; only the counter
        # needs initialising here.
        self.hash_count = 0

    def __hash__(self):
        self.hash_count = self.hash_count + 1
        return super().__hash__()
40
41class TestJointOps:
42    # Tests common to both set and frozenset
43
44    def setUp(self):
45        self.word = word = 'simsalabim'
46        self.otherword = 'madagascar'
47        self.letters = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
48        self.s = self.thetype(word)
49        self.d = dict.fromkeys(word)
50
51    def test_new_or_init(self):
52        self.assertRaises(TypeError, self.thetype, [], 2)
53        self.assertRaises(TypeError, set().__init__, a=1)
54
55    def test_uniquification(self):
56        actual = sorted(self.s)
57        expected = sorted(self.d)
58        self.assertEqual(actual, expected)
59        self.assertRaises(PassThru, self.thetype, check_pass_thru())
60        self.assertRaises(TypeError, self.thetype, [[]])
61
62    def test_len(self):
63        self.assertEqual(len(self.s), len(self.d))
64
65    def test_contains(self):
66        for c in self.letters:
67            self.assertEqual(c in self.s, c in self.d)
68        self.assertRaises(TypeError, self.s.__contains__, [[]])
69        s = self.thetype([frozenset(self.letters)])
70        self.assertIn(self.thetype(self.letters), s)
71
72    def test_union(self):
73        u = self.s.union(self.otherword)
74        for c in self.letters:
75            self.assertEqual(c in u, c in self.d or c in self.otherword)
76        self.assertEqual(self.s, self.thetype(self.word))
77        self.assertEqual(type(u), self.basetype)
78        self.assertRaises(PassThru, self.s.union, check_pass_thru())
79        self.assertRaises(TypeError, self.s.union, [[]])
80        for C in set, frozenset, dict.fromkeys, str, list, tuple:
81            self.assertEqual(self.thetype('abcba').union(C('cdc')), set('abcd'))
82            self.assertEqual(self.thetype('abcba').union(C('efgfe')), set('abcefg'))
83            self.assertEqual(self.thetype('abcba').union(C('ccb')), set('abc'))
84            self.assertEqual(self.thetype('abcba').union(C('ef')), set('abcef'))
85            self.assertEqual(self.thetype('abcba').union(C('ef'), C('fg')), set('abcefg'))
86
87        # Issue #6573
88        x = self.thetype()
89        self.assertEqual(x.union(set([1]), x, set([2])), self.thetype([1, 2]))
90
91    def test_or(self):
92        i = self.s.union(self.otherword)
93        self.assertEqual(self.s | set(self.otherword), i)
94        self.assertEqual(self.s | frozenset(self.otherword), i)
95        try:
96            self.s | self.otherword
97        except TypeError:
98            pass
99        else:
100            self.fail("s|t did not screen-out general iterables")
101
102    def test_intersection(self):
103        i = self.s.intersection(self.otherword)
104        for c in self.letters:
105            self.assertEqual(c in i, c in self.d and c in self.otherword)
106        self.assertEqual(self.s, self.thetype(self.word))
107        self.assertEqual(type(i), self.basetype)
108        self.assertRaises(PassThru, self.s.intersection, check_pass_thru())
109        for C in set, frozenset, dict.fromkeys, str, list, tuple:
110            self.assertEqual(self.thetype('abcba').intersection(C('cdc')), set('cc'))
111            self.assertEqual(self.thetype('abcba').intersection(C('efgfe')), set(''))
112            self.assertEqual(self.thetype('abcba').intersection(C('ccb')), set('bc'))
113            self.assertEqual(self.thetype('abcba').intersection(C('ef')), set(''))
114            self.assertEqual(self.thetype('abcba').intersection(C('cbcf'), C('bag')), set('b'))
115        s = self.thetype('abcba')
116        z = s.intersection()
117        if self.thetype == frozenset():
118            self.assertEqual(id(s), id(z))
119        else:
120            self.assertNotEqual(id(s), id(z))
121
122    def test_isdisjoint(self):
123        def f(s1, s2):
124            'Pure python equivalent of isdisjoint()'
125            return not set(s1).intersection(s2)
126        for larg in '', 'a', 'ab', 'abc', 'ababac', 'cdc', 'cc', 'efgfe', 'ccb', 'ef':
127            s1 = self.thetype(larg)
128            for rarg in '', 'a', 'ab', 'abc', 'ababac', 'cdc', 'cc', 'efgfe', 'ccb', 'ef':
129                for C in set, frozenset, dict.fromkeys, str, list, tuple:
130                    s2 = C(rarg)
131                    actual = s1.isdisjoint(s2)
132                    expected = f(s1, s2)
133                    self.assertEqual(actual, expected)
134                    self.assertTrue(actual is True or actual is False)
135
136    def test_and(self):
137        i = self.s.intersection(self.otherword)
138        self.assertEqual(self.s & set(self.otherword), i)
139        self.assertEqual(self.s & frozenset(self.otherword), i)
140        try:
141            self.s & self.otherword
142        except TypeError:
143            pass
144        else:
145            self.fail("s&t did not screen-out general iterables")
146
147    def test_difference(self):
148        i = self.s.difference(self.otherword)
149        for c in self.letters:
150            self.assertEqual(c in i, c in self.d and c not in self.otherword)
151        self.assertEqual(self.s, self.thetype(self.word))
152        self.assertEqual(type(i), self.basetype)
153        self.assertRaises(PassThru, self.s.difference, check_pass_thru())
154        self.assertRaises(TypeError, self.s.difference, [[]])
155        for C in set, frozenset, dict.fromkeys, str, list, tuple:
156            self.assertEqual(self.thetype('abcba').difference(C('cdc')), set('ab'))
157            self.assertEqual(self.thetype('abcba').difference(C('efgfe')), set('abc'))
158            self.assertEqual(self.thetype('abcba').difference(C('ccb')), set('a'))
159            self.assertEqual(self.thetype('abcba').difference(C('ef')), set('abc'))
160            self.assertEqual(self.thetype('abcba').difference(), set('abc'))
161            self.assertEqual(self.thetype('abcba').difference(C('a'), C('b')), set('c'))
162
163    def test_sub(self):
164        i = self.s.difference(self.otherword)
165        self.assertEqual(self.s - set(self.otherword), i)
166        self.assertEqual(self.s - frozenset(self.otherword), i)
167        try:
168            self.s - self.otherword
169        except TypeError:
170            pass
171        else:
172            self.fail("s-t did not screen-out general iterables")
173
174    def test_symmetric_difference(self):
175        i = self.s.symmetric_difference(self.otherword)
176        for c in self.letters:
177            self.assertEqual(c in i, (c in self.d) ^ (c in self.otherword))
178        self.assertEqual(self.s, self.thetype(self.word))
179        self.assertEqual(type(i), self.basetype)
180        self.assertRaises(PassThru, self.s.symmetric_difference, check_pass_thru())
181        self.assertRaises(TypeError, self.s.symmetric_difference, [[]])
182        for C in set, frozenset, dict.fromkeys, str, list, tuple:
183            self.assertEqual(self.thetype('abcba').symmetric_difference(C('cdc')), set('abd'))
184            self.assertEqual(self.thetype('abcba').symmetric_difference(C('efgfe')), set('abcefg'))
185            self.assertEqual(self.thetype('abcba').symmetric_difference(C('ccb')), set('a'))
186            self.assertEqual(self.thetype('abcba').symmetric_difference(C('ef')), set('abcef'))
187
188    def test_xor(self):
189        i = self.s.symmetric_difference(self.otherword)
190        self.assertEqual(self.s ^ set(self.otherword), i)
191        self.assertEqual(self.s ^ frozenset(self.otherword), i)
192        try:
193            self.s ^ self.otherword
194        except TypeError:
195            pass
196        else:
197            self.fail("s^t did not screen-out general iterables")
198
199    def test_equality(self):
200        self.assertEqual(self.s, set(self.word))
201        self.assertEqual(self.s, frozenset(self.word))
202        self.assertEqual(self.s == self.word, False)
203        self.assertNotEqual(self.s, set(self.otherword))
204        self.assertNotEqual(self.s, frozenset(self.otherword))
205        self.assertEqual(self.s != self.word, True)
206
207    def test_setOfFrozensets(self):
208        t = map(frozenset, ['abcdef', 'bcd', 'bdcb', 'fed', 'fedccba'])
209        s = self.thetype(t)
210        self.assertEqual(len(s), 3)
211
212    def test_sub_and_super(self):
213        p, q, r = map(self.thetype, ['ab', 'abcde', 'def'])
214        self.assertTrue(p < q)
215        self.assertTrue(p <= q)
216        self.assertTrue(q <= q)
217        self.assertTrue(q > p)
218        self.assertTrue(q >= p)
219        self.assertFalse(q < r)
220        self.assertFalse(q <= r)
221        self.assertFalse(q > r)
222        self.assertFalse(q >= r)
223        self.assertTrue(set('a').issubset('abc'))
224        self.assertTrue(set('abc').issuperset('a'))
225        self.assertFalse(set('a').issubset('cbs'))
226        self.assertFalse(set('cbs').issuperset('a'))
227
228    def test_pickling(self):
229        for i in range(pickle.HIGHEST_PROTOCOL + 1):
230            p = pickle.dumps(self.s, i)
231            dup = pickle.loads(p)
232            self.assertEqual(self.s, dup, "%s != %s" % (self.s, dup))
233            if type(self.s) not in (set, frozenset):
234                self.s.x = 10
235                p = pickle.dumps(self.s, i)
236                dup = pickle.loads(p)
237                self.assertEqual(self.s.x, dup.x)
238
239    def test_iterator_pickling(self):
240        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
241            itorg = iter(self.s)
242            data = self.thetype(self.s)
243            d = pickle.dumps(itorg, proto)
244            it = pickle.loads(d)
245            # Set iterators unpickle as list iterators due to the
246            # undefined order of set items.
247            # self.assertEqual(type(itorg), type(it))
248            self.assertIsInstance(it, collections.abc.Iterator)
249            self.assertEqual(self.thetype(it), data)
250
251            it = pickle.loads(d)
252            try:
253                drop = next(it)
254            except StopIteration:
255                continue
256            d = pickle.dumps(it, proto)
257            it = pickle.loads(d)
258            self.assertEqual(self.thetype(it), data - self.thetype((drop,)))
259
260    def test_deepcopy(self):
261        class Tracer:
262            def __init__(self, value):
263                self.value = value
264            def __hash__(self):
265                return self.value
266            def __deepcopy__(self, memo=None):
267                return Tracer(self.value + 1)
268        t = Tracer(10)
269        s = self.thetype([t])
270        dup = copy.deepcopy(s)
271        self.assertNotEqual(id(s), id(dup))
272        for elem in dup:
273            newt = elem
274        self.assertNotEqual(id(t), id(newt))
275        self.assertEqual(t.value + 1, newt.value)
276
277    def test_gc(self):
278        # Create a nest of cycles to exercise overall ref count check
279        class A:
280            pass
281        s = set(A() for i in range(1000))
282        for elem in s:
283            elem.cycle = s
284            elem.sub = elem
285            elem.set = set([elem])
286
287    def test_subclass_with_custom_hash(self):
288        # Bug #1257731
289        class H(self.thetype):
290            def __hash__(self):
291                return int(id(self) & 0x7fffffff)
292        s=H()
293        f=set()
294        f.add(s)
295        self.assertIn(s, f)
296        f.remove(s)
297        f.add(s)
298        f.discard(s)
299
300    def test_badcmp(self):
301        s = self.thetype([BadCmp()])
302        # Detect comparison errors during insertion and lookup
303        self.assertRaises(RuntimeError, self.thetype, [BadCmp(), BadCmp()])
304        self.assertRaises(RuntimeError, s.__contains__, BadCmp())
305        # Detect errors during mutating operations
306        if hasattr(s, 'add'):
307            self.assertRaises(RuntimeError, s.add, BadCmp())
308            self.assertRaises(RuntimeError, s.discard, BadCmp())
309            self.assertRaises(RuntimeError, s.remove, BadCmp())
310
311    def test_cyclical_repr(self):
312        w = ReprWrapper()
313        s = self.thetype([w])
314        w.value = s
315        if self.thetype == set:
316            self.assertEqual(repr(s), '{set(...)}')
317        else:
318            name = repr(s).partition('(')[0]    # strip class name
319            self.assertEqual(repr(s), '%s({%s(...)})' % (name, name))
320
321    def test_do_not_rehash_dict_keys(self):
322        n = 10
323        d = dict.fromkeys(map(HashCountingInt, range(n)))
324        self.assertEqual(sum(elem.hash_count for elem in d), n)
325        s = self.thetype(d)
326        self.assertEqual(sum(elem.hash_count for elem in d), n)
327        s.difference(d)
328        self.assertEqual(sum(elem.hash_count for elem in d), n)
329        if hasattr(s, 'symmetric_difference_update'):
330            s.symmetric_difference_update(d)
331        self.assertEqual(sum(elem.hash_count for elem in d), n)
332        d2 = dict.fromkeys(set(d))
333        self.assertEqual(sum(elem.hash_count for elem in d), n)
334        d3 = dict.fromkeys(frozenset(d))
335        self.assertEqual(sum(elem.hash_count for elem in d), n)
336        d3 = dict.fromkeys(frozenset(d), 123)
337        self.assertEqual(sum(elem.hash_count for elem in d), n)
338        self.assertEqual(d3, dict.fromkeys(d, 123))
339
340    def test_container_iterator(self):
341        # Bug #3680: tp_traverse was not implemented for set iterator object
342        class C(object):
343            pass
344        obj = C()
345        ref = weakref.ref(obj)
346        container = set([obj, 1])
347        obj.x = iter(container)
348        del obj, container
349        gc.collect()
350        self.assertTrue(ref() is None, "Cycle was not collected")
351
352    def test_free_after_iterating(self):
353        support.check_free_after_iterating(self, iter, self.thetype)
354
class TestSet(TestJointOps, unittest.TestCase):
    """Joint tests applied to ``set``, plus tests of the mutating API."""
    thetype = set
    basetype = set

    # Argument types accepted by the named in-place update methods.
    # Accessed through the class to avoid method binding of the entries.
    _iterable_makers = (set, frozenset, dict.fromkeys, str, list, tuple)

    def test_init(self):
        # Re-running __init__ replaces the contents of an existing set.
        s = self.thetype()
        s.__init__(self.word)
        self.assertEqual(s, set(self.word))
        s.__init__(self.otherword)
        self.assertEqual(s, set(self.otherword))
        self.assertRaises(TypeError, s.__init__, s, 2)
        self.assertRaises(TypeError, s.__init__, 1)

    def test_constructor_identity(self):
        s = self.thetype(range(3))
        t = self.thetype(s)
        self.assertIsNot(s, t)

    def test_set_literal(self):
        s = set([1,2,3])
        t = {1,2,3}
        self.assertEqual(s, t)

    def test_set_literal_insertion_order(self):
        # SF Issue #26020 -- Expect left to right insertion
        s = {1, 1.0, True}
        self.assertEqual(len(s), 1)
        stored_value = s.pop()
        self.assertEqual(type(stored_value), int)

    def test_set_literal_evaluation_order(self):
        # Expect left to right expression evaluation
        events = []
        def record(obj):
            events.append(obj)
        # Only the order of the side effects matters; the set is discarded.
        _ = {record(1), record(2), record(3)}
        self.assertEqual(events, [1, 2, 3])

    def test_hash(self):
        self.assertRaises(TypeError, hash, self.s)

    def test_clear(self):
        self.s.clear()
        self.assertEqual(self.s, set())
        self.assertEqual(len(self.s), 0)

    def test_copy(self):
        dup = self.s.copy()
        self.assertEqual(self.s, dup)
        self.assertIsNot(self.s, dup)
        self.assertEqual(type(dup), self.basetype)

    def test_add(self):
        self.s.add('Q')
        self.assertIn('Q', self.s)
        # Adding an element that is already present changes nothing.
        dup = self.s.copy()
        self.s.add('Q')
        self.assertEqual(self.s, dup)
        self.assertRaises(TypeError, self.s.add, [])

    def test_remove(self):
        self.s.remove('a')
        self.assertNotIn('a', self.s)
        self.assertRaises(KeyError, self.s.remove, 'Q')
        self.assertRaises(TypeError, self.s.remove, [])
        # A set used as the key is matched against frozenset elements.
        s = self.thetype([frozenset(self.word)])
        self.assertIn(self.thetype(self.word), s)
        s.remove(self.thetype(self.word))
        self.assertNotIn(self.thetype(self.word), s)
        self.assertRaises(KeyError, self.s.remove, self.thetype(self.word))

    def test_remove_keyerror_unpacking(self):
        # bug:  www.python.org/sf/1576657
        for v1 in ['Q', (1,)]:
            with self.assertRaises(KeyError) as cm:
                self.s.remove(v1)
            # The missing key is reported as-is; a tuple is not unpacked.
            self.assertEqual(v1, cm.exception.args[0])

    def test_remove_keyerror_set(self):
        key = self.thetype([3, 4])
        with self.assertRaises(KeyError) as cm:
            self.s.remove(key)
        reported = cm.exception.args[0]
        self.assertIs(reported, key,
                      "KeyError should be {0}, not {1}".format(key, reported))

    def test_discard(self):
        self.s.discard('a')
        self.assertNotIn('a', self.s)
        self.s.discard('Q')     # a missing element is not an error
        self.assertRaises(TypeError, self.s.discard, [])
        s = self.thetype([frozenset(self.word)])
        self.assertIn(self.thetype(self.word), s)
        s.discard(self.thetype(self.word))
        self.assertNotIn(self.thetype(self.word), s)
        s.discard(self.thetype(self.word))

    def test_pop(self):
        for _ in range(len(self.s)):
            elem = self.s.pop()
            self.assertNotIn(elem, self.s)
        self.assertRaises(KeyError, self.s.pop)

    def test_update(self):
        retval = self.s.update(self.otherword)
        self.assertIsNone(retval)
        for c in (self.word + self.otherword):
            self.assertIn(c, self.s)
        self.assertRaises(PassThru, self.s.update, check_pass_thru())
        self.assertRaises(TypeError, self.s.update, [[]])
        for p, q in (('cdc', 'abcd'), ('efgfe', 'abcefg'), ('ccb', 'abc'), ('ef', 'abcef')):
            for C in TestSet._iterable_makers:
                s = self.thetype('abcba')
                self.assertIsNone(s.update(C(p)))
                self.assertEqual(s, set(q))
        for p in ('cdc', 'efgfe', 'ccb', 'ef', 'abcda'):
            q = 'ahi'
            for C in TestSet._iterable_makers:
                s = self.thetype('abcba')
                self.assertIsNone(s.update(C(p), C(q)))
                # Build the expectation from the starting contents rather
                # than from s itself, so that an update that dropped or
                # invented elements would be caught.
                self.assertEqual(s, set('abcba') | set(p) | set(q))

    def test_ior(self):
        self.s |= set(self.otherword)
        for c in (self.word + self.otherword):
            self.assertIn(c, self.s)

    def test_intersection_update(self):
        retval = self.s.intersection_update(self.otherword)
        self.assertIsNone(retval)
        for c in (self.word + self.otherword):
            if c in self.otherword and c in self.word:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)
        self.assertRaises(PassThru, self.s.intersection_update, check_pass_thru())
        self.assertRaises(TypeError, self.s.intersection_update, [[]])
        for p, q in (('cdc', 'c'), ('efgfe', ''), ('ccb', 'bc'), ('ef', '')):
            for C in TestSet._iterable_makers:
                s = self.thetype('abcba')
                self.assertIsNone(s.intersection_update(C(p)))
                self.assertEqual(s, set(q))
                ss = 'abcba'
                s = self.thetype(ss)
                t = 'cbc'
                self.assertIsNone(s.intersection_update(C(p), C(t)))
                self.assertEqual(s, set('abcba')&set(p)&set(t))

    def test_iand(self):
        self.s &= set(self.otherword)
        for c in (self.word + self.otherword):
            if c in self.otherword and c in self.word:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)

    def test_difference_update(self):
        retval = self.s.difference_update(self.otherword)
        self.assertIsNone(retval)
        for c in (self.word + self.otherword):
            if c in self.word and c not in self.otherword:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)
        self.assertRaises(PassThru, self.s.difference_update, check_pass_thru())
        self.assertRaises(TypeError, self.s.difference_update, [[]])
        self.assertRaises(TypeError, self.s.symmetric_difference_update, [[]])
        for p, q in (('cdc', 'ab'), ('efgfe', 'abc'), ('ccb', 'a'), ('ef', 'abc')):
            for C in TestSet._iterable_makers:
                s = self.thetype('abcba')
                self.assertIsNone(s.difference_update(C(p)))
                self.assertEqual(s, set(q))

                # No argument: nothing is removed.
                s = self.thetype('abcdefghih')
                s.difference_update()
                self.assertEqual(s, self.thetype('abcdefghih'))

                s = self.thetype('abcdefghih')
                s.difference_update(C('aba'))
                self.assertEqual(s, self.thetype('cdefghih'))

                # Several arguments: all of them are removed.
                s = self.thetype('abcdefghih')
                s.difference_update(C('cdc'), C('aba'))
                self.assertEqual(s, self.thetype('efghih'))

    def test_isub(self):
        self.s -= set(self.otherword)
        for c in (self.word + self.otherword):
            if c in self.word and c not in self.otherword:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)

    def test_symmetric_difference_update(self):
        retval = self.s.symmetric_difference_update(self.otherword)
        self.assertIsNone(retval)
        for c in (self.word + self.otherword):
            if (c in self.word) ^ (c in self.otherword):
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)
        self.assertRaises(PassThru, self.s.symmetric_difference_update, check_pass_thru())
        self.assertRaises(TypeError, self.s.symmetric_difference_update, [[]])
        for p, q in (('cdc', 'abd'), ('efgfe', 'abcefg'), ('ccb', 'a'), ('ef', 'abcef')):
            for C in TestSet._iterable_makers:
                s = self.thetype('abcba')
                self.assertIsNone(s.symmetric_difference_update(C(p)))
                self.assertEqual(s, set(q))

    def test_ixor(self):
        self.s ^= set(self.otherword)
        for c in (self.word + self.otherword):
            if (c in self.word) ^ (c in self.otherword):
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)

    def test_inplace_on_self(self):
        # In-place operators must cope with the operand being the set itself.
        t = self.s.copy()
        t |= t
        self.assertEqual(t, self.s)
        t &= t
        self.assertEqual(t, self.s)
        t -= t
        self.assertEqual(t, self.thetype())
        t = self.s.copy()
        t ^= t
        self.assertEqual(t, self.thetype())

    def test_weakref(self):
        s = self.thetype('gallahad')
        p = weakref.proxy(s)
        self.assertEqual(str(p), str(s))
        s = None
        support.gc_collect()  # For PyPy or other GCs.
        self.assertRaises(ReferenceError, str, p)

    def test_rich_compare(self):
        class TestRichSetCompare:
            # Each method records that it was called and answers False.
            def __gt__(self, some_set):
                self.gt_called = True
                return False
            def __lt__(self, some_set):
                self.lt_called = True
                return False
            def __ge__(self, some_set):
                self.ge_called = True
                return False
            def __le__(self, some_set):
                self.le_called = True
                return False

        # This first tries the builtin rich set comparison, which doesn't know
        # how to handle the custom object. Upon returning NotImplemented, the
        # corresponding comparison on the right object is invoked.
        myset = {1, 2, 3}

        myobj = TestRichSetCompare()
        myset < myobj
        self.assertTrue(myobj.gt_called)

        myobj = TestRichSetCompare()
        myset > myobj
        self.assertTrue(myobj.lt_called)

        myobj = TestRichSetCompare()
        myset <= myobj
        self.assertTrue(myobj.ge_called)

        myobj = TestRichSetCompare()
        myset >= myobj
        self.assertTrue(myobj.le_called)

    @unittest.skipUnless(hasattr(set, "test_c_api"),
                         'C API test only available in a debug build')
    def test_c_api(self):
        self.assertEqual(set().test_c_api(), True)
639
class SetSubclass(set):
    """Plain subclass of set that adds and overrides nothing."""
642
class TestSetSubclass(TestSet):
    """Re-run the set tests against a trivial subclass of set."""
    thetype = SetSubclass
    basetype = set

    def test_keywords_in_subclass(self):
        # 1. A subclass that overrides nothing still rejects keywords.
        class Plain(set):
            pass
        built = Plain([1, 2])
        self.assertIs(type(built), Plain)
        self.assertEqual(set(built), {1, 2})
        with self.assertRaises(TypeError):
            Plain(sequence=())

        # 2. A subclass may take extra keywords through its own __init__.
        class ViaInit(set):
            def __init__(self, arg, newarg=None):
                super().__init__(arg)
                self.newarg = newarg
        built = ViaInit([1, 2], newarg=3)
        self.assertIs(type(built), ViaInit)
        self.assertEqual(set(built), {1, 2})
        self.assertEqual(built.newarg, 3)

        # 3. ... or through its own __new__.
        class ViaNew(set):
            def __new__(cls, arg, newarg=None):
                obj = super().__new__(cls, arg)
                obj.newarg = newarg
                return obj
        built = ViaNew([1, 2], newarg=3)
        self.assertIs(type(built), ViaNew)
        self.assertEqual(set(built), {1, 2})
        self.assertEqual(built.newarg, 3)
674
675
class TestFrozenSet(TestJointOps, unittest.TestCase):
    """Joint tests applied to ``frozenset``, plus identity and hash tests."""
    thetype = frozenset
    basetype = frozenset

    def test_init(self):
        # Calling __init__ again must leave the existing contents alone.
        s = self.thetype(self.word)
        s.__init__(self.otherword)
        self.assertEqual(s, set(self.word))

    def test_constructor_identity(self):
        # Building from an instance of the same type returns that instance.
        s = self.thetype(range(3))
        t = self.thetype(s)
        self.assertIs(s, t)

    def test_hash(self):
        self.assertEqual(hash(self.thetype('abcdeb')),
                         hash(self.thetype('ebecda')))

        # make sure that all permutations give the same hash value
        n = 100
        seq = [randrange(n) for _ in range(n)]
        results = set()
        for _ in range(200):
            shuffle(seq)
            results.add(hash(self.thetype(seq)))
        self.assertEqual(len(results), 1)

    def test_copy(self):
        dup = self.s.copy()
        self.assertIs(self.s, dup)

    def test_frozen_as_dictkey(self):
        seq = list(range(10)) + list('abcdefg') + ['apple']
        key1 = self.thetype(seq)
        key2 = self.thetype(reversed(seq))
        self.assertEqual(key1, key2)
        self.assertIsNot(key1, key2)
        # Equal but distinct objects must address the same dict entry.
        d = {}
        d[key1] = 42
        self.assertEqual(d[key2], 42)

    def test_hash_caching(self):
        f = self.thetype('abcdcda')
        self.assertEqual(hash(f), hash(f))

    def test_hash_effectiveness(self):
        # Every subset of {1, ..., n} must get its own hash value.
        n = 13
        hashvalues = set()
        addhashvalue = hashvalues.add
        elemmasks = [(i+1, 1<<i) for i in range(n)]
        for i in range(2**n):
            addhashvalue(hash(frozenset([e for e, m in elemmasks if m&i])))
        self.assertEqual(len(hashvalues), 2**n)

        def zf_range(n):
            # https://en.wikipedia.org/wiki/Set-theoretic_definition_of_natural_numbers
            nums = [frozenset()]
            for _ in range(n-1):
                nums.append(frozenset(nums))
            return nums[:n]     # the slice makes zf_range(0) return []

        def powerset(s):
            for i in range(len(s)+1):
                yield from map(frozenset, itertools.combinations(s, i))

        # For each n, look at the low n bits of the hashes of all 2**n
        # subsets: more than a quarter of the 2**n possible values must occur.
        for n in range(18):
            t = 2 ** n
            mask = t - 1
            for nums in (range, zf_range):
                u = len({h & mask for h in map(hash, powerset(nums(n)))})
                self.assertGreater(4*u, t)
748
class FrozenSetSubclass(frozenset):
    """Plain subclass of frozenset that adds and overrides nothing."""
751
class TestFrozenSetSubclass(TestFrozenSet):
    """Re-run the frozenset tests against a trivial frozenset subclass.

    The identity tests are overridden: here the constructor and copy()
    are expected to return new objects.
    """
    thetype = FrozenSetSubclass
    basetype = frozenset

    def test_keywords_in_subclass(self):
        # 1. A subclass that overrides nothing still rejects keywords.
        class Plain(frozenset):
            pass
        built = Plain([1, 2])
        self.assertIs(type(built), Plain)
        self.assertEqual(set(built), {1, 2})
        with self.assertRaises(TypeError):
            Plain(sequence=())

        # 2. Extra keywords accepted by an overriding __init__.
        class ViaInit(frozenset):
            def __init__(self, arg, newarg=None):
                self.newarg = newarg
        built = ViaInit([1, 2], newarg=3)
        self.assertIs(type(built), ViaInit)
        self.assertEqual(set(built), {1, 2})
        self.assertEqual(built.newarg, 3)

        # 3. Extra keywords accepted by an overriding __new__.
        class ViaNew(frozenset):
            def __new__(cls, arg, newarg=None):
                obj = super().__new__(cls, arg)
                obj.newarg = newarg
                return obj
        built = ViaNew([1, 2], newarg=3)
        self.assertIs(type(built), ViaNew)
        self.assertEqual(set(built), {1, 2})
        self.assertEqual(built.newarg, 3)

    def test_constructor_identity(self):
        original = self.thetype(range(3))
        rebuilt = self.thetype(original)
        self.assertNotEqual(id(rebuilt), id(original))

    def test_copy(self):
        duplicate = self.s.copy()
        self.assertNotEqual(id(duplicate), id(self.s))

    def test_nested_empty_constructor(self):
        empty = self.thetype()
        rebuilt = self.thetype(empty)
        self.assertEqual(empty, rebuilt)

    def test_singleton_empty_frozenset(self):
        Frozenset = self.thetype
        f = frozenset()
        F = Frozenset()
        # Many ways of spelling "empty"; the list keeps every object alive
        # so that no id() can be reused.
        efs = [
            Frozenset(), Frozenset([]), Frozenset(()), Frozenset(''),
            Frozenset(), Frozenset([]), Frozenset(()), Frozenset(''),
            Frozenset(range(0)), Frozenset(Frozenset()),
            Frozenset(frozenset()), f, F, Frozenset(f), Frozenset(F),
        ]
        # All empty frozenset subclass instances should have different ids
        distinct_ids = {id(obj) for obj in efs}
        self.assertEqual(len(distinct_ids), len(efs))
807
808# Tests taken from test_sets.py =============================================
809
# Shared empty set used as an operand and as an expected value by the tests
# below.  It is module-level state, so tests must not mutate it.
empty_set = set()
811
812#==============================================================================
813
class TestBasicOps:
    """Mixin with basic checks run against one fixed set.

    The concrete case assigns, in setUp(): ``self.set`` (the set under test),
    ``self.dup`` (a second set compared against it), ``self.values`` (the
    source values) and ``self.length``.  ``self.repr`` is read by test_repr()
    unless the case overrides that method.
    """

    def test_repr(self):
        # A repr of None means "do not check the exact repr".
        expected = self.repr
        if expected is None:
            return
        self.assertEqual(repr(self.set), expected)

    def check_repr_against_values(self):
        # Order-insensitive comparison of the displayed elements with the
        # repr() of every source value.
        text = repr(self.set)
        self.assertTrue(text.startswith('{'))
        self.assertTrue(text.endswith('}'))

        shown = sorted(text[1:-1].split(', '))
        wanted = sorted(repr(value) for value in self.values)
        self.assertEqual(shown, wanted)

    def test_length(self):
        self.assertEqual(len(self.set), self.length)

    def test_self_equality(self):
        self.assertEqual(self.set, self.set)

    def test_equivalent_equality(self):
        self.assertEqual(self.set, self.dup)

    def test_copy(self):
        self.assertEqual(self.set.copy(), self.dup)

    def test_self_union(self):
        self.assertEqual(self.set | self.set, self.dup)

    def test_empty_union(self):
        self.assertEqual(self.set | empty_set, self.dup)

    def test_union_empty(self):
        self.assertEqual(empty_set | self.set, self.dup)

    def test_self_intersection(self):
        self.assertEqual(self.set & self.set, self.dup)

    def test_empty_intersection(self):
        self.assertEqual(self.set & empty_set, empty_set)

    def test_intersection_empty(self):
        self.assertEqual(empty_set & self.set, empty_set)

    def test_self_isdisjoint(self):
        # A set is disjoint from itself only when it is empty.
        self.assertEqual(self.set.isdisjoint(self.set), not self.set)

    def test_empty_isdisjoint(self):
        self.assertEqual(self.set.isdisjoint(empty_set), True)

    def test_isdisjoint_empty(self):
        self.assertEqual(empty_set.isdisjoint(self.set), True)

    def test_self_symmetric_difference(self):
        self.assertEqual(self.set ^ self.set, empty_set)

    def test_empty_symmetric_difference(self):
        self.assertEqual(self.set ^ empty_set, self.set)

    def test_self_difference(self):
        self.assertEqual(self.set - self.set, empty_set)

    def test_empty_difference(self):
        self.assertEqual(self.set - empty_set, self.dup)

    def test_empty_difference_rev(self):
        self.assertEqual(empty_set - self.set, empty_set)

    def test_iteration(self):
        for element in self.set:
            self.assertIn(element, self.values)
        # A fresh iterator reports the full length as its hint.
        fresh = iter(self.set)
        self.assertEqual(fresh.__length_hint__(), len(self.set))

    def test_pickling(self):
        # Round-trip through every pickle protocol.
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            restored = pickle.loads(pickle.dumps(self.set, proto))
            self.assertEqual(self.set, restored,
                             "%s != %s" % (self.set, restored))

    def test_issue_37219(self):
        # A non-iterable argument must raise TypeError.
        self.assertRaises(TypeError, set().difference, 123)
        self.assertRaises(TypeError, set().difference_update, 123)
917
918#------------------------------------------------------------------------------
919
class TestBasicOpsEmpty(TestBasicOps, unittest.TestCase):
    'TestBasicOps applied to the empty set'
    def setUp(self):
        values = []
        self.case   = "empty set"
        self.values = values
        self.length = 0
        self.repr   = "set()"
        # Two separate set objects built from the same values.
        self.set    = set(values)
        self.dup    = set(values)
928
929#------------------------------------------------------------------------------
930
class TestBasicOpsSingleton(TestBasicOps, unittest.TestCase):
    'TestBasicOps applied to a set holding the single number 3'
    def setUp(self):
        values = [3]
        self.case   = "unit set (number)"
        self.values = values
        self.length = 1
        self.repr   = "{3}"
        # Two separate set objects built from the same values.
        self.set    = set(values)
        self.dup    = set(values)

    def test_in(self):
        present = 3
        self.assertIn(present, self.set)

    def test_not_in(self):
        absent = 2
        self.assertNotIn(absent, self.set)
945
946#------------------------------------------------------------------------------
947
class TestBasicOpsTuple(TestBasicOps, unittest.TestCase):
    'TestBasicOps applied to a set holding one tuple'
    def setUp(self):
        values = [(0, "zero")]
        self.case   = "unit set (tuple)"
        self.values = values
        self.length = 1
        self.repr   = "{(0, 'zero')}"
        # Two separate set objects built from the same values.
        self.set    = set(values)
        self.dup    = set(values)

    def test_in(self):
        present = (0, "zero")
        self.assertIn(present, self.set)

    def test_not_in(self):
        absent = 9
        self.assertNotIn(absent, self.set)
962
963#------------------------------------------------------------------------------
964
class TestBasicOpsTriple(TestBasicOps, unittest.TestCase):
    'TestBasicOps applied to a set of three values of different types'
    def setUp(self):
        values = [0, "zero", operator.add]
        self.case   = "triple set"
        self.values = values
        self.length = 3
        # None makes the inherited test_repr() skip the exact-repr check.
        self.repr   = None
        self.set    = set(values)
        self.dup    = set(values)
973
974#------------------------------------------------------------------------------
975
class TestBasicOpsString(TestBasicOps, unittest.TestCase):
    'TestBasicOps applied to a set of three one-character strings'
    def setUp(self):
        values = ["a", "b", "c"]
        self.case   = "string set"
        self.values = values
        self.length = 3
        self.set    = set(values)
        self.dup    = set(values)

    def test_repr(self):
        # Element order is not fixed, so compare the repr element-wise.
        self.check_repr_against_values()
986
987#------------------------------------------------------------------------------
988
class TestBasicOpsBytes(TestBasicOps, unittest.TestCase):
    'TestBasicOps applied to a set of three one-byte bytes objects'
    def setUp(self):
        values = [b"a", b"b", b"c"]
        self.case   = "bytes set"
        self.values = values
        self.length = 3
        self.set    = set(values)
        self.dup    = set(values)

    def test_repr(self):
        # Element order is not fixed, so compare the repr element-wise.
        self.check_repr_against_values()
999
1000#------------------------------------------------------------------------------
1001
class TestBasicOpsMixedStringBytes(TestBasicOps, unittest.TestCase):
    """TestBasicOps applied to a set mixing str and bytes elements."""

    def setUp(self):
        # Enter the check_warnings() context for the duration of the test and
        # register its exit right away: cleanups run even when the rest of
        # setUp() raises, whereas tearDown() would be skipped in that case.
        self._warning_filters = warnings_helper.check_warnings()
        self._warning_filters.__enter__()
        self.addCleanup(self._warning_filters.__exit__, None, None, None)
        # Mixing str and bytes in one set presumably triggers BytesWarning
        # when the interpreter runs with -b; ignore it for these tests.
        warnings.simplefilter('ignore', BytesWarning)
        self.case   = "string and bytes set"
        self.values = ["a", "b", b"a", b"b"]
        self.set    = set(self.values)
        self.dup    = set(self.values)
        self.length = 4

    def test_repr(self):
        # Element order is not fixed, so compare the repr element-wise.
        self.check_repr_against_values()
1018
1019#==============================================================================
1020
def baditer():
    'Generator that raises TypeError as soon as it is advanced'
    raise TypeError
    # Never reached; its presence is what makes this function a generator.
    yield True
1024
def gooditer():
    'Generator that yields the single value True'
    yield from (True,)
1027
class TestExceptionPropagation(unittest.TestCase):
    """SF 628246:  Set constructor should not trap iterator TypeErrors"""

    def test_instanceWithException(self):
        # The TypeError raised inside the iterator must reach the caller.
        self.assertRaises(TypeError, set, baditer())

    def test_instancesWithoutException(self):
        # All of these iterables should load without exception.
        set([1,2,3])
        set((1,2,3))
        set({'one':1, 'two':2, 'three':3})
        set(range(3))
        set('abc')
        set(gooditer())

    def test_changingSizeWhileIterating(self):
        # Growing the set inside a loop over it must raise RuntimeError.
        s = set([1,2,3])
        with self.assertRaises(
                RuntimeError,
                msg="no exception when changing size during iteration"):
            for i in s:
                s.update([4])
1052
1053#==============================================================================
1054
class TestSetOfSets(unittest.TestCase):
    'A frozenset can be stored in, and removed from, a set'
    def test_constructor(self):
        member = frozenset([1])
        container = set([member])
        popped = container.pop()
        self.assertEqual(type(popped), frozenset)
        # Rebuild set of sets with .add method
        container.add(member)
        container.remove(member)
        # Verify that remove worked
        self.assertEqual(container, set())
        # Absence of KeyError indicates working fine
        container.discard(member)
1065
1066#==============================================================================
1067
class TestBinaryOps(unittest.TestCase):
    'Binary operators and isdisjoint() applied to the fixed set {2, 4, 6}'
    def setUp(self):
        self.set = {2, 4, 6}

    def test_eq(self):              # SF bug 643115
        # Built from a dict on purpose: only the keys become elements.
        from_dict = set({2:1,4:3,6:5})
        self.assertEqual(self.set, from_dict)

    def test_union_subset(self):
        self.assertEqual(self.set | {2}, {2, 4, 6})

    def test_union_superset(self):
        self.assertEqual(self.set | {2, 4, 6, 8}, {2, 4, 6, 8})

    def test_union_overlap(self):
        self.assertEqual(self.set | {3, 4, 5}, {2, 3, 4, 5, 6})

    def test_union_non_overlap(self):
        self.assertEqual(self.set | {8}, {2, 4, 6, 8})

    def test_intersection_subset(self):
        self.assertEqual(self.set & {2, 4}, {2, 4})

    def test_intersection_superset(self):
        self.assertEqual(self.set & {2, 4, 6, 8}, {2, 4, 6})

    def test_intersection_overlap(self):
        self.assertEqual(self.set & {3, 4, 5}, {4})

    def test_intersection_non_overlap(self):
        self.assertEqual(self.set & {8}, empty_set)

    def test_isdisjoint_subset(self):
        self.assertEqual(self.set.isdisjoint({2, 4}), False)

    def test_isdisjoint_superset(self):
        self.assertEqual(self.set.isdisjoint({2, 4, 6, 8}), False)

    def test_isdisjoint_overlap(self):
        self.assertEqual(self.set.isdisjoint({3, 4, 5}), False)

    def test_isdisjoint_non_overlap(self):
        self.assertEqual(self.set.isdisjoint({8}), True)

    def test_sym_difference_subset(self):
        self.assertEqual(self.set ^ {2, 4}, {6})

    def test_sym_difference_superset(self):
        self.assertEqual(self.set ^ {2, 4, 6, 8}, {8})

    def test_sym_difference_overlap(self):
        self.assertEqual(self.set ^ {3, 4, 5}, {2, 3, 5, 6})

    def test_sym_difference_non_overlap(self):
        self.assertEqual(self.set ^ {8}, {2, 4, 6, 8})
1138
1139#==============================================================================
1140
class TestUpdateOps(unittest.TestCase):
    'In-place operators and *_update() methods applied to the set {2, 4, 6}'
    def setUp(self):
        self.set = {2, 4, 6}

    def test_union_subset(self):
        self.set |= {2}
        self.assertEqual(self.set, {2, 4, 6})

    def test_union_superset(self):
        self.set |= {2, 4, 6, 8}
        self.assertEqual(self.set, {2, 4, 6, 8})

    def test_union_overlap(self):
        self.set |= {3, 4, 5}
        self.assertEqual(self.set, {2, 3, 4, 5, 6})

    def test_union_non_overlap(self):
        self.set |= {8}
        self.assertEqual(self.set, {2, 4, 6, 8})

    def test_union_method_call(self):
        self.set.update({3, 4, 5})
        self.assertEqual(self.set, {2, 3, 4, 5, 6})

    def test_intersection_subset(self):
        self.set &= {2, 4}
        self.assertEqual(self.set, {2, 4})

    def test_intersection_superset(self):
        self.set &= {2, 4, 6, 8}
        self.assertEqual(self.set, {2, 4, 6})

    def test_intersection_overlap(self):
        self.set &= {3, 4, 5}
        self.assertEqual(self.set, {4})

    def test_intersection_non_overlap(self):
        self.set &= {8}
        self.assertEqual(self.set, empty_set)

    def test_intersection_method_call(self):
        self.set.intersection_update({3, 4, 5})
        self.assertEqual(self.set, {4})

    def test_sym_difference_subset(self):
        self.set ^= {2, 4}
        self.assertEqual(self.set, {6})

    def test_sym_difference_superset(self):
        self.set ^= {2, 4, 6, 8}
        self.assertEqual(self.set, {8})

    def test_sym_difference_overlap(self):
        self.set ^= {3, 4, 5}
        self.assertEqual(self.set, {2, 3, 5, 6})

    def test_sym_difference_non_overlap(self):
        self.set ^= {8}
        self.assertEqual(self.set, {2, 4, 6, 8})

    def test_sym_difference_method_call(self):
        self.set.symmetric_difference_update({3, 4, 5})
        self.assertEqual(self.set, {2, 3, 5, 6})

    def test_difference_subset(self):
        self.set -= {2, 4}
        self.assertEqual(self.set, {6})

    def test_difference_superset(self):
        self.set -= {2, 4, 6, 8}
        self.assertEqual(self.set, set())

    def test_difference_overlap(self):
        self.set -= {3, 4, 5}
        self.assertEqual(self.set, {2, 6})

    def test_difference_non_overlap(self):
        self.set -= {8}
        self.assertEqual(self.set, {2, 4, 6})

    def test_difference_method_call(self):
        self.set.difference_update({3, 4, 5})
        self.assertEqual(self.set, {2, 6})
1224
1225#==============================================================================
1226
class TestMutate(unittest.TestCase):
    """Checks for the mutating methods add, remove, discard, clear, pop, update.

    Every test starts from the set {"a", "b", "c"} built in setUp().
    """

    def setUp(self):
        self.values = ["a", "b", "c"]
        self.set = set(self.values)

    def test_add_present(self):
        self.set.add("c")
        self.assertEqual(self.set, set("abc"))

    def test_add_absent(self):
        self.set.add("d")
        self.assertEqual(self.set, set("abcd"))

    def test_add_until_full(self):
        # Each distinct value added must grow the set by exactly one.
        tmp = set()
        for expected_len, v in enumerate(self.values, 1):
            tmp.add(v)
            self.assertEqual(len(tmp), expected_len)
        self.assertEqual(tmp, self.set)

    def test_remove_present(self):
        self.set.remove("b")
        self.assertEqual(self.set, set("ac"))

    def test_remove_absent(self):
        # Any LookupError is accepted, as in the original try/except form.
        with self.assertRaises(
                LookupError,
                msg="Removing missing element should have raised LookupError"):
            self.set.remove("d")

    def test_remove_until_empty(self):
        # Each removal must shrink the set by exactly one.
        start = len(self.set)
        for removed, v in enumerate(self.values, 1):
            self.set.remove(v)
            self.assertEqual(len(self.set), start - removed)

    def test_discard_present(self):
        self.set.discard("c")
        self.assertEqual(self.set, set("ab"))

    def test_discard_absent(self):
        # Unlike remove(), discard() of a missing element is a no-op.
        self.set.discard("d")
        self.assertEqual(self.set, set("abc"))

    def test_clear(self):
        self.set.clear()
        self.assertEqual(len(self.set), 0)

    def test_pop(self):
        # A dict collects the popped elements so that the check does not
        # rely on the type under test.
        popped = {}
        while self.set:
            popped[self.set.pop()] = None
        self.assertEqual(len(popped), len(self.values))
        for v in self.values:
            self.assertIn(v, popped)

    def test_update_empty_tuple(self):
        self.set.update(())
        self.assertEqual(self.set, set(self.values))

    def test_update_unit_tuple_overlap(self):
        self.set.update(("a",))
        self.assertEqual(self.set, set(self.values))

    def test_update_unit_tuple_non_overlap(self):
        # NOTE: despite the name, the argument has two elements, one of
        # which ("a") is already present; only "z" is new.
        self.set.update(("a", "z"))
        self.assertEqual(self.set, set(self.values + ["z"]))
1298
1299#==============================================================================
1300
1301class TestSubsets:
1302
1303    case2method = {"<=": "issubset",
1304                   ">=": "issuperset",
1305                  }
1306
1307    reverse = {"==": "==",
1308               "!=": "!=",
1309               "<":  ">",
1310               ">":  "<",
1311               "<=": ">=",
1312               ">=": "<=",
1313              }
1314
1315    def test_issubset(self):
1316        x = self.left
1317        y = self.right
1318        for case in "!=", "==", "<", "<=", ">", ">=":
1319            expected = case in self.cases
1320            # Test the binary infix spelling.
1321            result = eval("x" + case + "y", locals())
1322            self.assertEqual(result, expected)
1323            # Test the "friendly" method-name spelling, if one exists.
1324            if case in TestSubsets.case2method:
1325                method = getattr(x, TestSubsets.case2method[case])
1326                result = method(y)
1327                self.assertEqual(result, expected)
1328
1329            # Now do the same for the operands reversed.
1330            rcase = TestSubsets.reverse[case]
1331            result = eval("y" + rcase + "x", locals())
1332            self.assertEqual(result, expected)
1333            if rcase in TestSubsets.case2method:
1334                method = getattr(y, TestSubsets.case2method[rcase])
1335                result = method(x)
1336                self.assertEqual(result, expected)
1337#------------------------------------------------------------------------------
1338
class TestSubsetEqualEmpty(TestSubsets, unittest.TestCase):
    'Both operands are empty sets'
    name  = "both empty"
    left, right = set(), set()
    cases = ("==", "<=", ">=")
1344
1345#------------------------------------------------------------------------------
1346
class TestSubsetEqualNonEmpty(TestSubsets, unittest.TestCase):
    'Both operands are equal, non-empty sets'
    name  = "equal pair"
    left, right = {1, 2}, {1, 2}
    cases = ("==", "<=", ">=")
1352
1353#------------------------------------------------------------------------------
1354
class TestSubsetEmptyNonEmpty(TestSubsets, unittest.TestCase):
    'The left operand is empty, the right one is not'
    name  = "one empty, one non-empty"
    left, right = set(), {1, 2}
    cases = ("!=", "<", "<=")
1360
1361#------------------------------------------------------------------------------
1362
class TestSubsetPartial(TestSubsets, unittest.TestCase):
    'The left operand is a non-empty proper subset of the right one'
    name  = "one a non-empty proper subset of other"
    left, right = {1}, {1, 2}
    cases = ("!=", "<", "<=")
1368
1369#------------------------------------------------------------------------------
1370
class TestSubsetNonOverlap(TestSubsets, unittest.TestCase):
    'Two non-empty sets with no element in common'
    left  = set([1])
    right = set([2])
    name  = "neither empty, neither contains"
    # One-element tuple rather than the bare string "!=": test_issubset()
    # evaluates ``case in self.cases``, which would be a substring test on
    # a str and only happened to give the right answers.
    cases = ("!=",)
1376
1377#==============================================================================
1378
1379class TestOnlySetsInBinaryOps:
1380
1381    def test_eq_ne(self):
1382        # Unlike the others, this is testing that == and != *are* allowed.
1383        self.assertEqual(self.other == self.set, False)
1384        self.assertEqual(self.set == self.other, False)
1385        self.assertEqual(self.other != self.set, True)
1386        self.assertEqual(self.set != self.other, True)
1387
1388    def test_ge_gt_le_lt(self):
1389        self.assertRaises(TypeError, lambda: self.set < self.other)
1390        self.assertRaises(TypeError, lambda: self.set <= self.other)
1391        self.assertRaises(TypeError, lambda: self.set > self.other)
1392        self.assertRaises(TypeError, lambda: self.set >= self.other)
1393
1394        self.assertRaises(TypeError, lambda: self.other < self.set)
1395        self.assertRaises(TypeError, lambda: self.other <= self.set)
1396        self.assertRaises(TypeError, lambda: self.other > self.set)
1397        self.assertRaises(TypeError, lambda: self.other >= self.set)
1398
1399    def test_update_operator(self):
1400        try:
1401            self.set |= self.other
1402        except TypeError:
1403            pass
1404        else:
1405            self.fail("expected TypeError")
1406
1407    def test_update(self):
1408        if self.otherIsIterable:
1409            self.set.update(self.other)
1410        else:
1411            self.assertRaises(TypeError, self.set.update, self.other)
1412
1413    def test_union(self):
1414        self.assertRaises(TypeError, lambda: self.set | self.other)
1415        self.assertRaises(TypeError, lambda: self.other | self.set)
1416        if self.otherIsIterable:
1417            self.set.union(self.other)
1418        else:
1419            self.assertRaises(TypeError, self.set.union, self.other)
1420
1421    def test_intersection_update_operator(self):
1422        try:
1423            self.set &= self.other
1424        except TypeError:
1425            pass
1426        else:
1427            self.fail("expected TypeError")
1428
1429    def test_intersection_update(self):
1430        if self.otherIsIterable:
1431            self.set.intersection_update(self.other)
1432        else:
1433            self.assertRaises(TypeError,
1434                              self.set.intersection_update,
1435                              self.other)
1436
1437    def test_intersection(self):
1438        self.assertRaises(TypeError, lambda: self.set & self.other)
1439        self.assertRaises(TypeError, lambda: self.other & self.set)
1440        if self.otherIsIterable:
1441            self.set.intersection(self.other)
1442        else:
1443            self.assertRaises(TypeError, self.set.intersection, self.other)
1444
1445    def test_sym_difference_update_operator(self):
1446        try:
1447            self.set ^= self.other
1448        except TypeError:
1449            pass
1450        else:
1451            self.fail("expected TypeError")
1452
1453    def test_sym_difference_update(self):
1454        if self.otherIsIterable:
1455            self.set.symmetric_difference_update(self.other)
1456        else:
1457            self.assertRaises(TypeError,
1458                              self.set.symmetric_difference_update,
1459                              self.other)
1460
1461    def test_sym_difference(self):
1462        self.assertRaises(TypeError, lambda: self.set ^ self.other)
1463        self.assertRaises(TypeError, lambda: self.other ^ self.set)
1464        if self.otherIsIterable:
1465            self.set.symmetric_difference(self.other)
1466        else:
1467            self.assertRaises(TypeError, self.set.symmetric_difference, self.other)
1468
1469    def test_difference_update_operator(self):
1470        try:
1471            self.set -= self.other
1472        except TypeError:
1473            pass
1474        else:
1475            self.fail("expected TypeError")
1476
1477    def test_difference_update(self):
1478        if self.otherIsIterable:
1479            self.set.difference_update(self.other)
1480        else:
1481            self.assertRaises(TypeError,
1482                              self.set.difference_update,
1483                              self.other)
1484
1485    def test_difference(self):
1486        self.assertRaises(TypeError, lambda: self.set - self.other)
1487        self.assertRaises(TypeError, lambda: self.other - self.set)
1488        if self.otherIsIterable:
1489            self.set.difference(self.other)
1490        else:
1491            self.assertRaises(TypeError, self.set.difference, self.other)
1492
1493#------------------------------------------------------------------------------
1494
class TestOnlySetsNumeric(TestOnlySetsInBinaryOps, unittest.TestCase):
    'The other operand is an int, which is not iterable'
    def setUp(self):
        self.otherIsIterable = False
        self.other = 19
        self.set   = {1, 2, 3}
1500
1501#------------------------------------------------------------------------------
1502
class TestOnlySetsDict(TestOnlySetsInBinaryOps, unittest.TestCase):
    'The other operand is a dict, which is iterable'
    def setUp(self):
        self.otherIsIterable = True
        self.other = {1:2, 3:4}
        self.set   = {1, 2, 3}
1508
1509#------------------------------------------------------------------------------
1510
class TestOnlySetsOperator(TestOnlySetsInBinaryOps, unittest.TestCase):
    'The other operand is the function operator.add, which is not iterable'
    def setUp(self):
        self.otherIsIterable = False
        self.other = operator.add
        self.set   = {1, 2, 3}
1516
1517#------------------------------------------------------------------------------
1518
class TestOnlySetsTuple(TestOnlySetsInBinaryOps, unittest.TestCase):
    'The other operand is a tuple, which is iterable'
    def setUp(self):
        self.otherIsIterable = True
        self.other = (2, 4, 6)
        self.set   = {1, 2, 3}
1524
1525#------------------------------------------------------------------------------
1526
class TestOnlySetsString(TestOnlySetsInBinaryOps, unittest.TestCase):
    'The other operand is a str, which is iterable'
    def setUp(self):
        self.otherIsIterable = True
        self.other = 'abc'
        self.set   = {1, 2, 3}
1532
1533#------------------------------------------------------------------------------
1534
class TestOnlySetsGenerator(TestOnlySetsInBinaryOps, unittest.TestCase):
    'The other operand is a generator object, which is iterable'
    def setUp(self):
        self.set   = {1, 2, 3}
        # A fresh generator per test: it can be consumed only once.
        self.other = (i for i in range(0, 10, 2))
        self.otherIsIterable = True
1543
1544#==============================================================================
1545
1546class TestCopying:
1547
1548    def test_copy(self):
1549        dup = self.set.copy()
1550        dup_list = sorted(dup, key=repr)
1551        set_list = sorted(self.set, key=repr)
1552        self.assertEqual(len(dup_list), len(set_list))
1553        for i in range(len(dup_list)):
1554            self.assertTrue(dup_list[i] is set_list[i])
1555
1556    def test_deep_copy(self):
1557        dup = copy.deepcopy(self.set)
1558        ##print type(dup), repr(dup)
1559        dup_list = sorted(dup, key=repr)
1560        set_list = sorted(self.set, key=repr)
1561        self.assertEqual(len(dup_list), len(set_list))
1562        for i in range(len(dup_list)):
1563            self.assertEqual(dup_list[i], set_list[i])
1564
1565#------------------------------------------------------------------------------
1566
class TestCopyingEmpty(TestCopying, unittest.TestCase):
    'Copy tests on the empty set'
    def setUp(self):
        no_items = ()
        self.set = set(no_items)
1570
1571#------------------------------------------------------------------------------
1572
class TestCopyingSingleton(TestCopying, unittest.TestCase):
    'Copy tests on a set holding one string'
    def setUp(self):
        self.set = {"hello"}
1576
1577#------------------------------------------------------------------------------
1578
class TestCopyingTriple(TestCopying, unittest.TestCase):
    'Copy tests on a set of three values of different types'
    def setUp(self):
        self.set = {"zero", 0, None}
1582
1583#------------------------------------------------------------------------------
1584
class TestCopyingTuple(TestCopying, unittest.TestCase):
    'Copy tests on a set holding one tuple'
    def setUp(self):
        self.set = {(1, 2)}
1588
1589#------------------------------------------------------------------------------
1590
class TestCopyingNested(TestCopying, unittest.TestCase):
    'Copy tests on a set holding one tuple of tuples'
    def setUp(self):
        self.set = {((1, 2), (3, 4))}
1594
1595#==============================================================================
1596
class TestIdentities(unittest.TestCase):
    'Algebraic identities between the set operators, checked on two letter sets'
    def setUp(self):
        self.a, self.b = set('abracadabra'), set('alacazam')

    def test_binopsVsSubsets(self):
        # Results of the operators are proper subsets/supersets as expected.
        a, b = self.a, self.b
        self.assertLess(a - b, a)
        self.assertLess(b - a, b)
        self.assertLess(a & b, a)
        self.assertLess(a & b, b)
        self.assertGreater(a | b, a)
        self.assertGreater(a | b, b)
        self.assertLess(a ^ b, a | b)

    def test_commutativity(self):
        a, b = self.a, self.b
        for combine in (operator.and_, operator.or_, operator.xor):
            self.assertEqual(combine(a, b), combine(b, a))
        # Difference is the one operator that is not commutative.
        if a != b:
            self.assertNotEqual(a-b, b-a)

    def test_summations(self):
        # check that sums of parts equal the whole
        a, b = self.a, self.b
        only_a, only_b, both = a - b, b - a, a & b
        union = a | b
        self.assertEqual(only_a | both | only_b, union)
        self.assertEqual(both | (a ^ b), union)
        self.assertEqual(a | only_b, union)
        self.assertEqual(only_a | b, union)
        self.assertEqual(only_a | both, a)
        self.assertEqual(only_b | both, b)
        self.assertEqual(only_a | only_b, a ^ b)

    def test_exclusion(self):
        # check that inverse operations show non-overlap
        a, b = self.a, self.b
        zero = set()
        self.assertEqual((a-b)&b, zero)
        self.assertEqual((b-a)&a, zero)
        self.assertEqual((a&b)&(a^b), zero)
1637
1638# Tests derived from test_itertools.py =======================================
1639
def R(seqn):
    'Plain generator yielding every item of seqn'
    yield from seqn
1644
class G:
    'Wrapper exposing seqn only through __getitem__'
    def __init__(self, seqn):
        self.seqn = seqn
    def __getitem__(self, i):
        wrapped = self.seqn
        return wrapped[i]
1651
class I:
    'Hand-written iterator over seqn (__iter__ and __next__)'
    def __init__(self, seqn):
        self.seqn, self.i = seqn, 0
    def __iter__(self):
        return self
    def __next__(self):
        pos = self.i
        if pos < len(self.seqn):
            # Advance the cursor only after the item was fetched.
            item = self.seqn[pos]
            self.i = pos + 1
            return item
        raise StopIteration
1664
class Ig:
    'Iterable whose __iter__ is written as a generator'
    def __init__(self, seqn):
        self.seqn, self.i = seqn, 0
    def __iter__(self):
        yield from self.seqn
1673
class X:
    'Has __next__ but neither __getitem__ nor __iter__, so it is not iterable'
    def __init__(self, seqn):
        self.seqn, self.i = seqn, 0
    def __next__(self):
        pos = self.i
        if pos < len(self.seqn):
            # Advance the cursor only after the item was fetched.
            item = self.seqn[pos]
            self.i = pos + 1
            return item
        raise StopIteration
1684
class N:
    'Returns itself from __iter__ but defines no __next__()'
    def __init__(self, seqn):
        self.seqn, self.i = seqn, 0
    def __iter__(self):
        return self
1692
class E:
    'Iterator whose __next__ always fails, to check exception propagation'
    def __init__(self, seqn):
        self.seqn, self.i = seqn, 0
    def __iter__(self):
        return self
    def __next__(self):
        # Deliberately raises ZeroDivisionError on every call.
        3 // 0
1702
class S:
    'Iterator that is exhausted from the very first __next__ call'
    def __init__(self, seqn):
        # The argument is accepted for interface parity and then ignored.
        pass
    def __iter__(self):
        return self
    def __next__(self):
        raise StopIteration()
1711
1712from itertools import chain
def L(seqn):
    'Feed seqn through several stacked layers of iterators'
    # Innermost to outermost: __getitem__ wrapper, generator-based iterable,
    # plain generator, identity map, and finally a single-argument chain.
    layered = R(Ig(G(seqn)))
    passthrough = map(lambda x: x, layered)
    return chain(passthrough)
1716
class TestVariousIteratorArgs(unittest.TestCase):
    # Check that set/frozenset constructors and set methods accept every
    # flavour of iterable helper defined in this module (G, I, Ig, S, L, R),
    # raise TypeError for the non-iterables (X, N), and let the
    # ZeroDivisionError raised by E propagate.

    def test_constructor(self):
        for cons in (set, frozenset):
            for s in ("123", "", range(1000), ('do', 1.2), range(2000,2200,5)):
                for g in (G, I, Ig, S, L, R):
                    # key=repr makes mixed str/float data sortable.
                    self.assertEqual(sorted(cons(g(s)), key=repr), sorted(g(s), key=repr))
                self.assertRaises(TypeError, cons, X(s))
                self.assertRaises(TypeError, cons, N(s))
                self.assertRaises(ZeroDivisionError, cons, E(s))

    def test_inline_methods(self):
        s = set('november')
        for data in ("123", "", range(1000), ('do', 1.2), range(2000,2200,5), 'december'):
            for meth in (s.union, s.intersection, s.difference, s.symmetric_difference, s.isdisjoint):
                # The reference result does not depend on the wrapper, so
                # compute it once per (data, method) pair.
                expected = meth(data)
                # S is left out here: it yields nothing, so it would not
                # reproduce the result obtained from the raw data.
                for g in (G, I, Ig, L, R):
                    actual = meth(g(data))
                    if isinstance(expected, bool):
                        # isdisjoint() returns a bool rather than a set.
                        self.assertEqual(actual, expected)
                    else:
                        self.assertEqual(sorted(actual, key=repr), sorted(expected, key=repr))
                # Wrap the data under test (not the set `s`), matching the
                # constructor and in-place tests.
                self.assertRaises(TypeError, meth, X(data))
                self.assertRaises(TypeError, meth, N(data))
                self.assertRaises(ZeroDivisionError, meth, E(data))

    def test_inplace_methods(self):
        for data in ("123", "", range(1000), ('do', 1.2), range(2000,2200,5), 'december'):
            for methname in ('update', 'intersection_update',
                             'difference_update', 'symmetric_difference_update'):
                for g in (G, I, Ig, S, L, R):
                    s = set('january')
                    t = s.copy()
                    # s gets the materialised list, t the raw iterable; both
                    # must end up with the same contents.
                    getattr(s, methname)(list(g(data)))
                    getattr(t, methname)(g(data))
                    self.assertEqual(sorted(s, key=repr), sorted(t, key=repr))

                self.assertRaises(TypeError, getattr(set('january'), methname), X(data))
                self.assertRaises(TypeError, getattr(set('january'), methname), N(data))
                self.assertRaises(ZeroDivisionError, getattr(set('january'), methname), E(data))
1757
class bad_eq:
    # Equality helper driven by the module globals be_bad and set2: while
    # be_bad is true, any comparison empties set2 and then fails.
    def __eq__(self, other):
        if not be_bad:
            return self is other
        set2.clear()
        raise ZeroDivisionError
    def __hash__(self):
        # A constant hash makes every instance collide.
        return 0
1766
class bad_dict_clear:
    # Equality helper driven by the module globals be_bad and dict2: while
    # be_bad is true, any comparison empties dict2 as a side effect.
    def __eq__(self, other):
        same_object = self is other
        if be_bad:
            dict2.clear()
        return same_object
    def __hash__(self):
        # A constant hash makes every instance collide.
        return 0
1774
class TestWeirdBugs(unittest.TestCase):
    # Regression tests in which a container is mutated while a set operation
    # or a set iteration over it is still in progress.  The exact order of
    # the statements matters; do not "tidy up" seemingly redundant lines.

    def test_8420_set_merge(self):
        # This used to segfault
        # be_bad, set2 and dict2 are module globals because the __eq__
        # methods of bad_eq and bad_dict_clear look them up by name.
        global be_bad, set2, dict2
        # Build the sets while be_bad is False so __eq__ behaves normally.
        be_bad = False
        set1 = {bad_eq()}
        set2 = {bad_eq() for i in range(75)}
        # From here on any bad_eq comparison clears set2 and raises.
        be_bad = True
        self.assertRaises(ZeroDivisionError, set1.update, set2)

        # Same idea with a dict argument whose key clears the dict when
        # compared.  No result is asserted; the call just has to complete.
        be_bad = False
        set1 = {bad_dict_clear()}
        dict2 = {bad_dict_clear(): None}
        be_bad = True
        set1.symmetric_difference_update(dict2)

    def test_iter_and_mutate(self):
        # Issue #24581
        # Clear and refill the set while an iterator over it is still alive,
        # then exhaust that iterator.  No result is asserted.
        s = set(range(100))
        s.clear()
        s.update(range(100))
        si = iter(s)
        s.clear()
        # NOTE(review): `a` is never read; presumably this allocation is
        # part of reproducing the original issue -- confirm before removing.
        a = list(range(100))
        s.update(range(100))
        list(si)

    def test_merge_and_mutate(self):
        # X.__eq__ empties `other` (a closure variable of this test) every
        # time it is called, i.e. potentially while s.update(other) runs.
        class X:
            def __hash__(self):
                return hash(0)
            def __eq__(self, o):
                other.clear()
                return False

        # `other` must already be bound here: all X instances share one hash,
        # so X.__eq__ can run while the comprehension below is still being
        # evaluated, before its result is assigned to `other`.
        other = set()
        other = {X() for i in range(10)}
        # 0 hashes like every X, so merging compares against X instances.
        s = {0}
        s.update(other)
1814
# Application tests (based on David Eppstein's graph recipes) ===================================
1816
def powerset(U):
    """Yield every subset of the set or sequence U as a frozenset."""
    remaining = iter(U)
    try:
        head = next(remaining)
    except StopIteration:
        # Nothing left to pick from: the only subset is the empty one.
        yield frozenset()
        return
    x = frozenset([head])
    # Each subset of the tail is produced once without and once with head.
    for subset in powerset(remaining):
        yield subset
        yield subset | x
1827
def cube(n):
    """Return the n-dimensional hypercube graph as a dict.

    Vertices are the subsets of range(n), as frozensets.  Each one maps to
    the frozenset of vertices reached by toggling a single element.
    """
    singletons = [frozenset([x]) for x in range(n)]
    return {
        x: frozenset(x ^ s for s in singletons)
        for x in powerset(range(n))
    }
1833
def linegraph(G):
    """Build the line graph of G.

    Every edge {x, y} of G becomes a vertex; two such vertices are
    adjacent iff the underlying edges of G have an endpoint in common.
    """
    line = {}
    for x in G:
        x_nbrs = G[x]
        for y in x_nbrs:
            # Other edges meeting this one at x, then those meeting it at y.
            at_x = [frozenset([x, z]) for z in x_nbrs if z != y]
            at_y = [frozenset([y, z]) for z in G[y] if z != x]
            line[frozenset([x, y])] = frozenset(at_x + at_y)
    return line
1845
def faces(G):
    'Return the set of faces of G, each face being a frozenset of its vertices'
    # Only triangles, squares and pentagons are detected.  Starting at v1,
    # the walk is extended one vertex at a time and closed as soon as the
    # newest vertex is adjacent to v1 again.
    found = set()
    for v1, neighbours in G.items():
        for v2 in neighbours:
            for v3 in G[v2]:
                if v1 == v3:
                    continue
                if v1 in G[v3]:
                    found.add(frozenset([v1, v2, v3]))
                    continue
                # No triangle through v3: try to close a square.
                for v4 in G[v3]:
                    if v4 == v2:
                        continue
                    if v1 in G[v4]:
                        found.add(frozenset([v1, v2, v3, v4]))
                        continue
                    # No square through v4: try to close a pentagon.
                    for v5 in G[v4]:
                        if v5 == v3 or v5 == v2:
                            continue
                        if v1 in G[v5]:
                            found.add(frozenset([v1, v2, v3, v4, v5]))
    return found
1870
1871
class TestGraphs(unittest.TestCase):

    def test_cube(self):

        g = cube(3)                             # vertex --> {v1, v2, v3}
        vertices1 = set(g)
        self.assertEqual(len(vertices1), 8)     # a cube has eight vertices
        for neighbours in g.values():
            self.assertEqual(len(neighbours), 3)    # three edges meet at each vertex
        vertices2 = {v for edges in g.values() for v in edges}
        self.assertEqual(vertices1, vertices2)  # neighbours are themselves vertices

        cubefaces = faces(g)
        self.assertEqual(len(cubefaces), 6)     # a cube has six faces
        for face in cubefaces:
            self.assertEqual(len(face), 4)      # and every one is a square

    def test_cuboctahedron(self):

        # http://en.wikipedia.org/wiki/Cuboctahedron
        # 8 triangular faces and 6 square faces
        # 12 identical vertices each connecting a triangle and square

        g = cube(3)
        cuboctahedron = linegraph(g)            # V --> {V1, V2, V3, V4}
        self.assertEqual(len(cuboctahedron), 12)    # twelve vertices

        vertices = set(cuboctahedron)
        for edges in cuboctahedron.values():
            self.assertEqual(len(edges), 4)     # four neighbours per vertex
        othervertices = {edge for edges in cuboctahedron.values() for edge in edges}
        self.assertEqual(vertices, othervertices)   # neighbours are themselves vertices

        # Tally the faces by their number of vertices.
        cubofaces = faces(cuboctahedron)
        facesizes = collections.Counter(len(face) for face in cubofaces)
        self.assertEqual(facesizes[3], 8)       # eight triangular faces
        self.assertEqual(facesizes[4], 6)       # six square faces

        # Each cuboctahedron vertex is an edge of the cube it was built from.
        for edge in cuboctahedron:
            self.assertEqual(len(edge), 2)      # an edge joins two cube vertices
            for cubevert in edge:
                self.assertIn(cubevert, g)
1917
1918
1919#==============================================================================
1920
# Run this module's tests through unittest's command-line entry point when
# the file is executed directly rather than imported.
if __name__ == "__main__":
    unittest.main()
1923