1"""Unit tests for the memoryview
2
3   Some tests are in test_bytes. Many tests that require _testbuffer.ndarray
4   are in test_buffer.
5"""
6
7import unittest
8import test.support
9import sys
10import gc
11import weakref
12import array
13import io
14import copy
15import pickle
16
17from test.support import import_helper
18
19
class AbstractMemoryTests:
    """Mixin with the memoryview tests shared by every concrete variant.

    Concrete test classes combine this mixin with unittest.TestCase and
    supply:

    * ``ro_type`` / ``rw_type``: callables that build a read-only / writable
      buffer object from bytes, or None when that kind is not tested;
    * ``getitem_type``: callable turning a one-byte ``bytes`` object into the
      raw bytes of the matching view item (used by test_tobytes);
    * ``itemsize`` and ``format``: the expected view attributes;
    * ``_view(obj)``: builds the memoryview under test from a buffer object;
    * ``_check_contents(tp, obj, contents)``: asserts that the region of
      ``obj`` covered by the view equals ``tp(contents)``.

    Subclasses may override ``source_bytes``; whatever they choose, the view
    returned by ``_view`` must expose the six items of b"abcdef".
    """

    source_bytes = b"abcdef"

    @property
    def _source(self):
        # Raw bytes fed to ro_type / rw_type to build the buffer object.
        return self.source_bytes

    @property
    def _types(self):
        # The buffer factories under test, omitting whichever one is None.
        # Returned as a tuple (not a one-shot filter iterator) so the result
        # can safely be iterated more than once.
        return tuple(tp for tp in (self.ro_type, self.rw_type) if tp)

    def check_getitem_with_type(self, tp):
        """Check item access on a view over ``tp(self._source)``."""
        b = tp(self._source)
        oldrefcount = sys.getrefcount(b)
        m = self._view(b)
        self.assertEqual(m[0], ord(b"a"))
        self.assertIsInstance(m[0], int)
        self.assertEqual(m[5], ord(b"f"))
        self.assertEqual(m[-1], ord(b"f"))
        self.assertEqual(m[-6], ord(b"a"))
        # Bounds checking
        self.assertRaises(IndexError, lambda: m[6])
        self.assertRaises(IndexError, lambda: m[-7])
        self.assertRaises(IndexError, lambda: m[sys.maxsize])
        self.assertRaises(IndexError, lambda: m[-sys.maxsize])
        # Type checking
        self.assertRaises(TypeError, lambda: m[None])
        self.assertRaises(TypeError, lambda: m[0.0])
        self.assertRaises(TypeError, lambda: m["a"])
        # Dropping the view must give back every reference it took on b.
        m = None
        self.assertEqual(sys.getrefcount(b), oldrefcount)

    def test_getitem(self):
        for tp in self._types:
            self.check_getitem_with_type(tp)

    def test_iter(self):
        # Iteration yields the same items as indexing.
        for tp in self._types:
            b = tp(self._source)
            m = self._view(b)
            self.assertEqual(list(m), [m[i] for i in range(len(m))])

    def test_setitem_readonly(self):
        # Any item assignment on a read-only view raises TypeError.
        if not self.ro_type:
            self.skipTest("no read-only type to test")
        b = self.ro_type(self._source)
        oldrefcount = sys.getrefcount(b)
        m = self._view(b)
        def setitem(value):
            m[0] = value
        self.assertRaises(TypeError, setitem, b"a")
        self.assertRaises(TypeError, setitem, 65)
        self.assertRaises(TypeError, setitem, memoryview(b"a"))
        m = None
        self.assertEqual(sys.getrefcount(b), oldrefcount)

    def test_setitem_writable(self):
        if not self.rw_type:
            self.skipTest("no writable type to test")
        tp = self.rw_type
        b = tp(self._source)
        oldrefcount = sys.getrefcount(b)
        m = self._view(b)
        m[0] = ord(b'1')
        self._check_contents(tp, b, b"1bcdef")
        m[0:1] = tp(b"0")
        self._check_contents(tp, b, b"0bcdef")
        m[1:3] = tp(b"12")
        self._check_contents(tp, b, b"012def")
        m[1:1] = tp(b"")
        self._check_contents(tp, b, b"012def")
        m[:] = tp(b"abcdef")
        self._check_contents(tp, b, b"abcdef")

        # Overlapping copies of a view into itself
        m[0:3] = m[2:5]
        self._check_contents(tp, b, b"cdedef")
        m[:] = tp(b"abcdef")
        m[2:5] = m[0:3]
        self._check_contents(tp, b, b"ababcf")

        def setitem(key, value):
            m[key] = tp(value)
        # Bounds checking
        self.assertRaises(IndexError, setitem, 6, b"a")
        self.assertRaises(IndexError, setitem, -7, b"a")
        self.assertRaises(IndexError, setitem, sys.maxsize, b"a")
        self.assertRaises(IndexError, setitem, -sys.maxsize, b"a")
        # Wrong index/slice types
        self.assertRaises(TypeError, setitem, 0.0, b"a")
        self.assertRaises(TypeError, setitem, (0,), b"a")
        self.assertRaises(TypeError, setitem, (slice(0,1,1), 0), b"a")
        self.assertRaises(TypeError, setitem, (0, slice(0,1,1)), b"a")
        self.assertRaises(TypeError, setitem, "a", b"a")
        # Not implemented: multidimensional slices
        slices = (slice(0,1,1), slice(0,1,2))
        self.assertRaises(NotImplementedError, setitem, slices, b"a")
        # Trying to resize the memory object
        exc = ValueError if m.format == 'c' else TypeError
        self.assertRaises(exc, setitem, 0, b"")
        self.assertRaises(exc, setitem, 0, b"ab")
        self.assertRaises(ValueError, setitem, slice(1,1), b"a")
        self.assertRaises(ValueError, setitem, slice(0,2), b"a")

        m = None
        self.assertEqual(sys.getrefcount(b), oldrefcount)

    def test_delitem(self):
        # Neither single items nor slices can be deleted from a view.
        for tp in self._types:
            b = tp(self._source)
            m = self._view(b)
            with self.assertRaises(TypeError):
                del m[1]
            with self.assertRaises(TypeError):
                del m[1:4]

    def test_tobytes(self):
        for tp in self._types:
            m = self._view(tp(self._source))
            b = m.tobytes()
            # This calls self.getitem_type() on each separate byte of b"abcdef"
            expected = b"".join(
                self.getitem_type(bytes([c])) for c in b"abcdef")
            self.assertEqual(b, expected)
            self.assertIsInstance(b, bytes)

    def test_tolist(self):
        for tp in self._types:
            m = self._view(tp(self._source))
            items = m.tolist()
            self.assertEqual(items, list(b"abcdef"))

    def test_compare(self):
        # memoryviews can compare for equality with other objects
        # having the buffer interface.
        for tp in self._types:
            m = self._view(tp(self._source))
            for tp_comp in self._types:
                self.assertTrue(m == tp_comp(b"abcdef"))
                self.assertFalse(m != tp_comp(b"abcdef"))
                self.assertFalse(m == tp_comp(b"abcde"))
                self.assertTrue(m != tp_comp(b"abcde"))
                self.assertFalse(m == tp_comp(b"abcde1"))
                self.assertTrue(m != tp_comp(b"abcde1"))
            self.assertTrue(m == m)
            self.assertTrue(m == m[:])
            self.assertTrue(m[0:6] == m[:])
            self.assertFalse(m[0:5] == m)

            # Comparison with objects which don't support the buffer API
            self.assertFalse(m == "abcdef")
            self.assertTrue(m != "abcdef")
            self.assertFalse("abcdef" == m)
            self.assertTrue("abcdef" != m)

            # Unordered comparisons: ordering operators raise TypeError.
            for c in (m, b"abcdef"):
                self.assertRaises(TypeError, lambda: m < c)
                self.assertRaises(TypeError, lambda: c <= m)
                self.assertRaises(TypeError, lambda: m >= c)
                self.assertRaises(TypeError, lambda: c > m)

    def check_attributes_with_type(self, tp):
        """Check the layout attributes of a view over ``tp(self._source)``.

        Returns the view so callers can make further assertions on it.
        """
        m = self._view(tp(self._source))
        self.assertEqual(m.format, self.format)
        self.assertEqual(m.itemsize, self.itemsize)
        self.assertEqual(m.ndim, 1)
        self.assertEqual(m.shape, (6,))
        self.assertEqual(len(m), 6)
        self.assertEqual(m.strides, (self.itemsize,))
        self.assertEqual(m.suboffsets, ())
        return m

    def test_attributes_readonly(self):
        if not self.ro_type:
            self.skipTest("no read-only type to test")
        m = self.check_attributes_with_type(self.ro_type)
        self.assertEqual(m.readonly, True)

    def test_attributes_writable(self):
        if not self.rw_type:
            self.skipTest("no writable type to test")
        m = self.check_attributes_with_type(self.rw_type)
        self.assertEqual(m.readonly, False)

    def test_getbuffer(self):
        # Test PyObject_GetBuffer() on a memoryview object.
        for tp in self._types:
            b = tp(self._source)
            oldrefcount = sys.getrefcount(b)
            m = self._view(b)
            oldviewrefcount = sys.getrefcount(m)
            # str(obj, encoding) reads obj through the buffer protocol.
            s = str(m, "utf-8")
            self._check_contents(tp, b, s.encode("utf-8"))
            self.assertEqual(sys.getrefcount(m), oldviewrefcount)
            m = None
            self.assertEqual(sys.getrefcount(b), oldrefcount)

    def test_gc(self):
        for tp in self._types:
            if not isinstance(tp, type):
                # If tp is a factory rather than a plain type, skip
                continue

            class MyView():
                def __init__(self, base):
                    self.m = memoryview(base)
            class MySource(tp):
                pass
            class MyObject:
                pass

            # Create a reference cycle through a memoryview object.
            # This exercises mbuf_clear().
            b = MySource(tp(b'abc'))
            m = self._view(b)
            o = MyObject()
            b.m = m
            b.o = o
            wr = weakref.ref(o)
            b = m = o = None
            # The cycle must be broken
            gc.collect()
            self.assertIsNone(wr())

            # This exercises memory_clear().
            m = MyView(tp(b'abc'))
            o = MyObject()
            m.x = m
            m.o = o
            wr = weakref.ref(o)
            m = o = None
            # The cycle must be broken
            gc.collect()
            self.assertIsNone(wr())

    def _check_released(self, m, tp):
        """Assert that the released view ``m`` refuses buffer access.

        str(), repr() and equality comparisons must keep working.
        """
        check = self.assertRaisesRegex(ValueError, "released")
        with check: bytes(m)
        with check: m.tobytes()
        with check: m.tolist()
        with check: m[0]
        with check: m[0] = b'x'
        with check: len(m)
        with check: m.format
        with check: m.itemsize
        with check: m.ndim
        with check: m.readonly
        with check: m.shape
        with check: m.strides
        with check:
            with m:
                pass
        # str() and repr() still function
        self.assertIn("released memory", str(m))
        self.assertIn("released memory", repr(m))
        self.assertEqual(m, m)
        self.assertNotEqual(m, memoryview(tp(self._source)))
        self.assertNotEqual(m, tp(self._source))

    def test_contextmanager(self):
        for tp in self._types:
            b = tp(self._source)
            m = self._view(b)
            # Leaving the with-block releases the view.
            with m as cm:
                self.assertIs(cm, m)
            self._check_released(m, tp)
            m = self._view(b)
            # Can release explicitly inside the context manager
            with m:
                m.release()

    def test_release(self):
        for tp in self._types:
            b = tp(self._source)
            m = self._view(b)
            m.release()
            self._check_released(m, tp)
            # Can be called a second time (it's a no-op)
            m.release()
            self._check_released(m, tp)

    def test_writable_readonly(self):
        # Issue #10451: memoryview incorrectly exposes a readonly
        # buffer as writable causing a segfault if using mmap
        tp = self.ro_type
        if tp is None:
            self.skipTest("no read-only type to test")
        b = tp(self._source)
        m = self._view(b)
        i = io.BytesIO(b'ZZZZ')
        self.assertRaises(TypeError, i.readinto, m)

    def test_getbuf_fail(self):
        # Objects without the buffer interface cannot be viewed.
        self.assertRaises(TypeError, self._view, {})

    def test_hash(self):
        # Memoryviews of readonly (hashable) types are hashable, and they
        # hash as hash(obj.tobytes()).
        tp = self.ro_type
        if tp is None:
            self.skipTest("no read-only type to test")
        b = tp(self._source)
        m = self._view(b)
        self.assertEqual(hash(m), hash(b"abcdef"))
        # Releasing the memoryview keeps the stored hash value (as with weakrefs)
        m.release()
        self.assertEqual(hash(m), hash(b"abcdef"))
        # Hashing a memoryview for the first time after it is released
        # results in an error (as with weakrefs).
        m = self._view(b)
        m.release()
        self.assertRaises(ValueError, hash, m)

    def test_hash_writable(self):
        # Memoryviews of writable types are unhashable
        tp = self.rw_type
        if tp is None:
            self.skipTest("no writable type to test")
        b = tp(self._source)
        m = self._view(b)
        self.assertRaises(ValueError, hash, m)

    def test_weakref(self):
        # Check memoryviews are weakrefable
        for tp in self._types:
            b = tp(self._source)
            m = self._view(b)
            L = []
            def callback(wr, b=b):
                L.append(b)
            wr = weakref.ref(m, callback)
            self.assertIs(wr(), m)
            del m
            test.support.gc_collect()
            self.assertIs(wr(), None)
            self.assertIs(L[0], b)

    def test_reversed(self):
        # reversed() agrees with both the reversed item list and a [::-1] view.
        for tp in self._types:
            b = tp(self._source)
            m = self._view(b)
            aslist = list(reversed(m.tolist()))
            self.assertEqual(list(reversed(m)), aslist)
            self.assertEqual(list(reversed(m)), list(m[::-1]))

    def test_toreadonly(self):
        # toreadonly() gives a read-only view with the same items; releasing
        # it leaves the original view usable.
        for tp in self._types:
            b = tp(self._source)
            m = self._view(b)
            mm = m.toreadonly()
            self.assertTrue(mm.readonly)
            self.assertTrue(memoryview(mm).readonly)
            self.assertEqual(mm.tolist(), m.tolist())
            mm.release()
            m.tolist()

    def test_issue22668(self):
        # Regression test for issue #22668: views derived from a cast keep
        # their own format and item values after the cast view they came
        # from is deleted, and after the underlying view is cast again.
        a = array.array('H', [256, 256, 256, 256])
        x = memoryview(a)
        m = x.cast('B')
        b = m.cast('H')
        c = b[0:2]
        d = memoryview(b)

        del b

        self.assertEqual(c[0], 256)
        self.assertEqual(d[0], 256)
        self.assertEqual(c.format, "H")
        self.assertEqual(d.format, "H")

        _ = m.cast('I')
        self.assertEqual(c[0], 256)
        self.assertEqual(d[0], 256)
        self.assertEqual(c.format, "H")
        self.assertEqual(d.format, "H")
