1import types as pytypes
2from numba import jit, njit, cfunc, types, int64, float64, float32, errors
3from numba import literal_unroll
4from numba.core.config import IS_32BITS, IS_WIN32
5import ctypes
6import warnings
7
8from .support import TestCase
9
10
def dump(foo):  # FOR DEBUGGING, TO BE REMOVED
    """Compile ``foo`` for the signature of its first-class function type
    and print the resulting LLVM IR between two banner lines.

    Debugging helper only.
    """
    from numba.core import function
    foo_type = function.fromobject(foo)
    foo_sig = foo_type.signature()
    foo.compile(foo_sig)
    # The banners must be f-strings: without the ``f`` prefix the whole
    # replacement field (including the ``:*^70`` spec) was printed verbatim.
    print(f'{" LLVM IR OF " + foo.__name__ + " ":*^70}')
    print(foo.inspect_llvm(foo_sig.args))
    print(f'{"":*^70}')
19
20
21# Decorators for transforming a Python function to different kinds of
22# functions:
23
def mk_cfunc_func(sig):
    """Build a decorator that compiles a plain Python function with
    ``cfunc(sig)``.

    The compiled object keeps a reference to the original function in its
    ``pyfunc`` attribute so tests can compare against the Python version.
    """
    def cfunc_func(func):
        # Only genuine Python functions are accepted (not dispatchers etc.).
        assert isinstance(func, pytypes.FunctionType), repr(func)
        compiled = cfunc(sig)(func)
        compiled.pyfunc = func
        return compiled

    return cfunc_func
31
32
def njit_func(func):
    """Compile *func* with ``jit(nopython=True)`` and no explicit signature.

    The original Python function is stored on the result as ``pyfunc``.
    """
    assert isinstance(func, pytypes.FunctionType), repr(func)
    dispatcher = jit(nopython=True)(func)
    dispatcher.pyfunc = func
    return dispatcher
38
39
def mk_njit_with_sig_func(sig):
    """Build a decorator that compiles a Python function in nopython mode
    for the explicit signature *sig*.

    The original Python function is stored on the result as ``pyfunc``.
    """
    def njit_with_sig_func(func):
        assert isinstance(func, pytypes.FunctionType), repr(func)
        dispatcher = jit(sig, nopython=True)(func)
        dispatcher.pyfunc = func
        return dispatcher

    return njit_with_sig_func
47
48
def mk_ctypes_func(sig):
    """Build a decorator that exposes a Python function as a ctypes function
    pointer backed by a ``cfunc`` compiled for *sig*.

    Only the ``int64(int64)`` signature is supported. Applying the decorator
    with any other signature raises ``NotImplementedError``. The returned
    ctypes object keeps the original Python function on ``pyfunc``.
    """
    default_sig = sig

    def ctypes_func(func, sig=None):
        # ``sig`` used to default to ``int64(int64)``, which silently shadowed
        # the signature given to ``mk_ctypes_func``.  ``None`` now means
        # "use the enclosing signature"; an explicit value still overrides it.
        if sig is None:
            sig = default_sig
        assert isinstance(func, pytypes.FunctionType), repr(func)
        if sig != int64(int64):
            raise NotImplementedError(
                f'ctypes decorator for {func} with signature {sig}')
        compiled = mk_cfunc_func(sig)(func)
        addr = compiled._wrapper_address
        # Declare the int64 argument as well as the int64 result; without
        # argtypes ctypes would pass Python ints with default C int conversion.
        f = ctypes.CFUNCTYPE(ctypes.c_int64, ctypes.c_int64)(addr)
        f.pyfunc = func
        return f

    return ctypes_func
61
62
class WAP(types.WrapperAddressProtocol):
    """An example implementation of wrapper address protocol.

    The wrapped Python function is compiled with ``cfunc(sig)``. Numba
    obtains the compiled wrapper's address through ``__wrapper_address__``,
    while calling the instance from Python runs the original function.
    """

    def __init__(self, func, sig):
        self.pyfunc = func
        self.cfunc = cfunc(sig)(func)
        self.sig = sig

    def __wrapper_address__(self):
        """Return the address of the compiled cfunc wrapper."""
        compiled = self.cfunc
        return compiled._wrapper_address

    def signature(self):
        """Return the signature the function was compiled for."""
        return self.sig

    def __call__(self, *args, **kwargs):
        """Forward Python-level calls to the original Python function."""
        target = self.pyfunc
        return target(*args, **kwargs)
80
81
def mk_wap_func(sig):
    """Build a decorator that wraps a Python function in a ``WAP`` instance
    compiled for *sig*."""
    def wap_func(func):
        wrapped = WAP(func, sig)
        return wrapped

    return wap_func
