"""
Assorted utilities for use in tests.
"""

import cmath
import contextlib
import enum
import errno
import gc
import math
import platform
import os
import shutil
import subprocess
import sys
import tempfile
import time
import io
import ctypes
import multiprocessing as mp
import warnings
import traceback
from contextlib import contextmanager

import numpy as np

from numba import testing
from numba.core import errors, typing, utils, config, cpu
from numba.core.compiler import (compile_extra, compile_isolated, Flags,
                                 DEFAULT_FLAGS)
import unittest
from numba.core.runtime import rtsys
from numba.np import numpy_support


try:
    import scipy
except ImportError:
    scipy = None


enable_pyobj_flags = Flags()
enable_pyobj_flags.set("enable_pyobject")

force_pyobj_flags = Flags()
force_pyobj_flags.set("force_pyobject")

no_pyobj_flags = Flags()

nrt_flags = Flags()
nrt_flags.set("nrt")


tag = testing.make_tag_decorator(['important', 'long_running'])

_32bit = sys.maxsize <= 2 ** 32
is_parfors_unsupported = _32bit
skip_parfors_unsupported = unittest.skipIf(
    is_parfors_unsupported,
    'parfors not supported',
)
skip_py38_or_later = unittest.skipIf(
    utils.PYVERSION >= (3, 8),
    "unsupported on py3.8 or later"
)
skip_tryexcept_unsupported = unittest.skipIf(
    utils.PYVERSION < (3, 7),
    "try-except unsupported on py3.6 or earlier"
)
skip_tryexcept_supported = unittest.skipIf(
    utils.PYVERSION >= (3, 7),
    "try-except supported on py3.7 or later"
)

_msg = "SciPy needed for test"
skip_unless_scipy = unittest.skipIf(scipy is None, _msg)

_lnx_reason = 'linux only test'
linux_only = unittest.skipIf(not sys.platform.startswith('linux'), _lnx_reason)

_is_armv7l = platform.machine() == 'armv7l'

disabled_test = unittest.skipIf(True, 'Test disabled')

# See issue #4026, PPC64LE LLVM bug
skip_ppc64le_issue4026 = unittest.skipIf(platform.machine() == 'ppc64le',
                                         ("Hits: 'LLVM Invalid PPC CTR Loop! "
                                          "UNREACHABLE executed' bug"))

try:
    import scipy.linalg.cython_lapack
    has_lapack = True
except ImportError:
    has_lapack = False

needs_lapack = unittest.skipUnless(has_lapack,
                                   "LAPACK needs SciPy 1.0+")

try:
    import scipy.linalg.cython_blas
    has_blas = True
except ImportError:
    has_blas = False

needs_blas = unittest.skipUnless(has_blas, "BLAS needs SciPy 1.0+")


class CompilationCache(object):
    """
    A cache of compilation results for various signatures and flags.
    This can make tests significantly faster (or less slow).
    """

    def __init__(self):
        # A single typing/target context pair is shared by every compilation
        # done through this cache.
        self.typingctx = typing.Context()
        self.targetctx = cpu.CPUContext(self.typingctx)
        # Maps (func, args, return_type, flags) -> compile result.
        self.cr_cache = {}

    def compile(self, func, args, return_type=None, flags=DEFAULT_FLAGS):
        """
        Compile the function or retrieve an already compiled result
        from the cache.

        The cache key is ``(func, args, return_type, flags)``, so all of
        these must be hashable.
        """
        from numba.core.registry import cpu_target

        cache_key = (func, args, return_type, flags)
        try:
            cr = self.cr_cache[cache_key]
        except KeyError:
            # Register the contexts in case for nested @jit or @overload calls
            # (same as compile_isolated())
            with cpu_target.nested_context(self.typingctx, self.targetctx):
                cr = compile_extra(self.typingctx, self.targetctx, func,
                                   args, return_type, flags, locals={})
            self.cr_cache[cache_key] = cr
        return cr


