# mode: run

cimport cython
from cython.view cimport array

from cython cimport integral
from cpython cimport Py_INCREF

from Cython import Shadow as pure_cython

# Fused types and pointer typedefs shared by the tests in this module.
ctypedef char * string_t

# floating = cython.fused_type(float, double) floating
# integral = cython.fused_type(int, long) integral
ctypedef cython.floating floating
fused_type1 = cython.fused_type(int, long, float, double, string_t)
fused_type2 = cython.fused_type(string_t)
# Pointer to a fused type; compared against p_double in test_specialize().
ctypedef fused_type1 *composed_t
other_t = cython.fused_type(int, double)
ctypedef double *p_double
ctypedef int *p_int
fused_type3 = cython.fused_type(int, double)
# A fused type built from two other fused types; used by test_composite().
fused_composite = cython.fused_type(fused_type2, fused_type3)

# Uses typedef() and fused_type() from Cython.Shadow (imported as pure_cython)
# and prints the result of calling the resulting type with 10.
def test_pure():
    """
    >>> test_pure()
    10
    """
    mytype = pure_cython.typedef(pure_cython.fused_type(int, long, complex))
    print mytype(10)


# Prints the three arguments (char* values are decoded as ASCII first) and
# returns x + y.  x and y share one fused type; z is always a string_t because
# fused_type2 has a single member.
cdef cdef_func_with_fused_args(fused_type1 x, fused_type1 y, fused_type2 z):
    if fused_type1 is string_t:
        print x.decode('ascii'), y.decode('ascii'), z.decode('ascii')
    else:
        print x, y, z.decode('ascii')

    return x + y

# Calls cdef_func_with_fused_args() with bytes, integer and floating point
# arguments and prints each result.
def test_cdef_func_with_fused_args():
    """
    >>> test_cdef_func_with_fused_args()
    spam ham eggs
    spamham
    10 20 butter
    30
    4.2 8.6 bunny
    12.8
    """
    print cdef_func_with_fused_args(b'spam', b'ham', b'eggs').decode('ascii')
    print cdef_func_with_fused_args(10, 20, b'butter')
    print cdef_func_with_fused_args(4.2, 8.6, b'bunny')

# Prints the first five elements behind `array` and returns array[0] + ... +
# array[4].  Note that the parameter name shadows the cimported
# cython.view.array inside this function.
cdef fused_type1 fused_with_pointer(fused_type1 *array):
    for i in range(5):
        if fused_type1 is string_t:
            print array[i].decode('ascii')
        else:
            print array[i]

    # `obj` is untyped, so the sum is held as a Python object here.
    obj = array[0] + array[1] + array[2] + array[3] + array[4]
    # if cython.typeof(fused_type1) is string_t:
    # NOTE(review): the extra reference is never released; presumably this
    # keeps `obj` alive so a char* return value stays valid -- confirm.
    Py_INCREF(obj)
    return obj

# Fills C arrays of int, long, float and string_t with five elements each and
# prints the result of fused_with_pointer() for every one of them.
def test_fused_with_pointer():
    """
    >>> test_fused_with_pointer()
    0
    1
    2
    3
    4
    10
    <BLANKLINE>
    0
    1
    2
    3
    4
    10
    <BLANKLINE>
    0.0
    1.0
    2.0
    3.0
    4.0
    10.0
    <BLANKLINE>
    humpty
    dumpty
    fall
    splatch
    breakfast
    humptydumptyfallsplatchbreakfast
    """
    cdef int[5] int_array
    cdef long[5] long_array
    cdef float[5] float_array
    cdef string_t[5] string_array

    cdef char *s

    strings = [b"humpty", b"dumpty", b"fall", b"splatch", b"breakfast"]

    for i in range(5):
        int_array[i] = i
        long_array[i] = i
        float_array[i] = i
        # Go through a char* local before storing into the string_t array.
        s = strings[i]
        string_array[i] = s

    print fused_with_pointer(int_array)
    print
    print fused_with_pointer(long_array)
    print
    print fused_with_pointer(float_array)
    print
    print fused_with_pointer(string_array).decode('ascii')

# Returns `x` unchanged after asserting on x[0]: a string must be non-empty,
# any other value must be below 10.  The function is declared `except NULL`;
# the doctests of test_fused_pointer_except_null() expect a failed assert to
# surface as AssertionError in the caller.
cdef fused_type1* fused_pointer_except_null(fused_type1* x) except NULL:
    if fused_type1 is string_t:
        assert(bool(x[0]))
    else:
        assert(x[0] < 10)
    return x

