1"""
2Assorted utilities for use in tests.
3"""
4
5import cmath
6import contextlib
7import enum
8import errno
9import gc
10import math
11import platform
12import os
13import shutil
14import subprocess
15import sys
16import tempfile
17import time
18import io
19import ctypes
20import multiprocessing as mp
21import warnings
22import traceback
23from contextlib import contextmanager
24
25import numpy as np
26
27from numba import testing
28from numba.core import errors, typing, utils, config, cpu
29from numba.core.compiler import compile_extra, compile_isolated, Flags, DEFAULT_FLAGS
30import unittest
31from numba.core.runtime import rtsys
32from numba.np import numpy_support
33
34
35try:
36    import scipy
37except ImportError:
38    scipy = None
39
40
41enable_pyobj_flags = Flags()
42enable_pyobj_flags.set("enable_pyobject")
43
44force_pyobj_flags = Flags()
45force_pyobj_flags.set("force_pyobject")
46
47no_pyobj_flags = Flags()
48
49nrt_flags = Flags()
50nrt_flags.set("nrt")
51
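# Illustrative sketch (not part of the public test API): these module-level
# Flags objects are meant to be passed straight to the compilation helpers,
# e.g.
#
#     cr = compile_isolated(pyfunc, (), flags=no_pyobj_flags)
#     got = cr.entry_point()
#
# as done in TestCase.run_nullary_func() further down.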

tag = testing.make_tag_decorator(['important', 'long_running'])

_32bit = sys.maxsize <= 2 ** 32
is_parfors_unsupported = _32bit
skip_parfors_unsupported = unittest.skipIf(
    is_parfors_unsupported,
    'parfors not supported',
)
skip_py38_or_later = unittest.skipIf(
    utils.PYVERSION >= (3, 8),
    "unsupported on py3.8 or later"
)
skip_tryexcept_unsupported = unittest.skipIf(
    utils.PYVERSION < (3, 7),
    "try-except unsupported on py3.6 or earlier"
)
skip_tryexcept_supported = unittest.skipIf(
    utils.PYVERSION >= (3, 7),
    "try-except supported on py3.7 or later"
)

_msg = "SciPy needed for test"
skip_unless_scipy = unittest.skipIf(scipy is None, _msg)

_lnx_reason = 'linux only test'
linux_only = unittest.skipIf(not sys.platform.startswith('linux'), _lnx_reason)

_is_armv7l = platform.machine() == 'armv7l'

disabled_test = unittest.skipIf(True, 'Test disabled')

# See issue #4026, PPC64LE LLVM bug
skip_ppc64le_issue4026 = unittest.skipIf(platform.machine() == 'ppc64le',
                                         ("Hits: 'LLVM Invalid PPC CTR Loop! "
                                          "UNREACHABLE executed' bug"))

try:
    import scipy.linalg.cython_lapack
    has_lapack = True
except ImportError:
    has_lapack = False

needs_lapack = unittest.skipUnless(has_lapack,
                                   "LAPACK needs SciPy 1.0+")

try:
    import scipy.linalg.cython_blas
    has_blas = True
except ImportError:
    has_blas = False

needs_blas = unittest.skipUnless(has_blas, "BLAS needs SciPy 1.0+")


class CompilationCache(object):
    """
    A cache of compilation results for various signatures and flags.
    This can make tests significantly faster (or less slow).
    """

    def __init__(self):
        self.typingctx = typing.Context()
        self.targetctx = cpu.CPUContext(self.typingctx)
        self.cr_cache = {}

    def compile(self, func, args, return_type=None, flags=DEFAULT_FLAGS):
        """
        Compile the function or retrieve an already compiled result
        from the cache.
        """
        from numba.core.registry import cpu_target

        cache_key = (func, args, return_type, flags)
        try:
            cr = self.cr_cache[cache_key]
        except KeyError:
            # Register the contexts in case of nested @jit or @overload calls
            # (same as compile_isolated())
            with cpu_target.nested_context(self.typingctx, self.targetctx):
                cr = compile_extra(self.typingctx, self.targetctx, func,
                                   args, return_type, flags, locals={})
            self.cr_cache[cache_key] = cr
        return cr
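
# Illustrative sketch only: a test module might keep a single cache and reuse
# compiled results across checks, roughly like
#
#     from numba.core import types
#
#     cache = CompilationCache()
#     cr = cache.compile(pyfunc, (types.int64,), flags=no_pyobj_flags)
#     cfunc = cr.entry_point
#
# repeated calls with the same (func, args, return_type, flags) key return
# the cached result instead of recompiling.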


class TestCase(unittest.TestCase):

    longMessage = True

    # A random state yielding the same random numbers for any test case.
    # Use as `self.random.<method name>`
    @utils.cached_property
    def random(self):
        return np.random.RandomState(42)

    def reset_module_warnings(self, module):
        """
        Reset the warnings registry of a module.  This can be necessary
        as the warnings module is buggy in that regard.
        See http://bugs.python.org/issue4180
        """
        if isinstance(module, str):
            module = sys.modules[module]
        try:
            del module.__warningregistry__
        except AttributeError:
            pass

    @contextlib.contextmanager
    def assertTypingError(self):
        """
        A context manager that asserts the enclosed code block fails
        compiling in nopython mode.
        """
        _accepted_errors = (errors.LoweringError, errors.TypingError,
                            TypeError, NotImplementedError)
        with self.assertRaises(_accepted_errors) as cm:
            yield cm

    @contextlib.contextmanager
    def assertRefCount(self, *objects):
        """
        A context manager that asserts the given objects have the
        same reference counts before and after executing the
        enclosed block.
        """
        old_refcounts = [sys.getrefcount(x) for x in objects]
        yield
        new_refcounts = [sys.getrefcount(x) for x in objects]
        for old, new, obj in zip(old_refcounts, new_refcounts, objects):
            if old != new:
                self.fail("Refcount changed from %d to %d for object: %r"
                          % (old, new, obj))

    @contextlib.contextmanager
    def assertNoNRTLeak(self):
        """
        A context manager that asserts no NRT leak was created during
        the execution of the enclosed block.
        """
        old = rtsys.get_allocation_stats()
        yield
        new = rtsys.get_allocation_stats()
        total_alloc = new.alloc - old.alloc
        total_free = new.free - old.free
        total_mi_alloc = new.mi_alloc - old.mi_alloc
        total_mi_free = new.mi_free - old.mi_free
        self.assertEqual(total_alloc, total_free,
                         "number of data allocs != number of data frees")
        self.assertEqual(total_mi_alloc, total_mi_free,
                         "number of meminfo allocs != number of meminfo frees")
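
    # Illustrative sketch only: both leak-checking context managers are used
    # around calls into compiled code from a test method, e.g.
    #
    #     a = np.arange(10)
    #     with self.assertRefCount(a):
    #         with self.assertNoNRTLeak():
    #             cfunc(a)
    #
    # where ``cfunc`` stands for some previously compiled function.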


    _bool_types = (bool, np.bool_)
    _exact_typesets = [_bool_types, utils.INT_TYPES, (str,), (np.integer,),
                       (bytes, np.bytes_)]
    _approx_typesets = [(float,), (complex,), (np.inexact)]
    _sequence_typesets = [(tuple, list)]
    _float_types = (float, np.floating)
    _complex_types = (complex, np.complexfloating)

    def _detect_family(self, numeric_object):
        """
        This function returns a string description of the type family
        that the object in question belongs to.  Possible return values
        are: "exact", "complex", "approximate", "sequence", "ndarray",
        "enum", and "unknown".
        """
        if isinstance(numeric_object, np.ndarray):
            return "ndarray"

        if isinstance(numeric_object, enum.Enum):
            return "enum"

        for tp in self._sequence_typesets:
            if isinstance(numeric_object, tp):
                return "sequence"

        for tp in self._exact_typesets:
            if isinstance(numeric_object, tp):
                return "exact"

        for tp in self._complex_types:
            if isinstance(numeric_object, tp):
                return "complex"

        for tp in self._approx_typesets:
            if isinstance(numeric_object, tp):
                return "approximate"

        return "unknown"

    def _fix_dtype(self, dtype):
        """
        Fix the given *dtype* for comparison.
        """
        # Under 64-bit Windows, Numpy may return either int32 or int64
        # arrays depending on the function.
        if (sys.platform == 'win32' and sys.maxsize > 2**32 and
            dtype == np.dtype('int32')):
            return np.dtype('int64')
        else:
            return dtype

    def _fix_strides(self, arr):
        """
        Return the strides of the given array, fixed for comparison.
        Strides for 0- or 1-sized dimensions are ignored.
        """
        if arr.size == 0:
            return [0] * arr.ndim
        else:
            return [stride / arr.itemsize
                    for (stride, shape) in zip(arr.strides, arr.shape)
                    if shape > 1]

    def assertStridesEqual(self, first, second):
        """
        Test that two arrays have the same shape and strides.
        """
        self.assertEqual(first.shape, second.shape, "shapes differ")
        self.assertEqual(first.itemsize, second.itemsize, "itemsizes differ")
        self.assertEqual(self._fix_strides(first), self._fix_strides(second),
                         "strides differ")

    def assertPreciseEqual(self, first, second, prec='exact', ulps=1,
                           msg=None, ignore_sign_on_zero=False,
                           abs_tol=None
                           ):
        """
        Versatile equality testing function with more built-in checks than
        standard assertEqual().

        For arrays, test that layout, dtype, shape are identical, and
        recursively call assertPreciseEqual() on the contents.

        For other sequences, recursively call assertPreciseEqual() on
        the contents.

        For scalars, test that the two scalars have similar types and are
        equal up to a computed precision.
        If the scalars are instances of exact types or if *prec* is
        'exact', they are compared exactly.
        If the scalars are instances of inexact types (float, complex)
        and *prec* is not 'exact', then the number of significant bits
        is computed according to the value of *prec*: 53 bits if *prec*
        is 'double', 24 bits if *prec* is 'single'.  This number of bits
        can be lowered by raising the *ulps* value.

        *ignore_sign_on_zero* can be set to True if zeros are to be
        considered equal regardless of their sign bit.

        *abs_tol*, if set to a float, is used as an absolute tolerance:
        if the absolute difference between *first* and *second* is less
        than that value, the numbers are considered equal.  If set to
        the string "eps", the machine precision of ``type(first)`` is
        used as the tolerance instead.  This is intended for comparing
        small numbers, typically of magnitude less than machine
        precision.

        Any value of *prec* other than 'exact', 'single' or 'double'
        will raise an error.
        """
        try:
            self._assertPreciseEqual(first, second, prec, ulps, msg,
                ignore_sign_on_zero, abs_tol)
        except AssertionError as exc:
            failure_msg = str(exc)
            # Fall off of the 'except' scope to avoid Python 3 exception
            # chaining.
        else:
            return
        # Decorate the failure message with more information
        self.fail("when comparing %s and %s: %s" % (first, second, failure_msg))
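
    # Illustrative sketch only: typical calls from a test method look like
    #
    #     self.assertPreciseEqual(got, expected)
    #     self.assertPreciseEqual(got, expected, prec='double', ulps=2)
    #     self.assertPreciseEqual(got, expected, abs_tol="eps")
    #
    # where ``got`` comes from the compiled function and ``expected`` from
    # the pure Python one.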

    def _assertPreciseEqual(self, first, second, prec='exact', ulps=1,
                            msg=None, ignore_sign_on_zero=False,
                            abs_tol=None):
        """Recursive workhorse for assertPreciseEqual()."""

        def _assertNumberEqual(first, second, delta=None):
            if (delta is None or first == second == 0.0
                or math.isinf(first) or math.isinf(second)):
                self.assertEqual(first, second, msg=msg)
                # For signed zeros
                if not ignore_sign_on_zero:
                    try:
                        if math.copysign(1, first) != math.copysign(1, second):
                            self.fail(
                                self._formatMessage(msg,
                                                    "%s != %s" %
                                                    (first, second)))
                    except TypeError:
                        pass
            else:
                self.assertAlmostEqual(first, second, delta=delta, msg=msg)

        first_family = self._detect_family(first)
        second_family = self._detect_family(second)

        assertion_message = "Type Family mismatch. (%s != %s)" % (first_family,
            second_family)
        if msg:
            assertion_message += ': %s' % (msg,)
        self.assertEqual(first_family, second_family, msg=assertion_message)

        # We now know they are in the same comparison family
        compare_family = first_family

        # For recognized sequences, recurse
        if compare_family == "ndarray":
            dtype = self._fix_dtype(first.dtype)
            self.assertEqual(dtype, self._fix_dtype(second.dtype))
            self.assertEqual(first.ndim, second.ndim,
                             "different number of dimensions")
            self.assertEqual(first.shape, second.shape,
                             "different shapes")
            self.assertEqual(first.flags.writeable, second.flags.writeable,
                             "different mutability")
            # itemsize is already checked by the dtype test above
            self.assertEqual(self._fix_strides(first),
                self._fix_strides(second), "different strides")
            if first.dtype != dtype:
                first = first.astype(dtype)
            if second.dtype != dtype:
                second = second.astype(dtype)
            for a, b in zip(first.flat, second.flat):
                self._assertPreciseEqual(a, b, prec, ulps, msg,
                                         ignore_sign_on_zero, abs_tol)
            return

        elif compare_family == "sequence":
            self.assertEqual(len(first), len(second), msg=msg)
            for a, b in zip(first, second):
                self._assertPreciseEqual(a, b, prec, ulps, msg,
                                         ignore_sign_on_zero, abs_tol)
            return

        elif compare_family == "exact":
            exact_comparison = True

        elif compare_family in ["complex", "approximate"]:
            exact_comparison = False

        elif compare_family == "enum":
            self.assertIs(first.__class__, second.__class__)
            self._assertPreciseEqual(first.value, second.value,
                                     prec, ulps, msg,
                                     ignore_sign_on_zero, abs_tol)
            return

        elif compare_family == "unknown":
            # Assume these are non-numeric types: we will fall back
            # on regular unittest comparison.
            self.assertIs(first.__class__, second.__class__)
            exact_comparison = True

        else:
            assert 0, "unexpected family"

        # If a Numpy scalar, check the dtype is exactly the same too
        # (required for datetime64 and timedelta64).
        if hasattr(first, 'dtype') and hasattr(second, 'dtype'):
            self.assertEqual(first.dtype, second.dtype)

        # Mixing bools and non-bools should always fail
        if (isinstance(first, self._bool_types) !=
            isinstance(second, self._bool_types)):
            assertion_message = ("Mismatching return types (%s vs. %s)"
                                 % (first.__class__, second.__class__))
            if msg:
                assertion_message += ': %s' % (msg,)
            self.fail(assertion_message)

        try:
            if cmath.isnan(first) and cmath.isnan(second):
                # The NaNs will compare unequal, skip regular comparison
                return
        except TypeError:
            # Not floats.
            pass

        # If an absolute tolerance is set, use it
        if abs_tol is not None:
            if abs_tol == "eps":
                rtol = np.finfo(type(first)).eps
            elif isinstance(abs_tol, float):
                rtol = abs_tol
            else:
                raise ValueError("abs_tol is not \"eps\" or a float, found %s"
                    % abs_tol)
            if abs(first - second) < rtol:
                return

        exact_comparison = exact_comparison or prec == 'exact'

        if not exact_comparison and prec != 'exact':
            if prec == 'single':
                bits = 24
            elif prec == 'double':
                bits = 53
            else:
                raise ValueError("unsupported precision %r" % (prec,))
            k = 2 ** (ulps - bits - 1)
            delta = k * (abs(first) + abs(second))
        else:
            delta = None
        if isinstance(first, self._complex_types):
            _assertNumberEqual(first.real, second.real, delta)
            _assertNumberEqual(first.imag, second.imag, delta)
        elif isinstance(first, (np.timedelta64, np.datetime64)):
            # Since NumPy 1.16, NaT == NaT is False, so a special
            # comparison is needed
            if numpy_support.numpy_version >= (1, 16) and np.isnat(first):
                self.assertEqual(np.isnat(first), np.isnat(second))
            else:
                _assertNumberEqual(first, second, delta)
        else:
            _assertNumberEqual(first, second, delta)

    def run_nullary_func(self, pyfunc, flags):
        """
        Compile the 0-argument *pyfunc* with the given *flags*, and check
        it returns the same result as the pure Python function.
        The got and expected results are returned.
        """
        cr = compile_isolated(pyfunc, (), flags=flags)
        cfunc = cr.entry_point
        expected = pyfunc()
        got = cfunc()
        self.assertPreciseEqual(got, expected)
        return got, expected
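
    # Illustrative sketch only: run_nullary_func() pairs a pure Python call
    # with its compiled counterpart, e.g.
    #
    #     def pyfunc():
    #         return 1 + 1
    #
    #     got, expected = self.run_nullary_func(pyfunc, flags=no_pyobj_flags)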


class SerialMixin(object):
    """Mixin to mark tests for serial execution.
    """
    _numba_parallel_test_ = False


# Various helpers

@contextlib.contextmanager
def override_config(name, value):
    """
    Return a context manager that temporarily sets Numba config variable
    *name* to *value*.  *name* must be the name of an existing variable
    in numba.config.
    """
    old_value = getattr(config, name)
    setattr(config, name, value)
    try:
        yield
    finally:
        setattr(config, name, old_value)
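
# Illustrative sketch only: typical usage in a test, e.g.
#
#     with override_config('DEBUG_CACHE', True):
#         ...  # code that relies on the temporary setting
#
# (capture_cache_log() below uses exactly this pattern.)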


@contextlib.contextmanager
def override_env_config(name, value):
    """
    Return a context manager that temporarily sets the Numba config
    environment variable *name* to *value*.
    """
    old = os.environ.get(name)
    os.environ[name] = value
    config.reload_config()

    try:
        yield
    finally:
        if old is None:
            # If it wasn't set originally, delete the environ var
            del os.environ[name]
        else:
            # Otherwise, restore to the old value
            os.environ[name] = old
        # Always reload config
        config.reload_config()
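
# Illustrative sketch only: the environment-variable variant is used the same
# way, e.g. run_in_new_process_in_cache_dir() below does
#
#     with override_env_config('NUMBA_CACHE_DIR', cache_dir):
#         ...  # spawn work that should cache into cache_dir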


def compile_function(name, code, globs):
    """
    Given a *code* string, compile it with globals *globs* and return
    the function named *name*.
    """
    co = compile(code.rstrip(), "<string>", "single")
    ns = {}
    eval(co, globs, ns)
    return ns[name]
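
# Illustrative sketch only, e.g.
#
#     add = compile_function("add", "def add(a, b): return a + b", globals())
#     assert add(1, 2) == 3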

def tweak_code(func, codestring=None, consts=None):
    """
    Tweak the code object of the given function by replacing its
    *codestring* (a bytes object) and *consts* tuple, optionally.
    """
    co = func.__code__
    tp = type(co)
    if codestring is None:
        codestring = co.co_code
    if consts is None:
        consts = co.co_consts
    if utils.PYVERSION >= (3, 8):
        new_code = tp(co.co_argcount, co.co_posonlyargcount,
                      co.co_kwonlyargcount, co.co_nlocals,
                      co.co_stacksize, co.co_flags, codestring,
                      consts, co.co_names, co.co_varnames,
                      co.co_filename, co.co_name, co.co_firstlineno,
                      co.co_lnotab)
    else:
        new_code = tp(co.co_argcount, co.co_kwonlyargcount, co.co_nlocals,
                      co.co_stacksize, co.co_flags, codestring,
                      consts, co.co_names, co.co_varnames,
                      co.co_filename, co.co_name, co.co_firstlineno,
                      co.co_lnotab)
    func.__code__ = new_code
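
# Illustrative sketch only: e.g. swapping a constant in a function's code
# object while leaving its bytecode untouched
#
#     def f():
#         return 1
#
#     consts = tuple(2 if c == 1 else c for c in f.__code__.co_consts)
#     tweak_code(f, consts=consts)   # f() now returns 2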


_trashcan_dir = 'numba-tests'

if os.name == 'nt':
    # Under Windows, gettempdir() points to the user-local temp dir
    _trashcan_dir = os.path.join(tempfile.gettempdir(), _trashcan_dir)
else:
    # Mix the UID into the directory name to allow different users to
    # run the test suite without permission errors (issue #1586)
    _trashcan_dir = os.path.join(tempfile.gettempdir(),
                                 "%s.%s" % (_trashcan_dir, os.getuid()))

# Stale temporary directories are deleted once they are older than this value.
# The test suite probably won't ever take longer than this...
_trashcan_timeout = 24 * 3600  # 1 day

def _create_trashcan_dir():
    try:
        os.mkdir(_trashcan_dir)
    except OSError as e:
        if e.errno != errno.EEXIST:
            raise

def _purge_trashcan_dir():
    freshness_threshold = time.time() - _trashcan_timeout
    for fn in sorted(os.listdir(_trashcan_dir)):
        fn = os.path.join(_trashcan_dir, fn)
        try:
            st = os.stat(fn)
            if st.st_mtime < freshness_threshold:
                shutil.rmtree(fn, ignore_errors=True)
        except OSError as e:
            # In parallel testing, several processes can attempt to
            # remove the same entry at once, ignore.
            pass

def _create_trashcan_subdir(prefix):
    _purge_trashcan_dir()
    path = tempfile.mkdtemp(prefix=prefix + '-', dir=_trashcan_dir)
    return path

def temp_directory(prefix):
    """
    Create a temporary directory with the given *prefix* that will survive
    at least as long as this process invocation.  The temporary directory
    will be eventually deleted when it becomes stale enough.

    This is necessary because a DLL file can't be deleted while in use
    under Windows.

    An interesting side-effect is to be able to inspect the test files
    shortly after a test suite run.
    """
    _create_trashcan_dir()
    return _create_trashcan_subdir(prefix)
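
# Illustrative sketch only, e.g.
#
#     cache_dir = temp_directory('my_test_cache')
#     # ... write files under cache_dir; it is purged later, once stale.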


def import_dynamic(modname):
    """
    Import and return a module of the given name.  Care is taken to
    avoid issues due to Python's internal directory caching.
    """
    import importlib
    importlib.invalidate_caches()
    __import__(modname)
    return sys.modules[modname]


# From CPython

@contextlib.contextmanager
def captured_output(stream_name):
    """Return a context manager used by captured_stdout/stdin/stderr
    that temporarily replaces the sys stream *stream_name* with a StringIO."""
    orig_stdout = getattr(sys, stream_name)
    setattr(sys, stream_name, io.StringIO())
    try:
        yield getattr(sys, stream_name)
    finally:
        setattr(sys, stream_name, orig_stdout)

def captured_stdout():
    """Capture the output of sys.stdout:

       with captured_stdout() as stdout:
           print("hello")
       self.assertEqual(stdout.getvalue(), "hello\n")
    """
    return captured_output("stdout")

def captured_stderr():
    """Capture the output of sys.stderr:

       with captured_stderr() as stderr:
           print("hello", file=sys.stderr)
       self.assertEqual(stderr.getvalue(), "hello\n")
    """
    return captured_output("stderr")


@contextlib.contextmanager
def capture_cache_log():
    with captured_stdout() as out:
        with override_config('DEBUG_CACHE', True):
            yield out
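
# Illustrative sketch only: capture_cache_log() combines the two helpers
# above, e.g.
#
#     with capture_cache_log() as out:
#         ...  # exercise cached compilation
#     log = out.getvalue()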


class MemoryLeak(object):

    __enable_leak_check = True

    def memory_leak_setup(self):
        # Clean up any NRT-backed objects hanging in a dead reference cycle
        gc.collect()
        self.__init_stats = rtsys.get_allocation_stats()

    def memory_leak_teardown(self):
        if self.__enable_leak_check:
            self.assert_no_memory_leak()

    def assert_no_memory_leak(self):
        old = self.__init_stats
        new = rtsys.get_allocation_stats()
        total_alloc = new.alloc - old.alloc
        total_free = new.free - old.free
        total_mi_alloc = new.mi_alloc - old.mi_alloc
        total_mi_free = new.mi_free - old.mi_free
        self.assertEqual(total_alloc, total_free)
        self.assertEqual(total_mi_alloc, total_mi_free)

    def disable_leak_check(self):
        # For per-test use when MemoryLeakMixin is injected into a TestCase
        self.__enable_leak_check = False


class MemoryLeakMixin(MemoryLeak):

    def setUp(self):
        super(MemoryLeakMixin, self).setUp()
        self.memory_leak_setup()

    def tearDown(self):
        super(MemoryLeakMixin, self).tearDown()
        gc.collect()
        self.memory_leak_teardown()
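
# Illustrative sketch only: the mixin is combined with TestCase, e.g.
#
#     class TestSomething(MemoryLeakMixin, TestCase):
#         def test_no_leak(self):
#             ...  # run compiled code; leaks are checked in tearDown()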


@contextlib.contextmanager
def forbid_codegen():
    """
    Forbid LLVM code generation during the execution of the context
    manager's enclosed block.

    If code generation is invoked, a RuntimeError is raised.
    """
    from numba.core import codegen
    patchpoints = ['CodeLibrary._finalize_final_module']

    old = {}
    def fail(*args, **kwargs):
        raise RuntimeError("codegen forbidden by test case")
    try:
        # XXX use the mock library instead?
        for name in patchpoints:
            parts = name.split('.')
            obj = codegen
            for attrname in parts[:-1]:
                obj = getattr(obj, attrname)
            attrname = parts[-1]
            value = getattr(obj, attrname)
            assert callable(value), ("%r should be callable" % name)
            old[obj, attrname] = value
            setattr(obj, attrname, fail)
        yield
    finally:
        for (obj, attrname), value in old.items():
            setattr(obj, attrname, value)
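
# Illustrative sketch only: useful to assert that a call is served from a
# cache rather than triggering a fresh compile, e.g.
#
#     with forbid_codegen():
#         cfunc(1)   # raises RuntimeError if this invokes codegen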


# For details about redirection of file descriptors, read
# https://eli.thegreenplace.net/2015/redirecting-all-kinds-of-stdout-in-python/

@contextlib.contextmanager
def redirect_fd(fd):
    """
    Temporarily redirect *fd* to a pipe's write end and return a file object
    wrapping the pipe's read end.
    """

    from numba import _helperlib
    libnumba = ctypes.CDLL(_helperlib.__file__)

    libnumba._numba_flush_stdout()
    save = os.dup(fd)
    r, w = os.pipe()
    try:
        os.dup2(w, fd)
        yield io.open(r, "r")
    finally:
        libnumba._numba_flush_stdout()
        os.close(w)
        os.dup2(save, fd)
        os.close(save)


def redirect_c_stdout():
    """Redirect C stdout
    """
    fd = sys.__stdout__.fileno()
    return redirect_fd(fd)
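
# Illustrative sketch only: capturing output written by native code through
# the C-level stdout, e.g.
#
#     with redirect_c_stdout() as stream:
#         cfunc_that_prints()   # hypothetical compiled function
#     captured = stream.read()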


def run_in_new_process_caching(func, cache_dir_prefix=__name__, verbose=True):
    """Spawn a new process to run `func` with a temporary cache directory.

    The child process's stdout and stderr will be captured and redirected to
    the current process's stdout and stderr.

    Returns
    -------
    ret : dict
        exitcode: 0 for success. 1 if an exception was raised.
        stdout: str
        stderr: str
    """
    cache_dir = temp_directory(cache_dir_prefix)
    return run_in_new_process_in_cache_dir(func, cache_dir, verbose=verbose)


def run_in_new_process_in_cache_dir(func, cache_dir, verbose=True):
    """Spawn a new process to run `func` with a temporary cache directory.

    The child process's stdout and stderr will be captured and redirected to
    the current process's stdout and stderr.

    Similar to ``run_in_new_process_caching()`` but the ``cache_dir`` is a
    directory path instead of a name prefix for the directory path.

    Returns
    -------
    ret : dict
        exitcode: 0 for success. 1 if an exception was raised.
        stdout: str
        stderr: str
    """
    ctx = mp.get_context('spawn')
    qout = ctx.Queue()
    with override_env_config('NUMBA_CACHE_DIR', cache_dir):
        proc = ctx.Process(target=_remote_runner, args=[func, qout])
        proc.start()
        proc.join()
        stdout = qout.get_nowait()
        stderr = qout.get_nowait()
        if verbose and stdout.strip():
            print()
            print('STDOUT'.center(80, '-'))
            print(stdout)
        if verbose and stderr.strip():
            print(file=sys.stderr)
            print('STDERR'.center(80, '-'), file=sys.stderr)
            print(stderr, file=sys.stderr)
    return {
        'exitcode': proc.exitcode,
        'stdout': stdout,
        'stderr': stderr,
    }
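
# Illustrative sketch only: `func` must be picklable (e.g. a module-level
# function) since the child process is spawned, e.g.
#
#     res = run_in_new_process_caching(some_module_level_func)  # hypothetical
#     assert res['exitcode'] == 0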


def _remote_runner(fn, qout):
    """Used by `run_in_new_process_caching()`
    """
    with captured_stderr() as stderr:
        with captured_stdout() as stdout:
            try:
                fn()
            except Exception:
                traceback.print_exc()
                exitcode = 1
            else:
                exitcode = 0
        qout.put(stdout.getvalue())
    qout.put(stderr.getvalue())
    sys.exit(exitcode)


class CheckWarningsMixin(object):
    @contextlib.contextmanager
    def check_warnings(self, messages, category=RuntimeWarning):
        with warnings.catch_warnings(record=True) as catch:
            warnings.simplefilter("always")
            yield
        found = 0
        for w in catch:
            for m in messages:
                if m in str(w.message):
                    self.assertEqual(w.category, category)
                    found += 1
        self.assertEqual(found, len(messages))

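
# Illustrative sketch only: CheckWarningsMixin is combined with a TestCase,
# e.g.
#
#     class TestWarns(CheckWarningsMixin, TestCase):
#         def test_div_warns(self):
#             with self.check_warnings(["divide by zero"], RuntimeWarning):
#                 np.float64(1.0) / np.float64(0.0)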