1# mode: run
2
3cimport cython
4from cython.view cimport array
5
6from cython cimport integral
7from cpython cimport Py_INCREF
8
9from Cython import Shadow as pure_cython
# Plain and fused typedefs shared by the tests below.
ctypedef char * string_t

# For reference, the cython.* fused types used in this file are:
# floating = cython.fused_type(float, double)
# integral = cython.fused_type(short, int, long)
ctypedef cython.floating floating
fused_type1 = cython.fused_type(int, long, float, double, string_t)
fused_type2 = cython.fused_type(string_t)
# composed_t is derived from fused_type1, so it is specialized together with it
# (e.g. fused_type1 == double implies composed_t == double*).
ctypedef fused_type1 *composed_t
other_t = cython.fused_type(int, double)
ctypedef double *p_double
ctypedef int *p_int
fused_type3 = cython.fused_type(int, double)
# A fused type built from other fused types: covers string_t, int and double.
fused_composite = cython.fused_type(fused_type2, fused_type3)
23
def test_pure():
    """
    Build a fused typedef through the pure-Python Shadow module and call it.

    >>> test_pure()
    10
    """
    mytype = pure_cython.typedef(pure_cython.fused_type(int, long, complex))
    print mytype(10)
31
32
cdef cdef_func_with_fused_args(fused_type1 x, fused_type1 y, fused_type2 z):
    # Print the three arguments (char* values decoded as ASCII) and return x + y.
    # x and y share one fused type; z uses a second, independent fused type.
    if fused_type1 is string_t:
        print x.decode('ascii'), y.decode('ascii'), z.decode('ascii')
    else:
        print x, y, z.decode('ascii')

    return x + y
40
def test_cdef_func_with_fused_args():
    """
    Call cdef_func_with_fused_args without explicit indexing, so the
    specialization is chosen from the bytes, int and float argument types.

    >>> test_cdef_func_with_fused_args()
    spam ham eggs
    spamham
    10 20 butter
    30
    4.2 8.6 bunny
    12.8
    """
    print cdef_func_with_fused_args(b'spam', b'ham', b'eggs').decode('ascii')
    print cdef_func_with_fused_args(10, 20, b'butter')
    print cdef_func_with_fused_args(4.2, 8.6, b'bunny')
54
cdef fused_type1 fused_with_pointer(fused_type1 *array):
    # Print the first five elements of `array`, then return their sum
    # (concatenation for string_t).
    # Assumes `array` points to at least 5 elements: the loop bound is fixed.
    for i in range(5):
        if fused_type1 is string_t:
            print array[i].decode('ascii')
        else:
            print array[i]

    # For string_t, adding char* values is a Python-level operation, so `obj`
    # is a bytes object. Returning it as char* only borrows that object's buffer.
    obj = array[0] + array[1] + array[2] + array[3] + array[4]
    if fused_type1 is string_t:
        # Deliberately leak one reference so the buffer behind the returned
        # char* outlives this function. For numeric specializations no
        # INCREF is needed; doing it there would box the C value into a
        # temporary Python object and leak a reference to it.
        Py_INCREF(obj)
    return obj
66
def test_fused_with_pointer():
    """
    Pass int, long, float and char* C arrays to fused_with_pointer; each call
    prints the five elements followed by their sum (or concatenation).

    >>> test_fused_with_pointer()
    0
    1
    2
    3
    4
    10
    <BLANKLINE>
    0
    1
    2
    3
    4
    10
    <BLANKLINE>
    0.0
    1.0
    2.0
    3.0
    4.0
    10.0
    <BLANKLINE>
    humpty
    dumpty
    fall
    splatch
    breakfast
    humptydumptyfallsplatchbreakfast
    """
    cdef int[5] int_array
    cdef long[5] long_array
    cdef float[5] float_array
    cdef string_t[5] string_array

    cdef char *s

    # The bytes objects stay referenced by this list, so the char* pointers
    # taken from them below remain valid for the whole function.
    strings = [b"humpty", b"dumpty", b"fall", b"splatch", b"breakfast"]

    for i in range(5):
        int_array[i] = i
        long_array[i] = i
        float_array[i] = i
        s = strings[i]  # borrows the buffer of strings[i]
        string_array[i] = s

    # Each call relies on the specialization being inferred from the array's
    # element type (C arrays decay to pointers).
    print fused_with_pointer(int_array)
    print
    print fused_with_pointer(long_array)
    print
    print fused_with_pointer(float_array)
    print
    print fused_with_pointer(string_array).decode('ascii')