class TestCase(unittest.TestCase):
    """
    ``unittest.TestCase`` subclass with Numba-specific assertion helpers.
    """

    longMessage = True

    # A random state yielding the same random numbers for any test case.
    # Use as `self.random.<method name>`
    @utils.cached_property
    def random(self):
        return np.random.RandomState(42)

    def reset_module_warnings(self, module):
        """
        Reset the warnings registry of a module. This can be necessary
        as the warnings module is buggy in that regard.
        See http://bugs.python.org/issue4180

        *module* may be a module object or the name of an already
        imported module.
        """
        if isinstance(module, str):
            module = sys.modules[module]
        try:
            del module.__warningregistry__
        except AttributeError:
            # No warning has been registered for this module yet.
            pass

    @contextlib.contextmanager
    def assertTypingError(self):
        """
        A context manager that asserts the enclosed code block fails
        compiling in nopython mode.

        Any of LoweringError, TypingError, TypeError or NotImplementedError
        is accepted; the ``assertRaises`` context object is yielded.
        """
        _accepted_errors = (errors.LoweringError, errors.TypingError,
                            TypeError, NotImplementedError)
        with self.assertRaises(_accepted_errors) as cm:
            yield cm

    @contextlib.contextmanager
    def assertRefCount(self, *objects):
        """
        A context manager that asserts the given objects have the
        same reference counts before and after executing the
        enclosed block.
        """
        old_refcounts = [sys.getrefcount(x) for x in objects]
        yield
        new_refcounts = [sys.getrefcount(x) for x in objects]
        for old, new, obj in zip(old_refcounts, new_refcounts, objects):
            if old != new:
                self.fail("Refcount changed from %d to %d for object: %r"
                          % (old, new, obj))

    @contextlib.contextmanager
    def assertNoNRTLeak(self):
        """
        A context manager that asserts no NRT leak was created during
        the execution of the enclosed block.

        The NRT allocation statistics are sampled before and after the
        block; the number of data allocations must equal the number of
        data frees, and likewise for meminfo allocations/frees.
        """
        old = rtsys.get_allocation_stats()
        yield
        new = rtsys.get_allocation_stats()
        total_alloc = new.alloc - old.alloc
        total_free = new.free - old.free
        total_mi_alloc = new.mi_alloc - old.mi_alloc
        total_mi_free = new.mi_free - old.mi_free
        self.assertEqual(total_alloc, total_free,
                         "number of data allocs != number of data frees")
        self.assertEqual(total_mi_alloc, total_mi_free,
                         "number of meminfo allocs != number of meminfo frees")

    # Type families used by _detect_family(); the order of the checks there
    # matters (e.g. bool is checked as "exact" before anything else numeric).
    _bool_types = (bool, np.bool_)
    _exact_typesets = [_bool_types, utils.INT_TYPES, (str,), (np.integer,),
                       (bytes, np.bytes_)]
    _approx_typesets = [(float,), (complex,), (np.inexact,)]
    _sequence_typesets = [(tuple, list)]
    _float_types = (float, np.floating)
    _complex_types = (complex, np.complexfloating)

    def _detect_family(self, numeric_object):
        """
        This function returns a string description of the type family
        that the object in question belongs to. Possible return values
        are: "ndarray", "enum", "sequence", "exact", "complex",
        "approximate" and "unknown" (checked in that order).
        """
        if isinstance(numeric_object, np.ndarray):
            return "ndarray"

        if isinstance(numeric_object, enum.Enum):
            return "enum"

        if any(isinstance(numeric_object, tp)
               for tp in self._sequence_typesets):
            return "sequence"

        if any(isinstance(numeric_object, tp)
               for tp in self._exact_typesets):
            return "exact"

        if isinstance(numeric_object, self._complex_types):
            return "complex"

        if any(isinstance(numeric_object, tp)
               for tp in self._approx_typesets):
            return "approximate"

        return "unknown"

    def _fix_dtype(self, dtype):
        """
        Fix the given *dtype* for comparison.
        """
        # Under 64-bit Windows, Numpy may return either int32 or int64
        # arrays depending on the function.
        if (sys.platform == 'win32' and sys.maxsize > 2**32 and
                dtype == np.dtype('int32')):
            return np.dtype('int64')
        else:
            return dtype

    def _fix_strides(self, arr):
        """
        Return the strides of the given array, fixed for comparison.
        Strides are expressed in units of the item size, and strides for
        0- or 1-sized dimensions are ignored.  An empty array yields a
        list of zeros of length ``arr.ndim``.
        """
        if arr.size == 0:
            return [0] * arr.ndim
        else:
            return [stride / arr.itemsize
                    for (stride, shape) in zip(arr.strides, arr.shape)
                    if shape > 1]

    def assertStridesEqual(self, first, second):
        """
        Test that two arrays have the same shape and strides.
        """
        self.assertEqual(first.shape, second.shape, "shapes differ")
        self.assertEqual(first.itemsize, second.itemsize, "itemsizes differ")
        self.assertEqual(self._fix_strides(first), self._fix_strides(second),
                         "strides differ")

    def assertPreciseEqual(self, first, second, prec='exact', ulps=1,
                           msg=None, ignore_sign_on_zero=False,
                           abs_tol=None
                           ):
        """
        Versatile equality testing function with more built-in checks than
        standard assertEqual().

        For arrays, test that layout, dtype, shape are identical, and
        recursively call assertPreciseEqual() on the contents.

        For other sequences, recursively call assertPreciseEqual() on
        the contents.

        For scalars, test that the two scalars have similar types and are
        equal up to a computed precision.
        If the scalars are instances of exact types or if *prec* is
        'exact', they are compared exactly.
        If the scalars are instances of inexact types (float, complex)
        and *prec* is not 'exact', then the number of significant bits
        is computed according to the value of *prec*: 53 bits if *prec*
        is 'double', 24 bits if *prec* is 'single'.  This number of bits
        can be lowered by raising the *ulps* value.
        ignore_sign_on_zero can be set to True if zeros are to be considered
        equal regardless of their sign bit.
        abs_tol if this is set to a float value its value is used in the
        following. If, however, this is set to the string "eps" then machine
        precision of the type(first) is used in the following instead. This
        kwarg is used to check if the absolute difference in value between first
        and second is less than the value set, if so the numbers being compared
        are considered equal. (This is to handle small numbers typically of
        magnitude less than machine precision).

        Any value of *prec* other than 'exact', 'single' or 'double'
        raises ValueError when an inexact (float/complex) comparison
        is required.
        """
        try:
            self._assertPreciseEqual(first, second, prec, ulps, msg,
                                     ignore_sign_on_zero, abs_tol)
        except AssertionError as exc:
            failure_msg = str(exc)
            # Fall off of the 'except' scope to avoid Python 3 exception
            # chaining.
        else:
            return
        # Decorate the failure message with more information
        self.fail("when comparing %s and %s: %s" % (first, second, failure_msg))

    def _assertPreciseEqual(self, first, second, prec='exact', ulps=1,
                            msg=None, ignore_sign_on_zero=False,
                            abs_tol=None):
        """Recursive workhorse for assertPreciseEqual()."""

        def _assertNumberEqual(first, second, delta=None):
            # Exact comparison when no tolerance is given, for (signed)
            # zeros, and for infinities (delta arithmetic is meaningless).
            if (delta is None or first == second == 0.0
                    or math.isinf(first) or math.isinf(second)):
                self.assertEqual(first, second, msg=msg)
                # For signed zeros
                if not ignore_sign_on_zero:
                    try:
                        if math.copysign(1, first) != math.copysign(1, second):
                            self.fail(
                                self._formatMessage(msg,
                                                    "%s != %s" %
                                                    (first, second)))
                    except TypeError:
                        # Not a real number (e.g. str): no sign to compare.
                        pass
            else:
                self.assertAlmostEqual(first, second, delta=delta, msg=msg)

        first_family = self._detect_family(first)
        second_family = self._detect_family(second)

        assertion_message = "Type Family mismatch. (%s != %s)" % (first_family,
                                                                 second_family)
        if msg:
            assertion_message += ': %s' % (msg,)
        self.assertEqual(first_family, second_family, msg=assertion_message)

        # We now know they are in the same comparison family
        compare_family = first_family

        # For recognized sequences, recurse
        if compare_family == "ndarray":
            dtype = self._fix_dtype(first.dtype)
            self.assertEqual(dtype, self._fix_dtype(second.dtype))
            self.assertEqual(first.ndim, second.ndim,
                             "different number of dimensions")
            self.assertEqual(first.shape, second.shape,
                             "different shapes")
            self.assertEqual(first.flags.writeable, second.flags.writeable,
                             "different mutability")
            # itemsize is already checked by the dtype test above
            self.assertEqual(self._fix_strides(first),
                             self._fix_strides(second), "different strides")
            if first.dtype != dtype:
                first = first.astype(dtype)
            if second.dtype != dtype:
                second = second.astype(dtype)
            for a, b in zip(first.flat, second.flat):
                self._assertPreciseEqual(a, b, prec, ulps, msg,
                                         ignore_sign_on_zero, abs_tol)
            return

        elif compare_family == "sequence":
            self.assertEqual(len(first), len(second), msg=msg)
            for a, b in zip(first, second):
                self._assertPreciseEqual(a, b, prec, ulps, msg,
                                         ignore_sign_on_zero, abs_tol)
            return

        elif compare_family == "exact":
            exact_comparison = True

        elif compare_family in ["complex", "approximate"]:
            exact_comparison = False

        elif compare_family == "enum":
            self.assertIs(first.__class__, second.__class__)
            self._assertPreciseEqual(first.value, second.value,
                                     prec, ulps, msg,
                                     ignore_sign_on_zero, abs_tol)
            return

        elif compare_family == "unknown":
            # Assume these are non-numeric types: we will fall back
            # on regular unittest comparison.
            self.assertIs(first.__class__, second.__class__)
            exact_comparison = True

        else:
            assert 0, "unexpected family"

        # If a Numpy scalar, check the dtype is exactly the same too
        # (required for datetime64 and timedelta64).
        if hasattr(first, 'dtype') and hasattr(second, 'dtype'):
            self.assertEqual(first.dtype, second.dtype)

        # Mixing bools and non-bools should always fail
        if (isinstance(first, self._bool_types) !=
                isinstance(second, self._bool_types)):
            assertion_message = ("Mismatching return types (%s vs. %s)"
                                 % (first.__class__, second.__class__))
            if msg:
                assertion_message += ': %s' % (msg,)
            self.fail(assertion_message)

        try:
            if cmath.isnan(first) and cmath.isnan(second):
                # The NaNs will compare unequal, skip regular comparison
                return
        except TypeError:
            # Not floats.
            pass

        # if absolute comparison is set, use it
        if abs_tol is not None:
            if abs_tol == "eps":
                tol = np.finfo(type(first)).eps
            elif isinstance(abs_tol, float):
                tol = abs_tol
            else:
                raise ValueError("abs_tol is not \"eps\" or a float, found %s"
                                 % abs_tol)
            if abs(first - second) < tol:
                return

        exact_comparison = exact_comparison or prec == 'exact'

        if not exact_comparison:
            if prec == 'single':
                bits = 24
            elif prec == 'double':
                bits = 53
            else:
                raise ValueError("unsupported precision %r" % (prec,))
            # Relative tolerance of (ulps) units in the last of *bits*
            # significant bits, scaled by the magnitude of the operands.
            k = 2 ** (ulps - bits - 1)
            delta = k * (abs(first) + abs(second))
        else:
            delta = None
        if isinstance(first, self._complex_types):
            _assertNumberEqual(first.real, second.real, delta)
            _assertNumberEqual(first.imag, second.imag, delta)
        elif isinstance(first, (np.timedelta64, np.datetime64)):
            # Since Np 1.16 NaT == NaT is False, so special comparison needed
            if numpy_support.numpy_version >= (1, 16) and np.isnat(first):
                self.assertEqual(np.isnat(first), np.isnat(second))
            else:
                _assertNumberEqual(first, second, delta)
        else:
            _assertNumberEqual(first, second, delta)

    def run_nullary_func(self, pyfunc, flags):
        """
        Compile the 0-argument *pyfunc* with the given *flags*, and check
        it returns the same result as the pure Python function.
        The got and expected results are returned.
        """
        cr = compile_isolated(pyfunc, (), flags=flags)
        cfunc = cr.entry_point
        expected = pyfunc()
        got = cfunc()
        self.assertPreciseEqual(got, expected)
        return got, expected