86
87
class TestFunctionType(TestCase):
    """Test first-class functions in the context of a Numba jit compiled
    function.

    Most tests wrap a pure-Python function with each of several decorators
    (cfunc, njit with/without signature, ctypes, wrapper address protocol).
    They then check that a caller compiled with ``nopython=True`` and with
    ``forceobj=True`` gives the same result as the pure-Python reference.
    """

    def test_in__(self):
        """Function is passed in as an argument.
        """

        def a(i):
            return i + 1

        def foo(f):
            return 0

        sig = int64(int64)

        for decor in [mk_cfunc_func(sig),
                      njit_func,
                      mk_njit_with_sig_func(sig),
                      mk_ctypes_func(sig),
                      mk_wap_func(sig)]:
            for jit_opts in [dict(nopython=True), dict(forceobj=True)]:
                jit_ = jit(**jit_opts)
                with self.subTest(decor=decor.__name__, jit=jit_opts):
                    a_ = decor(a)
                    self.assertEqual(jit_(foo)(a_), foo(a))

    def test_in_call__(self):
        """Function is passed in as an argument and called.
        Also test different return values.
        """

        def a_i64(i):
            return i + 1234567

        def a_f64(i):
            return i + 1.5

        def foo(f):
            return f(123)

        for f, sig in [(a_i64, int64(int64)), (a_f64, float64(int64))]:
            for decor in [mk_cfunc_func(sig), njit_func,
                          mk_njit_with_sig_func(sig),
                          mk_wap_func(sig)]:
                for jit_opts in [dict(nopython=True), dict(forceobj=True)]:
                    jit_ = jit(**jit_opts)
                    with self.subTest(
                            sig=sig, decor=decor.__name__, jit=jit_opts):
                        f_ = decor(f)
                        self.assertEqual(jit_(foo)(f_), foo(f))

    def test_in_call_out(self):
        """Function is passed in as an argument, called, and returned.
        """

        def a(i):
            return i + 1

        def foo(f):
            f(123)
            return f

        sig = int64(int64)

        for decor in [mk_cfunc_func(sig), njit_func,
                      mk_njit_with_sig_func(sig), mk_wap_func(sig)]:
            for jit_opts in [dict(nopython=True), dict(forceobj=True)]:
                jit_ = jit(**jit_opts)
                with self.subTest(decor=decor.__name__, jit=jit_opts):
                    a_ = decor(a)
                    r1 = jit_(foo)(a_).pyfunc
                    r2 = foo(a)
                    self.assertEqual(r1, r2)

    def test_in_seq_call(self):
        """Functions are passed in as arguments, used as tuple items, and
        called.

        """
        def a(i):
            return i + 1

        def b(i):
            return i + 2

        def foo(f, g):
            r = 0
            for f_ in (f, g):
                r = r + f_(r)
            return r

        sig = int64(int64)

        for decor in [mk_cfunc_func(sig), mk_wap_func(sig),
                      mk_njit_with_sig_func(sig)]:
            for jit_opts in [dict(nopython=True), dict(forceobj=True)]:
                jit_ = jit(**jit_opts)
                with self.subTest(decor=decor.__name__, jit=jit_opts):
                    a_ = decor(a)
                    b_ = decor(b)
                    self.assertEqual(jit_(foo)(a_, b_), foo(a, b))

    def test_in_ns_seq_call(self):
        """Functions are passed in as an argument and via namespace scoping
        (mixed pathways), used as tuple items, and called.

        """

        def a(i):
            return i + 1

        def b(i):
            return i + 2

        def mkfoo(b_):
            def foo(f):
                r = 0
                for f_ in (f, b_):
                    r = r + f_(r)
                return r
            return foo

        sig = int64(int64)

        # The ctypes decorator is listed last and deliberately excluded
        # by ``[:-1]`` for this pathway.
        for decor in [mk_cfunc_func(sig),
                      mk_njit_with_sig_func(sig), mk_wap_func(sig),
                      mk_ctypes_func(sig)][:-1]:
            for jit_opts in [dict(nopython=True), dict(forceobj=True)]:
                jit_ = jit(**jit_opts)
                with self.subTest(decor=decor.__name__, jit=jit_opts):
                    a_ = decor(a)
                    b_ = decor(b)
                    self.assertEqual(jit_(mkfoo(b_))(a_), mkfoo(b)(a))

    def test_ns_call(self):
        """Function is passed in via namespace scoping and called.

        """

        def a(i):
            return i + 1

        def mkfoo(a_):
            def foo():
                return a_(123)
            return foo

        sig = int64(int64)

        for decor in [mk_cfunc_func(sig), njit_func,
                      mk_njit_with_sig_func(sig), mk_wap_func(sig)]:
            for jit_opts in [dict(nopython=True), dict(forceobj=True)]:
                jit_ = jit(**jit_opts)
                with self.subTest(decor=decor.__name__, jit=jit_opts):
                    a_ = decor(a)
                    self.assertEqual(jit_(mkfoo(a_))(), mkfoo(a)())

    def test_ns_out(self):
        """Function is passed in via namespace scoping and returned.

        """
        def a(i):
            return i + 1

        def mkfoo(a_):
            def foo():
                return a_
            return foo

        sig = int64(int64)

        # The ctypes decorator is listed last and deliberately excluded
        # by ``[:-1]`` for this pathway.
        for decor in [mk_cfunc_func(sig), njit_func,
                      mk_njit_with_sig_func(sig), mk_wap_func(sig),
                      mk_ctypes_func(sig)][:-1]:
            for jit_opts in [dict(nopython=True), dict(forceobj=True)]:
                jit_ = jit(**jit_opts)
                with self.subTest(decor=decor.__name__, jit=jit_opts):
                    a_ = decor(a)
                    self.assertEqual(jit_(mkfoo(a_))().pyfunc, mkfoo(a)())

    def test_ns_call_out(self):
        """Function is passed in via namespace scoping, called, and then
        returned.

        """
        def a(i):
            return i + 1

        def mkfoo(a_):
            def foo():
                a_(123)
                return a_
            return foo

        sig = int64(int64)

        # As in ``test_ns_out``, the ctypes decorator is excluded by ``[:-1]``.
        # The subTest must sit inside the ``jit_opts`` loop; previously it was
        # dedented, so only the last mode (forceobj) was ever exercised.
        for decor in [mk_cfunc_func(sig), njit_func,
                      mk_njit_with_sig_func(sig), mk_wap_func(sig),
                      mk_ctypes_func(sig)][:-1]:
            for jit_opts in [dict(nopython=True), dict(forceobj=True)]:
                jit_ = jit(**jit_opts)
                with self.subTest(decor=decor.__name__, jit=jit_opts):
                    a_ = decor(a)
                    self.assertEqual(jit_(mkfoo(a_))().pyfunc, mkfoo(a)())

    def test_in_overload(self):
        """Function is passed in as an argument and called with different
        argument types.

        """
        def a(i):
            return i + 1

        def foo(f):
            r1 = f(123)
            r2 = f(123.45)
            return (r1, r2)

        for decor in [njit_func]:
            for jit_opts in [dict(nopython=True), dict(forceobj=True)]:
                jit_ = jit(**jit_opts)
                with self.subTest(decor=decor.__name__, jit=jit_opts):
                    a_ = decor(a)
                    self.assertEqual(jit_(foo)(a_), foo(a))

    def test_ns_overload(self):
        """Function is passed in via namespace scoping and called with
        different argument types.

        """
        def a(i):
            return i + 1

        def mkfoo(a_):
            def foo():
                r1 = a_(123)
                r2 = a_(123.45)
                return (r1, r2)
            return foo

        for decor in [njit_func]:
            for jit_opts in [dict(nopython=True), dict(forceobj=True)]:
                jit_ = jit(**jit_opts)
                with self.subTest(decor=decor.__name__, jit=jit_opts):
                    a_ = decor(a)
                    self.assertEqual(jit_(mkfoo(a_))(), mkfoo(a)())

    def test_in_choose(self):
        """Functions are passed in as arguments and called conditionally.

        """
        def a(i):
            return i + 1

        def b(i):
            return i + 2

        def foo(a, b, choose_left):
            if choose_left:
                r = a(1)
            else:
                r = b(2)
            return r

        sig = int64(int64)

        for decor in [mk_cfunc_func(sig), njit_func,
                      mk_njit_with_sig_func(sig), mk_wap_func(sig)]:
            for jit_opts in [dict(nopython=True), dict(forceobj=True)]:
                jit_ = jit(**jit_opts)
                with self.subTest(decor=decor.__name__, jit=jit_opts):
                    a_ = decor(a)
                    b_ = decor(b)
                    self.assertEqual(jit_(foo)(a_, b_, True), foo(a, b, True))
                    self.assertEqual(jit_(foo)(a_, b_, False),
                                     foo(a, b, False))
                    self.assertNotEqual(jit_(foo)(a_, b_, True),
                                        foo(a, b, False))

    def test_ns_choose(self):
        """Functions are passed in via namespace scoping and called
        conditionally.

        """
        def a(i):
            return i + 1

        def b(i):
            return i + 2

        def mkfoo(a_, b_):
            def foo(choose_left):
                if choose_left:
                    r = a_(1)
                else:
                    r = b_(2)
                return r
            return foo

        sig = int64(int64)

        for decor in [mk_cfunc_func(sig), njit_func,
                      mk_njit_with_sig_func(sig), mk_wap_func(sig)]:
            for jit_opts in [dict(nopython=True), dict(forceobj=True)]:
                jit_ = jit(**jit_opts)
                with self.subTest(decor=decor.__name__, jit=jit_opts):
                    a_ = decor(a)
                    b_ = decor(b)
                    self.assertEqual(jit_(mkfoo(a_, b_))(True),
                                     mkfoo(a, b)(True))
                    self.assertEqual(jit_(mkfoo(a_, b_))(False),
                                     mkfoo(a, b)(False))
                    self.assertNotEqual(jit_(mkfoo(a_, b_))(True),
                                        mkfoo(a, b)(False))

    def test_in_choose_out(self):
        """Functions are passed in as arguments and returned conditionally.

        """
        def a(i):
            return i + 1

        def b(i):
            return i + 2

        def foo(a, b, choose_left):
            if choose_left:
                return a
            else:
                return b

        sig = int64(int64)

        for decor in [mk_cfunc_func(sig), njit_func,
                      mk_njit_with_sig_func(sig), mk_wap_func(sig)]:
            for jit_opts in [dict(nopython=True), dict(forceobj=True)]:
                jit_ = jit(**jit_opts)
                with self.subTest(decor=decor.__name__, jit=jit_opts):
                    a_ = decor(a)
                    b_ = decor(b)
                    self.assertEqual(jit_(foo)(a_, b_, True).pyfunc,
                                     foo(a, b, True))
                    self.assertEqual(jit_(foo)(a_, b_, False).pyfunc,
                                     foo(a, b, False))
                    self.assertNotEqual(jit_(foo)(a_, b_, True).pyfunc,
                                        foo(a, b, False))

    def test_in_choose_func_value(self):
        """Functions are passed in as arguments, selected conditionally and
        called.

        """
        def a(i):
            return i + 1

        def b(i):
            return i + 2

        def foo(a, b, choose_left):
            if choose_left:
                f = a
            else:
                f = b
            return f(1)

        sig = int64(int64)

        for decor in [mk_cfunc_func(sig), mk_wap_func(sig), njit_func,
                      mk_njit_with_sig_func(sig)]:
            for jit_opts in [dict(nopython=True), dict(forceobj=True)]:
                jit_ = jit(**jit_opts)
                with self.subTest(decor=decor.__name__, jit=jit_opts):
                    a_ = decor(a)
                    b_ = decor(b)
                    self.assertEqual(jit_(foo)(a_, b_, True), foo(a, b, True))
                    self.assertEqual(jit_(foo)(a_, b_, False),
                                     foo(a, b, False))
                    self.assertNotEqual(jit_(foo)(a_, b_, True),
                                        foo(a, b, False))

    def test_in_pick_func_call(self):
        """Functions are passed in as items of tuple argument, retrieved via
        indexing, and called.

        """
        def a(i):
            return i + 1

        def b(i):
            return i + 2

        def foo(funcs, i):
            f = funcs[i]
            r = f(123)
            return r

        sig = int64(int64)

        for decor in [mk_cfunc_func(sig), mk_wap_func(sig),
                      mk_njit_with_sig_func(sig)]:
            for jit_opts in [dict(nopython=True), dict(forceobj=True)]:
                jit_ = jit(**jit_opts)
                with self.subTest(decor=decor.__name__, jit=jit_opts):
                    a_ = decor(a)
                    b_ = decor(b)
                    self.assertEqual(jit_(foo)((a_, b_), 0), foo((a, b), 0))
                    self.assertEqual(jit_(foo)((a_, b_), 1), foo((a, b), 1))
                    self.assertNotEqual(jit_(foo)((a_, b_), 0), foo((a, b), 1))

    def test_in_iter_func_call(self):
        """Functions are passed in as items of tuple argument, retrieved via
        indexing, and called within a variable for-loop.

        """
        def a(i):
            return i + 1

        def b(i):
            return i + 2

        def foo(funcs, n):
            r = 0
            for i in range(n):
                f = funcs[i]
                r = r + f(r)
            return r

        sig = int64(int64)

        for decor in [mk_cfunc_func(sig), mk_wap_func(sig),
                      mk_njit_with_sig_func(sig)]:
            for jit_opts in [dict(nopython=True), dict(forceobj=True)]:
                jit_ = jit(**jit_opts)
                with self.subTest(decor=decor.__name__, jit=jit_opts):
                    a_ = decor(a)
                    b_ = decor(b)
                    self.assertEqual(jit_(foo)((a_, b_), 2), foo((a, b), 2))

    def test_experimental_feature_warning(self):
        """Selecting between two dispatchers as a first-class function emits
        the "experimental feature" warning, and the call still works.
        """
        @jit(nopython=True)
        def more(x):
            return x + 1

        @jit(nopython=True)
        def less(x):
            return x - 1

        @jit(nopython=True)
        def foo(sel, x):
            fn = more if sel else less
            return fn(x)

        # The first call triggers compilation, which is where the warning is
        # expected to be emitted.
        with warnings.catch_warnings(record=True) as ws:
            warnings.simplefilter("always")
            res = foo(True, 10)

        self.assertEqual(res, 11)
        self.assertEqual(foo(False, 10), 9)

        self.assertGreaterEqual(len(ws), 1)
        pat = "First-class function type feature is experimental"
        if not any(pat in str(w.message) for w in ws):
            self.fail("missing warning")
