import operator
import unittest
from sys import maxint
from test import test_support
# Py_ssize_t bounds used by the overflow tests: ``maxsize`` is taken from
# test_support.MAX_Py_ssize_t and ``minsize`` is one below its negation.
maxsize = test_support.MAX_Py_ssize_t
minsize = -maxsize - 1
7
class oldstyle:
    """Base-less (classic, under Python 2) class usable as an index.

    ``__index__`` simply hands back the ``ind`` attribute, which the
    tests assign on each instance before using it.
    """

    def __index__(self):
        return self.ind
11
class newstyle(object):
    """New-style class usable as an index.

    ``__index__`` simply hands back the ``ind`` attribute, which the
    tests assign on each instance before using it.
    """

    def __index__(self):
        return self.ind
15
class TrapInt(int):
    """``int`` subclass whose ``__index__`` returns the instance itself."""

    def __index__(self):
        # Hands back the subclass instance, not a plain ``int``.
        return self
19
class TrapLong(long):
    """``long`` subclass whose ``__index__`` returns the instance itself."""

    def __index__(self):
        # Hands back the subclass instance, not a plain ``long``.
        return self
23
24class BaseTestCase(unittest.TestCase):
25    def setUp(self):
26        self.o = oldstyle()
27        self.n = newstyle()
28
29    def test_basic(self):
30        self.o.ind = -2
31        self.n.ind = 2
32        self.assertEqual(operator.index(self.o), -2)
33        self.assertEqual(operator.index(self.n), 2)
34
35    def test_slice(self):
36        self.o.ind = 1
37        self.n.ind = 2
38        slc = slice(self.o, self.o, self.o)
39        check_slc = slice(1, 1, 1)
40        self.assertEqual(slc.indices(self.o), check_slc.indices(1))
41        slc = slice(self.n, self.n, self.n)
42        check_slc = slice(2, 2, 2)
43        self.assertEqual(slc.indices(self.n), check_slc.indices(2))
44
45    def test_wrappers(self):
46        self.o.ind = 4
47        self.n.ind = 5
48        self.assertEqual(6 .__index__(), 6)
49        self.assertEqual(-7L.__index__(), -7)
50        self.assertEqual(self.o.__index__(), 4)
51        self.assertEqual(self.n.__index__(), 5)
52        self.assertEqual(True.__index__(), 1)
53        self.assertEqual(False.__index__(), 0)
54
55    def test_subclasses(self):
56        r = range(10)
57        self.assertEqual(r[TrapInt(5):TrapInt(10)], r[5:10])
58        self.assertEqual(r[TrapLong(5):TrapLong(10)], r[5:10])
59        self.assertEqual(slice(TrapInt()).indices(0), (0,0,1))
60        self.assertEqual(slice(TrapLong(0)).indices(0), (0,0,1))
61
62    def test_error(self):
63        self.o.ind = 'dumb'
64        self.n.ind = 'bad'
65        self.assertRaises(TypeError, operator.index, self.o)
66        self.assertRaises(TypeError, operator.index, self.n)
67        self.assertRaises(TypeError, slice(self.o).indices, 0)
68        self.assertRaises(TypeError, slice(self.n).indices, 0)
69
70
class SeqTestCase(unittest.TestCase):
    """Shared ``__index__`` tests for sequence types.

    This test case isn't run directly.  It just defines the tests common
    to the different sequence types below; each subclass supplies the
    sequence under test as the class attribute ``seq``.
    """

    def setUp(self):
        # Two index objects of each class flavour, for start/stop pairs.
        self.o = oldstyle()
        self.n = newstyle()
        self.o2 = oldstyle()
        self.n2 = newstyle()

    def test_index(self):
        # Subscripting with an index object matches the plain integer.
        self.o.ind = -2
        self.n.ind = 2
        self.assertEqual(self.seq[self.n], self.seq[2])
        self.assertEqual(self.seq[self.o], self.seq[-2])

    def test_slice(self):
        # Index objects are accepted as slice bounds.
        self.o.ind = 1
        self.o2.ind = 3
        self.n.ind = 2
        self.n2.ind = 4
        self.assertEqual(self.seq[self.o:self.o2], self.seq[1:3])
        self.assertEqual(self.seq[self.n:self.n2], self.seq[2:4])

    def test_slice_bug7532a(self):
        # Positive bounds larger than the sequence length.
        seqlen = len(self.seq)
        self.o.ind = int(seqlen * 1.5)
        self.n.ind = seqlen + 2
        self.assertEqual(self.seq[self.o:], self.seq[0:0])
        self.assertEqual(self.seq[:self.o], self.seq)
        self.assertEqual(self.seq[self.n:], self.seq[0:0])
        self.assertEqual(self.seq[:self.n], self.seq)

    def test_slice_bug7532b(self):
        # Negative bounds whose magnitude exceeds the sequence length.
        # These tests fail for ClassicSeq (see bug #7532)
        if isinstance(self.seq, ClassicSeq):
            self.skipTest('test fails for ClassicSeq')
        seqlen = len(self.seq)
        self.o2.ind = -seqlen - 2
        self.n2.ind = -int(seqlen * 1.5)
        self.assertEqual(self.seq[self.o2:], self.seq)
        self.assertEqual(self.seq[:self.o2], self.seq[0:0])
        self.assertEqual(self.seq[self.n2:], self.seq)
        self.assertEqual(self.seq[:self.n2], self.seq[0:0])

    def test_repeat(self):
        # Repetition accepts an index object on either side of ``*``.
        self.o.ind = 3
        self.n.ind = 2
        self.assertEqual(self.seq * self.o, self.seq * 3)
        self.assertEqual(self.seq * self.n, self.seq * 2)
        self.assertEqual(self.o * self.seq, self.seq * 3)
        self.assertEqual(self.n * self.seq, self.seq * 2)

    def test_wrappers(self):
        # The same operations, invoked through the special methods.
        self.o.ind = 4
        self.n.ind = 5
        self.assertEqual(self.seq.__getitem__(self.o), self.seq[4])
        self.assertEqual(self.seq.__mul__(self.o), self.seq * 4)
        self.assertEqual(self.seq.__rmul__(self.o), self.seq * 4)
        self.assertEqual(self.seq.__getitem__(self.n), self.seq[5])
        self.assertEqual(self.seq.__mul__(self.n), self.seq * 5)
        self.assertEqual(self.seq.__rmul__(self.n), self.seq * 5)

    def test_subclasses(self):
        # Default-constructed TrapInt/TrapLong compare equal to 0.
        self.assertEqual(self.seq[TrapInt()], self.seq[0])
        self.assertEqual(self.seq[TrapLong()], self.seq[0])

    def test_error(self):
        # A string coming back from ``__index__`` must raise TypeError for
        # both plain subscripts and slices.
        self.o.ind = 'dumb'
        self.n.ind = 'bad'
        indexobj = lambda x, obj: obj.seq[x]
        self.assertRaises(TypeError, indexobj, self.o, self)
        self.assertRaises(TypeError, indexobj, self.n, self)
        sliceobj = lambda x, obj: obj.seq[x:]
        self.assertRaises(TypeError, sliceobj, self.o, self)
        self.assertRaises(TypeError, sliceobj, self.n, self)