class SerialMixin(object):
    """Mixin to mark test for serial execution.
    """
    _numba_parallel_test_ = False


# Various helpers

@contextlib.contextmanager
def override_config(name, value):
    """
    Return a context manager that temporarily sets Numba config variable
    *name* to *value*.  *name* must be the name of an existing variable
    in numba.config.
    """
    old_value = getattr(config, name)
    setattr(config, name, value)
    try:
        yield
    finally:
        setattr(config, name, old_value)


@contextlib.contextmanager
def override_env_config(name, value):
    """
    Return a context manager that temporarily sets a Numba config environment
    variable *name* to *value* and reloads the Numba config.

    On exit (including when setting up the override fails part-way), the
    variable is restored to its previous value -- or removed if it was not
    previously set -- and the config is reloaded again.
    """
    old = os.environ.get(name)
    try:
        # Set up inside the ``try`` so a failing reload still restores
        # the environment.
        os.environ[name] = value
        config.reload_config()
        yield
    finally:
        if old is None:
            # If it wasn't set originally, delete the environ var.  pop()
            # tolerates it already being absent (e.g. the assignment above
            # failed, or the enclosed block removed it).
            os.environ.pop(name, None)
        else:
            # Otherwise, restore to the old value
            os.environ[name] = old
        # Always reload config
        config.reload_config()