# Declares a C variable matching the Python type of `value`, passes its
# address to fused_pointer_except_null() and prints what the returned pointer
# refers to.  Values of any other type print nothing.
def test_fused_pointer_except_null(value):
    """
    >>> test_fused_pointer_except_null(1)
    1
    >>> test_fused_pointer_except_null(2.0)
    2.0
    >>> test_fused_pointer_except_null(b'foo')
    foo
    >>> test_fused_pointer_except_null(16)
    Traceback (most recent call last):
    AssertionError
    >>> test_fused_pointer_except_null(15.1)
    Traceback (most recent call last):
    AssertionError
    >>> test_fused_pointer_except_null(b'')
    Traceback (most recent call last):
    AssertionError
    """
    if isinstance(value, int):
        test_int = cython.declare(cython.int, value)
        print fused_pointer_except_null(&test_int)[0]
    elif isinstance(value, float):
        test_float = cython.declare(cython.float, value)
        print fused_pointer_except_null(&test_float)[0]
    elif isinstance(value, bytes):
        test_str = cython.declare(string_t, value)
        print fused_pointer_except_null(&test_str)[0].decode('ascii')

# NOTE(review): presumably provides create_array(), which the doctests below
# call but which is not defined in this file -- confirm.
include "cythonarrayutil.pxi"

# cpdef function taking a 2D C-contiguous memoryview with a fused dtype;
# returns the element at [1, 2].  The doctest picks the specialization by
# indexing with cython.int.
cpdef cython.integral test_fused_memoryviews(cython.integral[:, ::1] a):
    """
    >>> import cython
    >>> a = create_array((3, 5), mode="c")
    >>> test_fused_memoryviews[cython.int](a)
    7
    """
    return a[1, 2]

# A fused type whose members are memoryview typedefs.
ctypedef int[:, ::1] memview_int
ctypedef long[:, ::1] memview_long
memview_t = cython.fused_type(memview_int, memview_long)

# def function taking the fused memoryview type; returns the element at
# [1, 2].  The doctest picks the specialization by the string "memview_int".
def test_fused_memoryview_def(memview_t a):
    """
    >>> a = create_array((3, 5), mode="c")
    >>> test_fused_memoryview_def["memview_int"](a)
    7
    """
    return a[1, 2]

# Prints "double pointer" when composed_t specializes to p_double.  For
# floating point specializations of fused_type1 it returns
# x + y[0] + z[0] + a[0]; for any other specialization there is no explicit
# return statement.
cdef test_specialize(fused_type1 x, fused_type1 *y, composed_t z, other_t *a):
    cdef fused_type1 result

    if composed_t is p_double:
        print "double pointer"

    if fused_type1 in floating:
        result = x + y[0] + z[0] + a[0]
        return result

# Selects the (double, int) specialization of test_specialize() in five
# different ways and checks the result of each call.
def test_specializations():
    """
    >>> test_specializations()
    double pointer
    double pointer
    double pointer
    double pointer
    double pointer
    """
    cdef object (*f)(double, double *, double *, int *)

    cdef double somedouble = 2.2
    cdef double otherdouble = 3.3
    cdef int someint = 4

    cdef p_double somedouble_p = &somedouble
    cdef p_double otherdouble_p = &otherdouble
    cdef p_int someint_p = &someint

    # 1. Assignment to a typed function pointer.
    f = test_specialize
    assert f(1.1, somedouble_p, otherdouble_p, someint_p) == 10.6

    # 2. Explicit cast to the function pointer type, then assignment.
    f = <object (*)(double, double *, double *, int *)> test_specialize
    assert f(1.1, somedouble_p, otherdouble_p, someint_p) == 10.6

    # 3. Explicit cast, called directly.
    assert (<object (*)(double, double *, double *, int *)>
            test_specialize)(1.1, somedouble_p, otherdouble_p, someint_p) == 10.6

    # 4. Indexing with the specialized types, then assignment.
    f = test_specialize[double, int]
    assert f(1.1, somedouble_p, otherdouble_p, someint_p) == 10.6

    # 5. Indexing with the specialized types, called directly.
    assert test_specialize[double, int](1.1, somedouble_p, otherdouble_p, someint_p) == 10.6

    # The following cases are not supported
    # f = test_specialize[double][p_int]
    # print f(1.1, somedouble_p, otherdouble_p)
    # print

    # print test_specialize[double][p_int](1.1, somedouble_p, otherdouble_p)
    # print

    # print test_specialize[double](1.1, somedouble_p, otherdouble_p)
    # print

