1from ctypes import *
2from ctypes.test import need_symbol
3import unittest
4import os
5
6import _ctypes_test
7
class BITS(Structure):
    # Nine c_int bit fields A..I with widths 1..9, followed by seven c_short
    # bit fields M..S with widths 1..7.  C_Test compares ctypes' reading of
    # every field with the C helper unpack_bitfields, so this layout has to
    # stay in step with the struct that helper is compiled against.
    _fields_ = ([(letter, c_int, nbits)
                 for nbits, letter in enumerate("ABCDEFGHI", start=1)] +
                [(letter, c_short, nbits)
                 for nbits, letter in enumerate("MNOPQRS", start=1)])
26
# C-side counterpart used by C_Test: it is handed a pointer to a BITS
# structure plus the one-letter field name as a single char, and its result
# is compared with ctypes' own reading of that field.
func = getattr(CDLL(_ctypes_test.__file__), "unpack_bitfields")
func.argtypes = (POINTER(BITS), c_char)
29
## Debug aid: uncomment to dump each field's raw size and offset attributes.
##for n in "ABCDEFGHIMNOPQRS":
##    print(n, hex(getattr(BITS, n).size), getattr(BITS, n).offset)
32
class C_Test(unittest.TestCase):
    """Cross-check ctypes bit-field reads against the C helper ``func``.

    For each candidate value and each field name, a fresh BITS instance gets
    the value stored through ctypes.  The value ctypes reads back must equal
    what the C helper returns when given the structure and the field's
    one-letter name.
    """

    def _compare_with_c(self, letters, limit):
        # Exercise every value in range(limit) on every listed field.
        for value in range(limit):
            for letter in letters:
                bits = BITS()
                setattr(bits, letter, value)
                seen_by_ctypes = getattr(bits, letter)
                seen_by_c = func(byref(bits), letter.encode('ascii'))
                self.assertEqual(seen_by_ctypes, seen_by_c)

    def test_ints(self):
        # c_int fields A..I are at most 9 bits wide; 512 values cover them.
        self._compare_with_c("ABCDEFGHI", 512)

    def test_shorts(self):
        # c_short fields M..S are at most 7 bits wide; 256 values cover them.
        self._compare_with_c("MNOPQRS", 256)
