1from contextlib import contextmanager 2import functools 3import operator 4import sys 5import warnings 6import numbers 7from collections import namedtuple 8import inspect 9import math 10from typing import ( 11 Optional, 12 Union, 13 TYPE_CHECKING, 14 TypeVar, 15) 16 17import numpy as np 18 19IntNumber = Union[int, np.integer] 20DecimalNumber = Union[float, np.floating, np.integer] 21 22# Since Generator was introduced in numpy 1.17, the following condition is needed for 23# backward compatibility 24if TYPE_CHECKING: 25 SeedType = Optional[Union[IntNumber, np.random.Generator, 26 np.random.RandomState]] 27 GeneratorType = TypeVar("GeneratorType", bound=Union[np.random.Generator, 28 np.random.RandomState]) 29 30try: 31 from numpy.random import Generator as Generator 32except ImportError: 33 class Generator(): # type: ignore[no-redef] 34 pass 35 36 37def _lazywhere(cond, arrays, f, fillvalue=None, f2=None): 38 """ 39 np.where(cond, x, fillvalue) always evaluates x even where cond is False. 40 This one only evaluates f(arr1[cond], arr2[cond], ...). 41 42 Examples 43 -------- 44 >>> a, b = np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]) 45 >>> def f(a, b): 46 ... return a*b 47 >>> _lazywhere(a > 2, (a, b), f, np.nan) 48 array([ nan, nan, 21., 32.]) 49 50 Notice, it assumes that all `arrays` are of the same shape, or can be 51 broadcasted together. 
52 53 """ 54 cond = np.asarray(cond) 55 if fillvalue is None: 56 if f2 is None: 57 raise ValueError("One of (fillvalue, f2) must be given.") 58 else: 59 fillvalue = np.nan 60 else: 61 if f2 is not None: 62 raise ValueError("Only one of (fillvalue, f2) can be given.") 63 64 args = np.broadcast_arrays(cond, *arrays) 65 cond, arrays = args[0], args[1:] 66 temp = tuple(np.extract(cond, arr) for arr in arrays) 67 tcode = np.mintypecode([a.dtype.char for a in arrays]) 68 out = np.full(np.shape(arrays[0]), fill_value=fillvalue, dtype=tcode) 69 np.place(out, cond, f(*temp)) 70 if f2 is not None: 71 temp = tuple(np.extract(~cond, arr) for arr in arrays) 72 np.place(out, ~cond, f2(*temp)) 73 74 return out 75 76 77def _lazyselect(condlist, choicelist, arrays, default=0): 78 """ 79 Mimic `np.select(condlist, choicelist)`. 80 81 Notice, it assumes that all `arrays` are of the same shape or can be 82 broadcasted together. 83 84 All functions in `choicelist` must accept array arguments in the order 85 given in `arrays` and must return an array of the same shape as broadcasted 86 `arrays`. 87 88 Examples 89 -------- 90 >>> x = np.arange(6) 91 >>> np.select([x <3, x > 3], [x**2, x**3], default=0) 92 array([ 0, 1, 4, 0, 64, 125]) 93 94 >>> _lazyselect([x < 3, x > 3], [lambda x: x**2, lambda x: x**3], (x,)) 95 array([ 0., 1., 4., 0., 64., 125.]) 96 97 >>> a = -np.ones_like(x) 98 >>> _lazyselect([x < 3, x > 3], 99 ... [lambda x, a: x**2, lambda x, a: a * x**3], 100 ... 
(x, a), default=np.nan) 101 array([ 0., 1., 4., nan, -64., -125.]) 102 103 """ 104 arrays = np.broadcast_arrays(*arrays) 105 tcode = np.mintypecode([a.dtype.char for a in arrays]) 106 out = np.full(np.shape(arrays[0]), fill_value=default, dtype=tcode) 107 for index in range(len(condlist)): 108 func, cond = choicelist[index], condlist[index] 109 if np.all(cond is False): 110 continue 111 cond, _ = np.broadcast_arrays(cond, arrays[0]) 112 temp = tuple(np.extract(cond, arr) for arr in arrays) 113 np.place(out, cond, func(*temp)) 114 return out 115 116 117def _aligned_zeros(shape, dtype=float, order="C", align=None): 118 """Allocate a new ndarray with aligned memory. 119 120 Primary use case for this currently is working around a f2py issue 121 in NumPy 1.9.1, where dtype.alignment is such that np.zeros() does 122 not necessarily create arrays aligned up to it. 123 124 """ 125 dtype = np.dtype(dtype) 126 if align is None: 127 align = dtype.alignment 128 if not hasattr(shape, '__len__'): 129 shape = (shape,) 130 size = functools.reduce(operator.mul, shape) * dtype.itemsize 131 buf = np.empty(size + align + 1, np.uint8) 132 offset = buf.__array_interface__['data'][0] % align 133 if offset != 0: 134 offset = align - offset 135 # Note: slices producing 0-size arrays do not necessarily change 136 # data pointer --- so we use and allocate size+1 137 buf = buf[offset:offset+size+1][:-1] 138 data = np.ndarray(shape, dtype, buf, order=order) 139 data.fill(0) 140 return data 141 142 143def _prune_array(array): 144 """Return an array equivalent to the input array. If the input 145 array is a view of a much larger array, copy its contents to a 146 newly allocated array. Otherwise, return the input unchanged. 147 """ 148 if array.base is not None and array.size < array.base.size // 2: 149 return array.copy() 150 return array 151 152 153def prod(iterable): 154 """ 155 Product of a sequence of numbers. 
156 157 Faster than np.prod for short lists like array shapes, and does 158 not overflow if using Python integers. 159 """ 160 product = 1 161 for x in iterable: 162 product *= x 163 return product 164 165 166def float_factorial(n: int) -> float: 167 """Compute the factorial and return as a float 168 169 Returns infinity when result is too large for a double 170 """ 171 return float(math.factorial(n)) if n < 171 else np.inf 172 173 174class DeprecatedImport: 175 """ 176 Deprecated import with redirection and warning. 177 178 Examples 179 -------- 180 Suppose you previously had in some module:: 181 182 from foo import spam 183 184 If this has to be deprecated, do:: 185 186 spam = DeprecatedImport("foo.spam", "baz") 187 188 to redirect users to use "baz" module instead. 189 190 """ 191 192 def __init__(self, old_module_name, new_module_name): 193 self._old_name = old_module_name 194 self._new_name = new_module_name 195 __import__(self._new_name) 196 self._mod = sys.modules[self._new_name] 197 198 def __dir__(self): 199 return dir(self._mod) 200 201 def __getattr__(self, name): 202 warnings.warn("Module %s is deprecated, use %s instead" 203 % (self._old_name, self._new_name), 204 DeprecationWarning) 205 return getattr(self._mod, name) 206 207 208# copy-pasted from scikit-learn utils/validation.py 209# change this to scipy.stats._qmc.check_random_state once numpy 1.16 is dropped 210def check_random_state(seed): 211 """Turn `seed` into a `np.random.RandomState` instance. 212 213 Parameters 214 ---------- 215 seed : {None, int, `numpy.random.Generator`, 216 `numpy.random.RandomState`}, optional 217 218 If `seed` is None (or `np.random`), the `numpy.random.RandomState` 219 singleton is used. 220 If `seed` is an int, a new ``RandomState`` instance is used, 221 seeded with `seed`. 222 If `seed` is already a ``Generator`` or ``RandomState`` instance then 223 that instance is used. 
224 225 Returns 226 ------- 227 seed : {`numpy.random.Generator`, `numpy.random.RandomState`} 228 Random number generator. 229 230 """ 231 if seed is None or seed is np.random: 232 return np.random.mtrand._rand 233 if isinstance(seed, (numbers.Integral, np.integer)): 234 return np.random.RandomState(seed) 235 if isinstance(seed, np.random.RandomState): 236 return seed 237 try: 238 # Generator is only available in numpy >= 1.17 239 if isinstance(seed, np.random.Generator): 240 return seed 241 except AttributeError: 242 pass 243 raise ValueError('%r cannot be used to seed a numpy.random.RandomState' 244 ' instance' % seed) 245 246 247def _asarray_validated(a, check_finite=True, 248 sparse_ok=False, objects_ok=False, mask_ok=False, 249 as_inexact=False): 250 """ 251 Helper function for SciPy argument validation. 252 253 Many SciPy linear algebra functions do support arbitrary array-like 254 input arguments. Examples of commonly unsupported inputs include 255 matrices containing inf/nan, sparse matrix representations, and 256 matrices with complicated elements. 257 258 Parameters 259 ---------- 260 a : array_like 261 The array-like input. 262 check_finite : bool, optional 263 Whether to check that the input matrices contain only finite numbers. 264 Disabling may give a performance gain, but may result in problems 265 (crashes, non-termination) if the inputs do contain infinities or NaNs. 266 Default: True 267 sparse_ok : bool, optional 268 True if scipy sparse matrices are allowed. 269 objects_ok : bool, optional 270 True if arrays with dype('O') are allowed. 271 mask_ok : bool, optional 272 True if masked arrays are allowed. 273 as_inexact : bool, optional 274 True to convert the input array to a np.inexact dtype. 275 276 Returns 277 ------- 278 ret : ndarray 279 The converted validated array. 280 281 """ 282 if not sparse_ok: 283 import scipy.sparse 284 if scipy.sparse.issparse(a): 285 msg = ('Sparse matrices are not supported by this function. 
' 286 'Perhaps one of the scipy.sparse.linalg functions ' 287 'would work instead.') 288 raise ValueError(msg) 289 if not mask_ok: 290 if np.ma.isMaskedArray(a): 291 raise ValueError('masked arrays are not supported') 292 toarray = np.asarray_chkfinite if check_finite else np.asarray 293 a = toarray(a) 294 if not objects_ok: 295 if a.dtype is np.dtype('O'): 296 raise ValueError('object arrays are not supported') 297 if as_inexact: 298 if not np.issubdtype(a.dtype, np.inexact): 299 a = toarray(a, dtype=np.float_) 300 return a 301 302 303def _validate_int(k, name, minimum=None): 304 """ 305 Validate a scalar integer. 306 307 This functon can be used to validate an argument to a function 308 that expects the value to be an integer. It uses `operator.index` 309 to validate the value (so, for example, k=2.0 results in a 310 TypeError). 311 312 Parameters 313 ---------- 314 k : int 315 The value to be validated. 316 name : str 317 The name of the parameter. 318 minimum : int, optional 319 An optional lower bound. 320 """ 321 try: 322 k = operator.index(k) 323 except TypeError: 324 raise TypeError(f'{name} must be an integer.') from None 325 if minimum is not None and k < minimum: 326 raise ValueError(f'{name} must be an integer not less ' 327 f'than {minimum}') from None 328 return k 329 330 331# Add a replacement for inspect.getfullargspec()/ 332# The version below is borrowed from Django, 333# https://github.com/django/django/pull/4846. 334 335# Note an inconsistency between inspect.getfullargspec(func) and 336# inspect.signature(func). If `func` is a bound method, the latter does *not* 337# list `self` as a first argument, while the former *does*. 338# Hence, cook up a common ground replacement: `getfullargspec_no_self` which 339# mimics `inspect.getfullargspec` but does not list `self`. 340# 341# This way, the caller code does not need to know whether it uses a legacy 342# .getfullargspec or a bright and shiny .signature. 
FullArgSpec = namedtuple('FullArgSpec',
                         ['args', 'varargs', 'varkw', 'defaults',
                          'kwonlyargs', 'kwonlydefaults', 'annotations'])


def getfullargspec_no_self(func):
    """inspect.getfullargspec replacement using inspect.signature.

    If func is a bound method, do not list the 'self' parameter.

    Parameters
    ----------
    func : callable
        A callable to inspect

    Returns
    -------
    fullargspec : FullArgSpec(args, varargs, varkw, defaults, kwonlyargs,
                              kwonlydefaults, annotations)

        NOTE: if the first argument of `func` is self, it is *not*, I repeat
        *not*, included in fullargspec.args.
        ``args`` holds positional-only and positional-or-keyword names, and
        ``defaults`` holds the defaults of the trailing ones among them, as
        in `inspect.getfullargspec`. Unlike `inspect.getfullargspec`,
        ``annotations`` does not include the return annotation.

    """
    sig = inspect.signature(func)
    params = sig.parameters.values()
    positional_kinds = (inspect.Parameter.POSITIONAL_ONLY,
                        inspect.Parameter.POSITIONAL_OR_KEYWORD)
    args = [p.name for p in params if p.kind in positional_kinds]
    varargs = [
        p.name for p in params
        if p.kind == inspect.Parameter.VAR_POSITIONAL
    ]
    varargs = varargs[0] if varargs else None
    varkw = [
        p.name for p in params
        if p.kind == inspect.Parameter.VAR_KEYWORD
    ]
    varkw = varkw[0] if varkw else None
    # Positional-only parameters are listed in `args`, so their defaults
    # must be listed too, or `defaults` would not align with the tail of
    # `args`.
    defaults = tuple(
        p.default for p in params
        if p.kind in positional_kinds and p.default is not p.empty
    ) or None
    kwonlyargs = [
        p.name for p in params
        if p.kind == inspect.Parameter.KEYWORD_ONLY
    ]
    kwdefaults = {p.name: p.default for p in params
                  if p.kind == inspect.Parameter.KEYWORD_ONLY and
                  p.default is not p.empty}
    annotations = {p.name: p.annotation for p in params
                   if p.annotation is not p.empty}
    return FullArgSpec(args, varargs, varkw, defaults, kwonlyargs,
                       kwdefaults or None, annotations)


class MapWrapper:
    """
    Parallelisation wrapper for working with map-like callables, such as
    `multiprocessing.Pool.map`.

    Parameters
    ----------
    pool : int or map-like callable
        If `pool` is an integer, then it specifies the number of threads to
        use for parallelization. If ``int(pool) == 1``, then no parallel
        processing is used and the map builtin is used.
        If ``pool == -1``, then the pool will utilize all available CPUs.
        If `pool` is a map-like callable that follows the same
        calling sequence as the built-in map function, then this callable is
        used for parallelization.
    """
    def __init__(self, pool=1):
        self.pool = None
        self._mapfunc = map
        # True only when this wrapper created the Pool and must clean it up.
        self._own_pool = False

        if callable(pool):
            self.pool = pool
            self._mapfunc = self.pool
        else:
            from multiprocessing import Pool
            # user supplies a number
            if int(pool) == -1:
                # use as many processors as possible
                self.pool = Pool()
                self._mapfunc = self.pool.map
                self._own_pool = True
            elif int(pool) == 1:
                pass
            elif int(pool) > 1:
                # use the number of processors requested
                self.pool = Pool(processes=int(pool))
                self._mapfunc = self.pool.map
                self._own_pool = True
            else:
                raise RuntimeError("Number of workers specified must be -1,"
                                   " an int >= 1, or an object with a 'map' "
                                   "method")

    def __enter__(self):
        return self

    def terminate(self):
        if self._own_pool:
            self.pool.terminate()

    def join(self):
        if self._own_pool:
            self.pool.join()

    def close(self):
        if self._own_pool:
            self.pool.close()

    def __exit__(self, exc_type, exc_value, traceback):
        if self._own_pool:
            self.pool.close()
            self.pool.terminate()

    def __call__(self, func, iterable):
        # only accept one iterable because that's all Pool.map accepts
        try:
            return self._mapfunc(func, iterable)
        except TypeError as e:
            # wrong number of arguments
            raise TypeError("The map-like callable must be of the"
                            " form f(func, iterable)") from e


def rng_integers(gen, low, high=None, size=None, dtype='int64',
                 endpoint=False):
    """
    Return random integers from low (inclusive) to high (exclusive), or if
    endpoint=True, low (inclusive) to high (inclusive). Replaces
    `RandomState.randint` (with endpoint=False) and
    `RandomState.random_integers` (with endpoint=True).

    Return random integers from the "discrete uniform" distribution of the
    specified dtype. If high is None (the default), then results are from
    0 to low.

    Parameters
    ----------
    gen : {None, np.random.RandomState, np.random.Generator}
        Random number generator. If None, then the np.random.RandomState
        singleton is used.
    low : int or array-like of ints
        Lowest (signed) integers to be drawn from the distribution (unless
        high=None, in which case this parameter is 0 and this value is used
        for high).
    high : int or array-like of ints
        If provided, one above the largest (signed) integer to be drawn from
        the distribution (see above for behavior if high=None). If array-like,
        must contain integer values.
    size : array-like of ints, optional
        Output shape. If the given shape is, e.g., (m, n, k), then m * n * k
        samples are drawn. Default is None, in which case a single value is
        returned.
    dtype : {str, dtype}, optional
        Desired dtype of the result. All dtypes are determined by their name,
        i.e., 'int64', 'int', etc, so byteorder is not available and a specific
        precision may have different C types depending on the platform.
        The default value is 'int64'.
    endpoint : bool, optional
        If True, sample from the interval [low, high] instead of the default
        [low, high) Defaults to False.

    Returns
    -------
    out: int or ndarray of ints
        size-shaped array of random integers from the appropriate distribution,
        or a single such random int if size not provided.
    """
    if isinstance(gen, Generator):
        return gen.integers(low, high=high, size=size, dtype=dtype,
                            endpoint=endpoint)

    if gen is None:
        # default is RandomState singleton used by np.random.
        gen = np.random.mtrand._rand
    if endpoint:
        # RandomState.randint excludes its upper bound, so shift it up by
        # one. low and high may be arrays: build new objects instead of
        # modifying them in place.
        if high is None:
            return gen.randint(low + 1, size=size, dtype=dtype)
        return gen.randint(low, high=high + 1, size=size, dtype=dtype)

    # exclusive
    return gen.randint(low, high=high, size=size, dtype=dtype)


@contextmanager
def _fixed_default_rng(seed=1638083107694713882823079058616272161):
    """Context with a fixed np.random.default_rng seed.

    While active, ``np.random.default_rng`` called without arguments is
    seeded with `seed`; the original function is restored on exit.
    """
    orig_fun = np.random.default_rng
    np.random.default_rng = lambda seed=seed: orig_fun(seed)
    try:
        yield
    finally:
        np.random.default_rng = orig_fun