1import itertools
2import functools
3import sys
4import operator
5
6import numpy as np
7
8import unittest
9from numba.core.compiler import compile_isolated, Flags
10from numba import jit, typeof, njit
11from numba.core import errors, types, utils, config
12from numba.tests.support import TestCase, tag
13
14
def _make_flags(*options):
    """Return a fresh compiler ``Flags`` instance with each named option set."""
    flags = Flags()
    for option in options:
        flags.set(option)
    return flags


# Object mode allowed as a fallback when nopython compilation fails.
enable_pyobj_flags = _make_flags("enable_pyobject")

# Always compile in object mode.
forceobj_flags = _make_flags("force_pyobject")

# No options set: nopython mode only (the *_npm tests use this).
no_pyobj_flags = _make_flags()

# nopython mode with the "nrt" option set (used by tests that build
# strings/lists, e.g. test_int_str and test_zip_first_exhausted_npm).
nrt_no_pyobj_flags = _make_flags("nrt")
25
26
def abs_usecase(x):
    """Return abs(x); compiled by TestBuiltins.test_abs."""
    return abs(x)
29
def all_usecase(x, y):
    """Return all() over whichever of x and y are not None.

    The branches exercise all() on empty, one-element and two-element
    lists.
    """
    # NOTE(review): `== None` (not `is None`) is presumably deliberate, as
    # this source is what the compiler under test sees; keep it unchanged.
    if x == None and y == None:
        return all([])
    elif x == None:
        return all([y])
    elif y == None:
        return all([x])
    else:
        return all([x, y])
39
def any_usecase(x, y):
    """Return any() over whichever of x and y are not None.

    The branches exercise any() on empty, one-element and two-element
    lists.
    """
    # NOTE(review): `== None` (not `is None`) is presumably deliberate, as
    # this source is what the compiler under test sees; keep it unchanged.
    if x == None and y == None:
        return any([])
    elif x == None:
        return any([y])
    elif y == None:
        return any([x])
    else:
        return any([x, y])
49
def bool_usecase(x):
    """Return bool(x)."""
    return bool(x)
52
def complex_usecase(x, y):
    """Return complex(x, y), i.e. x + y*1j."""
    return complex(x, y)
55
def divmod_usecase(x, y):
    """Return divmod(x, y) as a (quotient, remainder) tuple."""
    return divmod(x, y)
58
def enumerate_usecase():
    """Sum index * value over enumerate() of a constant float tuple."""
    result = 0
    for i, j in enumerate((1., 2.5, 3.)):
        result += i * j
    return result
64
def enumerate_start_usecase():
    """Like enumerate_usecase, but with enumerate() counting from 42."""
    result = 0
    for i, j in enumerate((1., 2.5, 3.), 42):
        result += i * j
    return result
70
def enumerate_invalid_start_usecase():
    """Call enumerate() with a float start value.

    Python raises TypeError at call time. The tests expect nopython mode to
    reject it during typing instead.
    """
    result = 0
    for i, j in enumerate((1., 2.5, 3.), 3.14159):
        result += i * j
    return result
76
def filter_usecase(x, filter_func):
    """Return filter(filter_func, x).

    Note that the argument order is the reverse of filter() itself.
    """
    return filter(filter_func, x)
79
def float_usecase(x):
    """Return float(x)."""
    return float(x)
82
def format_usecase(x, y):
    """Return x.format(y), where x is a format string."""
    return x.format(y)
85
def globals_usecase():
    """Return globals(); the tests check it is this module's globals dict."""
    return globals()
88
89# NOTE: hash() is tested in test_hashing
90
def hex_usecase(x):
    """Return hex(x)."""
    return hex(x)
93
def str_usecase(x):
    """Return str(x); used with integer arguments by test_int_str."""
    return str(x)
96
def int_usecase(x, base):
    """Return int(x, base=base); x must be a string when base is given."""
    return int(x, base=base)
99
def iter_next_usecase(x):
    """Return the first two items of iter(x) as a tuple.

    Raises StopIteration if x yields fewer than two items.
    """
    it = iter(x)
    return next(it), next(it)
103
def locals_usecase(x):
    """Read a local variable through locals()['y'].

    The tests expect the compiler to reject locals() (ForbiddenConstruct,
    or a typing error in nopython mode).
    """
    y = 5
    return locals()['y']
107
def long_usecase(x, base):
    """Convert x to an integer in the given base (Python 2's ``long``).

    Python 3 merged ``long`` into ``int`` and removed the ``long`` name.
    Calling it here would raise NameError, so ``int`` is used instead.
    """
    return int(x, base=base)
110
def map_usecase(x, map_func):
    """Return map(map_func, x).

    Note that the argument order is the reverse of map() itself.
    """
    return map(map_func, x)
113
114
def max_usecase1(x, y):
    """max(*args) form: max of two positional arguments."""
    return max(x, y)