48
# Integer types every generic bit-field test is run against.
signed_int_types = (
    c_byte,
    c_short,
    c_int,
    c_long,
    c_longlong,
)
unsigned_int_types = (
    c_ubyte,
    c_ushort,
    c_uint,
    c_ulong,
    c_ulonglong,
)
# Unsigned types first, then the signed ones.
int_types = (*unsigned_int_types, *signed_int_types)
52
class BitFieldTest(unittest.TestCase):
    """Layout, truncation and error-reporting tests for ctypes bit fields."""

    def test_longlong(self):
        # 1 + 62 + 1 bits exactly fill a single c_longlong.
        class X(Structure):
            _fields_ = [("a", c_longlong, 1),
                        ("b", c_longlong, 62),
                        ("c", c_longlong, 1)]

        self.assertEqual(sizeof(X), sizeof(c_longlong))
        x = X()
        # A signed 1-bit field round-trips -1.
        x.a, x.b, x.c = -1, 7, -1
        self.assertEqual((x.a, x.b, x.c), (-1, 7, -1))

    def test_ulonglong(self):
        class X(Structure):
            _fields_ = [("a", c_ulonglong, 1),
                        ("b", c_ulonglong, 62),
                        ("c", c_ulonglong, 1)]

        self.assertEqual(sizeof(X), sizeof(c_longlong))
        x = X()
        # A freshly created structure reads back as all zeros.
        self.assertEqual((x.a, x.b, x.c), (0, 0, 0))
        # Stored values are truncated to the field width: only the low bit
        # of 7 survives in the 1-bit fields.
        x.a, x.b, x.c = 7, 7, 7
        self.assertEqual((x.a, x.b, x.c), (1, 7, 1))

    def test_signed(self):
        for c_typ in signed_int_types:
            with self.subTest(c_typ=c_typ):
                class X(Structure):
                    # "dummy" is an ordinary field, so the bit fields that
                    # follow live in a second unit of the same type.
                    _fields_ = [("dummy", c_typ),
                                ("a", c_typ, 3),
                                ("b", c_typ, 3),
                                ("c", c_typ, 1)]
                self.assertEqual(sizeof(X), sizeof(c_typ)*2)

                x = X()
                self.assertEqual((c_typ, x.a, x.b, x.c), (c_typ, 0, 0, 0))
                x.a = -1
                self.assertEqual((c_typ, x.a, x.b, x.c), (c_typ, -1, 0, 0))
                # Writing one field must leave its neighbours untouched.
                x.a, x.b = 0, -1
                self.assertEqual((c_typ, x.a, x.b, x.c), (c_typ, 0, -1, 0))


    def test_unsigned(self):
        for c_typ in unsigned_int_types:
            with self.subTest(c_typ=c_typ):
                class X(Structure):
                    _fields_ = [("a", c_typ, 3),
                                ("b", c_typ, 3),
                                ("c", c_typ, 1)]
                self.assertEqual(sizeof(X), sizeof(c_typ))

                x = X()
                self.assertEqual((c_typ, x.a, x.b, x.c), (c_typ, 0, 0, 0))
                # -1 is stored as all ones, which reads back as 7 from an
                # unsigned 3-bit field.
                x.a = -1
                self.assertEqual((c_typ, x.a, x.b, x.c), (c_typ, 7, 0, 0))
                x.a, x.b = 0, -1
                self.assertEqual((c_typ, x.a, x.b, x.c), (c_typ, 0, 7, 0))


    def fail_fields(self, *fields):
        """Try to create a Structure subclass ``X`` with the given _fields_.

        Returns the ``(exception class, message)`` tuple produced by
        get_except; the test fails if the class is created successfully.
        """
        return self.get_except(type(Structure), "X", (),
                               {"_fields_": fields})

    def test_nonint_types(self):
        # bit fields are not allowed on non-integer types.
        result = self.fail_fields(("a", c_char_p, 1))
        self.assertEqual(result, (TypeError, 'bit fields not allowed for type c_char_p'))

        result = self.fail_fields(("a", c_void_p, 1))
        self.assertEqual(result, (TypeError, 'bit fields not allowed for type c_void_p'))

        # NOTE(review): presumably guarded because when c_int is an alias of
        # c_long the pointer type would be named LP_c_long instead.
        if c_int != c_long:
            result = self.fail_fields(("a", POINTER(c_int), 1))
            self.assertEqual(result, (TypeError, 'bit fields not allowed for type LP_c_int'))

        result = self.fail_fields(("a", c_char, 1))
        self.assertEqual(result, (TypeError, 'bit fields not allowed for type c_char'))

        class Dummy(Structure):
            _fields_ = []

        result = self.fail_fields(("a", Dummy, 1))
        self.assertEqual(result, (TypeError, 'bit fields not allowed for type Dummy'))

    @need_symbol('c_wchar')
    def test_c_wchar(self):
        result = self.fail_fields(("a", c_wchar, 1))
        self.assertEqual(result,
                (TypeError, 'bit fields not allowed for type c_wchar'))

    def test_single_bitfield_size(self):
        # Valid widths are 1 .. 8*sizeof(type); anything outside is rejected.
        for c_typ in int_types:
            with self.subTest(c_typ=c_typ):
                result = self.fail_fields(("a", c_typ, -1))
                self.assertEqual(result, (ValueError, 'number of bits invalid for bit field'))

                result = self.fail_fields(("a", c_typ, 0))
                self.assertEqual(result, (ValueError, 'number of bits invalid for bit field'))

                class X(Structure):
                    _fields_ = [("a", c_typ, 1)]
                self.assertEqual(sizeof(X), sizeof(c_typ))

                class X(Structure):
                    _fields_ = [("a", c_typ, sizeof(c_typ)*8)]
                self.assertEqual(sizeof(X), sizeof(c_typ))

                result = self.fail_fields(("a", c_typ, sizeof(c_typ)*8 + 1))
                self.assertEqual(result, (ValueError, 'number of bits invalid for bit field'))

    def test_multi_bitfields_size(self):
        # 1 + 14 + 1 bits fit in one c_short.
        class X(Structure):
            _fields_ = [("a", c_short, 1),
                        ("b", c_short, 14),
                        ("c", c_short, 1)]
        self.assertEqual(sizeof(X), sizeof(c_short))

        # An ordinary field between bit fields closes the current unit; b and
        # c then share the third c_short.
        class X(Structure):
            _fields_ = [("a", c_short, 1),
                        ("a1", c_short),
                        ("b", c_short, 14),
                        ("c", c_short, 1)]
        self.assertEqual(sizeof(X), sizeof(c_short)*3)
        self.assertEqual(X.a.offset, 0)
        self.assertEqual(X.a1.offset, sizeof(c_short))
        self.assertEqual(X.b.offset, sizeof(c_short)*2)
        self.assertEqual(X.c.offset, sizeof(c_short)*2)

        # A field that does not fit in the remaining bits starts a new unit.
        class X(Structure):
            _fields_ = [("a", c_short, 3),
                        ("b", c_short, 14),
                        ("c", c_short, 14)]
        self.assertEqual(sizeof(X), sizeof(c_short)*3)
        self.assertEqual(X.a.offset, sizeof(c_short)*0)
        self.assertEqual(X.b.offset, sizeof(c_short)*1)
        self.assertEqual(X.c.offset, sizeof(c_short)*2)


    def get_except(self, func, *args, **kw):
        """Call ``func(*args, **kw)`` and describe the exception it raises.

        Returns ``(exception class, str(exception))``.  Every caller expects
        the call to fail, so a normal return fails the test with an explicit
        message instead of silently yielding None.
        """
        try:
            func(*args, **kw)
        except Exception as detail:
            return detail.__class__, str(detail)
        self.fail("calling %r with %r did not raise an exception"
                  % (func, args))

    def test_mixed_1(self):
        # Different declared types; as test_mixed_4 explains, MSVC does not
        # combine them into one storage unit while GCC does.
        class X(Structure):
            _fields_ = [("a", c_byte, 4),
                        ("b", c_int, 4)]
        if os.name == "nt":
            self.assertEqual(sizeof(X), sizeof(c_int)*2)
        else:
            self.assertEqual(sizeof(X), sizeof(c_int))

    def test_mixed_2(self):
        # A full-width c_int cannot share space with "a", so it starts at
        # the next c_int-aligned offset.
        class X(Structure):
            _fields_ = [("a", c_byte, 4),
                        ("b", c_int, 32)]
        self.assertEqual(sizeof(X), alignment(c_int)+sizeof(c_int))

    def test_mixed_3(self):
        # Signed and unsigned variants of the same size share one byte.
        class X(Structure):
            _fields_ = [("a", c_byte, 4),
                        ("b", c_ubyte, 4)]
        self.assertEqual(sizeof(X), sizeof(c_byte))

    def test_mixed_4(self):
        class X(Structure):
            _fields_ = [("a", c_short, 4),
                        ("b", c_short, 4),
                        ("c", c_int, 24),
                        ("d", c_short, 4),
                        ("e", c_short, 4),
                        ("f", c_int, 24)]
        # MSVC does NOT combine c_short and c_int into one field, GCC
        # does (unless GCC is run with '-mms-bitfields' which
        # produces code compatible with MSVC).
        if os.name == "nt":
            self.assertEqual(sizeof(X), sizeof(c_int) * 4)
        else:
            self.assertEqual(sizeof(X), sizeof(c_int) * 2)

    def test_anon_bitfields(self):
        # anonymous bit-fields gave a strange error message; the test passes
        # as long as defining Y does not raise.
        class X(Structure):
            _fields_ = [("a", c_byte, 4),
                        ("b", c_ubyte, 4)]
        class Y(Structure):
            _anonymous_ = ["_"]
            _fields_ = [("_", X)]

    @need_symbol('c_uint32')
    def test_uint32(self):
        # A bit field as wide as its type must round-trip values with the
        # top bit set.
        class X(Structure):
            _fields_ = [("a", c_uint32, 32)]
        x = X()
        x.a = 10
        self.assertEqual(x.a, 10)
        x.a = 0xFDCBA987
        self.assertEqual(x.a, 0xFDCBA987)

    @need_symbol('c_uint64')
    def test_uint64(self):
        class X(Structure):
            _fields_ = [("a", c_uint64, 64)]
        x = X()
        x.a = 10
        self.assertEqual(x.a, 10)
        x.a = 0xFEDCBA9876543211
        self.assertEqual(x.a, 0xFEDCBA9876543211)

    @need_symbol('c_uint32')
    def test_uint32_swap_little_endian(self):
        # Issue #23319: checks the exact bytes written into the backing
        # buffer of an explicitly little-endian structure.
        class Little(LittleEndianStructure):
            _fields_ = [("a", c_uint32, 24),
                        ("b", c_uint32, 4),
                        ("c", c_uint32, 4)]
        b = bytearray(4)
        x = Little.from_buffer(b)
        x.a = 0xabcdef
        x.b = 1
        x.c = 2
        self.assertEqual(b, b'\xef\xcd\xab\x21')

    @need_symbol('c_uint32')
    def test_uint32_swap_big_endian(self):
        # Issue #23319: checks the exact bytes written into the backing
        # buffer of an explicitly big-endian structure.
        class Big(BigEndianStructure):
            _fields_ = [("a", c_uint32, 24),
                        ("b", c_uint32, 4),
                        ("c", c_uint32, 4)]
        b = bytearray(4)
        x = Big.from_buffer(b)
        x.a = 0xabcdef
        x.b = 1
        x.c = 2
        self.assertEqual(b, b'\xab\xcd\xef\x12')
288
if __name__ == "__main__":
    # Executed as a script: discover and run the TestCases in this module.
    unittest.main(module="__main__")
291