1"""Unit tests for the PickleBuffer object.
2
3Pickling tests themselves are in pickletester.py.
4"""
5
6import gc
7from pickle import PickleBuffer
8import weakref
9import unittest
10
11from test.support import import_helper
12
13
class B(bytes):
    """Minimal ``bytes`` subclass used by the cycle test.

    Plain ``bytes`` objects reject attribute assignment, but instances of a
    Python-level subclass carry a ``__dict__``; test_cycle relies on that to
    store a PickleBuffer on the instance and close a reference cycle.
    """
16
17
class PickleBufferTest(unittest.TestCase):
    """Tests for the pickle.PickleBuffer object itself.

    Pickling round-trips that involve PickleBuffer are exercised in
    pickletester.py; this class covers construction, buffer export,
    release, garbage collection and PickleBuffer.raw().
    """

    def check_memoryview(self, pb, equiv):
        """Assert that *pb* exports the same buffer as *equiv*.

        A memoryview is taken of each object and every layout attribute
        (nbytes, readonly, itemsize, shape, strides, contiguity flags and
        format) is compared, followed by the contents via tobytes().  Both
        views are released when the ``with`` statement exits.
        """
        with memoryview(pb) as m, memoryview(equiv) as expected:
            self.assertEqual(m.nbytes, expected.nbytes)
            self.assertEqual(m.readonly, expected.readonly)
            self.assertEqual(m.itemsize, expected.itemsize)
            self.assertEqual(m.shape, expected.shape)
            self.assertEqual(m.strides, expected.strides)
            self.assertEqual(m.c_contiguous, expected.c_contiguous)
            self.assertEqual(m.f_contiguous, expected.f_contiguous)
            self.assertEqual(m.format, expected.format)
            self.assertEqual(m.tobytes(), expected.tobytes())

    def test_constructor_failure(self):
        # The constructor requires exactly one argument...
        with self.assertRaises(TypeError):
            PickleBuffer()
        # ...and it must support the buffer protocol (str does not).
        with self.assertRaises(TypeError):
            PickleBuffer("foo")
        # Released memoryview fails taking a buffer
        m = memoryview(b"foo")
        m.release()
        with self.assertRaises(ValueError):
            PickleBuffer(m)

    def test_basics(self):
        # Wrapping immutable bytes exports a read-only buffer.
        pb = PickleBuffer(b"foo")
        self.assertEqual(b"foo", bytes(pb))
        with memoryview(pb) as m:
            self.assertTrue(m.readonly)

        # Wrapping a bytearray exports a writable buffer; a write through the
        # view is observed when the PickleBuffer is read again.
        pb = PickleBuffer(bytearray(b"foo"))
        self.assertEqual(b"foo", bytes(pb))
        with memoryview(pb) as m:
            self.assertFalse(m.readonly)
            m[0] = ord("0")
        self.assertEqual(b"0oo", bytes(pb))

    def test_release(self):
        # After release(), exporting a buffer must fail with a clear message.
        pb = PickleBuffer(b"foo")
        pb.release()
        with self.assertRaises(ValueError) as raises:
            memoryview(pb)
        self.assertIn("operation forbidden on released PickleBuffer object",
                      str(raises.exception))
        # Idempotency: releasing twice must not raise.
        pb.release()

    def test_cycle(self):
        # Build the cycle b -> pb (attribute) -> b (wrapped object) and check
        # that the cyclic garbage collector reclaims the PickleBuffer.
        b = B(b"foo")
        pb = PickleBuffer(b)
        b.cycle = pb
        wpb = weakref.ref(pb)
        del b, pb
        gc.collect()
        self.assertIsNone(wpb())

    def test_ndarray_2d(self):
        # PickleBuffer must preserve the full layout of multi-dimensional
        # exporters, whatever their contiguity.
        # C-contiguous
        ndarray = import_helper.import_module("_testbuffer").ndarray
        arr = ndarray(list(range(12)), shape=(4, 3), format='<i')
        self.assertTrue(arr.c_contiguous)
        self.assertFalse(arr.f_contiguous)
        pb = PickleBuffer(arr)
        self.check_memoryview(pb, arr)
        # Non-contiguous
        arr = arr[::2]
        self.assertFalse(arr.c_contiguous)
        self.assertFalse(arr.f_contiguous)
        pb = PickleBuffer(arr)
        self.check_memoryview(pb, arr)
        # F-contiguous: first-axis stride equals the itemsize (4 bytes).
        arr = ndarray(list(range(12)), shape=(3, 4), strides=(4, 12), format='<i')
        self.assertTrue(arr.f_contiguous)
        self.assertFalse(arr.c_contiguous)
        pb = PickleBuffer(arr)
        self.check_memoryview(pb, arr)

    # Tests for PickleBuffer.raw()

    def check_raw(self, obj, equiv):
        """Assert that PickleBuffer(obj).raw() matches *equiv*.

        raw() must return a memoryview; its layout and contents are
        compared with a memoryview of *equiv*.  The callers pass bytes-like
        equivalents, so this pins raw() to a flat byte view of the
        underlying memory.
        """
        pb = PickleBuffer(obj)
        with pb.raw() as m:
            self.assertIsInstance(m, memoryview)
            self.check_memoryview(m, equiv)

    def test_raw(self):
        for obj in (b"foo", bytearray(b"foo")):
            with self.subTest(obj=obj):
                self.check_raw(obj, obj)

    def test_raw_ndarray(self):
        # 1-D, contiguous
        ndarray = import_helper.import_module("_testbuffer").ndarray
        arr = ndarray(list(range(3)), shape=(3,), format='<h')
        equiv = b"\x00\x00\x01\x00\x02\x00"
        self.check_raw(arr, equiv)
        # 2-D, C-contiguous
        arr = ndarray(list(range(6)), shape=(2, 3), format='<h')
        equiv = b"\x00\x00\x01\x00\x02\x00\x03\x00\x04\x00\x05\x00"
        self.check_raw(arr, equiv)
        # 2-D, F-contiguous
        arr = ndarray(list(range(6)), shape=(2, 3), strides=(2, 4),
                      format='<h')
        # Note this is different from arr.tobytes(): raw() exposes memory
        # order, whereas tobytes() yields logical (C) order.
        equiv = b"\x00\x00\x01\x00\x02\x00\x03\x00\x04\x00\x05\x00"
        self.check_raw(arr, equiv)
        # 0-D: 456 == 0x1c8, stored little-endian as a 4-byte int.
        arr = ndarray(456, shape=(), format='<i')
        equiv = b'\xc8\x01\x00\x00'
        self.check_raw(arr, equiv)

    def check_raw_non_contiguous(self, obj):
        """Assert that raw() refuses a non-contiguous exporter *obj*."""
        pb = PickleBuffer(obj)
        with self.assertRaisesRegex(BufferError, "non-contiguous"):
            pb.raw()

    def test_raw_non_contiguous(self):
        # 1-D
        ndarray = import_helper.import_module("_testbuffer").ndarray
        arr = ndarray(list(range(6)), shape=(6,), format='<i')[::2]
        self.check_raw_non_contiguous(arr)
        # 2-D
        arr = ndarray(list(range(12)), shape=(4, 3), format='<i')[::2]
        self.check_raw_non_contiguous(arr)

    def test_raw_released(self):
        # raw() on a released PickleBuffer fails the same way as taking a
        # memoryview does (see test_release), including the message.
        pb = PickleBuffer(b"foo")
        pb.release()
        with self.assertRaises(ValueError) as raises:
            pb.raw()
        self.assertIn("operation forbidden on released PickleBuffer object",
                      str(raises.exception))
150
151
if __name__ == "__main__":
    # Executed as a script: discover and run the TestCases defined here.
    unittest.main(module="__main__")
154