1""" Test Iterator Length Transparency
2
3Some functions or methods which accept general iterable arguments have
4optional, more efficient code paths if they know how many items to expect.
For instance, map(func, iterable) will pre-allocate the exact amount of
6space required whenever the iterable can report its length.
7
8The desired invariant is:  len(it)==len(list(it)).
9
10A complication is that an iterable and iterator can be the same object. To
11maintain the invariant, an iterator needs to dynamically update its length.
12For instance, an iterable such as range(10) always reports its length as ten,
13but it=iter(range(10)) starts at ten, and then goes to nine after next(it).
14Having this capability means that map() can ignore the distinction between
15map(func, iterable) and map(func, iter(iterable)).
16
When the iterable is immutable, the implementation can straightforwardly
18report the original length minus the cumulative number of calls to next().
19This is the case for tuples, range objects, and itertools.repeat().
20
21Some containers become temporarily immutable during iteration.  This includes
22dicts, sets, and collections.deque.  Their implementation is equally simple
23though they need to permanently set their length to zero whenever there is
24an attempt to iterate after a length mutation.
25
The situation is slightly more involved whenever an object allows length mutation
27during iteration.  Lists and sequence iterators are dynamically updatable.
28So, if a list is extended during iteration, the iterator will continue through
29the new items.  If it shrinks to a point before the most recent iteration,
30then no further items are available and the length is reported at zero.
31
32Reversed objects can also be wrapped around mutable objects; however, any
33appends after the current position are ignored.  Any other approach leads
34to confusion and possibly returning the same item more than once.
35
36The iterators not listed above, such as enumerate and the other itertools,
37are not length transparent because they have no way to distinguish between
38iterables that report static length and iterators whose length changes with
39each call (i.e. the difference between enumerate('abc') and
enumerate(iter('abc'))).
41
42"""
43
44import unittest
45from itertools import repeat
46from collections import deque
47from operator import length_hint
48
# Number of items placed in every iterable built by the tests below.
n = 10
50
51
class TestInvariantWithoutMutations:
    """Mixin checking that the length hint counts down as items are consumed.

    Subclasses must set ``self.it`` to an iterator that yields exactly
    ``n`` items.
    """

    def test_invariant(self):
        # Before each item is taken, the hint must equal the number of
        # items still to come.
        iterator = self.it
        for remaining in range(n, 0, -1):
            self.assertEqual(length_hint(iterator), remaining)
            next(iterator)

        # Once exhausted, the hint is zero and stays zero after the
        # iterator has raised StopIteration.
        self.assertEqual(length_hint(iterator), 0)
        self.assertRaises(StopIteration, next, iterator)
        self.assertEqual(length_hint(iterator), 0)
62
class TestTemporarilyImmutable(TestInvariantWithoutMutations):
    """Mixin for containers that reject length changes while being iterated.

    Subclasses must set ``self.it`` to an iterator over ``n`` items and
    ``self.mutate`` to a callable that changes the container's length.
    """

    def test_immutable_during_iteration(self):
        # Objects such as deques, sets, and dictionaries enforce
        # length immutability during iteration.
        iterator = self.it
        self.assertEqual(length_hint(iterator), n)
        next(iterator)
        self.assertEqual(length_hint(iterator), n - 1)

        # After a length mutation the next step must fail, and the hint
        # must then report zero.
        self.mutate()
        self.assertRaises(RuntimeError, next, iterator)
        self.assertEqual(length_hint(iterator), 0)
76
77## ------- Concrete Type Tests -------
78
class TestRepeat(TestInvariantWithoutMutations, unittest.TestCase):
    """Run the invariant check against itertools.repeat() with a count."""

    def setUp(self):
        self.it = repeat(None, times=n)
83
class TestXrange(TestInvariantWithoutMutations, unittest.TestCase):
    """Run the invariant check against a forward iterator over range(n)."""

    def setUp(self):
        numbers = range(n)
        self.it = iter(numbers)
88
class TestXrangeCustomReversed(TestInvariantWithoutMutations, unittest.TestCase):
    """Run the invariant check against a reversed iterator over range(n)."""

    def setUp(self):
        numbers = range(n)
        self.it = reversed(numbers)
93
class TestTuple(TestInvariantWithoutMutations, unittest.TestCase):
    """Run the invariant check against a tuple iterator."""

    def setUp(self):
        items = tuple(range(n))
        self.it = iter(items)
98
99## ------- Types that should not be mutated during iteration -------
100
class TestDeque(TestTemporarilyImmutable, unittest.TestCase):
    """Length-hint checks for a forward deque iterator."""

    def setUp(self):
        container = deque(range(n))
        # Popping an element is the length mutation used by the mixin.
        self.mutate = container.pop
        self.it = iter(container)
107
class TestDequeReversed(TestTemporarilyImmutable, unittest.TestCase):
    """Length-hint checks for a reversed deque iterator."""

    def setUp(self):
        container = deque(range(n))
        # Popping an element is the length mutation used by the mixin.
        self.mutate = container.pop
        self.it = reversed(container)