def compile_function(name, code, globs):
    """
    Given a *code* string, compile it with globals *globs* and return
    the function named *name*.

    NOTE: *code* is executed as-is; only pass trusted source (test code).
    """
    co = compile(code.rstrip(), "<string>", "single")
    ns = {}
    eval(co, globs, ns)
    return ns[name]


def tweak_code(func, codestring=None, consts=None):
    """
    Tweak the code object of the given function by replacing its
    *codestring* (a bytes object) and *consts* tuple, optionally.

    Every other attribute of the code object -- including its free and
    cell variables, so closures keep working -- is preserved.
    """
    co = func.__code__
    if codestring is None:
        codestring = co.co_code
    if consts is None:
        consts = co.co_consts
    if utils.PYVERSION >= (3, 8):
        # CodeType.replace() copies all untouched fields, so it is immune
        # to the positional constructor signature changing between Python
        # versions.
        new_code = co.replace(co_code=codestring, co_consts=consts)
    else:
        tp = type(co)
        new_code = tp(co.co_argcount, co.co_kwonlyargcount, co.co_nlocals,
                      co.co_stacksize, co.co_flags, codestring,
                      consts, co.co_names, co.co_varnames,
                      co.co_filename, co.co_name, co.co_firstlineno,
                      co.co_lnotab, co.co_freevars, co.co_cellvars)
    func.__code__ = new_code