398
399
400# Variations on source objects for the buffer: bytes-like objects, then arrays
401# with itemsize > 1.
402# NOTE: support for multi-dimensional objects is unimplemented.
403
class BaseBytesMemoryTests(AbstractMemoryTests):
    """Bytes-like sources: read-only ``bytes`` and writable ``bytearray``.

    Each item is a single unsigned byte, so views report format 'B' and
    itemsize 1.
    """
    # Buffer factories for the writable and read-only cases.
    rw_type = bytearray
    ro_type = bytes
    # Per-item byte representation and the expected view attributes.
    getitem_type = bytes
    format = 'B'
    itemsize = 1
410
class BaseArrayMemoryTests(AbstractMemoryTests):
    """Sources with itemsize > 1: a writable ``array.array('i')``.

    Each source byte becomes one C int item, so views report format 'i'.
    There is no read-only array type, so ``ro_type`` is None and the
    read-only tests skip themselves.
    """
    ro_type = None
    itemsize = array.array('i').itemsize
    format = 'i'

    def rw_type(self, b):
        # Writable int array holding one item per byte of b.
        return array.array('i', list(b))

    def getitem_type(self, b):
        # Raw bytes of the int items corresponding to the bytes of b.
        return array.array('i', list(b)).tobytes()

    # test_tolist is inherited unchanged: memoryview.tolist() handles every
    # native single-character struct format (not only bytes) since 3.3.

    @unittest.skip('XXX test should be adapted for non-byte buffers')
    def test_getbuffer(self):
        pass
