"""Tests for first-class function types in Numba.

Functions (``cfunc`` objects, jit dispatchers, ctypes function pointers and
``WrapperAddressProtocol`` implementations) are passed into, stored in,
selected within, called from, and returned out of jit-compiled functions.
"""
import types as pytypes
from numba import jit, njit, cfunc, types, int64, float64, float32, errors
from numba import literal_unroll
from numba.core.config import IS_32BITS, IS_WIN32
import ctypes
import warnings

from .support import TestCase


def dump(foo):  # FOR DEBUGGING, TO BE REMOVED
    """Compile *foo* for its first-class function signature and print the
    resulting LLVM IR between two banner lines.
    """
    from numba.core import function
    foo_type = function.fromobject(foo)
    foo_sig = foo_type.signature()
    foo.compile(foo_sig)
    # The banners must be f-strings: without the prefix the replacement
    # field is printed literally instead of a centred, '*'-padded title.
    print(f'{" LLVM IR OF " + foo.__name__ + " ":*^70}')
    print(foo.inspect_llvm(foo_sig.args))
    print(f'{"":*^70}')


# Decorators for transforming a Python function to different kinds of
# functions. Each result keeps the original Python function as ``pyfunc``
# so tests can compare it with what jitted code returns.

def mk_cfunc_func(sig):
    """Return a decorator that compiles a Python function as a ``cfunc``
    with signature *sig*.
    """
    def cfunc_func(func):
        assert isinstance(func, pytypes.FunctionType), repr(func)
        f = cfunc(sig)(func)
        f.pyfunc = func
        return f
    return cfunc_func


def njit_func(func):
    """Wrap *func* in a lazily compiling nopython jit dispatcher."""
    assert isinstance(func, pytypes.FunctionType), repr(func)
    f = jit(nopython=True)(func)
    f.pyfunc = func
    return f


def mk_njit_with_sig_func(sig):
    """Return a decorator that compiles a Python function eagerly as a
    nopython jit dispatcher with signature *sig*.
    """
    def njit_with_sig_func(func):
        assert isinstance(func, pytypes.FunctionType), repr(func)
        f = jit(sig, nopython=True)(func)
        f.pyfunc = func
        return f
    return njit_with_sig_func


def mk_ctypes_func(sig):
    """Return a decorator that exposes a Python function as a ctypes
    function pointer to a ``cfunc`` compiled with *sig*.

    Only the ``int64(int64)`` signature is supported; any other signature
    raises ``NotImplementedError``.
    """
    # Default to the outer ``sig`` (previously the inner default shadowed it
    # with a hard-coded int64(int64), silently ignoring the argument).
    def ctypes_func(func, sig=sig):
        assert isinstance(func, pytypes.FunctionType), repr(func)
        # Check support before paying for a cfunc compilation.
        if sig != int64(int64):
            raise NotImplementedError(
                f'ctypes decorator for {func} with signature {sig}')
        compiled = mk_cfunc_func(sig)(func)
        addr = compiled._wrapper_address
        # CFUNCTYPE(restype, *argtypes): declare the int64 argument as well
        # so the ctypes prototype matches the compiled int64(int64) cfunc.
        f = ctypes.CFUNCTYPE(ctypes.c_int64, ctypes.c_int64)(addr)
        f.pyfunc = func
        return f
    return ctypes_func


class WAP(types.WrapperAddressProtocol):
    """An example implementation of wrapper address protocol.

    The wrapped Python function is compiled as a ``cfunc`` with the given
    signature. The protocol methods expose that cfunc's wrapper address and
    signature, while ``__call__`` runs the original Python function (needed
    when the object is called outside nopython code, e.g. with
    ``forceobj=True``).
    """
    def __init__(self, func, sig):
        self.pyfunc = func              # original Python function
        self.cfunc = cfunc(sig)(func)   # compiled C-callable counterpart
        self.sig = sig

    def __wrapper_address__(self):
        return self.cfunc._wrapper_address

    def signature(self):
        return self.sig

    def __call__(self, *args, **kwargs):
        return self.pyfunc(*args, **kwargs)


def mk_wap_func(sig):
    """Return a decorator that wraps a Python function in a :class:`WAP`
    with signature *sig*.
    """
    def wap_func(func):
        return WAP(func, sig)
    return wap_func