# Prints both arguments; the second one is fused-typed and has a default.
cdef opt_args(integral x, floating y = 4.0):
    print x, y

# Calls explicitly specialized versions of opt_args() with and without the
# optional argument.
def test_opt_args():
    """
    >>> test_opt_args()
    3 4.0
    3 4.0
    3 4.0
    3 4.0
    """
    opt_args[int, float](3)
    opt_args[int, double](3)
    opt_args[int, float](3, 4.0)
    opt_args[int, double](3, 4.0)

# A regular Python class with a method that takes a fused-typed argument.
class NormalClass(object):
    def method(self, cython.integral i):
        print cython.typeof(i), i

# Indexes a bound method with a type from Cython.Shadow to pick the `short`
# specialization.
def test_normal_class():
    """
    >>> test_normal_class()
    short 10
    """
    NormalClass().method[pure_cython.short](10)

# Same call as test_normal_class(), additionally printing how much the
# reference count of the instance changed across the call (the doctest
# expects 0).
def test_normal_class_refcount():
    """
    >>> test_normal_class_refcount()
    short 10
    0
    """
    import sys
    x = NormalClass()
    c = sys.getrefcount(x)
    x.method[pure_cython.short](10)
    print sys.getrefcount(x) - c

# Declares locals with the same fused types as the arguments and asserts that
# cython.typeof() reports identical types for the locals and the arguments.
def test_fused_declarations(cython.integral i, cython.floating f):
    """
    >>> test_fused_declarations[pure_cython.short, pure_cython.float](5, 6.6)
    short
    float
    25 43.56
    >>> test_fused_declarations[pure_cython.long, pure_cython.double](5, 6.6)
    long
    double
    25 43.56
    """
    cdef cython.integral squared_int = i * i
    cdef cython.floating squared_float = f * f

    assert cython.typeof(squared_int) == cython.typeof(i)
    assert cython.typeof(squared_float) == cython.typeof(f)

    print cython.typeof(squared_int)
    print cython.typeof(squared_float)
    print '%d %.2f' % (squared_int, squared_float)

# Asserts that sizeof() of the argument, of the fused type and of `double`
# agree; the doctest only runs the double specialization.
def test_sizeof_fused_type(fused_type1 b):
    """
    >>> test_sizeof_fused_type[pure_cython.double](11.1)
    """
    t = sizeof(b), sizeof(fused_type1), sizeof(double)
    assert t[0] == t[1] == t[2], t

# Helper for the memoryview doctests: builds a 10-element cython.view.array
# with the given item size and format, and sets elements 5 and 6 to 5.0 and
# 6.0.
def get_array(itemsize, format):
    result = array((10,), itemsize, format)
    result[5] = 5.0
    result[6] = 6.0
    return result

# Same as get_array(), but for C ints (format 'i') with integer values.
def get_intc_array():
    result = array((10,), sizeof(int), 'i')
    result[5] = 5
    result[6] = 6
    return result

# Takes a 1D memoryview with a fused dtype, slices it into a second memoryview
# declared with the same fused dtype, and prints both types plus one element
# of each.  The parameter name shadows the cimported cython.view.array inside
# this function.
def test_fused_memslice_dtype(cython.floating[:] array):
    """
    Note: the np.ndarray dtype test is in numpy_test

    >>> import cython
    >>> sorted(test_fused_memslice_dtype.__signatures__)
    ['double', 'float']

    >>> test_fused_memslice_dtype[cython.double](get_array(8, 'd'))
    double[:] double[:] 5.0 6.0
    >>> test_fused_memslice_dtype[cython.float](get_array(4, 'f'))
    float[:] float[:] 5.0 6.0
    """
    # The slice bounds (0:100) exceed the 10 elements the doctests pass in.
    cdef cython.floating[:] otherarray = array[0:100:1]
    print cython.typeof(array), cython.typeof(otherarray), \
          array[5], otherarray[6]
    # The cast result is unused.  NOTE(review): presumably this only checks
    # that a pointer-to-memoryview cast works with a fused dtype -- confirm.
    cdef cython.floating value;
    cdef cython.floating[:] test_cast = <cython.floating[:1:1]>&value

