1from numbers import Number
2import operator
3import os
4import threading
5import contextlib
6
7import numpy as np
8# good_size is exposed (and used) from this import
9from .pypocketfft import good_size
10
# Thread-local holder for the default worker count installed by `set_workers`
_config = threading.local()
# ``os.cpu_count()`` returns None when the number of CPUs cannot be
# determined; fall back to 1 so that negative ``workers`` values can still be
# resolved arithmetically instead of failing with a TypeError.
_cpu_count = os.cpu_count() or 1
13
14
def _iterable_of_int(x, name=None):
    """Normalise ``x`` into a list of Python ints

    A single number is treated as a one-element sequence. Every element is
    converted with ``operator.index``, so only true integer-like objects are
    accepted.

    Parameters
    ----------
    x : value, or sequence of values, convertible to int
    name : str, optional
        Name of the argument being converted, only used in the error message.
        Defaults to ``"value"``.

    Returns
    -------
    y : ``List[int]``

    Raises
    ------
    ValueError
        If ``x`` is not iterable or an element is not integer-like. The
        underlying ``TypeError`` is attached as the cause.
    """
    candidates = (x,) if isinstance(x, Number) else x

    try:
        return [operator.index(item) for item in candidates]
    except TypeError as exc:
        label = name or "value"
        message = "{} must be a scalar or iterable of integers".format(label)
        raise ValueError(message) from exc
39
40
def _init_nd_shape_and_axes(x, shape, axes):
    """Handles shape and axes arguments for nd transforms

    Parameters
    ----------
    x : object with ``ndim`` and ``shape`` attributes
        The array to be transformed; only its dimensionality and shape are
        read.
    shape : int or sequence of ints, or None
        Requested length along each transformed axis. An entry of ``-1`` is
        replaced by the length of ``x`` along the matching axis. If None, the
        lengths of ``x`` along ``axes`` are used.
    axes : int or sequence of ints, or None
        Axes to transform. Negative values count from the last axis. If None,
        the last ``len(shape)`` axes are used, or all axes when ``shape`` is
        also None.

    Returns
    -------
    shape : list of int
        Length along each transformed axis.
    axes : list of int or range
        Non-negative axis indices, matching ``shape`` element by element.

    Raises
    ------
    ValueError
        If an axis is out of range or repeated, if ``shape`` and ``axes`` are
        both given with different lengths, if ``shape`` has more entries than
        ``x`` has dimensions, or if any resulting length is smaller than 1.
    """
    noshape = shape is None
    noaxes = axes is None

    if not noaxes:
        axes = _iterable_of_int(axes, 'axes')
        # Map negative axes onto their non-negative equivalents
        axes = [a + x.ndim if a < 0 else a for a in axes]

        if any(a >= x.ndim or a < 0 for a in axes):
            raise ValueError("axes exceeds dimensionality of input")
        if len(set(axes)) != len(axes):
            raise ValueError("all axes must be unique")

    if not noshape:
        shape = _iterable_of_int(shape, 'shape')

        # Test ``noaxes`` rather than the truthiness of ``axes``: an explicitly
        # given but empty ``axes`` must not bypass this check, otherwise the
        # ``zip`` below would silently discard the requested shape.
        if not noaxes and len(axes) != len(shape):
            raise ValueError("when given, axes and shape arguments"
                             " have to be of the same length")
        if noaxes:
            if len(shape) > x.ndim:
                raise ValueError("shape requires more axes than are present")
            # A bare shape applies to the trailing axes of x
            axes = range(x.ndim - len(shape), x.ndim)

        # -1 means "keep the input's length along this axis"
        shape = [x.shape[a] if s == -1 else s for s, a in zip(shape, axes)]
    elif noaxes:
        shape = list(x.shape)
        axes = range(x.ndim)
    else:
        shape = [x.shape[a] for a in axes]

    if any(s < 1 for s in shape):
        raise ValueError(
            "invalid number of data points ({0}) specified".format(shape))

    return shape, axes
78
79
def _asfarray(x):
    """
    Convert to array with floating or complex dtype.

    float16 values are also promoted to float32.

    Parameters
    ----------
    x : array_like
        Input data. Objects without a ``dtype`` attribute are first converted
        with ``np.asarray``.

    Returns
    -------
    out : ndarray
        ``float32`` for float16 input and ``float64`` for any other input
        that is neither floating point nor complex. Floating and complex
        input keeps its dtype, converted to native byte order, and is copied
        only when a byte-order change or re-alignment requires it.
    """
    if not hasattr(x, "dtype"):
        x = np.asarray(x)

    if x.dtype == np.float16:
        return np.asarray(x, np.float32)
    elif x.dtype.kind not in 'fc':
        return np.asarray(x, np.float64)

    # Require native byte order
    dtype = x.dtype.newbyteorder('=')

    if not x.flags['ALIGNED']:
        # Always align input: a forced copy lands in freshly allocated,
        # aligned memory
        return np.array(x, dtype=dtype, copy=True)

    # Copy only if the byte order has to change. ``np.array(..., copy=False)``
    # is not used here because NumPy >= 2 raises ValueError when a copy
    # cannot be avoided; ``np.asarray`` copies only when needed on every
    # NumPy version.
    return np.asarray(x, dtype=dtype)
99
def _datacopied(arr, original):
    """
    Report whether `arr` is certain not to share data with `original`.

    The check is strict and assumes ``arr = asarray(original)``. Sharing is
    assumed when `arr` is the very same object as `original`, when `original`
    is a non-ndarray object exposing ``__array__``, or when `arr` has a
    ``base``.

    Returns
    -------
    bool
        True only if none of the sharing conditions above holds.
    """
    may_share = (
        arr is original
        or (not isinstance(original, np.ndarray)
            and hasattr(original, '__array__'))
        or arr.base is not None
    )
    return not may_share