121
cdef fused_type1* fused_pointer_except_null(fused_type1* x) except NULL:
    # Return x unchanged after asserting on the value it points to: a non-empty
    # string for string_t, a value below 10 otherwise. A failing assert raises
    # AssertionError, which is signalled to the caller through the NULL error
    # return declared with `except NULL`.
    if fused_type1 is string_t:
        assert(bool(x[0]))
    else:
        assert(x[0] < 10)
    return x
128
def test_fused_pointer_except_null(value):
    """
    Dispatch on the Python type of `value` to the int, float or string_t
    specialization of fused_pointer_except_null and print the dereferenced
    result. Values of any other type are ignored (no branch matches).

    >>> test_fused_pointer_except_null(1)
    1
    >>> test_fused_pointer_except_null(2.0)
    2.0
    >>> test_fused_pointer_except_null(b'foo')
    foo
    >>> test_fused_pointer_except_null(16)
    Traceback (most recent call last):
    AssertionError
    >>> test_fused_pointer_except_null(15.1)
    Traceback (most recent call last):
    AssertionError
    >>> test_fused_pointer_except_null(b'')
    Traceback (most recent call last):
    AssertionError
    """
    if isinstance(value, int):
        # C int local initialised from value, so its address can be taken.
        test_int = cython.declare(cython.int, value)
        print fused_pointer_except_null(&test_int)[0]
    elif isinstance(value, float):
        test_float = cython.declare(cython.float, value)
        print fused_pointer_except_null(&test_float)[0]
    elif isinstance(value, bytes):
        # char* borrowing the buffer of `value`, which the caller keeps alive.
        test_str = cython.declare(string_t, value)
        print fused_pointer_except_null(&test_str)[0].decode('ascii')
156
157include "cythonarrayutil.pxi"
158
cpdef cython.integral test_fused_memoryviews(cython.integral[:, ::1] a):
    """
    Return a[1, 2] from a C-contiguous 2D memoryview whose item type is a
    cython.integral specialization.

    >>> import cython
    >>> a = create_array((3, 5), mode="c")
    >>> test_fused_memoryviews[cython.int](a)
    7
    """
    return a[1, 2]
167
# A fused type whose members are two C-contiguous 2D memoryview typedefs.
ctypedef int[:, ::1] memview_int
ctypedef long[:, ::1] memview_long
memview_t = cython.fused_type(memview_int, memview_long)
171
def test_fused_memoryview_def(memview_t a):
    """
    Def function taking a fused memoryview typedef; the specialization is
    selected by indexing with the typedef's name as a string.

    >>> a = create_array((3, 5), mode="c")
    >>> test_fused_memoryview_def["memview_int"](a)
    7
    """
    return a[1, 2]
179
cdef test_specialize(fused_type1 x, fused_type1 *y, composed_t z, other_t *a):
    # Fused cdef function used to test obtaining specialized function pointers.
    # composed_t is fused_type1*, so only fused_type1 and other_t vary
    # independently (hence test_specialize[double, int]).
    # Prints a marker when composed_t is double*. For floating specializations
    # it returns x + y[0] + z[0] + a[0]; otherwise it falls through and returns None.
    cdef fused_type1 result

    if composed_t is p_double:
        print "double pointer"

    if fused_type1 in floating:
        result = x + y[0] + z[0] + a[0]
        return result
189
def test_specializations():
    """
    Obtain the (double, int) specialization of test_specialize in several ways:
    implicit function-pointer assignment, explicit cast, and indexing. Call
    each one; every call prints "double pointer".

    >>> test_specializations()
    double pointer
    double pointer
    double pointer
    double pointer
    double pointer
    """
    cdef object (*f)(double, double *, double *, int *)

    cdef double somedouble = 2.2
    cdef double otherdouble = 3.3
    cdef int someint = 4

    cdef p_double somedouble_p = &somedouble
    cdef p_double otherdouble_p = &otherdouble
    cdef p_int someint_p = &someint

    # Specialization inferred from the target function-pointer type.
    f = test_specialize
    assert f(1.1, somedouble_p, otherdouble_p, someint_p) == 10.6

    # Specialization selected by an explicit cast.
    f = <object (*)(double, double *, double *, int *)> test_specialize
    assert f(1.1, somedouble_p, otherdouble_p, someint_p) == 10.6

    assert (<object (*)(double, double *, double *, int *)>
            test_specialize)(1.1, somedouble_p, otherdouble_p, someint_p) == 10.6

    # Specialization selected by indexing with all independent fused types.
    f = test_specialize[double, int]
    assert f(1.1, somedouble_p, otherdouble_p, someint_p) == 10.6

    assert test_specialize[double, int](1.1, somedouble_p, otherdouble_p, someint_p) == 10.6

    # The following cases are not supported: indexing one fused type at a time
    # (chained subscripts) or supplying only some of the fused types.
    # f = test_specialize[double][p_int]
    # print f(1.1, somedouble_p, otherdouble_p)
    # print

    # print test_specialize[double][p_int](1.1, somedouble_p, otherdouble_p)
    # print

    # print test_specialize[double](1.1, somedouble_p, otherdouble_p)
    # print
