1from contextlib import contextmanager
2import functools
3import operator
4import sys
5import warnings
6import numbers
7from collections import namedtuple
8import inspect
9import math
10from typing import (
11    Optional,
12    Union,
13    TYPE_CHECKING,
14    TypeVar,
15)
16
17import numpy as np
18
# Type aliases for scalar arguments: IntNumber is a Python int or any NumPy
# integer scalar; DecimalNumber additionally admits Python floats and NumPy
# floating scalars.
IntNumber = Union[int, np.integer]
DecimalNumber = Union[float, np.floating, np.integer]

# Since Generator was introduced in numpy 1.17, the following condition is needed for
# backward compatibility: the aliases below reference np.random.Generator, so
# they are only evaluated by static type checkers (TYPE_CHECKING is False at
# runtime), which keeps this module importable on older NumPy versions.
if TYPE_CHECKING:
    # Accepted forms of a random seed: None, an integer, or an existing
    # Generator / RandomState instance.
    SeedType = Optional[Union[IntNumber, np.random.Generator,
                              np.random.RandomState]]
    # Type variable bound to either of NumPy's random number generator classes.
    GeneratorType = TypeVar("GeneratorType", bound=Union[np.random.Generator,
                                                         np.random.RandomState])

# Runtime counterpart of the guard above: when numpy.random.Generator is not
# available, define an empty placeholder class so that
# ``isinstance(obj, Generator)`` checks evaluate to False instead of failing
# with a NameError.
try:
    from numpy.random import Generator as Generator
except ImportError:
    class Generator():  # type: ignore[no-redef]
        pass
35
36
37def _lazywhere(cond, arrays, f, fillvalue=None, f2=None):
38    """
39    np.where(cond, x, fillvalue) always evaluates x even where cond is False.
40    This one only evaluates f(arr1[cond], arr2[cond], ...).
41
42    Examples
43    --------
44    >>> a, b = np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8])
45    >>> def f(a, b):
46    ...     return a*b
47    >>> _lazywhere(a > 2, (a, b), f, np.nan)
48    array([ nan,  nan,  21.,  32.])
49
50    Notice, it assumes that all `arrays` are of the same shape, or can be
51    broadcasted together.
52
53    """
54    cond = np.asarray(cond)
55    if fillvalue is None:
56        if f2 is None:
57            raise ValueError("One of (fillvalue, f2) must be given.")
58        else:
59            fillvalue = np.nan
60    else:
61        if f2 is not None:
62            raise ValueError("Only one of (fillvalue, f2) can be given.")
63
64    args = np.broadcast_arrays(cond, *arrays)
65    cond,  arrays = args[0], args[1:]
66    temp = tuple(np.extract(cond, arr) for arr in arrays)
67    tcode = np.mintypecode([a.dtype.char for a in arrays])
68    out = np.full(np.shape(arrays[0]), fill_value=fillvalue, dtype=tcode)
69    np.place(out, cond, f(*temp))
70    if f2 is not None:
71        temp = tuple(np.extract(~cond, arr) for arr in arrays)
72        np.place(out, ~cond, f2(*temp))
73
74    return out
75
76
77def _lazyselect(condlist, choicelist, arrays, default=0):
78    """
79    Mimic `np.select(condlist, choicelist)`.
80
81    Notice, it assumes that all `arrays` are of the same shape or can be
82    broadcasted together.
83
84    All functions in `choicelist` must accept array arguments in the order
85    given in `arrays` and must return an array of the same shape as broadcasted
86    `arrays`.
87
88    Examples
89    --------
90    >>> x = np.arange(6)
91    >>> np.select([x <3, x > 3], [x**2, x**3], default=0)
92    array([  0,   1,   4,   0,  64, 125])
93
94    >>> _lazyselect([x < 3, x > 3], [lambda x: x**2, lambda x: x**3], (x,))
95    array([   0.,    1.,    4.,   0.,   64.,  125.])
96
97    >>> a = -np.ones_like(x)
98    >>> _lazyselect([x < 3, x > 3],
99    ...             [lambda x, a: x**2, lambda x, a: a * x**3],
100    ...             (x, a), default=np.nan)
101    array([   0.,    1.,    4.,   nan,  -64., -125.])
102
103    """
104    arrays = np.broadcast_arrays(*arrays)
105    tcode = np.mintypecode([a.dtype.char for a in arrays])
106    out = np.full(np.shape(arrays[0]), fill_value=default, dtype=tcode)
107    for index in range(len(condlist)):
108        func, cond = choicelist[index], condlist[index]
109        if np.all(cond is False):
110            continue
111        cond, _ = np.broadcast_arrays(cond, arrays[0])
112        temp = tuple(np.extract(cond, arr) for arr in arrays)
113        np.place(out, cond, func(*temp))
114    return out
115
116
117def _aligned_zeros(shape, dtype=float, order="C", align=None):
118    """Allocate a new ndarray with aligned memory.
119
120    Primary use case for this currently is working around a f2py issue
121    in NumPy 1.9.1, where dtype.alignment is such that np.zeros() does
122    not necessarily create arrays aligned up to it.
123
124    """
125    dtype = np.dtype(dtype)
126    if align is None:
127        align = dtype.alignment
128    if not hasattr(shape, '__len__'):
129        shape = (shape,)
130    size = functools.reduce(operator.mul, shape) * dtype.itemsize
131    buf = np.empty(size + align + 1, np.uint8)
132    offset = buf.__array_interface__['data'][0] % align
133    if offset != 0:
134        offset = align - offset
135    # Note: slices producing 0-size arrays do not necessarily change
136    # data pointer --- so we use and allocate size+1
137    buf = buf[offset:offset+size+1][:-1]
138    data = np.ndarray(shape, dtype, buf, order=order)
139    data.fill(0)
140    return data
141
142
143def _prune_array(array):
144    """Return an array equivalent to the input array. If the input
145    array is a view of a much larger array, copy its contents to a
146    newly allocated array. Otherwise, return the input unchanged.
147    """
148    if array.base is not None and array.size < array.base.size // 2:
149        return array.copy()
150    return array
151
152
def prod(iterable):
    """
    Product of a sequence of numbers.

    Faster than np.prod for short lists like array shapes, and does
    not overflow if using Python integers. An empty iterable gives 1.
    """
    # Fold starting from 1; `imul` mirrors the in-place `*=` of a manual
    # accumulation loop.
    return functools.reduce(operator.imul, iterable, 1)
