1import functools
2import unittest
3from test import support
4
5from ctypes import *
6from ctypes.test import need_symbol
7import _ctypes_test
8
class Callbacks(unittest.TestCase):
    """Call Python callbacks through ctypes prototypes and check both the
    arguments they receive and the results they return.

    Subclasses override ``functype`` to rerun every test with a different
    prototype factory (e.g. another calling convention).
    """

    # Prototype factory.  Looked up through ``self`` it becomes a bound
    # method, so the tests use ``self.functype.__func__`` to reach the
    # plain factory function.
    functype = CFUNCTYPE

    def callback(self, *args):
        """Record the received arguments and echo back the last one."""
        self.got_args = args
        return args[-1]

    def check_type(self, typ, arg):
        """Round-trip *arg* through a callback whose result type and last
        argument type are both *typ*.

        The callback is exercised twice: with *typ* as its only argument,
        and with a leading ``c_byte`` argument (-3) in front of it.
        ``c_float`` values are only compared approximately because they
        travel double -> float -> double.
        """
        cases = (
            ((typ,), (arg,)),
            ((c_byte, typ), (-3, arg)),
        )
        for argtypes, args in cases:
            PROTO = self.functype.__func__(typ, *argtypes)
            # Reset first so a callback that is never invoked cannot pass by
            # leaving the arguments of an earlier call in place.
            self.got_args = None
            result = PROTO(self.callback)(*args)
            if typ is c_float:
                self.assertIsNotNone(self.got_args)
                self.assertEqual(len(self.got_args), len(args))
                self.assertEqual(self.got_args[:-1], args[:-1])
                self.assertAlmostEqual(self.got_args[-1], arg, places=5)
                self.assertAlmostEqual(result, arg, places=5)
            else:
                self.assertEqual(self.got_args, args)
                self.assertEqual(result, arg)

    ################

    def test_byte(self):
        self.check_type(c_byte, 42)
        self.check_type(c_byte, -42)

    def test_ubyte(self):
        self.check_type(c_ubyte, 42)

    def test_short(self):
        self.check_type(c_short, 42)
        self.check_type(c_short, -42)

    def test_ushort(self):
        self.check_type(c_ushort, 42)

    def test_int(self):
        self.check_type(c_int, 42)
        self.check_type(c_int, -42)

    def test_uint(self):
        self.check_type(c_uint, 42)

    def test_long(self):
        self.check_type(c_long, 42)
        self.check_type(c_long, -42)

    def test_ulong(self):
        self.check_type(c_ulong, 42)

    def test_longlong(self):
        self.check_type(c_longlong, 42)
        self.check_type(c_longlong, -42)

    def test_ulonglong(self):
        self.check_type(c_ulonglong, 42)

    def test_float(self):
        # only almost equal: double -> float -> double
        import math
        self.check_type(c_float, math.e)
        self.check_type(c_float, -math.e)

    def test_double(self):
        self.check_type(c_double, 3.14)
        self.check_type(c_double, -3.14)

    def test_longdouble(self):
        self.check_type(c_longdouble, 3.14)
        self.check_type(c_longdouble, -3.14)

    def test_char(self):
        self.check_type(c_char, b"x")
        self.check_type(c_char, b"a")

    # disabled: would now (correctly) raise a RuntimeWarning about
    # a memory leak.  A callback function cannot return a non-integral
    # C type without causing a memory leak.
    @unittest.skip('test disabled')
    def test_char_p(self):
        # c_char_p accepts bytes, not str, so use bytes literals in case
        # this test is ever re-enabled.
        self.check_type(c_char_p, b"abc")
        self.check_type(c_char_p, b"def")

    def test_pyobject(self):
        from sys import getrefcount as grc
        for o in (), [], object():
            # This call leaks a reference to 'o'...
            self.check_type(py_object, o)
            before = grc(o)
            # ...but this call doesn't leak any more.  Where is the refcount?
            self.check_type(py_object, o)
            after = grc(o)
            # 'o' is included so a failure message shows which object leaked.
            self.assertEqual((after, o), (before, o))

    def test_unsupported_restype_1(self):
        # Only "fundamental" result types are supported for callback
        # functions, the type must have a non-NULL stgdict->setfunc.
        # POINTER(c_double), for example, is not supported.

        prototype = self.functype.__func__(POINTER(c_double))
        # The type is checked when the prototype is called
        self.assertRaises(TypeError, prototype, lambda: None)

    def test_unsupported_restype_2(self):
        prototype = self.functype.__func__(object)
        self.assertRaises(TypeError, prototype, lambda: None)

    def test_issue_7959(self):
        # Callback objects stored on an instance must not keep that instance
        # alive once the cycle is collected.
        proto = self.functype.__func__(None)

        class X(object):
            def func(self): pass
            def __init__(self):
                self.v = proto(self.func)

        import gc
        for i in range(32):
            X()
        gc.collect()
        live = [x for x in gc.get_objects()
                if isinstance(x, X)]
        self.assertEqual(len(live), 0)

    def test_issue12483(self):
        # Destroying the callback drops its default argument; that
        # object's __del__ runs a garbage collection mid-deallocation,
        # which must not crash.
        import gc
        class Nasty:
            def __del__(self):
                gc.collect()
        self.functype.__func__(None)(lambda x=Nasty(): None)