117
def max_usecase2(x, y):
    """max(iterable) form: max of a two-element list."""
    return max([x, y])
120
def max_usecase3(x):
    """max(iterable) form with a single caller-supplied argument."""
    return max(x)
123
def max_usecase4():
    """max() of an empty tuple; nopython typing is expected to reject it."""
    return max(())
126
127
def min_usecase1(x, y):
    """min(*args) form: min of two positional arguments."""
    return min(x, y)
130
def min_usecase2(x, y):
    """min(iterable) form: min of a two-element list."""
    return min([x, y])
133
def min_usecase3(x):
    """min(iterable) form with a single caller-supplied argument."""
    return min(x)
136
def min_usecase4():
    """min() of an empty tuple; nopython typing is expected to reject it."""
    return min(())
139
140
def oct_usecase(x):
    """Return oct(x)."""
    return oct(x)
143
def reduce_usecase(reduce_func, x):
    """Return functools.reduce(reduce_func, x) (no initial value)."""
    return functools.reduce(reduce_func, x)
146
def round_usecase1(x):
    """Return round(x) with no ndigits argument."""
    return round(x)
149
def round_usecase2(x, n):
    """Return round(x, n), rounding x to n decimal digits."""
    return round(x, n)
152
def sum_usecase(x):
    """Return sum(x) (default start of 0)."""
    return sum(x)
155
def type_unary_usecase(a, b):
    """Convert b to the type of a, i.e. return type(a)(b)."""
    return type(a)(b)
158
def truth_usecase(p):
    """Return operator.truth(p)."""
    return operator.truth(p)
161
def unichr_usecase(x):
    """Return the one-character string for code point x (Python 2's ``unichr``).

    Python 3 removed ``unichr``; ``chr`` now covers the full Unicode range.
    Calling the old name would raise NameError.
    """
    return chr(x)
164
def zip_usecase():
    """Sum i * j over zip() of a 3-tuple and a 2-tuple.

    zip() stops at the shorter input, so only two pairs contribute.
    """
    result = 0
    for i, j in zip((1, 2, 3), (4.5, 6.7)):
        result += i * j
    return result
170
def zip_0_usecase():
    """Count the iterations of a zero-argument zip(), which yields nothing."""
    result = 0
    for i in zip():
        result += 1
    return result
176
def zip_1_usecase():
    """Sum the items of a single-argument zip(), unpacking each 1-tuple."""
    result = 0
    # `i,` unpacks the 1-tuples that zip() yields for a single iterable.
    for i, in zip((1, 2)):
        result += i
    return result
182
183
def zip_3_usecase():
    """Sum i * j * k over zip() of three tuples of unequal length."""
    result = 0
    for i, j, k in zip((1, 2), (3, 4, 5), (6.7, 8.9)):
        result += i * j * k
    return result
189
190
def zip_first_exhausted():
    """Zip a short range with a shared iterator, then drain that iterator.

    Returns (front, back). Python's zip() pulls from its arguments left to
    right. When range(n) runs out it stops before advancing ``it`` again,
    so ``back`` resumes right after the last paired item.
    """
    iterable = range(7)
    n = 3
    it = iter(iterable)
    # 1st iterator is shorter
    front = list(zip(range(n), it))
    # Make sure that we didn't skip one in `it`
    back = list(it)
    return front, back
200
201
def pow_op_usecase(x, y):
    """Return x ** y using the power operator."""
    return x ** y
204
205
def pow_usecase(x, y):
    """Return pow(x, y) using the builtin function."""
    return pow(x, y)