164
165
def float_factorial(n: int) -> float:
    """Return ``n!`` as a float.

    For ``n >= 171`` the factorial exceeds the largest finite double, so
    ``np.inf`` is returned instead of computing it.
    """
    if n < 171:
        return float(math.factorial(n))
    return np.inf
172
173
class DeprecatedImport:
    """
    Deprecated import with redirection and warning.

    Attribute access on an instance is forwarded to the module named by
    `new_module_name`, and each forwarded access emits a DeprecationWarning
    naming both the old and the new module.

    Parameters
    ----------
    old_module_name : str
        Deprecated name, shown to users in the warning.
    new_module_name : str
        Name of the module to redirect to; it is imported in ``__init__``.

    Examples
    --------
    Suppose you previously had in some module::

        from foo import spam

    If this has to be deprecated, do::

        spam = DeprecatedImport("foo.spam", "baz")

    to redirect users to use "baz" module instead.

    """

    def __init__(self, old_module_name, new_module_name):
        self._old_name = old_module_name
        self._new_name = new_module_name
        __import__(self._new_name)
        # __import__('a.b') returns the top-level package `a`, so look up
        # the actual (sub)module in sys.modules.
        self._mod = sys.modules[self._new_name]

    def __dir__(self):
        return dir(self._mod)

    def __getattr__(self, name):
        # __getattr__ only runs when normal lookup fails. If one of this
        # object's own attributes is missing (e.g. copy/pickle created the
        # instance without calling __init__), fail with AttributeError
        # instead of recursing through `self._mod` until RecursionError.
        if name in ('_old_name', '_new_name', '_mod'):
            raise AttributeError(name)
        # stacklevel=2 attributes the warning to the code performing the
        # attribute access rather than to this method.
        warnings.warn("Module %s is deprecated, use %s instead"
                      % (self._old_name, self._new_name),
                      DeprecationWarning, stacklevel=2)
        return getattr(self._mod, name)