233
cdef opt_args(integral x, floating y = 4.0):
    # Print both arguments; y is optional, so each specialization can be
    # called with one or two arguments.
    print x, y
236
def test_opt_args():
    """
    Call explicitly indexed opt_args specializations with and without the
    optional argument.

    >>> test_opt_args()
    3 4.0
    3 4.0
    3 4.0
    3 4.0
    """
    opt_args[int,  float](3)
    opt_args[int, double](3)
    opt_args[int,  float](3, 4.0)
    opt_args[int, double](3, 4.0)
249
class NormalClass(object):
    # Plain Python class with a fused def method. Callers specialize the
    # method by indexing the bound method (e.g. method[short]).
    def method(self, cython.integral i):
        # Print the C type chosen for i, followed by its value.
        print cython.typeof(i), i
253
def test_normal_class():
    """
    Index a bound fused method with a type taken from the Shadow module.

    >>> test_normal_class()
    short 10
    """
    NormalClass().method[pure_cython.short](10)
260
def test_normal_class_refcount():
    """
    Check that indexing and calling a bound fused method leaves the
    instance's reference count unchanged (the printed difference is 0).

    >>> test_normal_class_refcount()
    short 10
    0
    """
    import sys
    x = NormalClass()
    c = sys.getrefcount(x)
    x.method[pure_cython.short](10)
    print sys.getrefcount(x) - c
272
def test_fused_declarations(cython.integral i, cython.floating f):
    """
    Locals declared with the same fused types as the arguments must resolve
    to the same specialization as those arguments.

    >>> test_fused_declarations[pure_cython.short, pure_cython.float](5, 6.6)
    short
    float
    25 43.56
    >>> test_fused_declarations[pure_cython.long, pure_cython.double](5, 6.6)
    long
    double
    25 43.56
    """
    cdef cython.integral squared_int = i * i
    cdef cython.floating squared_float = f * f

    assert cython.typeof(squared_int) == cython.typeof(i)
    assert cython.typeof(squared_float) == cython.typeof(f)

    print cython.typeof(squared_int)
    print cython.typeof(squared_float)
    print '%d %.2f' % (squared_int, squared_float)
293
def test_sizeof_fused_type(fused_type1 b):
    """
    sizeof() of a fused-type argument and of the fused type itself must both
    equal sizeof(double) in the double specialization.

    >>> test_sizeof_fused_type[pure_cython.double](11.1)
    """
    t = sizeof(b), sizeof(fused_type1), sizeof(double)
    assert t[0] == t[1] == t[2], t
300
def get_array(itemsize, format):
    """Return a 10-item cython.view.array of the given item size and format.

    Items 5 and 6 are set to 5.0 and 6.0; no other item is set here.
    """
    arr = array((10,), itemsize, format)
    arr[5], arr[6] = 5.0, 6.0
    return arr
306
def get_intc_array():
    """Return a 10-item C int ('i' format) cython.view.array.

    Items 5 and 6 are set to 5 and 6; no other item is set here.
    """
    ints = array((10,), sizeof(int), 'i')
    ints[5], ints[6] = 5, 6
    return ints
312
def test_fused_memslice_dtype(cython.floating[:] array):
    """
    Exercise a memoryview with a fused dtype: slicing into a second fused
    memoryview, cython.typeof on both, and casting a pointer to a fused
    memoryview type.

    Note: the np.ndarray dtype test is in numpy_test

    >>> import cython
    >>> sorted(test_fused_memslice_dtype.__signatures__)
    ['double', 'float']

    >>> test_fused_memslice_dtype[cython.double](get_array(8, 'd'))
    double[:] double[:] 5.0 6.0
    >>> test_fused_memslice_dtype[cython.float](get_array(4, 'f'))
    float[:] float[:] 5.0 6.0
    """
    # The stop index (100) is past the end of the 10-item input; the doctests
    # above show this slice still succeeds.
    cdef cython.floating[:] otherarray = array[0:100:1]
    print cython.typeof(array), cython.typeof(otherarray), \
          array[5], otherarray[6]
    # Compile-time check that a pointer can be cast to a fused memoryview type;
    # `value` is never read and test_cast is unused.
    cdef cython.floating value;
    cdef cython.floating[:] test_cast = <cython.floating[:1:1]>&value
