from numbers import Number
import operator
import os
import threading
import contextlib

import numpy as np
# good_size is exposed (and used) from this import
from .pypocketfft import good_size

# Thread-local storage: `set_workers` writes ``default_workers`` here and
# `get_workers` reads it back, so each thread sees its own default.
_config = threading.local()
# ``os.cpu_count()`` can return None when the count cannot be determined;
# `_workers` guards against that case.
_cpu_count = os.cpu_count()


def _iterable_of_int(x, name=None):
    """Convert ``x`` to a list of ints.

    Parameters
    ----------
    x : value, or sequence of values, convertible to int
        A scalar ``Number`` is first wrapped into a one-element sequence.
        Every element must support ``operator.index`` (so e.g. floats are
        rejected).
    name : str, optional
        Name of the argument being converted, only used in the error message.

    Returns
    -------
    y : ``List[int]``

    Raises
    ------
    ValueError
        If ``x`` is not iterable or any element cannot be converted with
        ``operator.index``.
    """
    if isinstance(x, Number):
        x = (x,)

    try:
        x = [operator.index(a) for a in x]
    except TypeError as e:
        name = name or "value"
        raise ValueError(
            f"{name} must be a scalar or iterable of integers") from e

    return x


def _init_nd_shape_and_axes(x, shape, axes):
    """Handles shape and axes arguments for nd transforms.

    Parameters
    ----------
    x : array_like
        Input array; only its ``ndim`` and ``shape`` attributes are read.
    shape : int, sequence of ints, or None
        Requested transform length along each axis. An entry of ``-1``
        means "use the input's length along that axis". ``None`` means the
        input's own lengths.
    axes : int, sequence of ints, or None
        Axes to transform; negative values count from the end. ``None``
        means the last ``len(shape)`` axes when `shape` is given, otherwise
        all axes.

    Returns
    -------
    shape : list of int
        Transform length for each entry of `axes`.
    axes : list of int or range
        Non-negative axis indices.

    Raises
    ------
    ValueError
        If an axis is out of range or repeated, if `shape` and `axes` are
        both given with different lengths, if `shape` needs more axes than
        `x` has, or if any resulting length is less than 1.
    """
    noshape = shape is None
    noaxes = axes is None

    if not noaxes:
        axes = _iterable_of_int(axes, 'axes')
        # Map negative axes to their non-negative equivalents.
        axes = [a + x.ndim if a < 0 else a for a in axes]

        if any(a >= x.ndim or a < 0 for a in axes):
            raise ValueError("axes exceeds dimensionality of input")
        if len(set(axes)) != len(axes):
            raise ValueError("all axes must be unique")

    if not noshape:
        shape = _iterable_of_int(shape, 'shape')

        # Check ``noaxes`` rather than the truthiness of ``axes``: an
        # explicit empty ``axes`` must still match ``len(shape)``, otherwise
        # the ``zip`` below would silently drop every entry of ``shape``.
        if not noaxes and len(axes) != len(shape):
            raise ValueError("when given, axes and shape arguments"
                             " have to be of the same length")
        if noaxes:
            if len(shape) > x.ndim:
                raise ValueError("shape requires more axes than are present")
            # Default to the trailing ``len(shape)`` axes.
            axes = range(x.ndim - len(shape), x.ndim)

        shape = [x.shape[a] if s == -1 else s for s, a in zip(shape, axes)]
    elif noaxes:
        shape = list(x.shape)
        axes = range(x.ndim)
    else:
        shape = [x.shape[a] for a in axes]

    if any(s < 1 for s in shape):
        raise ValueError(
            f"invalid number of data points ({shape}) specified")

    return shape, axes


def _asfarray(x):
    """
    Convert to array with floating or complex dtype.

    float16 values are also promoted to float32, and any other dtype that is
    neither floating nor complex is converted to float64. Floating and
    complex inputs keep their dtype but are returned in native byte order
    and aligned memory.
    """
    if not hasattr(x, "dtype"):
        x = np.asarray(x)

    if x.dtype == np.float16:
        return np.asarray(x, np.float32)
    elif x.dtype.kind not in 'fc':
        return np.asarray(x, np.float64)

    # Require native byte order
    dtype = x.dtype.newbyteorder('=')
    # Always align input: misaligned data needs an explicit copy. Otherwise
    # let ``asarray`` copy only if the byte-order conversion requires it.
    # ``np.array(..., copy=False)`` is not used because under NumPy >= 2 it
    # means "never copy" and raises when a copy is unavoidable (e.g. for
    # big-endian input).
    if not x.flags['ALIGNED']:
        return np.array(x, dtype=dtype, copy=True)
    return np.asarray(x, dtype=dtype)


def _datacopied(arr, original):
    """
    Strict check for `arr` not sharing any data with `original`,
    under the assumption that arr = asarray(original)

    Returns False when `arr` is `original` itself, or when `original` is a
    non-ndarray exposing ``__array__`` (its buffer may be shared). Otherwise
    returns True only if `arr` owns its data (``arr.base is None``).
    """
    if arr is original:
        return False
    if not isinstance(original, np.ndarray) and hasattr(original, '__array__'):
        return False
    return arr.base is None