206
207
# copy-pasted from scikit-learn utils/validation.py
# change this to scipy.stats._qmc.check_random_state once numpy 1.16 is dropped
def check_random_state(seed):
    """Turn `seed` into a NumPy random number generator instance.

    Parameters
    ----------
    seed : {None, int, `numpy.random.Generator`,
            `numpy.random.RandomState`}, optional

        If `seed` is None (or `np.random`), the `numpy.random.RandomState`
        singleton is used.
        If `seed` is an int, a new ``RandomState`` instance is used,
        seeded with `seed`.
        If `seed` is already a ``Generator`` or ``RandomState`` instance then
        that instance is used.

    Returns
    -------
    seed : {`numpy.random.Generator`, `numpy.random.RandomState`}
        Random number generator.

    Raises
    ------
    ValueError
        If `seed` is of any other type.

    """
    if seed is None or seed is np.random:
        return np.random.mtrand._rand
    if isinstance(seed, (numbers.Integral, np.integer)):
        return np.random.RandomState(seed)
    if isinstance(seed, np.random.RandomState):
        return seed
    try:
        # Generator is only available in numpy >= 1.17
        if isinstance(seed, np.random.Generator):
            return seed
    except AttributeError:
        pass
    # Wrap `seed` in a 1-tuple: a tuple seed would otherwise be unpacked by
    # %-formatting and raise TypeError instead of this ValueError.
    raise ValueError('%r cannot be used to seed a numpy.random.RandomState'
                     ' instance' % (seed,))
245
246
247def _asarray_validated(a, check_finite=True,
248                       sparse_ok=False, objects_ok=False, mask_ok=False,
249                       as_inexact=False):
250    """
251    Helper function for SciPy argument validation.
252
253    Many SciPy linear algebra functions do support arbitrary array-like
254    input arguments. Examples of commonly unsupported inputs include
255    matrices containing inf/nan, sparse matrix representations, and
256    matrices with complicated elements.
257
258    Parameters
259    ----------
260    a : array_like
261        The array-like input.
262    check_finite : bool, optional
263        Whether to check that the input matrices contain only finite numbers.
264        Disabling may give a performance gain, but may result in problems
265        (crashes, non-termination) if the inputs do contain infinities or NaNs.
266        Default: True
267    sparse_ok : bool, optional
268        True if scipy sparse matrices are allowed.
269    objects_ok : bool, optional
270        True if arrays with dype('O') are allowed.
271    mask_ok : bool, optional
272        True if masked arrays are allowed.
273    as_inexact : bool, optional
274        True to convert the input array to a np.inexact dtype.
275
276    Returns
277    -------
278    ret : ndarray
279        The converted validated array.
280
281    """
282    if not sparse_ok:
283        import scipy.sparse
284        if scipy.sparse.issparse(a):
285            msg = ('Sparse matrices are not supported by this function. '
286                   'Perhaps one of the scipy.sparse.linalg functions '
287                   'would work instead.')
288            raise ValueError(msg)
289    if not mask_ok:
290        if np.ma.isMaskedArray(a):
291            raise ValueError('masked arrays are not supported')
292    toarray = np.asarray_chkfinite if check_finite else np.asarray
293    a = toarray(a)
294    if not objects_ok:
295        if a.dtype is np.dtype('O'):
296            raise ValueError('object arrays are not supported')
297    if as_inexact:
298        if not np.issubdtype(a.dtype, np.inexact):
299            a = toarray(a, dtype=np.float_)
300    return a
301
302
303def _validate_int(k, name, minimum=None):
304    """
305    Validate a scalar integer.
306
307    This functon can be used to validate an argument to a function
308    that expects the value to be an integer.  It uses `operator.index`
309    to validate the value (so, for example, k=2.0 results in a
310    TypeError).
311
312    Parameters
313    ----------
314    k : int
315        The value to be validated.
316    name : str
317        The name of the parameter.
318    minimum : int, optional
319        An optional lower bound.
320    """
321    try:
322        k = operator.index(k)
323    except TypeError:
324        raise TypeError(f'{name} must be an integer.') from None
325    if minimum is not None and k < minimum:
326        raise ValueError(f'{name} must be an integer not less '
327                         f'than {minimum}') from None
328    return k
329
330
# Add a replacement for inspect.getfullargspec().
# The version below is borrowed from Django,
# https://github.com/django/django/pull/4846.