148
149
@need_symbol('WINFUNCTYPE')
class StdcallCallbacks(Callbacks):
    # Rerun every Callbacks test with the WINFUNCTYPE prototype factory.
    # The class body is evaluated even on platforms where need_symbol()
    # skips the class, so only rebind functype when the name exists;
    # otherwise the inherited CFUNCTYPE is kept.
    if 'WINFUNCTYPE' in globals():
        functype = WINFUNCTYPE
156
157################################################################
158
class SampleCallbacksTestCase(unittest.TestCase):
    """Callbacks invoked from real C code (the _ctypes_test DLL and libc)."""

    def test_integrate(self):
        # Derived from some then non-working code, posted by David Foster
        dll = CDLL(_ctypes_test.__file__)

        # The function prototype called by 'integrate': double func(double);
        CALLBACK = CFUNCTYPE(c_double, c_double)

        # The integrate function itself, exposed from the _ctypes_test dll
        integrate = dll.integrate
        integrate.argtypes = (c_double, c_double, CALLBACK, c_long)
        integrate.restype = c_double

        def func(x):
            return x**2

        result = integrate(0.0, 1.0, CALLBACK(func), 10)
        diff = abs(result - 1./3.)

        self.assertLess(diff, 0.01, "%s not less than 0.01" % diff)

    def test_issue_8959_a(self):
        from ctypes.util import find_library
        libc_path = find_library("c")
        if not libc_path:
            self.skipTest('could not find libc')
        libc = CDLL(libc_path)

        CMP_FUNC = CFUNCTYPE(c_int, POINTER(c_int), POINTER(c_int))

        @CMP_FUNC
        def cmp_func(a, b):
            return a[0] - b[0]

        # void qsort(void *base, size_t nmemb, size_t size,
        #            int (*compar)(const void *, const void *));
        # Declaring the prototype makes ctypes pass the counts as size_t
        # rather than the default C int.
        qsort = libc.qsort
        qsort.argtypes = (c_void_p, c_size_t, c_size_t, CMP_FUNC)
        qsort.restype = None

        array = (c_int * 5)(5, 1, 99, 7, 33)

        qsort(array, len(array), sizeof(c_int), cmp_func)
        self.assertEqual(array[:], [1, 5, 7, 33, 99])

    @need_symbol('WINFUNCTYPE')
    def test_issue_8959_b(self):
        # Only checks that a WINFUNCTYPE callback can be driven by a Win32
        # API without crashing; the number of windows is not asserted.
        from ctypes.wintypes import BOOL, HWND, LPARAM
        window_count = 0

        @WINFUNCTYPE(BOOL, HWND, LPARAM)
        def EnumWindowsCallbackFunc(hwnd, lParam):
            # A closure variable instead of a module global keeps the
            # test module's namespace clean.
            nonlocal window_count
            window_count += 1
            return True  # Allow windows to keep enumerating

        windll.user32.EnumWindows(EnumWindowsCallbackFunc, 0)

    def test_callback_register_int(self):
        # Issue #8275: buggy handling of callback args under Win64
        # NOTE: should be run on release builds as well
        dll = CDLL(_ctypes_test.__file__)
        CALLBACK = CFUNCTYPE(c_int, c_int, c_int, c_int, c_int, c_int)
        # All this function does is call the callback with its args squared
        func = dll._testfunc_cbk_reg_int
        func.argtypes = (c_int, c_int, c_int, c_int, c_int, CALLBACK)
        func.restype = c_int

        def callback(a, b, c, d, e):
            return a + b + c + d + e

        result = func(2, 3, 4, 5, 6, CALLBACK(callback))
        self.assertEqual(result, callback(2*2, 3*3, 4*4, 5*5, 6*6))

    def test_callback_register_double(self):
        # Issue #8275: buggy handling of callback args under Win64
        # NOTE: should be run on release builds as well
        dll = CDLL(_ctypes_test.__file__)
        CALLBACK = CFUNCTYPE(c_double, c_double, c_double, c_double,
                             c_double, c_double)
        # All this function does is call the callback with its args squared
        func = dll._testfunc_cbk_reg_double
        func.argtypes = (c_double, c_double, c_double,
                         c_double, c_double, CALLBACK)
        func.restype = c_double

        def callback(a, b, c, d, e):
            return a + b + c + d + e

        result = func(1.1, 2.2, 3.3, 4.4, 5.5, CALLBACK(callback))
        self.assertEqual(result,
                         callback(1.1*1.1, 2.2*2.2, 3.3*3.3, 4.4*4.4, 5.5*5.5))

    def test_callback_large_struct(self):
        class Check: pass

        # This should mirror the structure in Modules/_ctypes/_ctypes_test.c
        class X(Structure):
            _fields_ = [
                ('first', c_ulong),
                ('second', c_ulong),
                ('third', c_ulong),
            ]

        def callback(check, s):
            check.first = s.first
            check.second = s.second
            check.third = s.third
            # See issue #29565.
            # The structure should be passed by value, so
            # any changes to it should not be reflected in
            # the value passed
            s.first = s.second = s.third = 0x0badf00d

        check = Check()
        s = X()
        s.first = 0xdeadbeef
        s.second = 0xcafebabe
        s.third = 0x0bad1dea

        CALLBACK = CFUNCTYPE(None, X)
        dll = CDLL(_ctypes_test.__file__)
        func = dll._testfunc_cbk_large_struct
        func.argtypes = (X, CALLBACK)
        func.restype = None
        # the function just calls the callback with the passed structure
        func(s, CALLBACK(functools.partial(callback, check)))
        self.assertEqual(check.first, s.first)
        self.assertEqual(check.second, s.second)
        self.assertEqual(check.third, s.third)
        self.assertEqual(check.first, 0xdeadbeef)
        self.assertEqual(check.second, 0xcafebabe)
        self.assertEqual(check.third, 0x0bad1dea)
        # See issue #29565.
        # Ensure that the original struct is unchanged.
        self.assertEqual(s.first, check.first)
        self.assertEqual(s.second, check.second)
        self.assertEqual(s.third, check.third)

    def test_callback_too_many_args(self):
        def func(*args):
            return len(args)

        CTYPES_MAX_ARGCOUNT = 1024
        proto = CFUNCTYPE(c_int, *(c_int,) * CTYPES_MAX_ARGCOUNT)
        cb = proto(func)
        args1 = (1,) * CTYPES_MAX_ARGCOUNT
        self.assertEqual(cb(*args1), CTYPES_MAX_ARGCOUNT)

        args2 = (1,) * (CTYPES_MAX_ARGCOUNT + 1)
        with self.assertRaises(ArgumentError):
            cb(*args2)

    def test_convert_result_error(self):
        def func():
            return ("tuple",)

        proto = CFUNCTYPE(c_int)
        ctypes_func = proto(func)
        with support.catch_unraisable_exception() as cm:
            # Call only for its side effect; the result is not tested
            # because it is an uninitialized value.
            ctypes_func()

            self.assertIsInstance(cm.unraisable.exc_value, TypeError)
            self.assertEqual(cm.unraisable.err_msg,
                             "Exception ignored on converting result "
                             "of ctypes callback function")
            self.assertIs(cm.unraisable.object, func)
321
322
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()
325