_trashcan_dir = 'numba-tests'

if os.name == 'nt':
    # Under Windows, gettempdir() points to the user-local temp dir
    _trashcan_dir = os.path.join(tempfile.gettempdir(), _trashcan_dir)
else:
    # Mix the UID into the directory name to allow different users to
    # run the test suite without permission errors (issue #1586)
    _trashcan_dir = os.path.join(tempfile.gettempdir(),
                                 "%s.%s" % (_trashcan_dir, os.getuid()))

# Stale temporary directories are deleted after they are older than this value.
# The test suite probably won't ever take longer than this...
_trashcan_timeout = 24 * 3600  # 1 day


def _create_trashcan_dir():
    """Create the shared trashcan directory if it does not exist yet."""
    try:
        os.mkdir(_trashcan_dir)
    except OSError as e:
        # Already existing (e.g. created by a concurrent test process) is fine.
        if e.errno != errno.EEXIST:
            raise


def _purge_trashcan_dir():
    """
    Best-effort removal of trashcan entries whose modification time is
    older than ``_trashcan_timeout`` seconds.
    """
    freshness_threshold = time.time() - _trashcan_timeout
    for fn in sorted(os.listdir(_trashcan_dir)):
        fn = os.path.join(_trashcan_dir, fn)
        try:
            st = os.stat(fn)
            if st.st_mtime < freshness_threshold:
                shutil.rmtree(fn, ignore_errors=True)
        except OSError:
            # In parallel testing, several processes can attempt to
            # remove the same entry at once, ignore.
            pass


def _create_trashcan_subdir(prefix):
    """Purge stale entries, then create a fresh ``<prefix>-*`` subdirectory."""
    _purge_trashcan_dir()
    path = tempfile.mkdtemp(prefix=prefix + '-', dir=_trashcan_dir)
    return path


def temp_directory(prefix):
    """
    Create a temporary directory with the given *prefix* that will survive
    at least as long as this process invocation.  The temporary directory
    will be eventually deleted when it becomes stale enough.

    This is necessary because a DLL file can't be deleted while in use
    under Windows.

    An interesting side-effect is to be able to inspect the test files
    shortly after a test suite run.
    """
    _create_trashcan_dir()
    return _create_trashcan_subdir(prefix)


def import_dynamic(modname):
    """
    Import and return a module of the given name.  Care is taken to
    avoid issues due to Python's internal directory caching.
    """
    import importlib
    importlib.invalidate_caches()
    __import__(modname)
    return sys.modules[modname]