146
147
class ListTestCase(SeqTestCase):
    """Runs the shared sequence tests on a list, plus mutation tests."""

    seq = [0,10,20,30,40,50]

    def test_setdelitem(self):
        # Item deletion and assignment through index objects.
        self.o.ind = -2
        self.n.ind = 2
        lst = list('ab!cdefghi!j')
        # The two deletions remove the two '!' characters (index -2,
        # then index 2), leaving 'abcdefghij'.
        del lst[self.o]
        del lst[self.n]
        lst[self.o] = 'X'
        lst[self.n] = 'Y'
        self.assertEqual(lst, list('abYdefghXj'))

        # Same operations through the special methods.
        lst = [5, 6, 7, 8, 9, 10, 11]
        lst.__setitem__(self.n, "here")
        self.assertEqual(lst, [5, 6, "here", 8, 9, 10, 11])
        lst.__delitem__(self.n)
        self.assertEqual(lst, [5, 6, 8, 9, 10, 11])

    def test_inplace_repeat(self):
        # ``*=`` with an index object as the repeat count.
        self.o.ind = 2
        self.n.ind = 3
        lst = [6, 4]
        lst *= self.o
        self.assertEqual(lst, [6, 4, 6, 4])
        lst *= self.n
        self.assertEqual(lst, [6, 4, 6, 4] * 3)

        # __imul__ must return the very same list object.
        lst = [5, 6, 7, 8, 9, 11]
        l2 = lst.__imul__(self.n)
        self.assertIs(l2, lst)
        self.assertEqual(lst, [5, 6, 7, 8, 9, 11] * 3)
180
181
class _BaseSeq:
    """Minimal sequence wrapper that delegates to a private list.

    Declared without a base class; the concrete sequence classes in this
    file derive from it, some of them mixing in ``object``.
    """

    def __init__(self, iterable):
        # Copy into a fresh list so the wrapper owns its storage.
        self._list = list(iterable)

    def __repr__(self):
        return repr(self._list)

    def __eq__(self, other):
        # Compare the wrapped list against ``other`` as given.
        return self._list == other

    def __len__(self):
        return len(self._list)

    def __mul__(self, n):
        # ``self.__class__`` keeps the result in the instance's own class.
        return self.__class__(self._list*n)
    __rmul__ = __mul__

    def __getitem__(self, index):
        # Forward the key unchanged to the wrapped list.
        return self._list[index]
202
203
class _GetSliceMixin:
    """Mixin adding a ``__getslice__`` that delegates to ``self._list``."""

    def __getslice__(self, i, j):
        # Hand both bounds straight to the wrapped list's own __getslice__.
        return self._list.__getslice__(i, j)