class TestFunctionType(TestCase):
    """Test first-class functions in the context of a Numba jit compiled
    function.

    """

    def test_in__(self):
        """Function is passed in as an argument.
        """

        def a(i):
            return i + 1

        def foo(f):
            return 0

        sig = int64(int64)

        for decor in [mk_cfunc_func(sig),
                      njit_func,
                      mk_njit_with_sig_func(sig),
                      mk_ctypes_func(sig),
                      mk_wap_func(sig)]:
            for jit_opts in [dict(nopython=True), dict(forceobj=True)]:
                jit_ = jit(**jit_opts)
                with self.subTest(decor=decor.__name__, jit=jit_opts):
                    a_ = decor(a)
                    self.assertEqual(jit_(foo)(a_), foo(a))

    def test_in_call__(self):
        """Function is passed in as an argument and called.
        Also test different return values.
        """

        def a_i64(i):
            return i + 1234567

        def a_f64(i):
            return i + 1.5

        def a_str(i):
            return "abc"

        def foo(f):
            return f(123)

        for f, sig in [(a_i64, int64(int64)), (a_f64, float64(int64))]:
            for decor in [mk_cfunc_func(sig), njit_func,
                          mk_njit_with_sig_func(sig),
                          mk_wap_func(sig)]:
                for jit_opts in [dict(nopython=True), dict(forceobj=True)]:
                    jit_ = jit(**jit_opts)
                    with self.subTest(
                            sig=sig, decor=decor.__name__, jit=jit_opts):
                        f_ = decor(f)
                        self.assertEqual(jit_(foo)(f_), foo(f))

    def test_in_call_out(self):
        """Function is passed in as an argument, called, and returned.
        """

        def a(i):
            return i + 1

        def foo(f):
            f(123)
            return f

        sig = int64(int64)

        for decor in [mk_cfunc_func(sig), njit_func,
                      mk_njit_with_sig_func(sig), mk_wap_func(sig)]:
            for jit_opts in [dict(nopython=True), dict(forceobj=True)]:
                jit_ = jit(**jit_opts)
                with self.subTest(decor=decor.__name__):
                    a_ = decor(a)
                    r1 = jit_(foo)(a_).pyfunc
                    r2 = foo(a)
                    self.assertEqual(r1, r2)

    def test_in_seq_call(self):
        """Functions are passed in as arguments, used as tuple items, and
        called.

        """
        def a(i):
            return i + 1

        def b(i):
            return i + 2

        def foo(f, g):
            r = 0
            for f_ in (f, g):
                r = r + f_(r)
            return r

        sig = int64(int64)

        for decor in [mk_cfunc_func(sig), mk_wap_func(sig),
                      mk_njit_with_sig_func(sig)]:
            for jit_opts in [dict(nopython=True), dict(forceobj=True)]:
                jit_ = jit(**jit_opts)
                with self.subTest(decor=decor.__name__):
                    a_ = decor(a)
                    b_ = decor(b)
                    self.assertEqual(jit_(foo)(a_, b_), foo(a, b))

    def test_in_ns_seq_call(self):
        """Functions are passed in as an argument and via namespace scoping
        (mixed pathways), used as tuple items, and called.

        """

        def a(i):
            return i + 1

        def b(i):
            return i + 2

        def mkfoo(b_):
            def foo(f):
                r = 0
                for f_ in (f, b_):
                    r = r + f_(r)
                return r
            return foo

        sig = int64(int64)

        # The trailing ctypes decorator is excluded via [:-1].
        for decor in [mk_cfunc_func(sig),
                      mk_njit_with_sig_func(sig), mk_wap_func(sig),
                      mk_ctypes_func(sig)][:-1]:
            for jit_opts in [dict(nopython=True), dict(forceobj=True)]:
                jit_ = jit(**jit_opts)
                with self.subTest(decor=decor.__name__):
                    a_ = decor(a)
                    b_ = decor(b)
                    self.assertEqual(jit_(mkfoo(b_))(a_), mkfoo(b)(a))

    def test_ns_call(self):
        """Function is passed in via namespace scoping and called.

        """

        def a(i):
            return i + 1

        def mkfoo(a_):
            def foo():
                return a_(123)
            return foo

        sig = int64(int64)

        for decor in [mk_cfunc_func(sig), njit_func,
                      mk_njit_with_sig_func(sig), mk_wap_func(sig)]:
            for jit_opts in [dict(nopython=True), dict(forceobj=True)]:
                jit_ = jit(**jit_opts)
                with self.subTest(decor=decor.__name__):
                    a_ = decor(a)
                    self.assertEqual(jit_(mkfoo(a_))(), mkfoo(a)())

    def test_ns_out(self):
        """Function is passed in via namespace scoping and returned.

        """
        def a(i):
            return i + 1

        def mkfoo(a_):
            def foo():
                return a_
            return foo

        sig = int64(int64)

        # The trailing ctypes decorator is excluded via [:-1].
        for decor in [mk_cfunc_func(sig), njit_func,
                      mk_njit_with_sig_func(sig), mk_wap_func(sig),
                      mk_ctypes_func(sig)][:-1]:
            for jit_opts in [dict(nopython=True), dict(forceobj=True)]:
                jit_ = jit(**jit_opts)
                with self.subTest(decor=decor.__name__):
                    a_ = decor(a)
                    self.assertEqual(jit_(mkfoo(a_))().pyfunc, mkfoo(a)())

    def test_ns_call_out(self):
        """Function is passed in via namespace scoping, called, and then
        returned.

        """
        def a(i):
            return i + 1

        def mkfoo(a_):
            def foo():
                a_(123)
                return a_
            return foo

        sig = int64(int64)

        # Exclude the trailing ctypes decorator via [:-1], consistent with
        # test_ns_out, which also returns the function and reads .pyfunc.
        for decor in [mk_cfunc_func(sig), njit_func,
                      mk_njit_with_sig_func(sig), mk_wap_func(sig),
                      mk_ctypes_func(sig)][:-1]:
            for jit_opts in [dict(nopython=True), dict(forceobj=True)]:
                jit_ = jit(**jit_opts)
                with self.subTest(decor=decor.__name__):
                    a_ = decor(a)
                    self.assertEqual(jit_(mkfoo(a_))().pyfunc, mkfoo(a)())

    def test_in_overload(self):
        """Function is passed in as an argument and called with different
        argument types.

        """
        def a(i):
            return i + 1

        def foo(f):
            r1 = f(123)
            r2 = f(123.45)
            return (r1, r2)

        for decor in [njit_func]:
            for jit_opts in [dict(nopython=True), dict(forceobj=True)]:
                jit_ = jit(**jit_opts)
                with self.subTest(decor=decor.__name__):
                    a_ = decor(a)
                    self.assertEqual(jit_(foo)(a_), foo(a))

    def test_ns_overload(self):
        """Function is passed in via namespace scoping and called with
        different argument types.

        """
        def a(i):
            return i + 1

        def mkfoo(a_):
            def foo():
                r1 = a_(123)
                r2 = a_(123.45)
                return (r1, r2)
            return foo

        for decor in [njit_func]:
            for jit_opts in [dict(nopython=True), dict(forceobj=True)]:
                jit_ = jit(**jit_opts)
                with self.subTest(decor=decor.__name__):
                    a_ = decor(a)
                    self.assertEqual(jit_(mkfoo(a_))(), mkfoo(a)())

    def test_in_choose(self):
        """Functions are passed in as arguments and called conditionally.

        """
        def a(i):
            return i + 1

        def b(i):
            return i + 2

        def foo(a, b, choose_left):
            if choose_left:
                r = a(1)
            else:
                r = b(2)
            return r

        sig = int64(int64)

        for decor in [mk_cfunc_func(sig), njit_func,
                      mk_njit_with_sig_func(sig), mk_wap_func(sig)]:
            for jit_opts in [dict(nopython=True), dict(forceobj=True)]:
                jit_ = jit(**jit_opts)
                with self.subTest(decor=decor.__name__):
                    a_ = decor(a)
                    b_ = decor(b)
                    self.assertEqual(jit_(foo)(a_, b_, True), foo(a, b, True))
                    self.assertEqual(jit_(foo)(a_, b_, False),
                                     foo(a, b, False))
                    self.assertNotEqual(jit_(foo)(a_, b_, True),
                                        foo(a, b, False))

    def test_ns_choose(self):
        """Functions are passed in via namespace scoping and called
        conditionally.

        """
        def a(i):
            return i + 1

        def b(i):
            return i + 2

        def mkfoo(a_, b_):
            def foo(choose_left):
                if choose_left:
                    r = a_(1)
                else:
                    r = b_(2)
                return r
            return foo

        sig = int64(int64)

        for decor in [mk_cfunc_func(sig), njit_func,
                      mk_njit_with_sig_func(sig), mk_wap_func(sig)]:
            for jit_opts in [dict(nopython=True), dict(forceobj=True)]:
                jit_ = jit(**jit_opts)
                with self.subTest(decor=decor.__name__):
                    a_ = decor(a)
                    b_ = decor(b)
                    self.assertEqual(jit_(mkfoo(a_, b_))(True),
                                     mkfoo(a, b)(True))
                    self.assertEqual(jit_(mkfoo(a_, b_))(False),
                                     mkfoo(a, b)(False))
                    self.assertNotEqual(jit_(mkfoo(a_, b_))(True),
                                        mkfoo(a, b)(False))

    def test_in_choose_out(self):
        """Functions are passed in as arguments and returned conditionally.

        """
        def a(i):
            return i + 1

        def b(i):
            return i + 2

        def foo(a, b, choose_left):
            if choose_left:
                return a
            else:
                return b

        sig = int64(int64)

        for decor in [mk_cfunc_func(sig), njit_func,
                      mk_njit_with_sig_func(sig), mk_wap_func(sig)]:
            for jit_opts in [dict(nopython=True), dict(forceobj=True)]:
                jit_ = jit(**jit_opts)
                with self.subTest(decor=decor.__name__):
                    a_ = decor(a)
                    b_ = decor(b)
                    self.assertEqual(jit_(foo)(a_, b_, True).pyfunc,
                                     foo(a, b, True))
                    self.assertEqual(jit_(foo)(a_, b_, False).pyfunc,
                                     foo(a, b, False))
                    self.assertNotEqual(jit_(foo)(a_, b_, True).pyfunc,
                                        foo(a, b, False))

    def test_in_choose_func_value(self):
        """Functions are passed in as arguments, selected conditionally and
        called.

        """
        def a(i):
            return i + 1

        def b(i):
            return i + 2

        def foo(a, b, choose_left):
            if choose_left:
                f = a
            else:
                f = b
            return f(1)

        sig = int64(int64)

        for decor in [mk_cfunc_func(sig), mk_wap_func(sig), njit_func,
                      mk_njit_with_sig_func(sig)]:
            for jit_opts in [dict(nopython=True), dict(forceobj=True)]:
                jit_ = jit(**jit_opts)
                with self.subTest(decor=decor.__name__):
                    a_ = decor(a)
                    b_ = decor(b)
                    self.assertEqual(jit_(foo)(a_, b_, True), foo(a, b, True))
                    self.assertEqual(jit_(foo)(a_, b_, False),
                                     foo(a, b, False))
                    self.assertNotEqual(jit_(foo)(a_, b_, True),
                                        foo(a, b, False))

    def test_in_pick_func_call(self):
        """Functions are passed in as items of tuple argument, retrieved via
        indexing, and called.

        """
        def a(i):
            return i + 1

        def b(i):
            return i + 2

        def foo(funcs, i):
            f = funcs[i]
            r = f(123)
            return r

        sig = int64(int64)

        for decor in [mk_cfunc_func(sig), mk_wap_func(sig),
                      mk_njit_with_sig_func(sig)]:
            for jit_opts in [dict(nopython=True), dict(forceobj=True)]:
                jit_ = jit(**jit_opts)
                with self.subTest(decor=decor.__name__):
                    a_ = decor(a)
                    b_ = decor(b)
                    self.assertEqual(jit_(foo)((a_, b_), 0), foo((a, b), 0))
                    self.assertEqual(jit_(foo)((a_, b_), 1), foo((a, b), 1))
                    self.assertNotEqual(jit_(foo)((a_, b_), 0), foo((a, b), 1))

    def test_in_iter_func_call(self):
        """Functions are passed in as items of tuple argument, retrieved via
        indexing, and called within a variable for-loop.

        """
        def a(i):
            return i + 1

        def b(i):
            return i + 2

        def foo(funcs, n):
            r = 0
            for i in range(n):
                f = funcs[i]
                r = r + f(r)
            return r

        sig = int64(int64)

        for decor in [mk_cfunc_func(sig), mk_wap_func(sig),
                      mk_njit_with_sig_func(sig)]:
            for jit_opts in [dict(nopython=True), dict(forceobj=True)]:
                jit_ = jit(**jit_opts)
                with self.subTest(decor=decor.__name__):
                    a_ = decor(a)
                    b_ = decor(b)
                    self.assertEqual(jit_(foo)((a_, b_), 2), foo((a, b), 2))

    def test_experimental_feature_warning(self):
        """Selecting between dispatchers at runtime emits the
        experimental-feature warning.
        """
        @jit(nopython=True)
        def more(x):
            return x + 1

        @jit(nopython=True)
        def less(x):
            return x - 1

        @jit(nopython=True)
        def foo(sel, x):
            fn = more if sel else less
            return fn(x)

        with warnings.catch_warnings(record=True) as ws:
            warnings.simplefilter("always")
            res = foo(True, 10)

        self.assertEqual(res, 11)
        self.assertEqual(foo(False, 10), 9)

        self.assertGreaterEqual(len(ws), 1)
        pat = "First-class function type feature is experimental"
        for w in ws:
            if pat in str(w.message):
                break
        else:
            self.fail("missing warning")