def _fix_shape(x, shape, axes):
    """Internal auxiliary function for _raw_fft, _raw_fftnd.

    Truncate or zero-pad ``x`` so that ``x.shape[axes[i]] == shape[i]``.

    Parameters
    ----------
    x : ndarray
        Input data.
    shape : sequence of int
        Target length for each axis in `axes`.
    axes : sequence of int
        Axes to resize (same length as `shape`).

    Returns
    -------
    z : ndarray
        The resized data. When only truncation is needed this is a slice
        (view) of ``x``; otherwise it is a new zero-padded array.
    copied : bool
        True if ``z`` is a newly allocated array, False if it is a slice of
        ``x``.
    """
    must_copy = False

    # Build an nd slice with the dimensions to be read from x
    index = [slice(None)] * x.ndim
    for n, ax in zip(shape, axes):
        if x.shape[ax] >= n:
            index[ax] = slice(0, n)
        else:
            # Axis is shorter than requested: read all of it, pad below.
            index[ax] = slice(0, x.shape[ax])
            must_copy = True

    index = tuple(index)

    if not must_copy:
        return x[index], False

    s = list(x.shape)
    for n, axis in zip(shape, axes):
        s[axis] = n

    z = np.zeros(s, x.dtype)
    z[index] = x[index]
    return z, True


def _fix_shape_1d(x, n, axis):
    """Truncate or zero-pad ``x`` to length `n` along `axis`.

    Validates `n` and then delegates to `_fix_shape`, returning its
    ``(array, copied)`` pair.

    Raises
    ------
    ValueError
        If ``n < 1``.
    """
    if n < 1:
        raise ValueError(
            f"invalid number of data points ({n}) specified")

    return _fix_shape(x, (n,), (axis,))


# ``norm`` argument -> pypocketfft normalization mode for the forward
# transform; `_normalization` derives the inverse mode as ``2 - mode``.
# NOTE(review): presumably 0 = unscaled, 1 = 1/sqrt(n), 2 = 1/n -- confirm
# against the pypocketfft bindings.
_NORM_MAP = {None: 0, 'backward': 0, 'ortho': 1, 'forward': 2}


def _normalization(norm, forward):
    """Returns the pypocketfft normalization mode from the norm argument.

    Parameters
    ----------
    norm : {None, "backward", "ortho", "forward"}
        User-facing normalization keyword (``None`` means "backward").
    forward : bool
        True for a forward transform, False for an inverse transform.

    Returns
    -------
    int
        ``_NORM_MAP[norm]`` for forward transforms, ``2 - _NORM_MAP[norm]``
        for inverse transforms.

    Raises
    ------
    ValueError
        If `norm` is not one of the accepted values.
    """
    try:
        inorm = _NORM_MAP[norm]
    except KeyError:
        raise ValueError(
            f'Invalid norm value {norm!r}, should '
            'be "backward", "ortho" or "forward"') from None
    return inorm if forward else (2 - inorm)


def _workers(workers):
    """Normalize a ``workers`` argument to a concrete worker count.

    Parameters
    ----------
    workers : int or None
        ``None`` selects the current context default (see `get_workers`).
        Negative values count back from the number of CPUs, so ``-1`` means
        all CPUs and ``-os.cpu_count()`` means one.

    Returns
    -------
    int
        The number of workers to use.

    Raises
    ------
    ValueError
        If `workers` is zero, or negative with magnitude larger than the
        CPU count.
    """
    if workers is None:
        return get_workers()

    # os.cpu_count() may be None; treat an unknown count as a single CPU
    # instead of failing with TypeError on ``-None``.
    cpu_count = _cpu_count or 1

    if workers < 0:
        if workers >= -cpu_count:
            workers += 1 + cpu_count
        else:
            raise ValueError(f"workers value out of range; got {workers}, "
                             f"must not be less than {-cpu_count}")
    elif workers == 0:
        raise ValueError("workers must not be zero")

    return workers


@contextlib.contextmanager
def set_workers(workers):
    """Context manager for the default number of workers used in `scipy.fft`

    The previous default is restored when the ``with`` block exits, even if
    it exits via an exception. The default is stored per thread.

    Parameters
    ----------
    workers : int
        The default number of workers to use. Negative values count back
        from the number of CPUs (``-1`` means all CPUs).

    Raises
    ------
    TypeError
        If `workers` is not an integer.
    ValueError
        If `workers` is zero or out of range.

    Examples
    --------
    >>> import numpy as np
    >>> from scipy import fft, signal
    >>> rng = np.random.default_rng()
    >>> x = rng.standard_normal((128, 64))
    >>> with fft.set_workers(4):
    ...     y = signal.fftconvolve(x, x)

    """
    old_workers = get_workers()
    _config.default_workers = _workers(operator.index(workers))
    try:
        yield
    finally:
        _config.default_workers = old_workers


def get_workers():
    """Returns the default number of workers within the current context

    Returns 1 if no `set_workers` context is active in the calling thread.

    Examples
    --------
    >>> from scipy import fft
    >>> fft.get_workers()
    1
    >>> with fft.set_workers(4):
    ...     fft.get_workers()
    4
    """
    return getattr(_config, 'default_workers', 1)