110
111
def _fix_shape(x, shape, axes):
    """Crop or zero-pad ``x`` to the requested lengths along ``axes``.

    Internal auxiliary function for _raw_fft, _raw_fftnd.

    Parameters
    ----------
    x : ndarray
        Input array.
    shape : sequence of int
        Target length for each entry of ``axes``.
    axes : sequence of int
        Axes of ``x`` to resize, paired element-wise with ``shape``.

    Returns
    -------
    out : ndarray
        A slice of ``x`` when every target length fits inside ``x``;
        otherwise a new zero-filled array holding the overlapping part of
        ``x``.
    copied : bool
        True if a new array was allocated, False if ``out`` is a slice of
        ``x``.
    """
    selector = [slice(None)] * x.ndim
    target_shape = list(x.shape)
    needs_padding = False

    # One pass records both the region to read from x and the output shape
    for length, axis in zip(shape, axes):
        available = x.shape[axis]
        target_shape[axis] = length
        if available < length:
            selector[axis] = slice(0, available)
            needs_padding = True
        else:
            selector[axis] = slice(0, length)

    selector = tuple(selector)

    if needs_padding:
        padded = np.zeros(target_shape, x.dtype)
        padded[selector] = x[selector]
        return padded, True

    return x[selector], False
137
138
def _fix_shape_1d(x, n, axis):
    """Single-axis front end for `_fix_shape`.

    Parameters
    ----------
    x : ndarray
        Input array.
    n : int
        Target length along ``axis``; must be at least 1.
    axis : int
        Axis of ``x`` to resize.

    Returns
    -------
    The result of ``_fix_shape(x, (n,), (axis,))``.

    Raises
    ------
    ValueError
        If ``n`` is smaller than 1.
    """
    if n < 1:
        message = "invalid number of data points ({0}) specified".format(n)
        raise ValueError(message)

    lengths, axes = (n,), (axis,)
    return _fix_shape(x, lengths, axes)
145
146
# Maps each accepted ``norm`` argument to the integer mode used for a forward
# transform; ``None`` is treated the same as 'backward'.
_NORM_MAP = {None: 0, 'backward': 0, 'ortho': 1, 'forward': 2}


def _normalization(norm, forward):
    """Returns the pypocketfft normalization mode from the norm argument

    Parameters
    ----------
    norm : {None, 'backward', 'ortho', 'forward'}
        Requested normalization.
    forward : bool
        If true, the mode from ``_NORM_MAP`` is returned as is; otherwise it
        is mirrored as ``2 - mode``.

    Returns
    -------
    int
        The mode 0, 1 or 2.

    Raises
    ------
    ValueError
        If ``norm`` is not a key of ``_NORM_MAP``.
    """
    try:
        forward_mode = _NORM_MAP[norm]
    except KeyError:
        raise ValueError(
            f'Invalid norm value {norm!r}, should '
            'be "backward", "ortho" or "forward"') from None
    else:
        if forward:
            return forward_mode
        return 2 - forward_mode
159
160
def _workers(workers):
    """Resolve a ``workers`` argument to the worker count to use.

    Parameters
    ----------
    workers : int or None
        ``None`` selects the current thread's default (1 unless one has been
        stored on ``_config``). A negative value counts back from
        ``_cpu_count``, so that ``-1`` maps to ``_cpu_count``. A positive
        value is returned unchanged.

    Returns
    -------
    The resolved worker count.

    Raises
    ------
    ValueError
        If ``workers`` is zero or less than ``-_cpu_count``.
    """
    if workers is None:
        return getattr(_config, 'default_workers', 1)

    if workers < 0:
        lowest = -_cpu_count
        if workers >= lowest:
            # Wrap around: -1 -> _cpu_count, -_cpu_count -> 1
            return workers + (1 + _cpu_count)
        raise ValueError("workers value out of range; got {}, must not be"
                         " less than {}".format(workers, lowest))

    if workers == 0:
        raise ValueError("workers must not be zero")

    return workers
175
176
@contextlib.contextmanager
def set_workers(workers):
    """Context manager for the default number of workers used in `scipy.fft`

    The value is stored on the thread-local ``_config`` for the duration of
    the ``with`` block. On exit, including exit through an exception, the
    value that ``get_workers`` reported on entry is written back.

    Parameters
    ----------
    workers : int
        The default number of workers to use

    Examples
    --------
    >>> import numpy as np
    >>> from scipy import fft, signal
    >>> rng = np.random.default_rng()
    >>> x = rng.standard_normal((128, 64))
    >>> with fft.set_workers(4):
    ...     y = signal.fftconvolve(x, x)

    """
    # Validate the request before touching any state
    requested = _workers(operator.index(workers))
    previous = get_workers()

    _config.default_workers = requested
    try:
        yield
    finally:
        _config.default_workers = previous
201
202
def get_workers():
    """Returns the default number of workers within the current context

    The value is read from the thread-local ``_config``; 1 is returned when
    no default has been stored on it for the calling thread.

    Examples
    --------
    >>> from scipy import fft
    >>> fft.get_workers()
    1
    >>> with fft.set_workers(4):
    ...     fft.get_workers()
    4
    """
    try:
        return _config.default_workers
    except AttributeError:
        # Nothing stored for this thread yet
        return 1
216