class TestFunctionTypeExtensions(TestCase):
    """Test calling external library functions within Numba jit compiled
    functions.

    """

    def test_wrapper_address_protocol_libm(self):
        """Call cos and sinf from standard math library.

        """
        import ctypes.util

        class LibM(types.WrapperAddressProtocol):

            def __init__(self, fname):
                if IS_WIN32:
                    lib = ctypes.cdll.msvcrt
                else:
                    libpath = ctypes.util.find_library('m')
                    lib = ctypes.cdll.LoadLibrary(libpath)
                self.lib = lib
                self._name = fname
                if fname == 'cos':
                    # test for double-precision math function
                    if IS_WIN32 and IS_32BITS:
                        # 32-bit Windows math library does not provide
                        # a double-precision cos function, so
                        # disabling the function
                        addr = None
                        signature = None
                    else:
                        addr = ctypes.cast(self.lib.cos, ctypes.c_voidp).value
                        signature = float64(float64)
                elif fname == 'sinf':
                    # test for single-precision math function
                    if IS_WIN32 and IS_32BITS:
                        # 32-bit Windows math library provides sin
                        # (instead of sinf) that is a single-precision
                        # sin function
                        addr = ctypes.cast(self.lib.sin, ctypes.c_voidp).value
                    else:
                        # Other 32/64 bit platforms define sinf as the
                        # single-precision sin function
                        addr = ctypes.cast(self.lib.sinf, ctypes.c_voidp).value
                    signature = float32(float32)
                else:
                    # `signature` is never bound on this path, so it must
                    # not be formatted into the message (that raised
                    # UnboundLocalError instead of NotImplementedError).
                    raise NotImplementedError(
                        f'wrapper address of `{fname}`')
                self._signature = signature
                self._address = addr

            def __repr__(self):
                return f'{type(self).__name__}({self._name!r})'

            def __wrapper_address__(self):
                return self._address

            def signature(self):
                return self._signature

        mycos = LibM('cos')
        mysin = LibM('sinf')

        def myeval(f, x):
            return f(x)

        # Not testing forceobj=True as it requires implementing
        # LibM.__call__ using ctypes which would be out-of-scope here.
        for jit_opts in [dict(nopython=True)]:
            jit_ = jit(**jit_opts)
            with self.subTest(jit=jit_opts):
                if mycos.signature() is not None:
                    self.assertEqual(jit_(myeval)(mycos, 0.0), 1.0)
                if mysin.signature() is not None:
                    self.assertEqual(jit_(myeval)(mysin, float32(0.0)), 0.0)

    def test_compilation_results(self):
        """Turn the existing compilation results of a dispatcher instance to
        first-class functions with precise types.
        """

        @jit(nopython=True)
        def add_template(x, y):
            return x + y

        # Trigger compilations
        self.assertEqual(add_template(1, 2), 3)
        self.assertEqual(add_template(1.2, 3.4), 4.6)

        cres1, cres2 = add_template.overloads.values()

        # Turn compilation results into first-class functions
        iadd = types.CompileResultWAP(cres1)
        fadd = types.CompileResultWAP(cres2)

        @jit(nopython=True)
        def foo(add, x, y):
            return add(x, y)

        @jit(forceobj=True)
        def foo_obj(add, x, y):
            return add(x, y)

        self.assertEqual(foo(iadd, 3, 4), 7)
        self.assertEqual(foo(fadd, 3.4, 4.5), 7.9)

        self.assertEqual(foo_obj(iadd, 3, 4), 7)
        self.assertEqual(foo_obj(fadd, 3.4, 4.5), 7.9)