560
561
class TestFunctionTypeExtensions(TestCase):
    """Test calling external library functions within Numba jit compiled
    functions.

    """

    def test_wrapper_address_protocol_libm(self):
        """Call cos and sinf from standard math library.

        """
        import ctypes.util

        class LibM(types.WrapperAddressProtocol):
            """Expose a C math-library function through the wrapper address
            protocol.

            ``signature()`` returns ``None`` when the function is not made
            available on the current platform; callers check for that.
            """

            def __init__(self, fname):
                if IS_WIN32:
                    lib = ctypes.cdll.msvcrt
                else:
                    libpath = ctypes.util.find_library('m')
                    lib = ctypes.cdll.LoadLibrary(libpath)
                self.lib = lib
                self._name = fname
                if fname == 'cos':
                    # test for double-precision math function
                    if IS_WIN32 and IS_32BITS:
                        # 32-bit Windows math library does not provide
                        # a double-precision cos function, so
                        # disabling the function
                        addr = None
                        signature = None
                    else:
                        addr = ctypes.cast(self.lib.cos, ctypes.c_voidp).value
                        signature = float64(float64)
                elif fname == 'sinf':
                    # test for single-precision math function
                    if IS_WIN32 and IS_32BITS:
                        # 32-bit Windows math library provides sin
                        # (instead of sinf) that is a single-precision
                        # sin function
                        addr = ctypes.cast(self.lib.sin, ctypes.c_voidp).value
                    else:
                        # Other 32/64 bit platforms define sinf as the
                        # single-precision sin function
                        addr = ctypes.cast(self.lib.sinf, ctypes.c_voidp).value
                    signature = float32(float32)
                else:
                    # ``signature`` is unbound on this path; referencing it
                    # in the message used to raise UnboundLocalError instead
                    # of the intended NotImplementedError.
                    raise NotImplementedError(
                        f'wrapper address of `{fname}`')
                self._signature = signature
                self._address = addr

            def __repr__(self):
                return f'{type(self).__name__}({self._name!r})'

            def __wrapper_address__(self):
                return self._address

            def signature(self):
                return self._signature

        mycos = LibM('cos')
        mysin = LibM('sinf')

        def myeval(f, x):
            return f(x)

        # Not testing forceobj=True as it requires implementing
        # LibM.__call__ using ctypes which would be out-of-scope here.
        for jit_opts in [dict(nopython=True)]:
            jit_ = jit(**jit_opts)
            with self.subTest(jit=jit_opts):
                if mycos.signature() is not None:
                    self.assertEqual(jit_(myeval)(mycos, 0.0), 1.0)
                if mysin.signature() is not None:
                    self.assertEqual(jit_(myeval)(mysin, float32(0.0)), 0.0)

    def test_compilation_results(self):
        """Turn the existing compilation results of a dispatcher instance to
        first-class functions with precise types.
        """

        @jit(nopython=True)
        def add_template(x, y):
            return x + y

        # Trigger compilations
        self.assertEqual(add_template(1, 2), 3)
        self.assertEqual(add_template(1.2, 3.4), 4.6)

        # One overload per compilation above: (int, int) then (float, float).
        cres1, cres2 = add_template.overloads.values()

        # Turn compilation results into first-class functions
        iadd = types.CompileResultWAP(cres1)
        fadd = types.CompileResultWAP(cres2)

        @jit(nopython=True)
        def foo(add, x, y):
            return add(x, y)

        @jit(forceobj=True)
        def foo_obj(add, x, y):
            return add(x, y)

        self.assertEqual(foo(iadd, 3, 4), 7)
        self.assertEqual(foo(fadd, 3.4, 4.5), 7.9)

        self.assertEqual(foo_obj(iadd, 3, 4), 7)
        self.assertEqual(foo_obj(fadd, 3.4, 4.5), 7.9)