# Note an inconsistency between inspect.getfullargspec(func) and
# inspect.signature(func): if `func` is a bound method, the latter does *not*
# list `self` as the first argument, while the former *does*.
# Hence, provide a common-ground replacement, `getfullargspec_no_self`, which
# mimics `inspect.getfullargspec` but does not list `self`.
#
# This way, the calling code does not need to know whether it uses the legacy
# .getfullargspec or the newer .signature.
343
# Result type of getfullargspec_no_self; same field names and order as
# inspect.getfullargspec's result.
FullArgSpec = namedtuple(
    'FullArgSpec',
    'args varargs varkw defaults kwonlyargs kwonlydefaults annotations',
)
347
348
def getfullargspec_no_self(func):
    """inspect.getfullargspec replacement using inspect.signature.

    If func is a bound method, do not list the 'self' parameter.

    Parameters
    ----------
    func : callable
        A callable to inspect

    Returns
    -------
    fullargspec : FullArgSpec(args, varargs, varkw, defaults, kwonlyargs,
                              kwonlydefaults, annotations)

        NOTE: if the first argument of `func` is self, it is *not*, I repeat
        *not*, included in fullargspec.args, because inspect.signature omits
        it for bound methods.

        `args` lists positional-only and positional-or-keyword parameters,
        and `defaults` holds the defaults of the trailing ones among them
        (None if there are none). `kwonlydefaults` is None when no
        keyword-only parameter has a default. Unlike
        inspect.getfullargspec, `annotations` does not include the return
        annotation.

    """
    sig = inspect.signature(func)
    params = list(sig.parameters.values())
    positional_kinds = (inspect.Parameter.POSITIONAL_ONLY,
                        inspect.Parameter.POSITIONAL_OR_KEYWORD)
    args = [p.name for p in params if p.kind in positional_kinds]
    # A signature has at most one *args and one **kwargs parameter.
    varargs = next((p.name for p in params
                    if p.kind == inspect.Parameter.VAR_POSITIONAL), None)
    varkw = next((p.name for p in params
                  if p.kind == inspect.Parameter.VAR_KEYWORD), None)
    # Defaults align with the *tail* of `args`, so defaults of positional-only
    # parameters must be included too (as inspect.getfullargspec does);
    # otherwise they would be attributed to the wrong parameters.
    defaults = tuple(
        p.default for p in params
        if p.kind in positional_kinds and p.default is not p.empty
    ) or None
    kwonlyargs = [p.name for p in params
                  if p.kind == inspect.Parameter.KEYWORD_ONLY]
    kwdefaults = {p.name: p.default for p in params
                  if p.kind == inspect.Parameter.KEYWORD_ONLY and
                  p.default is not p.empty}
    annotations = {p.name: p.annotation for p in params
                   if p.annotation is not p.empty}
    return FullArgSpec(args, varargs, varkw, defaults, kwonlyargs,
                       kwdefaults or None, annotations)