114
class TestDictKeys(TestTemporarilyImmutable, unittest.TestCase):
    """Length-hint checks for an iterator over a dict's keys."""

    def setUp(self):
        mapping = {key: None for key in range(n)}
        # Removing an item is the length mutation used by the mixin.
        self.mutate = mapping.popitem
        self.it = iter(mapping)
121
class TestDictItems(TestTemporarilyImmutable, unittest.TestCase):
    """Length-hint checks for an iterator over a dict's items view."""

    def setUp(self):
        mapping = {key: None for key in range(n)}
        # Removing an item is the length mutation used by the mixin.
        self.mutate = mapping.popitem
        self.it = iter(mapping.items())
128
class TestDictValues(TestTemporarilyImmutable, unittest.TestCase):
    """Length-hint checks for an iterator over a dict's values view."""

    def setUp(self):
        mapping = {key: None for key in range(n)}
        # Removing an item is the length mutation used by the mixin.
        self.mutate = mapping.popitem
        self.it = iter(mapping.values())
135
class TestSet(TestTemporarilyImmutable, unittest.TestCase):
    """Length-hint checks for a set iterator."""

    def setUp(self):
        members = set(range(n))
        # Popping an element is the length mutation used by the mixin.
        self.mutate = members.pop
        self.it = iter(members)
142
143## ------- Types that can mutate during iteration -------
144
class TestList(TestInvariantWithoutMutations, unittest.TestCase):
    """Length-hint checks for list iterators, including mutation mid-iteration."""

    def setUp(self):
        # range(n) is not a list in Python 3, so build a real list;
        # otherwise the inherited invariant test would exercise the range
        # iterator instead of the list iterator.
        self.it = iter(list(range(n)))

    def test_mutation(self):
        # A list iterator tracks the live list: its hint grows when the
        # list grows and drops to zero when the list shrinks below the
        # current position.
        d = list(range(n))
        it = iter(d)
        next(it)
        next(it)
        self.assertEqual(length_hint(it), n - 2)
        d.append(n)
        self.assertEqual(length_hint(it), n - 1)  # grow with append
        d[1:] = []
        self.assertEqual(length_hint(it), 0)
        self.assertEqual(list(it), [])
        # Once exhausted, the iterator does not pick up later additions.
        d.extend(range(20))
        self.assertEqual(length_hint(it), 0)
163
164
class TestListReversed(TestInvariantWithoutMutations, unittest.TestCase):
    """Length-hint checks for reversed list iterators, including mutation."""

    def setUp(self):
        # range(n) is not a list in Python 3, so build a real list;
        # otherwise the inherited invariant test would exercise the range
        # reverse iterator instead of the list reverse iterator.
        self.it = reversed(list(range(n)))

    def test_mutation(self):
        # A reversed list iterator ignores items appended after its
        # current position, and reports zero once the list has shrunk
        # below that position.
        d = list(range(n))
        it = reversed(d)
        next(it)
        next(it)
        self.assertEqual(length_hint(it), n - 2)
        d.append(n)
        self.assertEqual(length_hint(it), n - 2)  # ignore append
        d[1:] = []
        self.assertEqual(length_hint(it), 0)
        self.assertEqual(list(it), [])  # confirm invariant
        # Once exhausted, the iterator does not pick up later additions.
        d.extend(range(20))
        self.assertEqual(length_hint(it), 0)
183
184## -- Check to make sure exceptions are not suppressed by __length_hint__()
185
186
class BadLen:
    """Iterable over ten integers whose __len__ always raises RuntimeError."""

    def __len__(self):
        raise RuntimeError('hello')

    def __iter__(self):
        numbers = range(10)
        return iter(numbers)
193
194
class BadLengthHint:
    """Iterable over ten integers whose __length_hint__ always raises RuntimeError."""

    def __length_hint__(self):
        raise RuntimeError('hello')

    def __iter__(self):
        numbers = range(10)
        return iter(numbers)
201
202
class NoneLengthHint:
    """Iterable over ten integers whose __length_hint__ returns NotImplemented."""

    def __length_hint__(self):
        return NotImplemented

    def __iter__(self):
        numbers = range(10)
        return iter(numbers)
209
210
class TestLengthHintExceptions(unittest.TestCase):
    """Errors raised while obtaining a length must not be suppressed."""

    def test_issue1242657(self):
        # A RuntimeError from __len__ or __length_hint__ has to reach the
        # caller of list(), list.extend() and bytearray.extend().
        with self.assertRaises(RuntimeError):
            list(BadLen())
        with self.assertRaises(RuntimeError):
            list(BadLengthHint())
        with self.assertRaises(RuntimeError):
            [].extend(BadLen())
        with self.assertRaises(RuntimeError):
            [].extend(BadLengthHint())
        b = bytearray(range(10))
        with self.assertRaises(RuntimeError):
            b.extend(BadLen())
        with self.assertRaises(RuntimeError):
            b.extend(BadLengthHint())

    def test_invalid_hint(self):
        # Make sure an invalid result doesn't muck-up the works
        expected = list(range(10))
        self.assertEqual(list(NoneLengthHint()), expected)
225
226
# Run the test suite when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()
229