208
209
class TestBuiltins(TestCase):
    """
    Tests for Python builtins compiled with Numba.

    Most checks are written as ``test_<name>(self, flags=...)`` with an
    object-mode default, and a matching ``test_<name>_npm`` variant re-runs
    the same check with ``no_pyobj_flags`` (nopython mode). The npm variant
    is often wrapped in ``assertTypingError()`` when the builtin is
    unsupported there.
    """

    def run_nullary_func(self, pyfunc, flags):
        """
        Compile the zero-argument *pyfunc* with *flags* and check that the
        compiled result precisely equals the interpreted result.
        """
        cr = compile_isolated(pyfunc, (), flags=flags)
        cfunc = cr.entry_point
        expected = pyfunc()
        self.assertPreciseEqual(cfunc(), expected)

    def test_abs(self, flags=enable_pyobj_flags):
        pyfunc = abs_usecase

        cr = compile_isolated(pyfunc, (types.int32,), flags=flags)
        cfunc = cr.entry_point
        for x in [-1, 0, 1]:
            self.assertPreciseEqual(cfunc(x), pyfunc(x))

        cr = compile_isolated(pyfunc, (types.float32,), flags=flags)
        cfunc = cr.entry_point
        for x in [-1.1, 0.0, 1.1]:
            self.assertPreciseEqual(cfunc(x), pyfunc(x), prec='single')

        complex_values = [-1.1 + 0.5j, 0.0 + 0j, 1.1 + 3j,
                          float('inf') + 1j * float('nan'),
                          float('nan') - 1j * float('inf')]
        cr = compile_isolated(pyfunc, (types.complex64,), flags=flags)
        cfunc = cr.entry_point
        for x in complex_values:
            self.assertPreciseEqual(cfunc(x), pyfunc(x), prec='single')
        cr = compile_isolated(pyfunc, (types.complex128,), flags=flags)
        cfunc = cr.entry_point
        for x in complex_values:
            self.assertPreciseEqual(cfunc(x), pyfunc(x))

        for unsigned_type in types.unsigned_domain:
            unsigned_values = [0, 10, 2, 2 ** unsigned_type.bitwidth - 1]
            cr = compile_isolated(pyfunc, (unsigned_type,), flags=flags)
            cfunc = cr.entry_point
            for x in unsigned_values:
                self.assertPreciseEqual(cfunc(x), pyfunc(x))

    def test_abs_npm(self):
        self.test_abs(flags=no_pyobj_flags)

    def test_all(self, flags=enable_pyobj_flags):
        pyfunc = all_usecase

        cr = compile_isolated(pyfunc, (types.int32, types.int32), flags=flags)
        cfunc = cr.entry_point
        x_operands = [-1, 0, 1, None]
        y_operands = [-1, 0, 1, None]
        for x, y in itertools.product(x_operands, y_operands):
            self.assertPreciseEqual(cfunc(x, y), pyfunc(x, y))

    def test_all_npm(self):
        with self.assertTypingError():
            self.test_all(flags=no_pyobj_flags)

    def test_any(self, flags=enable_pyobj_flags):
        pyfunc = any_usecase

        cr = compile_isolated(pyfunc, (types.int32, types.int32), flags=flags)
        cfunc = cr.entry_point
        x_operands = [-1, 0, 1, None]
        y_operands = [-1, 0, 1, None]
        for x, y in itertools.product(x_operands, y_operands):
            self.assertPreciseEqual(cfunc(x, y), pyfunc(x, y))

    def test_any_npm(self):
        with self.assertTypingError():
            self.test_any(flags=no_pyobj_flags)

    def test_bool(self, flags=enable_pyobj_flags):
        pyfunc = bool_usecase

        cr = compile_isolated(pyfunc, (types.int32,), flags=flags)
        cfunc = cr.entry_point
        for x in [-1, 0, 1]:
            self.assertPreciseEqual(cfunc(x), pyfunc(x))
        cr = compile_isolated(pyfunc, (types.float64,), flags=flags)
        cfunc = cr.entry_point
        for x in [0.0, -0.0, 1.5, float('inf'), float('nan')]:
            self.assertPreciseEqual(cfunc(x), pyfunc(x))
        cr = compile_isolated(pyfunc, (types.complex128,), flags=flags)
        cfunc = cr.entry_point
        for x in [complex(0, float('inf')), complex(0, float('nan'))]:
            self.assertPreciseEqual(cfunc(x), pyfunc(x))

    def test_bool_npm(self):
        self.test_bool(flags=no_pyobj_flags)

    def test_bool_nonnumber(self, flags=enable_pyobj_flags):
        pyfunc = bool_usecase

        cr = compile_isolated(pyfunc, (types.string,), flags=flags)
        cfunc = cr.entry_point
        for x in ['x', '']:
            self.assertPreciseEqual(cfunc(x), pyfunc(x))

        cr = compile_isolated(pyfunc, (types.Dummy('list'),), flags=flags)
        cfunc = cr.entry_point
        for x in [[1], []]:
            self.assertPreciseEqual(cfunc(x), pyfunc(x))

    def test_bool_nonnumber_npm(self):
        with self.assertTypingError():
            self.test_bool_nonnumber(flags=no_pyobj_flags)

    def test_complex(self, flags=enable_pyobj_flags):
        pyfunc = complex_usecase

        cr = compile_isolated(pyfunc, (types.int32, types.int32), flags=flags)
        cfunc = cr.entry_point

        x_operands = [-1, 0, 1]
        y_operands = [-1, 0, 1]
        for x, y in itertools.product(x_operands, y_operands):
            self.assertPreciseEqual(cfunc(x, y), pyfunc(x, y))

    def test_complex_npm(self):
        self.test_complex(flags=no_pyobj_flags)

    def test_divmod_ints(self, flags=enable_pyobj_flags):
        pyfunc = divmod_usecase

        cr = compile_isolated(pyfunc, (types.int64, types.int64),
                              flags=flags)
        cfunc = cr.entry_point

        def truncate_result(x, bits=64):
            # Remove any extraneous bits (since Numba will return
            # a 64-bit result by definition)
            if x >= 0:
                x &= (1 << (bits - 1)) - 1
            return x

        denominators = [1, 3, 7, 15, -1, -3, -7, -15, 2**63 - 1, -2**63]
        numerators = denominators + [0]
        f = truncate_result
        for x, y in itertools.product(numerators, denominators):
            expected_quot, expected_rem = pyfunc(x, y)
            quot, rem = cfunc(x, y)
            self.assertPreciseEqual((f(quot), f(rem)),
                                    (f(expected_quot), f(expected_rem)))

        for x in numerators:
            with self.assertRaises(ZeroDivisionError):
                cfunc(x, 0)

    def test_divmod_ints_npm(self):
        self.test_divmod_ints(flags=no_pyobj_flags)

    def test_divmod_floats(self, flags=enable_pyobj_flags):
        pyfunc = divmod_usecase

        cr = compile_isolated(pyfunc, (types.float64, types.float64),
                              flags=flags)
        cfunc = cr.entry_point

        denominators = [1., 3.5, 1e100, -2., -7.5, -1e101,
                        np.inf, -np.inf, np.nan]
        numerators = denominators + [-0.0, 0.0]
        for x, y in itertools.product(numerators, denominators):
            expected_quot, expected_rem = pyfunc(x, y)
            quot, rem = cfunc(x, y)
            self.assertPreciseEqual((quot, rem), (expected_quot, expected_rem))

        for x in numerators:
            with self.assertRaises(ZeroDivisionError):
                cfunc(x, 0.0)

    def test_divmod_floats_npm(self):
        self.test_divmod_floats(flags=no_pyobj_flags)

    def test_enumerate(self, flags=enable_pyobj_flags):
        self.run_nullary_func(enumerate_usecase, flags)

    def test_enumerate_npm(self):
        self.test_enumerate(flags=no_pyobj_flags)

    def test_enumerate_start(self, flags=enable_pyobj_flags):
        self.run_nullary_func(enumerate_start_usecase, flags)

    def test_enumerate_start_npm(self):
        self.test_enumerate_start(flags=no_pyobj_flags)

    def test_enumerate_start_invalid_start_type(self):
        pyfunc = enumerate_invalid_start_usecase
        cr = compile_isolated(pyfunc, (), flags=enable_pyobj_flags)
        with self.assertRaises(TypeError) as raises:
            cr.entry_point()

        msg = "'float' object cannot be interpreted as an integer"
        self.assertIn(msg, str(raises.exception))

    def test_enumerate_start_invalid_start_type_npm(self):
        pyfunc = enumerate_invalid_start_usecase
        # Compilation itself must fail; the result is never used.
        with self.assertRaises(errors.TypingError) as raises:
            compile_isolated(pyfunc, (), flags=no_pyobj_flags)
        msg = "Only integers supported as start value in enumerate"
        self.assertIn(msg, str(raises.exception))

    def test_filter(self, flags=enable_pyobj_flags):
        pyfunc = filter_usecase
        cr = compile_isolated(pyfunc, (types.Dummy('list'),
                                       types.Dummy('function_ptr')),
                                       flags=flags)
        cfunc = cr.entry_point

        filter_func = lambda x: x % 2
        x = [0, 1, 2, 3, 4]
        self.assertSequenceEqual(list(cfunc(x, filter_func)),
                                 list(pyfunc(x, filter_func)))

    def test_filter_npm(self):
        with self.assertTypingError():
            self.test_filter(flags=no_pyobj_flags)

    def test_float(self, flags=enable_pyobj_flags):
        pyfunc = float_usecase

        cr = compile_isolated(pyfunc, (types.int32,), flags=flags)
        cfunc = cr.entry_point
        for x in [-1, 0, 1]:
            self.assertPreciseEqual(cfunc(x), pyfunc(x))

        cr = compile_isolated(pyfunc, (types.float32,), flags=flags)
        cfunc = cr.entry_point
        for x in [-1.1, 0.0, 1.1]:
            self.assertPreciseEqual(cfunc(x), pyfunc(x), prec='single')

        cr = compile_isolated(pyfunc, (types.string,), flags=flags)
        cfunc = cr.entry_point
        for x in ['-1.1', '0.0', '1.1']:
            self.assertPreciseEqual(cfunc(x), pyfunc(x))

    def test_float_npm(self):
        with self.assertTypingError():
            self.test_float(flags=no_pyobj_flags)

    def test_format(self, flags=enable_pyobj_flags):
        pyfunc = format_usecase

        cr = compile_isolated(pyfunc, (types.string, types.int32,), flags=flags)
        cfunc = cr.entry_point
        x = '{0}'
        for y in [-1, 0, 1]:
            self.assertPreciseEqual(cfunc(x, y), pyfunc(x, y))

        cr = compile_isolated(pyfunc, (types.string,
                                       types.float32,), flags=flags)
        cfunc = cr.entry_point
        x = '{0}'
        for y in [-1.1, 0.0, 1.1]:
            self.assertPreciseEqual(cfunc(x, y), pyfunc(x, y))

        cr = compile_isolated(pyfunc, (types.string,
                                       types.string,), flags=flags)
        cfunc = cr.entry_point
        x = '{0}'
        for y in ['a', 'b', 'c']:
            self.assertPreciseEqual(cfunc(x, y), pyfunc(x, y))

    def test_format_npm(self):
        with self.assertTypingError():
            self.test_format(flags=no_pyobj_flags)

    def test_globals(self, flags=enable_pyobj_flags):
        pyfunc = globals_usecase
        cr = compile_isolated(pyfunc, (), flags=flags)
        cfunc = cr.entry_point
        g = cfunc()
        self.assertIs(g, globals())

    def test_globals_npm(self):
        with self.assertTypingError():
            self.test_globals(flags=no_pyobj_flags)

    def test_globals_jit(self, **jit_flags):
        # Issue #416: weird behaviour of globals() in combination with
        # the @jit decorator.
        pyfunc = globals_usecase
        jitted = jit(**jit_flags)(pyfunc)
        self.assertIs(jitted(), globals())
        self.assertIs(jitted(), globals())

    def test_globals_jit_npm(self):
        with self.assertTypingError():
            self.test_globals_jit(nopython=True)

    def test_hex(self, flags=enable_pyobj_flags):
        pyfunc = hex_usecase

        cr = compile_isolated(pyfunc, (types.int32,), flags=flags)
        cfunc = cr.entry_point
        for x in [-1, 0, 1]:
            self.assertPreciseEqual(cfunc(x), pyfunc(x))

    def test_hex_npm(self):
        with self.assertTypingError():
            self.test_hex(flags=no_pyobj_flags)

    def test_int_str(self, flags=nrt_no_pyobj_flags):
        pyfunc = str_usecase

        small_inputs = [
            1234,
            1,
            0,
        ]

        large_inputs = [
            123456789,
            2222222,
            ~0x0
        ]

        args = [*small_inputs, *large_inputs]

        typs = [
            types.int8,
            types.int16,
            types.int32,
            types.int64,
            types.uint,
            types.uint8,
            types.uint16,
            types.uint32,
            types.uint64,
        ]

        for typ in typs:
            cr = compile_isolated(pyfunc, (typ,), flags=flags)
            cfunc = cr.entry_point
            for v in args:
                self.assertPreciseEqual(cfunc(typ(v)), pyfunc(typ(v)))

                if typ.signed:
                    self.assertPreciseEqual(cfunc(typ(-v)), pyfunc(typ(-v)))

    def test_int(self, flags=enable_pyobj_flags):
        pyfunc = int_usecase

        cr = compile_isolated(pyfunc, (types.string, types.int32), flags=flags)
        cfunc = cr.entry_point

        x_operands = ['-1', '0', '1', '10']
        y_operands = [2, 8, 10, 16]
        for x, y in itertools.product(x_operands, y_operands):
            self.assertPreciseEqual(cfunc(x, y), pyfunc(x, y))

    def test_int_npm(self):
        with self.assertTypingError():
            self.test_int(flags=no_pyobj_flags)

    def test_iter_next(self, flags=enable_pyobj_flags):
        pyfunc = iter_next_usecase
        cr = compile_isolated(pyfunc, (types.UniTuple(types.int32, 3),),
                              flags=flags)
        cfunc = cr.entry_point
        self.assertPreciseEqual(cfunc((1, 42, 5)), (1, 42))

        cr = compile_isolated(pyfunc, (types.UniTuple(types.int32, 1),),
                              flags=flags)
        cfunc = cr.entry_point
        with self.assertRaises(StopIteration):
            cfunc((1,))

    def test_iter_next_npm(self):
        self.test_iter_next(flags=no_pyobj_flags)

    def test_locals(self, flags=enable_pyobj_flags):
        pyfunc = locals_usecase
        # Compilation itself must fail; the result is never used.
        with self.assertRaises(errors.ForbiddenConstruct):
            compile_isolated(pyfunc, (types.int64,), flags=flags)

    def test_locals_forceobj(self):
        self.test_locals(flags=forceobj_flags)

    def test_locals_npm(self):
        with self.assertTypingError():
            self.test_locals(flags=no_pyobj_flags)

    def test_map(self, flags=enable_pyobj_flags):
        pyfunc = map_usecase
        cr = compile_isolated(pyfunc, (types.Dummy('list'),
                                       types.Dummy('function_ptr')),
                                       flags=flags)
        cfunc = cr.entry_point

        map_func = lambda x: x * 2
        x = [0, 1, 2, 3, 4]
        self.assertSequenceEqual(list(cfunc(x, map_func)),
                                 list(pyfunc(x, map_func)))

    def test_map_npm(self):
        with self.assertTypingError():
            self.test_map(flags=no_pyobj_flags)

    #
    # min() and max()
    #

    def check_minmax_1(self, pyfunc, flags):
        """
        Compile the two-argument *pyfunc* for (int32, int32) and compare it
        against Python over a small grid of operands.
        """
        cr = compile_isolated(pyfunc, (types.int32, types.int32), flags=flags)
        cfunc = cr.entry_point

        x_operands = [-1, 0, 1]
        y_operands = [-1, 0, 1]
        for x, y in itertools.product(x_operands, y_operands):
            self.assertPreciseEqual(cfunc(x, y), pyfunc(x, y))

    def test_max_1(self, flags=enable_pyobj_flags):
        """
        max(*args)
        """
        self.check_minmax_1(max_usecase1, flags)

    def test_min_1(self, flags=enable_pyobj_flags):
        """
        min(*args)
        """
        self.check_minmax_1(min_usecase1, flags)

    def test_max_npm_1(self):
        self.test_max_1(flags=no_pyobj_flags)

    def test_min_npm_1(self):
        self.test_min_1(flags=no_pyobj_flags)

    def check_minmax_2(self, pyfunc, flags):
        """
        Same (int32, int32) operand grid as check_minmax_1. Only the
        usecase differs: it builds a list rather than passing *args.
        """
        self.check_minmax_1(pyfunc, flags)

    def test_max_2(self, flags=enable_pyobj_flags):
        """
        max(list)
        """
        self.check_minmax_2(max_usecase2, flags)

    def test_min_2(self, flags=enable_pyobj_flags):
        """
        min(list)
        """
        self.check_minmax_2(min_usecase2, flags)

    def test_max_npm_2(self):
        with self.assertTypingError():
            self.test_max_2(flags=no_pyobj_flags)

    def test_min_npm_2(self):
        with self.assertTypingError():
            self.test_min_2(flags=no_pyobj_flags)

    def check_minmax_3(self, pyfunc, flags):
        """
        Compile the one-argument *pyfunc* for homogeneous and heterogeneous
        float 3-tuples and compare against Python, including a NaN.
        """
        def check(argty):
            cr = compile_isolated(pyfunc, (argty,), flags=flags)
            cfunc = cr.entry_point
            # Check that the algorithm matches Python's with a non-total order
            tup = (1.5, float('nan'), 2.5)
            for val in [tup, tup[::-1]]:
                self.assertPreciseEqual(cfunc(val), pyfunc(val))

        check(types.UniTuple(types.float64, 3))
        check(types.Tuple((types.float32, types.float64, types.float32)))

    def test_max_3(self, flags=enable_pyobj_flags):
        """
        max(tuple)
        """
        self.check_minmax_3(max_usecase3, flags)

    def test_min_3(self, flags=enable_pyobj_flags):
        """
        min(tuple)
        """
        self.check_minmax_3(min_usecase3, flags)

    def test_max_npm_3(self):
        self.test_max_3(flags=no_pyobj_flags)

    def test_min_npm_3(self):
        self.test_min_3(flags=no_pyobj_flags)

    def check_min_max_invalid_types(self, pyfunc, flags=enable_pyobj_flags):
        """
        Compile *pyfunc* for (int32, list) and call it with (1, [1]).
        Callers expect this to fail: TypeError in object mode, a typing
        error in nopython mode.
        """
        cr = compile_isolated(pyfunc, (types.int32, types.Dummy('list')),
                              flags=flags)
        cfunc = cr.entry_point
        cfunc(1, [1])

    def test_max_1_invalid_types(self):
        with self.assertRaises(TypeError):
            self.check_min_max_invalid_types(max_usecase1)

    def test_max_1_invalid_types_npm(self):
        with self.assertTypingError():
            self.check_min_max_invalid_types(max_usecase1, flags=no_pyobj_flags)

    def test_min_1_invalid_types(self):
        with self.assertRaises(TypeError):
            self.check_min_max_invalid_types(min_usecase1)

    def test_min_1_invalid_types_npm(self):
        with self.assertTypingError():
            self.check_min_max_invalid_types(min_usecase1, flags=no_pyobj_flags)

    # Test that max(1) and min(1) fail

    def check_min_max_unary_non_iterable(self, pyfunc, flags=enable_pyobj_flags):
        """
        Compile the one-argument *pyfunc* for int32 and call it with 1.
        Callers expect this to fail, since an int is not iterable.
        """
        cr = compile_isolated(pyfunc, (types.int32,), flags=flags)
        cfunc = cr.entry_point
        cfunc(1)

    def test_max_unary_non_iterable(self):
        with self.assertRaises(TypeError):
            self.check_min_max_unary_non_iterable(max_usecase3)

    def test_max_unary_non_iterable_npm(self):
        # Pass the nopython flags explicitly. With the default object-mode
        # flags, the runtime TypeError would satisfy assertTypingError and
        # nopython mode would never be exercised.
        with self.assertTypingError():
            self.check_min_max_unary_non_iterable(max_usecase3,
                                                  flags=no_pyobj_flags)

    def test_min_unary_non_iterable(self):
        with self.assertRaises(TypeError):
            self.check_min_max_unary_non_iterable(min_usecase3)

    def test_min_unary_non_iterable_npm(self):
        # See test_max_unary_non_iterable_npm: nopython flags are required
        # for this to test nopython typing at all.
        with self.assertTypingError():
            self.check_min_max_unary_non_iterable(min_usecase3,
                                                  flags=no_pyobj_flags)

    # Test that max(()) and min(()) fail

    def check_min_max_empty_tuple(self, pyfunc, func_name):
        """
        Check that compiling *pyfunc* in nopython mode fails with the
        "<func_name>() argument is an empty tuple" typing error.
        """
        with self.assertTypingError() as raises:
            compile_isolated(pyfunc, (), flags=no_pyobj_flags)
        self.assertIn("%s() argument is an empty tuple" % func_name,
                      str(raises.exception))

    def test_max_empty_tuple(self):
        self.check_min_max_empty_tuple(max_usecase4, "max")

    def test_min_empty_tuple(self):
        self.check_min_max_empty_tuple(min_usecase4, "min")

    def test_oct(self, flags=enable_pyobj_flags):
        pyfunc = oct_usecase

        cr = compile_isolated(pyfunc, (types.int32,), flags=flags)
        cfunc = cr.entry_point
        for x in [-8, -1, 0, 1, 8]:
            self.assertPreciseEqual(cfunc(x), pyfunc(x))

    def test_oct_npm(self):
        with self.assertTypingError():
            self.test_oct(flags=no_pyobj_flags)

    def test_reduce(self, flags=enable_pyobj_flags):
        pyfunc = reduce_usecase
        cr = compile_isolated(pyfunc, (types.Dummy('function_ptr'),
                                       types.Dummy('list')),
                                       flags=flags)
        cfunc = cr.entry_point

        reduce_func = lambda x, y: x + y

        x = range(10)
        self.assertPreciseEqual(cfunc(reduce_func, x), pyfunc(reduce_func, x))

        x = [x + x/10.0 for x in range(10)]
        self.assertPreciseEqual(cfunc(reduce_func, x), pyfunc(reduce_func, x))

        x = [complex(x, x) for x in range(10)]
        self.assertPreciseEqual(cfunc(reduce_func, x), pyfunc(reduce_func, x))

    def test_reduce_npm(self):
        with self.assertTypingError():
            self.test_reduce(flags=no_pyobj_flags)

    def test_round1(self, flags=enable_pyobj_flags):
        pyfunc = round_usecase1

        for tp in (types.float64, types.float32):
            cr = compile_isolated(pyfunc, (tp,), flags=flags)
            cfunc = cr.entry_point
            values = [-1.6, -1.5, -1.4, -0.5, 0.0, 0.1, 0.5, 0.6, 1.4, 1.5, 5.0]
            values += [-0.1, -0.0]
            for x in values:
                self.assertPreciseEqual(cfunc(x), pyfunc(x))

    def test_round1_npm(self):
        self.test_round1(flags=no_pyobj_flags)

    def test_round2(self, flags=enable_pyobj_flags):
        pyfunc = round_usecase2

        for tp in (types.float64, types.float32):
            prec = 'single' if tp is types.float32 else 'exact'
            cr = compile_isolated(pyfunc, (tp, types.int32), flags=flags)
            cfunc = cr.entry_point
            for x in [0.0, 0.1, 0.125, 0.25, 0.5, 0.75, 1.25,
                      1.5, 1.75, 2.25, 2.5, 2.75, 12.5, 15.0, 22.5]:
                for n in (-1, 0, 1, 2):
                    self.assertPreciseEqual(cfunc(x, n), pyfunc(x, n),
                                            prec=prec)
                    # Also check the mirrored negative value.
                    expected = pyfunc(-x, n)
                    self.assertPreciseEqual(cfunc(-x, n), expected,
                                            prec=prec)

    def test_round2_npm(self):
        self.test_round2(flags=no_pyobj_flags)

    def test_sum(self, flags=enable_pyobj_flags):
        pyfunc = sum_usecase

        cr = compile_isolated(pyfunc, (types.Dummy('list'),), flags=flags)
        cfunc = cr.entry_point

        x = range(10)
        self.assertPreciseEqual(cfunc(x), pyfunc(x))

        x = [x + x/10.0 for x in range(10)]
        self.assertPreciseEqual(cfunc(x), pyfunc(x))

        x = [complex(x, x) for x in range(10)]
        self.assertPreciseEqual(cfunc(x), pyfunc(x))

    def test_sum_npm(self):
        with self.assertTypingError():
            self.test_sum(flags=no_pyobj_flags)

    def test_truth(self):
        pyfunc = truth_usecase
        cfunc = jit(nopython=True)(pyfunc)

        self.assertEqual(pyfunc(True), cfunc(True))
        self.assertEqual(pyfunc(False), cfunc(False))

    def test_type_unary(self):
        # Test type(val) and type(val)(other_val)
        pyfunc = type_unary_usecase
        cfunc = jit(nopython=True)(pyfunc)

        def check(*args):
            expected = pyfunc(*args)
            self.assertPreciseEqual(cfunc(*args), expected)

        check(1.5, 2)
        check(1, 2.5)
        check(1.5j, 2)
        check(True, 2)
        check(2.5j, False)

    def test_zip(self, flags=forceobj_flags):
        self.run_nullary_func(zip_usecase, flags)

    def test_zip_npm(self):
        self.test_zip(flags=no_pyobj_flags)

    def test_zip_1(self, flags=forceobj_flags):
        self.run_nullary_func(zip_1_usecase, flags)

    def test_zip_1_npm(self):
        self.test_zip_1(flags=no_pyobj_flags)

    def test_zip_3(self, flags=forceobj_flags):
        self.run_nullary_func(zip_3_usecase, flags)

    def test_zip_3_npm(self):
        self.test_zip_3(flags=no_pyobj_flags)

    def test_zip_0(self, flags=forceobj_flags):
        self.run_nullary_func(zip_0_usecase, flags)

    def test_zip_0_npm(self):
        self.test_zip_0(flags=no_pyobj_flags)

    def test_zip_first_exhausted(self, flags=forceobj_flags):
        """
        Test side effect to the input iterators when a left iterator has been
        exhausted before the ones on the right.
        """
        self.run_nullary_func(zip_first_exhausted, flags)

    def test_zip_first_exhausted_npm(self):
        self.test_zip_first_exhausted(flags=nrt_no_pyobj_flags)

    def _check_pow(self, pyfunc):
        """
        Compile the two-argument *pyfunc* in nopython mode for several
        int/float/complex argument pairs and compare against Python.
        """
        args = [
            (2, 3),
            (2.0, 3),
            (2, 3.0),
            (2j, 3.0j),
        ]

        for x, y in args:
            cres = compile_isolated(pyfunc, (typeof(x), typeof(y)),
                                    flags=no_pyobj_flags)
            r = cres.entry_point(x, y)
            self.assertPreciseEqual(r, pyfunc(x, y))

    def test_pow_op_usecase(self):
        self._check_pow(pow_op_usecase)

    def test_pow_usecase(self):
        self._check_pow(pow_usecase)

    def _check_min_max(self, pyfunc):
        """njit-compile the zero-argument *pyfunc* and compare with Python."""
        cfunc = njit()(pyfunc)
        expected = pyfunc()
        got = cfunc()
        self.assertPreciseEqual(expected, got)

    def test_min_max_iterable_input(self):

        @njit
        def frange(start, stop, step):
            i = start
            while i < stop:
                yield i
                i += step

        def sample_functions(op):
            yield lambda: op(range(10))
            yield lambda: op(range(4, 12))
            yield lambda: op(range(-4, -15, -1))
            yield lambda: op([6.6, 5.5, 7.7])
            yield lambda: op([(3, 4), (1, 2)])
            yield lambda: op(frange(1.1, 3.3, 0.1))
            yield lambda: op([np.nan, -np.inf, np.inf, np.nan])
            yield lambda: op([(3,), (1,), (2,)])

        for fn in sample_functions(op=min):
            self._check_min_max(fn)

        for fn in sample_functions(op=max):
            self._check_min_max(fn)