# Two memoryview arguments sharing one fused dtype: __signatures__ lists only
# 'double' and 'float', and the doctest expects a ValueError when the two
# buffers have different dtypes.
def test_fused_memslice_dtype_repeated(cython.floating[:] array1, cython.floating[:] array2):
    """
    Note: the np.ndarray dtype test is in numpy_test

    >>> sorted(test_fused_memslice_dtype_repeated.__signatures__)
    ['double', 'float']

    >>> test_fused_memslice_dtype_repeated(get_array(8, 'd'), get_array(8, 'd'))
    double[:] double[:]
    >>> test_fused_memslice_dtype_repeated(get_array(4, 'f'), get_array(4, 'f'))
    float[:] float[:]
    >>> test_fused_memslice_dtype_repeated(get_array(8, 'd'), get_array(4, 'f'))
    Traceback (most recent call last):
    ValueError: Buffer dtype mismatch, expected 'double' but got 'float'
    """
    print cython.typeof(array1), cython.typeof(array2)

# As above, plus a third memoryview with an independent fused dtype, which
# gives four signatures (2 x 2) in __signatures__.
def test_fused_memslice_dtype_repeated_2(cython.floating[:] array1, cython.floating[:] array2,
                                         fused_type3[:] array3):
    """
    Note: the np.ndarray dtype test is in numpy_test

    >>> sorted(test_fused_memslice_dtype_repeated_2.__signatures__)
    ['double|double', 'double|int', 'float|double', 'float|int']

    >>> test_fused_memslice_dtype_repeated_2(get_array(8, 'd'), get_array(8, 'd'), get_array(8, 'd'))
    double[:] double[:] double[:]
    >>> test_fused_memslice_dtype_repeated_2(get_array(8, 'd'), get_array(8, 'd'), get_intc_array())
    double[:] double[:] int[:]
    >>> test_fused_memslice_dtype_repeated_2(get_array(4, 'f'), get_array(4, 'f'), get_intc_array())
    float[:] float[:] int[:]
    """
    print cython.typeof(array1), cython.typeof(array2), cython.typeof(array3)

def test_cython_numeric(cython.numeric arg):
    """
    Test to see whether complex numbers have their utility code declared
    properly.

    >>> test_cython_numeric(10.0 + 1j)
    double complex (10+1j)
    """
    print cython.typeof(arg), arg

# Fused type declared with the `cdef fused` block syntax.
cdef fused ints_t:
    int
    long

# Prints the specialized types of both arguments.
cdef _test_index_fused_args(cython.floating f, ints_t i):
    print cython.typeof(f), cython.typeof(i)

# Indexes the cdef function with the fused type names themselves from inside
# a fused function; the doctest output ("double int") shows they stand for
# the types of the current specialization.
def test_index_fused_args(cython.floating f, ints_t i):
    """
    >>> import cython
    >>> test_index_fused_args[cython.double, cython.int](2.0, 3)
    double int
    """
    _test_index_fused_args[cython.floating, ints_t](f, i)


# Takes the composite fused type: returns a string argument unchanged and
# twice the value for the numeric members.
def test_composite(fused_composite x):
    """
    >>> print(test_composite(b'a').decode('ascii'))
    a
    >>> test_composite(3)
    6
    >>> test_composite(3.0)
    6.0
    """
    if fused_composite is string_t:
        return x
    else:
        return 2 * x


### see GH3642 - presence of cdef inside "unrelated" caused a type to be incorrectly inferred
cdef unrelated(cython.floating x):
    cdef cython.floating t = 1
    return t

# Non-fused helpers; each returns the name of the pointer type it accepts.
cdef handle_float(float* x): return 'float'

cdef handle_double(double* x): return 'double'

# Passes the address of the fused-typed argument to the helper matching the
# current specialization and returns that helper's result.
def convert_to_ptr(cython.floating x):
    """
    >>> convert_to_ptr(1.0)
    'double'
    >>> convert_to_ptr['double'](1.0)
    'double'
    >>> convert_to_ptr['float'](1.0)
    'float'
    """
    if cython.floating is float:
        return handle_float(&x)
    elif cython.floating is double:
        return handle_double(&x)