class TestMiscIssues(TestCase):
    """Test issues of using first-class functions in the context of Numba
    jit compiled functions.

    """

    def test_issue_3405_using_cfunc(self):
        """Select one of two cfuncs by a runtime condition and call it."""

        @cfunc('int64()')
        def a():
            return 2

        @cfunc('int64()')
        def b():
            return 3

        def g(arg):
            if arg:
                f = a
            else:
                f = b
            return f()

        self.assertEqual(jit(nopython=True)(g)(True), 2)
        self.assertEqual(jit(nopython=True)(g)(False), 3)

    def test_issue_3405_using_njit(self):
        """Select one of two njit dispatchers by a runtime condition and
        call it.
        """

        @jit(nopython=True)
        def a():
            return 2

        @jit(nopython=True)
        def b():
            return 3

        def g(arg):
            if not arg:
                f = b
            else:
                f = a
            return f()

        self.assertEqual(jit(nopython=True)(g)(True), 2)
        self.assertEqual(jit(nopython=True)(g)(False), 3)

    def test_pr4967_example(self):
        """Call cfunc arguments directly and while iterating over a tuple
        of them.
        """

        @cfunc('int64(int64)')
        def a(i):
            return i + 1

        @cfunc('int64(int64)')
        def b(i):
            return i + 2

        @jit(nopython=True)
        def foo(f, g):
            i = f(2)
            seq = (f, g)
            for fun in seq:
                i += fun(i)
            return i

        a_ = a._pyfunc
        b_ = b._pyfunc
        self.assertEqual(foo(a, b),
                         a_(2) + a_(a_(2)) + b_(a_(2) + a_(a_(2))))

    def test_pr4967_array(self):
        """Select between cfuncs taking array arguments and call the chosen
        one.
        """
        import numpy as np

        @cfunc("intp(intp[:], float64[:])")
        def foo1(x, y):
            return x[0] + y[0]

        @cfunc("intp(intp[:], float64[:])")
        def foo2(x, y):
            return x[0] - y[0]

        def bar(fx, fy, i):
            a = np.array([10], dtype=np.intp)
            b = np.array([12], dtype=np.float64)
            if i == 0:
                f = fx
            elif i == 1:
                f = fy
            else:
                return
            return f(a, b)

        r = jit(nopython=True, no_cfunc_wrapper=True)(bar)(foo1, foo2, 0)
        self.assertEqual(r, bar(foo1, foo2, 0))
        self.assertNotEqual(r, bar(foo1, foo2, 1))

    def test_reference_example(self):
        """Compose a tuple mixing a cfunc and an njit function."""
        import numba

        @numba.njit
        def composition(funcs, x):
            r = x
            for f in funcs[::-1]:
                r = f(r)
            return r

        @numba.cfunc("double(double)")
        def a(x):
            return x + 1.0

        @numba.njit()
        def b(x):
            return x * x

        r = composition((a, b, b, a), 0.5)
        self.assertEqual(r, (0.5 + 1.0) ** 4 + 1.0)

        r = composition((b, a, b, b, a), 0.5)
        self.assertEqual(r, ((0.5 + 1.0) ** 4 + 1.0) ** 2)

    def test_apply_function_in_function(self):
        """Pass a cfunc to another cfunc whose signature takes a function
        type argument.
        """

        def foo(f, f_inner):
            return f(f_inner)

        @cfunc('int64(float64)')
        def f_inner(i):
            return int64(i * 3)

        @cfunc(int64(types.FunctionType(f_inner._sig)))
        def f(f_inner):
            return f_inner(123.4)

        self.assertEqual(jit(nopython=True)(foo)(f, f_inner),
                         foo(f._pyfunc, f_inner._pyfunc))

    def test_function_with_none_argument(self):
        """Call a cfunc whose signature takes a ``none`` argument."""

        @cfunc(int64(types.none))
        def a(i):
            return 1

        @jit(nopython=True)
        def foo(f):
            return f(None)

        self.assertEqual(foo(a), 1)

    def test_constant_functions(self):
        """Call two njit functions from another njit function; on a wrong
        result, print the LLVM IR before failing.
        """

        @jit(nopython=True)
        def a():
            return 123

        @jit(nopython=True)
        def b():
            return 456

        @jit(nopython=True)
        def foo():
            return a() + b()

        r = foo()
        if r != 123 + 456:
            print(foo.overloads[()].library.get_llvm_str())
        self.assertEqual(r, 123 + 456)

    def test_generators(self):
        """Pass object-mode and nopython generator functions to an
        object-mode consumer.
        """

        @jit(forceobj=True)
        def gen(xs):
            for x in xs:
                x += 1
                yield x

        @jit(forceobj=True)
        def con(gen_fn, xs):
            return [it for it in gen_fn(xs)]

        self.assertEqual(con(gen, (1, 2, 3)), [2, 3, 4])

        @jit(nopython=True)
        def gen_(xs):
            for x in xs:
                x += 1
                yield x
        self.assertEqual(con(gen_, (1, 2, 3)), [2, 3, 4])

    def test_jit_support(self):
        """Pass lazily compiled dispatchers as arguments and call them with
        different argument types.
        """

        @jit(nopython=True)
        def foo(f, x):
            return f(x)

        @jit()
        def a(x):
            return x + 1

        @jit()
        def a2(x):
            return x - 1

        @jit()
        def b(x):
            return x + 1.5

        self.assertEqual(foo(a, 1), 2)
        a2(5)  # pre-compile
        self.assertEqual(foo(a2, 2), 1)
        self.assertEqual(foo(a2, 3), 2)
        self.assertEqual(foo(a, 2), 3)
        self.assertEqual(foo(a, 1.5), 2.5)
        self.assertEqual(foo(a2, 1), 0)
        self.assertEqual(foo(a, 2.5), 3.5)
        self.assertEqual(foo(b, 1.5), 3.0)
        self.assertEqual(foo(b, 1), 2.5)

    def test_signature_mismatch(self):
        """Calling a selected function with arguments of different types
        reports a function type mismatch.
        """
        @jit(nopython=True)
        def f1(x):
            return x

        @jit(nopython=True)
        def f2(x):
            return x

        @jit(nopython=True)
        def foo(disp1, disp2, sel):
            if sel == 1:
                fn = disp1
            else:
                fn = disp2
            return fn([1]), fn(2)

        with self.assertRaises(errors.UnsupportedError) as cm:
            foo(f1, f2, sel=1)
        self.assertRegex(
            str(cm.exception), 'mismatch of function types:')

        # this works because `sel == 1` condition is optimized away:
        self.assertEqual(foo(f1, f1, sel=1), ([1], 2))

    def test_unique_dispatcher(self):
        # In general, the type of a dispatcher instance is imprecise
        # and when used as an input to type-inference, the typing will
        # likely fail. However, if a dispatcher instance contains
        # exactly one overload and compilation is disabled for the dispatcher,
        # then the type of dispatcher instance is interpreted as precise
        # and is transformed to a FunctionType instance with the defined
        # signature of the single overload.

        def foo_template(funcs, x):
            r = x
            for f in funcs:
                r = f(r)
            return r

        a = jit(nopython=True)(lambda x: x + 1)
        b = jit(nopython=True)(lambda x: x + 2)
        foo = jit(nopython=True)(foo_template)

        # compiling and disabling compilation for `a` is sufficient,
        # `b` will inherit its type from the container Tuple type
        a(0)
        a.disable_compile()

        r = foo((a, b), 0)
        self.assertEqual(r, 3)
        # the Tuple type of foo's first argument is a precise FunctionType:
        self.assertEqual(foo.signatures[0][0].dtype.is_precise(), True)

    def test_zero_address(self):
        """Check how zero (null) wrapper addresses are handled, both for
        function arguments and for functions referenced as globals.
        """

        sig = int64()

        @cfunc(sig)
        def test():
            return 123

        class Good(types.WrapperAddressProtocol):
            """A first-class function type with valid address.
            """

            def __wrapper_address__(self):
                return test.address

            def signature(self):
                return sig

        class Bad(types.WrapperAddressProtocol):
            """A first-class function type with invalid 0 address.
            """

            def __wrapper_address__(self):
                return 0

            def signature(self):
                return sig

        class BadToGood(types.WrapperAddressProtocol):
            """A first-class function type with invalid address that is
            recovered to a valid address.
            """

            # Incremented on every address query: the first query yields
            # 0 (address * 0), every later query the real address.
            counter = -1

            def __wrapper_address__(self):
                self.counter += 1
                return test.address * min(1, self.counter)

            def signature(self):
                return sig

        good = Good()
        bad = Bad()
        bad2good = BadToGood()

        @jit(int64(sig.as_type()))
        def foo(func):
            return func()

        @jit(int64())
        def foo_good():
            return good()

        @jit(int64())
        def foo_bad():
            return bad()

        @jit(int64())
        def foo_bad2good():
            return bad2good()

        self.assertEqual(foo(good), 123)

        self.assertEqual(foo_good(), 123)

        with self.assertRaises(ValueError) as cm:
            foo(bad)
        self.assertRegex(
            str(cm.exception),
            'wrapper address of <.*> instance must be a positive')

        with self.assertRaises(RuntimeError) as cm:
            foo_bad()
        self.assertRegex(
            str(cm.exception), r'.* function address is null')

        self.assertEqual(foo_bad2good(), 123)

    def test_issue_5470(self):
        """Call njit functions unpacked from a global tuple; redefining the
        tuple and ``bar`` picks up the new functions.
        """

        @njit()
        def foo1():
            return 10

        @njit()
        def foo2():
            return 20

        formulae_foo = (foo1, foo1)

        @njit()
        def bar_scalar(f1, f2):
            return f1() + f2()

        @njit()
        def bar():
            return bar_scalar(*formulae_foo)

        self.assertEqual(bar(), 20)

        formulae_foo = (foo1, foo2)

        @njit()
        def bar():
            return bar_scalar(*formulae_foo)

        self.assertEqual(bar(), 30)

    def test_issue_5540(self):
        """Keyword arguments are rejected when calling a first-class
        function.
        """

        @njit(types.int64(types.int64))
        def foo(x):
            return x + 1

        @njit
        def bar_bad(foos):
            f = foos[0]
            return f(x=1)

        @njit
        def bar_good(foos):
            f = foos[0]
            return f(1)

        self.assertEqual(bar_good((foo, )), 2)

        with self.assertRaises(errors.TypingError) as cm:
            bar_bad((foo, ))

        self.assertRegex(
            str(cm.exception),
            r'.*first-class function call cannot use keyword arguments')

    def test_issue_5615(self):
        """Call pairs of functions taken from a nested tuple, both by
        indexing and with ``literal_unroll``.
        """

        @njit
        def foo1(x):
            return x + 1

        @njit
        def foo2(x):
            return x + 2

        @njit
        def bar(fcs):
            x = 0
            a = 10
            i, j = fcs[0]
            x += i(j(a))
            for t in literal_unroll(fcs):
                i, j = t
                x += i(j(a))
            return x

        tup = ((foo1, foo2), (foo2, foo1))

        self.assertEqual(bar(tup), 39)

    def test_issue_5685(self):
        """Iterate with ``literal_unroll`` over tuples of functions with
        different signatures.
        """

        @njit
        def foo1():
            return 1

        @njit
        def foo2(x):
            return x + 1

        @njit
        def foo3(x):
            return x + 2

        @njit
        def bar(fcs):
            r = 0
            for pair in literal_unroll(fcs):
                f1, f2 = pair
                r += f1() + f2(2)
            return r

        self.assertEqual(bar(((foo1, foo2),)), 4)
        self.assertEqual(bar(((foo1, foo2), (foo1, foo3))), 9)  # reproducer