425
426
427# Variations on indirection levels: memoryview, slice of memoryview,
428# slice of slice of memoryview.
429# This is important to test allocation subtleties.
430
class BaseMemoryviewTests:
    """Indirection level 0: the view under test is a plain memoryview."""

    def _view(self, obj):
        # Wrap the whole buffer object directly.
        view = memoryview(obj)
        return view

    def _check_contents(self, tp, obj, contents):
        # The view covers all of obj, so obj itself must match.
        expected = tp(contents)
        self.assertEqual(obj, expected)
437
class BaseMemorySliceTests:
    """Indirection level 1: the view is a slice of a memoryview.

    The source carries one sentinel byte on each side ("X" and "Y") so that
    the slice [1:7] exposes exactly b"abcdef".
    """
    source_bytes = b"XabcdefY"

    def _view(self, obj):
        # Drop the leading and trailing sentinel items.
        return memoryview(obj)[1:7]

    def _check_contents(self, tp, obj, contents):
        # Compare only the region of obj that the sliced view covers.
        window = obj[1:7]
        self.assertEqual(window, tp(contents))

    def test_refs(self):
        # Taking (and discarding) a slice must not leave an extra
        # reference on the parent memoryview.
        for factory in self._types:
            parent = memoryview(factory(self._source))
            before = sys.getrefcount(parent)
            parent[1:2]
            self.assertEqual(sys.getrefcount(parent), before)