# From CPython

@contextlib.contextmanager
def captured_output(stream_name):
    """Return a context manager used by captured_stdout/stdin/stderr
    that temporarily replaces the sys stream *stream_name* with a StringIO.

    The StringIO is yielded; the original stream is always restored.
    """
    orig_stream = getattr(sys, stream_name)
    setattr(sys, stream_name, io.StringIO())
    try:
        yield getattr(sys, stream_name)
    finally:
        setattr(sys, stream_name, orig_stream)


def captured_stdout():
    """Capture the output of sys.stdout:

       with captured_stdout() as stdout:
           print("hello")
       self.assertEqual(stdout.getvalue(), "hello\\n")
    """
    return captured_output("stdout")


def captured_stderr():
    """Capture the output of sys.stderr:

       with captured_stderr() as stderr:
           print("hello", file=sys.stderr)
       self.assertEqual(stderr.getvalue(), "hello\\n")
    """
    return captured_output("stderr")


@contextlib.contextmanager
def capture_cache_log():
    """Capture stdout while Numba's DEBUG_CACHE config is enabled."""
    with captured_stdout() as out:
        with override_config('DEBUG_CACHE', True):
            yield out


class MemoryLeak(object):
    """
    Helpers comparing NRT allocation statistics between a setup point and
    a teardown point.  Intended to be mixed into a ``unittest.TestCase``
    (it relies on ``self.assertEqual``).
    """

    __enable_leak_check = True

    def memory_leak_setup(self):
        # Clean up any NRT-backed objects hanging in a dead reference cycle
        gc.collect()
        self.__init_stats = rtsys.get_allocation_stats()

    def memory_leak_teardown(self):
        if self.__enable_leak_check:
            self.assert_no_memory_leak()

    def assert_no_memory_leak(self):
        """Assert allocs == frees (data and meminfo) since setup."""
        old = self.__init_stats
        new = rtsys.get_allocation_stats()
        total_alloc = new.alloc - old.alloc
        total_free = new.free - old.free
        total_mi_alloc = new.mi_alloc - old.mi_alloc
        total_mi_free = new.mi_free - old.mi_free
        self.assertEqual(total_alloc, total_free)
        self.assertEqual(total_mi_alloc, total_mi_free)

    def disable_leak_check(self):
        # For per-test use when MemoryLeakMixin is injected into a TestCase
        self.__enable_leak_check = False


class MemoryLeakMixin(MemoryLeak):
    """Run the NRT leak check around every test via setUp/tearDown."""

    def setUp(self):
        super(MemoryLeakMixin, self).setUp()
        self.memory_leak_setup()

    def tearDown(self):
        super(MemoryLeakMixin, self).tearDown()
        gc.collect()
        self.memory_leak_teardown()


@contextlib.contextmanager
def forbid_codegen():
    """
    Forbid LLVM code generation during the execution of the context
    manager's enclosed block.

    If code generation is invoked, a RuntimeError is raised.
    """
    from numba.core import codegen
    # Dotted attribute paths, relative to numba.core.codegen, to patch.
    patchpoints = ['CodeLibrary._finalize_final_module']

    old = {}

    def fail(*args, **kwargs):
        raise RuntimeError("codegen forbidden by test case")
    try:
        # XXX use the mock library instead?
        for name in patchpoints:
            parts = name.split('.')
            obj = codegen
            for attrname in parts[:-1]:
                obj = getattr(obj, attrname)
            attrname = parts[-1]
            value = getattr(obj, attrname)
            assert callable(value), ("%r should be callable" % name)
            old[obj, attrname] = value
            setattr(obj, attrname, fail)
        yield
    finally:
        # Restore only what was actually patched.
        for (obj, attrname), value in old.items():
            setattr(obj, attrname, value)


# For details about redirection of file-descriptor, read
# https://eli.thegreenplace.net/2015/redirecting-all-kinds-of-stdout-in-python/

@contextlib.contextmanager
def redirect_fd(fd):
    """
    Temporarily redirect *fd* to a pipe's write end and return a file object
    wrapping the pipe's read end.

    The write end is closed and *fd* restored on exit, so the caller should
    read the yielded file after the block.
    NOTE(review): output larger than the OS pipe buffer written inside the
    block may block, since nothing drains the pipe until then.
    """

    from numba import _helperlib
    libnumba = ctypes.CDLL(_helperlib.__file__)

    # Presumably flushes C-level stdout buffers so that output is not
    # misattributed across the redirection boundary.
    libnumba._numba_flush_stdout()
    save = os.dup(fd)
    r, w = os.pipe()
    try:
        os.dup2(w, fd)
        yield io.open(r, "r")
    finally:
        libnumba._numba_flush_stdout()
        os.close(w)
        os.dup2(save, fd)
        os.close(save)