331
def test_fused_memslice_dtype_repeated(cython.floating[:] array1, cython.floating[:] array2):
    """
    Two arguments using the same fused dtype share one specialization, so
    passing a double array with a float array fails with a dtype mismatch.

    Note: the np.ndarray dtype test is in numpy_test

    >>> sorted(test_fused_memslice_dtype_repeated.__signatures__)
    ['double', 'float']

    >>> test_fused_memslice_dtype_repeated(get_array(8, 'd'), get_array(8, 'd'))
    double[:] double[:]
    >>> test_fused_memslice_dtype_repeated(get_array(4, 'f'), get_array(4, 'f'))
    float[:] float[:]
    >>> test_fused_memslice_dtype_repeated(get_array(8, 'd'), get_array(4, 'f'))
    Traceback (most recent call last):
    ValueError: Buffer dtype mismatch, expected 'double' but got 'float'
    """
    print cython.typeof(array1), cython.typeof(array2)
348
def test_fused_memslice_dtype_repeated_2(cython.floating[:] array1, cython.floating[:] array2,
                                         fused_type3[:] array3):
    """
    Two arguments share cython.floating and a third uses the independent
    fused_type3, giving 2 x 2 = 4 signatures.

    Note: the np.ndarray dtype test is in numpy_test

    >>> sorted(test_fused_memslice_dtype_repeated_2.__signatures__)
    ['double|double', 'double|int', 'float|double', 'float|int']

    >>> test_fused_memslice_dtype_repeated_2(get_array(8, 'd'), get_array(8, 'd'), get_array(8, 'd'))
    double[:] double[:] double[:]
    >>> test_fused_memslice_dtype_repeated_2(get_array(8, 'd'), get_array(8, 'd'), get_intc_array())
    double[:] double[:] int[:]
    >>> test_fused_memslice_dtype_repeated_2(get_array(4, 'f'), get_array(4, 'f'), get_intc_array())
    float[:] float[:] int[:]
    """
    print cython.typeof(array1), cython.typeof(array2), cython.typeof(array3)
365
def test_cython_numeric(cython.numeric arg):
    """
    Print the C type and value of a cython.numeric argument.

    Test to see whether complex numbers have their utility code declared
    properly.

    >>> test_cython_numeric(10.0 + 1j)
    double complex (10+1j)
    """
    print cython.typeof(arg), arg
375
# Fused type declared with the block-style `cdef fused` syntax.
cdef fused ints_t:
    int
    long
379
cdef _test_index_fused_args(cython.floating f, ints_t i):
    # Print the C types of both arguments for the selected specialization.
    print cython.typeof(f), cython.typeof(i)
382
def test_index_fused_args(cython.floating f, ints_t i):
    """
    Index a fused cdef function with fused types; inside each specialization
    they resolve to that specialization's concrete types.

    >>> import cython
    >>> test_index_fused_args[cython.double, cython.int](2.0, 3)
    double int
    """
    _test_index_fused_args[cython.floating, ints_t](f, i)
390
391
def test_composite(fused_composite x):
    """
    fused_composite combines fused_type2 (string_t) and fused_type3
    (int, double). Bytes are returned unchanged; numbers are doubled.

    >>> print(test_composite(b'a').decode('ascii'))
    a
    >>> test_composite(3)
    6
    >>> test_composite(3.0)
    6.0
    """
    if fused_composite is string_t:
        return x
    else:
        return 2 * x
405
406
407### see GH3642 - presence of cdef inside "unrelated" caused a type to be incorrectly inferred
cdef unrelated(cython.floating x):
    # Regression helper for GH3642; it is not called anywhere in the visible code.
    # Its mere presence, with a cdef cython.floating local, must not disturb type
    # inference elsewhere (presumably of x in convert_to_ptr — confirm against
    # the issue).
    cdef cython.floating t = 1
    return t
411
cdef handle_float(float* x):
    # Target for the float branch of convert_to_ptr: reports which overload ran.
    return 'float'
413
cdef handle_double(double* x):
    # Target for the double branch of convert_to_ptr: reports which overload ran.
    return 'double'
415
def convert_to_ptr(cython.floating x):
    """
    Take the address of a fused argument and pass it to the non-fused handler
    matching the active specialization; return the handler's type name.

    >>> convert_to_ptr(1.0)
    'double'
    >>> convert_to_ptr['double'](1.0)
    'double'
    >>> convert_to_ptr['float'](1.0)
    'float'
    """
    # Branches are resolved per specialization, so &x always has exactly the
    # pointer type the chosen handler expects.
    if cython.floating is float:
        return handle_float(&x)
    elif cython.floating is double:
        return handle_double(&x)
429