1from ctypes import *
2from ctypes.test import need_symbol
3import unittest
4import os
5
6import _ctypes_test
7
class BITS(Structure):
    """Structure made only of bit fields, used to compare ctypes with C.

    Fields ``A``..``I`` are ``c_int`` bit fields of width 1..9 and fields
    ``M``..``S`` are ``c_short`` bit fields of width 1..7, in that order.
    """

    # Each field is one bit wider than the previous one of the same base
    # type, so the width is simply the 1-based position of its letter.
    _fields_ = (
        [(letter, c_int, width)
         for width, letter in enumerate("ABCDEFGHI", 1)]
        + [(letter, c_short, width)
           for width, letter in enumerate("MNOPQRS", 1)]
    )
26
# C helper exported by the _ctypes_test extension module.  The tests below
# call it with a pointer to a BITS instance and a one-byte field name and
# compare its result with what ctypes reads from the same field.
func = CDLL(_ctypes_test.__file__).unpack_bitfields
func.argtypes = (POINTER(BITS), c_char)
29
## Debugging aid: dump the size and offset recorded for every BITS field.
##for n in "ABCDEFGHIMNOPQRS":
##    print(n, hex(getattr(BITS, n).size), getattr(BITS, n).offset)
32
class C_Test(unittest.TestCase):
    """Cross-check ctypes bit-field reads against the C helper ``func``."""

    def test_ints(self):
        # Store every value in 0..511 into each c_int bit field in turn and
        # require that ctypes and the C helper read back the same number.
        for value in range(512):
            for field in "ABCDEFGHI":
                bits = BITS()
                setattr(bits, field, value)
                seen_by_python = getattr(bits, field)
                seen_by_c = func(byref(bits), field.encode('ascii'))
                self.assertEqual(seen_by_python, seen_by_c)

    def test_shorts(self):
        # Probe field "M" on an untouched structure first: a reply of 999 is
        # treated as "short bit fields unavailable" and the test is skipped.
        probe = BITS()
        if func(byref(probe), "M".encode('ascii')) == 999:
            self.skipTest("Compiler does not support signed short bitfields")
        # Same round trip as test_ints, for the c_short bit fields.
        for value in range(256):
            for field in "MNOPQRS":
                bits = BITS()
                setattr(bits, field, value)
                seen_by_python = getattr(bits, field)
                seen_by_c = func(byref(bits), field.encode('ascii'))
                self.assertEqual(seen_by_python, seen_by_c)