454
class BaseMemorySliceSliceTests:
    """Indirection level 2: the view is a slice of a slice of a memoryview.

    The source wraps b"abcdef" in the sentinel bytes "X" and "Y"; the two
    successive slices together select exactly the middle six items.
    """
    source_bytes = b"XabcdefY"

    def _view(self, obj):
        # First drop the trailing sentinel, then the leading one.
        outer = memoryview(obj)[:7]
        return outer[1:]

    def _check_contents(self, tp, obj, contents):
        # The nested slices cover obj[1:7].
        expected = tp(contents)
        self.assertEqual(obj[1:7], expected)
464
465
466# Concrete test classes
467
class BytesMemoryviewTest(unittest.TestCase,
                          BaseMemoryviewTests, BaseBytesMemoryTests):
    """bytes/bytearray sources viewed through a plain memoryview."""

    def test_constructor(self):
        # memoryview() takes exactly one buffer argument, which may also be
        # given by the keyword name "object"; anything else is rejected.
        for kind in self._types:
            source = kind(self._source)
            # Positional and keyword forms both produce a non-empty view.
            self.assertTrue(memoryview(source))
            self.assertTrue(memoryview(object=source))
            # Missing, extra, or unknown arguments are all TypeErrors.
            with self.assertRaises(TypeError):
                memoryview()
            with self.assertRaises(TypeError):
                memoryview(source, source)
            with self.assertRaises(TypeError):
                memoryview(argument=source)
            with self.assertRaises(TypeError):
                memoryview(source, argument=True)