958
959
class TestOperatorMixedTypes(TestCase):
    """
    Check that operator-module comparisons on mixed Python scalar types give
    the same answers under ``njit`` as in the interpreter.
    """

    def _compare_over_pairs(self, opnames, things, check):
        """
        For each name in *opnames*, njit-compile a wrapper around that
        ``operator`` function. Then call ``check(expected, got)`` for every
        ordered pair drawn from *things*, where *expected* is the
        interpreted result and *got* is the compiled one.
        """
        for name in opnames:
            op = getattr(operator, name)

            @njit
            def func(a, b):
                return op(a, b)

            for lhs, rhs in itertools.product(things, things):
                check(func.py_func(lhs, rhs), func(lhs, rhs))

    def test_eq_ne(self):
        # all these things should evaluate to being equal or not, all should
        # survive typing.
        things = (1, 0, True, False, 1.0, 2.0, 1.1, 1j, None, "", "1")
        self._compare_over_pairs(('eq', 'ne'), things,
                                 self.assertPreciseEqual)

    def test_cmp(self):
        # numerical things should all be comparable
        things = (1, 0, True, False, 1.0, 0.0, 1.1)
        self._compare_over_pairs(('gt', 'lt', 'ge', 'le', 'eq', 'ne'),
                                 things, self.assertEqual)
989
990
if __name__ == '__main__':
    # Allow running this test module directly as a script.
    unittest.main()
993