402
403
class MapWrapper:
    """
    Parallelisation wrapper for working with map-like callables, such as
    `multiprocessing.Pool.map`.

    Parameters
    ----------
    pool : int or map-like callable
        If `pool` is an integer, it specifies the number of worker processes
        (a `multiprocessing.Pool`) used for parallelization. If
        ``int(pool) == 1``, then no parallel processing is used and the map
        builtin is used.
        If ``pool == -1``, then a pool is created with the default number of
        processes, i.e. all available CPUs.
        If `pool` is a map-like callable that follows the same
        calling sequence as the built-in map function, then this callable is
        used for parallelization.

    Notes
    -----
    Instances can be used as context managers. On exit, a pool that this
    wrapper created is closed and terminated; a callable passed in by the
    caller is never closed, joined or terminated by this class.
    """
    def __init__(self, pool=1):
        self.pool = None
        self._mapfunc = map
        self._own_pool = False

        if callable(pool):
            # Caller-supplied map-like callable: use it directly, not owned.
            self.pool = pool
            self._mapfunc = self.pool
            return

        from multiprocessing import Pool
        # user supplies a number
        nworkers = int(pool)
        if nworkers == 1:
            # Serial execution with the builtin map.
            return
        if nworkers == -1:
            # use as many processors as possible
            self.pool = Pool()
        elif nworkers > 1:
            # use the number of processors requested
            self.pool = Pool(processes=nworkers)
        else:
            raise RuntimeError("Number of workers specified must be -1,"
                               " an int >= 1, or an object with a 'map' "
                               "method")
        self._mapfunc = self.pool.map
        self._own_pool = True

    def __enter__(self):
        return self

    def terminate(self):
        """Terminate the worker pool, if this wrapper created it."""
        if self._own_pool:
            self.pool.terminate()

    def join(self):
        """Wait for the worker processes, if this wrapper created the pool."""
        if self._own_pool:
            self.pool.join()

    def close(self):
        """Close the worker pool, if this wrapper created it."""
        if self._own_pool:
            self.pool.close()

    def __exit__(self, exc_type, exc_value, traceback):
        # Only tear down a pool we created ourselves.
        if self._own_pool:
            self.pool.close()
            self.pool.terminate()

    def __call__(self, func, iterable):
        # only accept one iterable because that's all Pool.map accepts
        mapper = self._mapfunc
        try:
            return mapper(func, iterable)
        except TypeError as exc:
            # wrong number of arguments
            raise TypeError("The map-like callable must be of the"
                            " form f(func, iterable)") from exc
476
477
def rng_integers(gen, low, high=None, size=None, dtype='int64',
                 endpoint=False):
    """
    Return random integers from low (inclusive) to high (exclusive), or if
    endpoint=True, low (inclusive) to high (inclusive). Replaces
    `RandomState.randint` (with endpoint=False) and
    `RandomState.random_integers` (with endpoint=True).

    Return random integers from the "discrete uniform" distribution of the
    specified dtype. If high is None (the default), then results are from
    0 to low.

    Parameters
    ----------
    gen : {None, np.random.RandomState, np.random.Generator}
        Random number generator. If None, then the np.random.RandomState
        singleton is used.
    low : int or array-like of ints
        Lowest (signed) integers to be drawn from the distribution (unless
        high=None, in which case this parameter is 0 and this value is used
        for high).
    high : int or array-like of ints
        If provided, one above the largest (signed) integer to be drawn from
        the distribution (see above for behavior if high=None). If array-like,
        must contain integer values.
    size : array-like of ints, optional
        Output shape. If the given shape is, e.g., (m, n, k), then m * n * k
        samples are drawn. Default is None, in which case a single value is
        returned.
    dtype : {str, dtype}, optional
        Desired dtype of the result. All dtypes are determined by their name,
        i.e., 'int64', 'int', etc, so byteorder is not available and a specific
        precision may have different C types depending on the platform.
        The default value is 'int64'.
    endpoint : bool, optional
        If True, sample from the interval [low, high] instead of the default
        [low, high). Defaults to False.

    Returns
    -------
    out: int or ndarray of ints
        size-shaped array of random integers from the appropriate distribution,
        or a single such random int if size not provided.
    """
    if isinstance(gen, Generator):
        # Generator.integers supports `endpoint` natively.
        return gen.integers(low, high=high, size=size, dtype=dtype,
                            endpoint=endpoint)

    if gen is None:
        # default is RandomState singleton used by np.random.
        gen = np.random.mtrand._rand
    if endpoint:
        # RandomState.randint excludes its upper bound, so raise the bound by
        # one to include it. low and high can be arrays, so build new values
        # rather than modifying them in place.
        if high is None:
            return gen.randint(low + 1, size=size, dtype=dtype)
        return gen.randint(low, high=high + 1, size=size, dtype=dtype)

    # exclusive
    return gen.randint(low, high=high, size=size, dtype=dtype)
540
541
542@contextmanager
543def _fixed_default_rng(seed=1638083107694713882823079058616272161):
544    """Context with a fixed np.random.default_rng seed."""
545    orig_fun = np.random.default_rng
546    np.random.default_rng = lambda seed=seed: orig_fun(seed)
547    try:
548        yield
549    finally:
550        np.random.default_rng = orig_fun
551