52
# Integer ctypes the bit-field tests iterate over, grouped by signedness.
signed_int_types = (
    c_byte,
    c_short,
    c_int,
    c_long,
    c_longlong,
)
unsigned_int_types = (
    c_ubyte,
    c_ushort,
    c_uint,
    c_ulong,
    c_ulonglong,
)
# Combined tuple: the unsigned types come first, then the signed ones.
int_types = unsigned_int_types + signed_int_types
56
class BitFieldTest(unittest.TestCase):
    """Tests for the size, value range and validation of ctypes bit fields."""

    def test_longlong(self):
        # Widths 1 + 62 + 1 must pack into a single c_longlong.
        class X(Structure):
            _fields_ = [("a", c_longlong, 1),
                        ("b", c_longlong, 62),
                        ("c", c_longlong, 1)]

        self.assertEqual(sizeof(X), sizeof(c_longlong))
        x = X()
        x.a, x.b, x.c = -1, 7, -1
        self.assertEqual((x.a, x.b, x.c), (-1, 7, -1))

    def test_ulonglong(self):
        # Widths 1 + 62 + 1 must pack into a single c_ulonglong.
        class X(Structure):
            _fields_ = [("a", c_ulonglong, 1),
                        ("b", c_ulonglong, 62),
                        ("c", c_ulonglong, 1)]

        self.assertEqual(sizeof(X), sizeof(c_ulonglong))
        x = X()
        self.assertEqual((x.a, x.b, x.c), (0, 0, 0))
        # 7 does not fit into the 1-bit fields; only the low bit is kept.
        x.a, x.b, x.c = 7, 7, 7
        self.assertEqual((x.a, x.b, x.c), (1, 7, 1))

    def test_signed(self):
        for c_typ in signed_int_types:
            # subTest labels a failure with the type and lets the remaining
            # types still be checked.
            with self.subTest(c_typ=c_typ):
                class X(Structure):
                    _fields_ = [("dummy", c_typ),
                                ("a", c_typ, 3),
                                ("b", c_typ, 3),
                                ("c", c_typ, 1)]
                # One unit for the plain field, one shared by the bit fields.
                self.assertEqual(sizeof(X), sizeof(c_typ)*2)

                x = X()
                self.assertEqual((x.a, x.b, x.c), (0, 0, 0))
                x.a = -1
                self.assertEqual((x.a, x.b, x.c), (-1, 0, 0))
                x.a, x.b = 0, -1
                self.assertEqual((x.a, x.b, x.c), (0, -1, 0))

    def test_unsigned(self):
        for c_typ in unsigned_int_types:
            # subTest labels a failure with the type and lets the remaining
            # types still be checked.
            with self.subTest(c_typ=c_typ):
                class X(Structure):
                    _fields_ = [("a", c_typ, 3),
                                ("b", c_typ, 3),
                                ("c", c_typ, 1)]
                self.assertEqual(sizeof(X), sizeof(c_typ))

                x = X()
                self.assertEqual((x.a, x.b, x.c), (0, 0, 0))
                # -1 stored in an unsigned 3-bit field reads back as 7.
                x.a = -1
                self.assertEqual((x.a, x.b, x.c), (7, 0, 0))
                x.a, x.b = 0, -1
                self.assertEqual((x.a, x.b, x.c), (0, 7, 0))

    def fail_fields(self, *fields):
        """Try to build a class "X" whose ``_fields_`` is *fields*.

        The class is created by calling the metaclass of ``Structure``
        directly.  Returns whatever ``get_except`` returns for that call:
        an ``(exception class, message)`` pair, or None if nothing was raised.
        """
        return self.get_except(type(Structure), "X", (),
                               {"_fields_": fields})

    def get_except(self, func, *args, **kw):
        """Call ``func(*args, **kw)`` and report the exception it raises.

        Returns ``(exception class, str(exception))`` when the call raises an
        ``Exception``, and None when it completes normally.
        """
        try:
            func(*args, **kw)
        except Exception as detail:
            return detail.__class__, str(detail)
        return None

    def test_nonint_types(self):
        # bit fields are not allowed on non-integer types.
        result = self.fail_fields(("a", c_char_p, 1))
        self.assertEqual(result, (TypeError, 'bit fields not allowed for type c_char_p'))

        result = self.fail_fields(("a", c_void_p, 1))
        self.assertEqual(result, (TypeError, 'bit fields not allowed for type c_void_p'))

        # NOTE(review): presumably skipped where c_int is an alias of c_long
        # because the pointer type is then not named LP_c_int -- confirm.
        if c_int != c_long:
            result = self.fail_fields(("a", POINTER(c_int), 1))
            self.assertEqual(result, (TypeError, 'bit fields not allowed for type LP_c_int'))

        result = self.fail_fields(("a", c_char, 1))
        self.assertEqual(result, (TypeError, 'bit fields not allowed for type c_char'))

        class Dummy(Structure):
            _fields_ = []

        result = self.fail_fields(("a", Dummy, 1))
        self.assertEqual(result, (TypeError, 'bit fields not allowed for type Dummy'))

    @need_symbol('c_wchar')
    def test_c_wchar(self):
        result = self.fail_fields(("a", c_wchar, 1))
        self.assertEqual(result,
                (TypeError, 'bit fields not allowed for type c_wchar'))

    def test_single_bitfield_size(self):
        for c_typ in int_types:
            # subTest labels a failure with the type and lets the remaining
            # types still be checked.
            with self.subTest(c_typ=c_typ):
                # Negative and zero widths are rejected.
                result = self.fail_fields(("a", c_typ, -1))
                self.assertEqual(result, (ValueError, 'number of bits invalid for bit field'))

                result = self.fail_fields(("a", c_typ, 0))
                self.assertEqual(result, (ValueError, 'number of bits invalid for bit field'))

                # Smallest and largest legal widths both occupy one unit.
                class X(Structure):
                    _fields_ = [("a", c_typ, 1)]
                self.assertEqual(sizeof(X), sizeof(c_typ))

                class X(Structure):
                    _fields_ = [("a", c_typ, sizeof(c_typ)*8)]
                self.assertEqual(sizeof(X), sizeof(c_typ))

                # One bit more than the type holds is rejected.
                result = self.fail_fields(("a", c_typ, sizeof(c_typ)*8 + 1))
                self.assertEqual(result, (ValueError, 'number of bits invalid for bit field'))

    def test_multi_bitfields_size(self):
        # 1 + 14 + 1 bits share a single c_short.
        class X(Structure):
            _fields_ = [("a", c_short, 1),
                        ("b", c_short, 14),
                        ("c", c_short, 1)]
        self.assertEqual(sizeof(X), sizeof(c_short))

        # A plain field between bit fields ends the current storage unit.
        class X(Structure):
            _fields_ = [("a", c_short, 1),
                        ("a1", c_short),
                        ("b", c_short, 14),
                        ("c", c_short, 1)]
        self.assertEqual(sizeof(X), sizeof(c_short)*3)
        self.assertEqual(X.a.offset, 0)
        self.assertEqual(X.a1.offset, sizeof(c_short))
        self.assertEqual(X.b.offset, sizeof(c_short)*2)
        self.assertEqual(X.c.offset, sizeof(c_short)*2)

        # A bit field that does not fit in the remaining bits starts a new
        # storage unit.
        class X(Structure):
            _fields_ = [("a", c_short, 3),
                        ("b", c_short, 14),
                        ("c", c_short, 14)]
        self.assertEqual(sizeof(X), sizeof(c_short)*3)
        self.assertEqual(X.a.offset, sizeof(c_short)*0)
        self.assertEqual(X.b.offset, sizeof(c_short)*1)
        self.assertEqual(X.c.offset, sizeof(c_short)*2)

    def test_mixed_1(self):
        class X(Structure):
            _fields_ = [("a", c_byte, 4),
                        ("b", c_int, 4)]
        # Expected size is two ints on Windows and one int elsewhere.
        if os.name == "nt":
            self.assertEqual(sizeof(X), sizeof(c_int)*2)
        else:
            self.assertEqual(sizeof(X), sizeof(c_int))

    def test_mixed_2(self):
        class X(Structure):
            _fields_ = [("a", c_byte, 4),
                        ("b", c_int, 32)]
        # The full-width int field cannot share storage with "a": the byte
        # field is padded up to int alignment and the int follows it.
        self.assertEqual(sizeof(X), alignment(c_int)+sizeof(c_int))

    def test_mixed_3(self):
        # Signed and unsigned byte fields share one byte.
        class X(Structure):
            _fields_ = [("a", c_byte, 4),
                        ("b", c_ubyte, 4)]
        self.assertEqual(sizeof(X), sizeof(c_byte))

    def test_mixed_4(self):
        class X(Structure):
            _fields_ = [("a", c_short, 4),
                        ("b", c_short, 4),
                        ("c", c_int, 24),
                        ("d", c_short, 4),
                        ("e", c_short, 4),
                        ("f", c_int, 24)]
        # MSVC does NOT combine c_short and c_int into one field, GCC
        # does (unless GCC is run with '-mms-bitfields' which
        # produces code compatible with MSVC).
        if os.name == "nt":
            self.assertEqual(sizeof(X), sizeof(c_int) * 4)
        else:
            self.assertEqual(sizeof(X), sizeof(c_int) * 2)

    def test_anon_bitfields(self):
        # anonymous bit-fields gave a strange error message
        # Defining the two classes without an exception is the whole test.
        class X(Structure):
            _fields_ = [("a", c_byte, 4),
                        ("b", c_ubyte, 4)]
        class Y(Structure):
            _anonymous_ = ["_"]
            _fields_ = [("_", X)]

    @need_symbol('c_uint32')
    def test_uint32(self):
        # A bit field spanning all 32 bits must keep values with the top
        # bit set.
        class X(Structure):
            _fields_ = [("a", c_uint32, 32)]
        x = X()
        x.a = 10
        self.assertEqual(x.a, 10)
        x.a = 0xFDCBA987
        self.assertEqual(x.a, 0xFDCBA987)

    @need_symbol('c_uint64')
    def test_uint64(self):
        # A bit field spanning all 64 bits must keep values with the top
        # bit set.
        class X(Structure):
            _fields_ = [("a", c_uint64, 64)]
        x = X()
        x.a = 10
        self.assertEqual(x.a, 10)
        x.a = 0xFEDCBA9876543211
        self.assertEqual(x.a, 0xFEDCBA9876543211)

    @need_symbol('c_uint32')
    def test_uint32_swap_little_endian(self):
        # Issue #23319
        class Little(LittleEndianStructure):
            _fields_ = [("a", c_uint32, 24),
                        ("b", c_uint32, 4),
                        ("c", c_uint32, 4)]
        # The structure is a view on the bytearray, so the assignments
        # below can be checked byte by byte.
        b = bytearray(4)
        x = Little.from_buffer(b)
        x.a = 0xabcdef
        x.b = 1
        x.c = 2
        self.assertEqual(b, b'\xef\xcd\xab\x21')

    @need_symbol('c_uint32')
    def test_uint32_swap_big_endian(self):
        # Issue #23319
        class Big(BigEndianStructure):
            _fields_ = [("a", c_uint32, 24),
                        ("b", c_uint32, 4),
                        ("c", c_uint32, 4)]
        # The structure is a view on the bytearray, so the assignments
        # below can be checked byte by byte.
        b = bytearray(4)
        x = Big.from_buffer(b)
        x.a = 0xabcdef
        x.b = 1
        x.c = 2
        self.assertEqual(b, b'\xab\xcd\xef\x12')
292
293if __name__ == "__main__":
294    unittest.main()
295