480
class ArrayMemoryviewTest(unittest.TestCase,
                          BaseMemoryviewTests, BaseArrayMemoryTests):
    """int-array sources viewed through a plain memoryview."""

    def test_array_assign(self):
        # Issue #4569: segfault when mutating a memoryview with itemsize != 1
        target = array.array('i', range(10))
        view = memoryview(target)
        replacement = array.array('i', reversed(range(10)))
        view[:] = replacement
        self.assertEqual(target, replacement)
491
492
class BytesMemorySliceTest(unittest.TestCase,
                           BaseMemorySliceTests, BaseBytesMemoryTests):
    """bytes/bytearray sources viewed through one memoryview slice."""
496
class ArrayMemorySliceTest(unittest.TestCase,
                           BaseMemorySliceTests, BaseArrayMemoryTests):
    """int-array sources viewed through one memoryview slice."""
500
class BytesMemorySliceSliceTest(unittest.TestCase,
                                BaseMemorySliceSliceTests, BaseBytesMemoryTests):
    """bytes/bytearray sources viewed through a slice of a slice."""
504
class ArrayMemorySliceSliceTest(unittest.TestCase,
                                BaseMemorySliceSliceTests, BaseArrayMemoryTests):
    """int-array sources viewed through a slice of a slice."""
508
509
510class OtherTest(unittest.TestCase):
511    def test_ctypes_cast(self):
512        # Issue 15944: Allow all source formats when casting to bytes.
513        ctypes = import_helper.import_module("ctypes")
514        p6 = bytes(ctypes.c_double(0.6))
515
516        d = ctypes.c_double()
517        m = memoryview(d).cast("B")
518        m[:2] = p6[:2]
519        m[2:] = p6[2:]
520        self.assertEqual(d.value, 0.6)
521
522        for format in "Bbc":
523            with self.subTest(format):
524                d = ctypes.c_double()
525                m = memoryview(d).cast(format)
526                m[:2] = memoryview(p6).cast(format)[:2]
527                m[2:] = memoryview(p6).cast(format)[2:]
528                self.assertEqual(d.value, 0.6)
529
530    def test_memoryview_hex(self):
531        # Issue #9951: memoryview.hex() segfaults with non-contiguous buffers.
532        x = b'0' * 200000
533        m1 = memoryview(x)
534        m2 = m1[::-1]
535        self.assertEqual(m2.hex(), '30' * 200000)
536
537    def test_copy(self):
538        m = memoryview(b'abc')
539        with self.assertRaises(TypeError):
540            copy.copy(m)
541
542    def test_pickle(self):
543        m = memoryview(b'abc')
544        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
545            with self.assertRaises(TypeError):
546                pickle.dumps(m, proto)
547
548
if __name__ == "__main__":
    # Executed as a script: discover and run this module's test cases.
    unittest.main(module="__main__")
551