208
209
# Concrete sequence classes: with and without ``object`` as a base, and
# with and without the ``__getslice__`` mixin.
class ClassicSeq(_BaseSeq): pass
class NewSeq(_BaseSeq, object): pass
class ClassicSeqDeprecated(_GetSliceMixin, ClassicSeq): pass
class NewSeqDeprecated(_GetSliceMixin, NewSeq): pass
214
215
# The shared SeqTestCase tests, run against the built-in sequence types.
class TupleTestCase(SeqTestCase):
    seq = (0,10,20,30,40,50)

class StringTestCase(SeqTestCase):
    seq = "this is a test"

class ByteArrayTestCase(SeqTestCase):
    seq = bytearray("this is a test")

class UnicodeTestCase(SeqTestCase):
    seq = u"this is a test"
227
# The shared SeqTestCase tests, run against the wrapper sequence classes
# defined in this file.
class ClassicSeqTestCase(SeqTestCase):
    seq = ClassicSeq((0,10,20,30,40,50))

class NewSeqTestCase(SeqTestCase):
    seq = NewSeq((0,10,20,30,40,50))

class ClassicSeqDeprecatedTestCase(SeqTestCase):
    seq = ClassicSeqDeprecated((0,10,20,30,40,50))

class NewSeqDeprecatedTestCase(SeqTestCase):
    seq = NewSeqDeprecated((0,10,20,30,40,50))
239
240
class XRangeTestCase(unittest.TestCase):
    """Checks that ``xrange`` accepts an object with ``__index__``."""

    def test_xrange(self):
        n = newstyle()
        n.ind = 5
        # Element 5 of xrange(1, 20) is 6, via both spellings.
        self.assertEqual(xrange(1, 20)[n], 6)
        self.assertEqual(xrange(1, 20).__getitem__(n), 6)
248
class OverflowTestCase(unittest.TestCase):
    """Index behaviour for the values 2**100 and -2**100."""

    def setUp(self):
        self.pos = 2**100
        self.neg = -self.pos

    def test_large_longs(self):
        # ``__index__`` on the large values returns them unchanged.
        self.assertEqual(self.pos.__index__(), self.pos)
        self.assertEqual(self.neg.__index__(), self.neg)

    def _getitem_helper(self, base):
        # Builds a subclass of ``base`` whose __getitem__ echoes its key,
        # then checks plain subscripts and slices with the large values.
        class GetItem(base):
            def __len__(self):
                return maxint # cannot return long here
            def __getitem__(self, key):
                return key
        x = GetItem()
        self.assertEqual(x[self.pos], self.pos)
        self.assertEqual(x[self.neg], self.neg)
        # The echoed key is expected to be a slice object here, since
        # ``.indices`` is called on it.
        self.assertEqual(x[self.neg:self.pos].indices(maxsize),
                         (0, maxsize, 1))
        self.assertEqual(x[self.neg:self.pos:1].indices(maxsize),
                         (0, maxsize, 1))

    def _getslice_helper_deprecated(self, base):
        # Same as _getitem_helper, but the class also defines
        # __getslice__, which echoes its two bounds as a tuple.
        class GetItem(base):
            def __len__(self):
                return maxint # cannot return long here
            def __getitem__(self, key):
                return key
            def __getslice__(self, i, j):
                return i, j
        x = GetItem()
        self.assertEqual(x[self.pos], self.pos)
        self.assertEqual(x[self.neg], self.neg)
        # The two-bound slice is expected to come back as the (i, j)
        # tuple from __getslice__; the three-part slice is expected to be
        # a slice object, as in _getitem_helper.
        self.assertEqual(x[self.neg:self.pos], (maxint+minsize, maxsize))
        self.assertEqual(x[self.neg:self.pos:1].indices(maxsize),
                         (0, maxsize, 1))

    def test_getitem(self):
        self._getitem_helper(object)
        with test_support.check_py3k_warnings():
            self._getslice_helper_deprecated(object)

    def test_getitem_classic(self):
        class Empty: pass
        # XXX This test fails (see bug #7532)
        #self._getitem_helper(Empty)
        with test_support.check_py3k_warnings():
            self._getslice_helper_deprecated(Empty)

    def test_sequence_repeat(self):
        # Repeating a string by either large value must raise
        # OverflowError.
        self.assertRaises(OverflowError, lambda: "a" * self.pos)
        self.assertRaises(OverflowError, lambda: "a" * self.neg)
303
304
def test_main():
    """Run every test case in this module through ``test_support``.

    The two ``*Deprecated`` cases are run separately, inside
    ``test_support.check_py3k_warnings()``.
    """
    test_support.run_unittest(
        BaseTestCase,
        ListTestCase,
        TupleTestCase,
        ByteArrayTestCase,
        StringTestCase,
        UnicodeTestCase,
        ClassicSeqTestCase,
        NewSeqTestCase,
        XRangeTestCase,
        OverflowTestCase,
    )
    with test_support.check_py3k_warnings():
        test_support.run_unittest(
            ClassicSeqDeprecatedTestCase,
            NewSeqDeprecatedTestCase,
        )
323
324
# Run the whole suite when this file is executed as a script.
if __name__ == "__main__":
    test_main()
327