671
672
class TestMiscIssues(TestCase):
    """Test issues of using first-class functions in the context of Numba
    jit compiled functions.

    """

    def test_issue_3405_using_cfunc(self):
        """A local bound to one of two cfuncs on different branches can be
        called from nopython code (issue 3405, cfunc variant).
        """

        @cfunc('int64()')
        def a():
            return 2

        @cfunc('int64()')
        def b():
            return 3

        def g(arg):
            if arg:
                f = a
            else:
                f = b
            return f()

        self.assertEqual(jit(nopython=True)(g)(True), 2)
        self.assertEqual(jit(nopython=True)(g)(False), 3)

    def test_issue_3405_using_njit(self):
        """Same as the cfunc variant but with njit dispatchers; the branch
        condition is inverted but selects the same function (issue 3405).
        """

        @jit(nopython=True)
        def a():
            return 2

        @jit(nopython=True)
        def b():
            return 3

        def g(arg):
            if not arg:
                f = b
            else:
                f = a
            return f()

        self.assertEqual(jit(nopython=True)(g)(True), 2)
        self.assertEqual(jit(nopython=True)(g)(False), 3)

    def test_pr4967_example(self):
        """cfuncs passed as arguments are called directly and while
        iterating over a tuple that contains them.
        """

        @cfunc('int64(int64)')
        def a(i):
            return i + 1

        @cfunc('int64(int64)')
        def b(i):
            return i + 2

        @jit(nopython=True)
        def foo(f, g):
            i = f(2)
            seq = (f, g)
            for fun in seq:
                i += fun(i)
            return i

        # Pure-Python reference: i = a(2); i += a(i); i += b(i).
        a_ = a._pyfunc
        b_ = b._pyfunc
        self.assertEqual(foo(a, b),
                         a_(2) + a_(a_(2)) + b_(a_(2) + a_(a_(2))))

    def test_pr4967_array(self):
        """cfuncs with array arguments are selected at run time by index and
        called; the caller is compiled with ``no_cfunc_wrapper=True``.
        """
        import numpy as np

        @cfunc("intp(intp[:], float64[:])")
        def foo1(x, y):
            return x[0] + y[0]

        @cfunc("intp(intp[:], float64[:])")
        def foo2(x, y):
            return x[0] - y[0]

        def bar(fx, fy, i):
            a = np.array([10], dtype=np.intp)
            b = np.array([12], dtype=np.float64)
            if i == 0:
                f = fx
            elif i == 1:
                f = fy
            else:
                # Any other index returns None without calling anything.
                return
            return f(a, b)

        r = jit(nopython=True, no_cfunc_wrapper=True)(bar)(foo1, foo2, 0)
        self.assertEqual(r, bar(foo1, foo2, 0))
        self.assertNotEqual(r, bar(foo1, foo2, 1))

    def test_reference_example(self):
        """A tuple mixing a cfunc and an njit function is applied as a
        composition; ``funcs[::-1]`` means the last item is applied first.
        """
        import numba

        @numba.njit
        def composition(funcs, x):
            r = x
            for f in funcs[::-1]:
                r = f(r)
            return r

        @numba.cfunc("double(double)")
        def a(x):
            return x + 1.0

        @numba.njit()
        def b(x):
            return x * x

        # a, b, b, a applied to 0.5 -> ((0.5 + 1) ** 2) ** 2 + 1
        r = composition((a, b, b, a), 0.5)
        self.assertEqual(r, (0.5 + 1.0) ** 4 + 1.0)

        r = composition((b, a, b, b, a), 0.5)
        self.assertEqual(r, ((0.5 + 1.0) ** 4 + 1.0) ** 2)

    def test_apply_function_in_function(self):
        """A cfunc whose argument type is a first-class function type built
        from another cfunc's signature.
        """

        def foo(f, f_inner):
            return f(f_inner)

        @cfunc('int64(float64)')
        def f_inner(i):
            return int64(i * 3)

        @cfunc(int64(types.FunctionType(f_inner._sig)))
        def f(f_inner):
            return f_inner(123.4)

        self.assertEqual(jit(nopython=True)(foo)(f, f_inner),
                         foo(f._pyfunc, f_inner._pyfunc))

    def test_function_with_none_argument(self):
        """A cfunc with a ``none``-typed argument can be called through a
        first-class function argument.
        """

        @cfunc(int64(types.none))
        def a(i):
            return 1

        @jit(nopython=True)
        def foo(f):
            return f(None)

        self.assertEqual(foo(a), 1)

    def test_constant_functions(self):
        """Dispatchers referenced from the enclosing scope (not passed as
        arguments) are called inside nopython code.
        """

        @jit(nopython=True)
        def a():
            return 123

        @jit(nopython=True)
        def b():
            return 456

        @jit(nopython=True)
        def foo():
            return a() + b()

        r = foo()
        # Dump the LLVM IR to help diagnose a wrong result before failing.
        if r != 123 + 456:
            print(foo.overloads[()].library.get_llvm_str())
        self.assertEqual(r, 123 + 456)

    def test_generators(self):
        """An object-mode function consumes a generator function passed as an
        argument; both an object-mode and a nopython generator are tried.
        """

        @jit(forceobj=True)
        def gen(xs):
            for x in xs:
                x += 1
                yield x

        @jit(forceobj=True)
        def con(gen_fn, xs):
            return [it for it in gen_fn(xs)]

        self.assertEqual(con(gen, (1, 2, 3)), [2, 3, 4])

        @jit(nopython=True)
        def gen_(xs):
            for x in xs:
                x += 1
                yield x
        self.assertEqual(con(gen_, (1, 2, 3)), [2, 3, 4])

    def test_jit_support(self):
        """Dispatchers created with plain ``jit()`` are passed as first-class
        functions and called with both int and float arguments.
        """

        @jit(nopython=True)
        def foo(f, x):
            return f(x)

        @jit()
        def a(x):
            return x + 1

        @jit()
        def a2(x):
            return x - 1

        @jit()
        def b(x):
            return x + 1.5

        self.assertEqual(foo(a, 1), 2)
        a2(5)  # pre-compile
        self.assertEqual(foo(a2, 2), 1)
        self.assertEqual(foo(a2, 3), 2)
        self.assertEqual(foo(a, 2), 3)
        self.assertEqual(foo(a, 1.5), 2.5)
        self.assertEqual(foo(a2, 1), 0)
        self.assertEqual(foo(a, 2.5), 3.5)
        self.assertEqual(foo(b, 1.5), 3.0)
        self.assertEqual(foo(b, 1), 2.5)

    def test_signature_mismatch(self):
        """Merging two different dispatchers into one variable that is then
        called with incompatible argument types raises UnsupportedError.
        """
        @jit(nopython=True)
        def f1(x):
            return x

        @jit(nopython=True)
        def f2(x):
            return x

        @jit(nopython=True)
        def foo(disp1, disp2, sel):
            if sel == 1:
                fn = disp1
            else:
                fn = disp2
            return fn([1]), fn(2)

        with self.assertRaises(errors.UnsupportedError) as cm:
            foo(f1, f2, sel=1)
        self.assertRegex(
            str(cm.exception), 'mismatch of function types:')

        # this works because `sel == 1` condition is optimized away:
        self.assertEqual(foo(f1, f1, sel=1), ([1], 2))

    def test_unique_dispatcher(self):
        """A dispatcher with exactly one overload and compilation disabled
        is typed precisely (see the comment below).
        """
        # In general, the type of a dispatcher instance is imprecise
        # and when used as an input to type-inference, the typing will
        # likely fail. However, if a dispatcher instance contains
        # exactly one overload and compilation is disabled for the dispatcher,
        # then the type of dispatcher instance is interpreted as precise
        # and is transformed to a FunctionType instance with the defined
        # signature of the single overload.

        def foo_template(funcs, x):
            r = x
            for f in funcs:
                r = f(r)
            return r

        a = jit(nopython=True)(lambda x: x + 1)
        b = jit(nopython=True)(lambda x: x + 2)
        foo = jit(nopython=True)(foo_template)

        # compiling and disabling compilation for `a` is sufficient,
        # `b` will inherit its type from the container Tuple type
        a(0)
        a.disable_compile()

        r = foo((a, b), 0)
        self.assertEqual(r, 3)
        # the Tuple type of foo's first argument is a precise FunctionType:
        self.assertEqual(foo.signatures[0][0].dtype.is_precise(), True)

    def test_zero_address(self):
        """A wrapper address of 0 is rejected, both when the object is passed
        as an argument and when it is referenced from jitted code.
        """

        sig = int64()

        @cfunc(sig)
        def test():
            return 123

        class Good(types.WrapperAddressProtocol):
            """A first-class function type with valid address.
            """

            def __wrapper_address__(self):
                return test.address

            def signature(self):
                return sig

        class Bad(types.WrapperAddressProtocol):
            """A first-class function type with invalid 0 address.
            """

            def __wrapper_address__(self):
                return 0

            def signature(self):
                return sig

        class BadToGood(types.WrapperAddressProtocol):
            """A first-class function type with invalid address that is
            recovered to a valid address.
            """

            # Starts at -1, so the first query returns
            # test.address * min(1, 0) == 0 and every later query returns
            # test.address.
            counter = -1

            def __wrapper_address__(self):
                self.counter += 1
                return test.address * min(1, self.counter)

            def signature(self):
                return sig

        good = Good()
        bad = Bad()
        bad2good = BadToGood()

        @jit(int64(sig.as_type()))
        def foo(func):
            return func()

        @jit(int64())
        def foo_good():
            return good()

        @jit(int64())
        def foo_bad():
            return bad()

        @jit(int64())
        def foo_bad2good():
            return bad2good()

        self.assertEqual(foo(good), 123)

        self.assertEqual(foo_good(), 123)

        # Passing `bad` as an argument is rejected with ValueError.
        with self.assertRaises(ValueError) as cm:
            foo(bad)
        self.assertRegex(
            str(cm.exception),
            'wrapper address of <.*> instance must be a positive')

        # Referencing `bad` from jitted code fails with RuntimeError when
        # the function is called.
        with self.assertRaises(RuntimeError) as cm:
            foo_bad()
        self.assertRegex(
            str(cm.exception), r'.* function address is null')

        # Presumably the address is queried again at call time, after the
        # first (zero) query, so the recovered address is what gets used.
        self.assertEqual(foo_bad2good(), 123)

    def test_issue_5470(self):
        """A tuple of dispatchers from the enclosing scope is star-unpacked
        into a call of a jitted function (issue 5470).
        """

        @njit()
        def foo1():
            return 10

        @njit()
        def foo2():
            return 20

        formulae_foo = (foo1, foo1)

        @njit()
        def bar_scalar(f1, f2):
            return f1() + f2()

        @njit()
        def bar():
            return bar_scalar(*formulae_foo)

        self.assertEqual(bar(), 20)

        # Rebind the captured tuple and redefine `bar`; presumably the
        # closure value is fixed once `bar` is compiled, hence the redefinition.
        formulae_foo = (foo1, foo2)

        @njit()
        def bar():
            return bar_scalar(*formulae_foo)

        self.assertEqual(bar(), 30)

    def test_issue_5540(self):
        """Calling a first-class function with keyword arguments is a typing
        error, while the positional call works (issue 5540).
        """

        @njit(types.int64(types.int64))
        def foo(x):
            return x + 1

        @njit
        def bar_bad(foos):
            f = foos[0]
            return f(x=1)

        @njit
        def bar_good(foos):
            f = foos[0]
            return f(1)

        self.assertEqual(bar_good((foo, )), 2)

        with self.assertRaises(errors.TypingError) as cm:
            bar_bad((foo, ))

        self.assertRegex(
            str(cm.exception),
            r'.*first-class function call cannot use keyword arguments')

    def test_issue_5615(self):
        """Pairs of functions are unpacked from a nested tuple both by
        indexing and inside ``literal_unroll`` (issue 5615).
        """

        @njit
        def foo1(x):
            return x + 1

        @njit
        def foo2(x):
            return x + 2

        @njit
        def bar(fcs):
            x = 0
            a = 10
            i, j = fcs[0]
            x += i(j(a))
            for t in literal_unroll(fcs):
                i, j = t
                x += i(j(a))
            return x

        tup = ((foo1, foo2), (foo2, foo1))

        # fcs[0] gives foo1(foo2(10)) == 13; the loop adds 13 + 13.
        self.assertEqual(bar(tup), 39)

    def test_issue_5685(self):
        """``literal_unroll`` over a tuple of (zero-argument, one-argument)
        function pairs (issue 5685).
        """

        @njit
        def foo1():
            return 1

        @njit
        def foo2(x):
            return x + 1

        @njit
        def foo3(x):
            return x + 2

        @njit
        def bar(fcs):
            r = 0
            for pair in literal_unroll(fcs):
                f1, f2 = pair
                r += f1() + f2(2)
            return r

        # 1 + 3 == 4; then (1 + 3) + (1 + 4) == 9.
        self.assertEqual(bar(((foo1, foo2),)), 4)
        self.assertEqual(bar(((foo1, foo2), (foo1, foo3))), 9)  # reproducer
1126