def redirect_c_stdout():
    """Redirect C stdout
    """
    fd = sys.__stdout__.fileno()
    return redirect_fd(fd)


def run_in_new_process_caching(func, cache_dir_prefix=__name__, verbose=True):
    """Spawn a new process to run `func` with a temporary cache directory.

    The childprocess's stdout and stderr will be captured and redirected to
    the current process's stdout and stderr.

    Returns
    -------
    ret : dict
        exitcode: 0 for success. 1 for exception-raised.
        stdout: str
        stderr: str
    """
    cache_dir = temp_directory(cache_dir_prefix)
    return run_in_new_process_in_cache_dir(func, cache_dir, verbose=verbose)


def _get_from_child(qout, proc, poll_interval=0.1):
    """Return the next item *proc* puts on *qout*.

    Waits while *proc* is alive.  Once it has exited, one final
    non-blocking read is attempted, which raises ``queue.Empty`` if the
    child died without sending anything.
    """
    import queue

    while True:
        try:
            return qout.get(timeout=poll_interval)
        except queue.Empty:
            if not proc.is_alive():
                return qout.get_nowait()


def run_in_new_process_in_cache_dir(func, cache_dir, verbose=True):
    """Spawn a new process to run `func` with a temporary cache directory.

    The childprocess's stdout and stderr will be captured and redirected to
    the current process's stdout and stderr.

    Similar to ``run_in_new_process_caching()`` but the ``cache_dir`` is a
    directory path instead of a name prefix for the directory path.

    Returns
    -------
    ret : dict
        exitcode: 0 for success. 1 for exception-raised.
        stdout: str
        stderr: str
    """
    ctx = mp.get_context('spawn')
    qout = ctx.Queue()
    with override_env_config('NUMBA_CACHE_DIR', cache_dir):
        # The child inherits the overridden environment at spawn time.
        proc = ctx.Process(target=_remote_runner, args=[func, qout])
        proc.start()
    # Drain the queue *before* joining: a child does not exit until the
    # items it put have been flushed to the underlying pipe, so joining
    # first can deadlock when the captured output is large.
    stdout = _get_from_child(qout, proc)
    stderr = _get_from_child(qout, proc)
    proc.join()
    if verbose and stdout.strip():
        print()
        print('STDOUT'.center(80, '-'))
        print(stdout)
    if verbose and stderr.strip():
        print(file=sys.stderr)
        print('STDERR'.center(80, '-'), file=sys.stderr)
        print(stderr, file=sys.stderr)
    return {
        'exitcode': proc.exitcode,
        'stdout': stdout,
        'stderr': stderr,
    }


def _remote_runner(fn, qout):
    """Used by `run_in_new_process_caching()` and
    `run_in_new_process_in_cache_dir()`.

    Runs *fn* with stdout/stderr captured, puts the captured stdout then
    stderr on *qout*, and exits with 1 if *fn* raised, else 0.
    """
    with captured_stderr() as stderr:
        with captured_stdout() as stdout:
            try:
                fn()
            except Exception:
                traceback.print_exc()
                exitcode = 1
            else:
                exitcode = 0
        qout.put(stdout.getvalue())
    qout.put(stderr.getvalue())
    sys.exit(exitcode)


class CheckWarningsMixin(object):
    """Mixin (for ``unittest.TestCase``) providing ``check_warnings``."""

    @contextlib.contextmanager
    def check_warnings(self, messages, category=RuntimeWarning):
        """
        Record all warnings raised in the enclosed block.  For every
        (warning, message) pair where *message* is a substring of the
        warning text, assert the warning's category is *category*; the
        number of such pairs must equal ``len(messages)``.
        """
        with warnings.catch_warnings(record=True) as catch:
            warnings.simplefilter("always")
            yield
        found = 0
        for w in catch:
            for m in messages:
                if m in str(w.message):
                    self.assertEqual(w.category, category)
                    found += 1
        self.assertEqual(found, len(messages))