""" A set of NumPy functions to apply per chunk """
import contextlib
from collections.abc import Container, Iterable, Sequence
from functools import wraps
from numbers import Integral

import numpy as np
from tlz import concat

from ..core import flatten
from . import numpy_compat as npcompat

try:
    from numpy import take_along_axis
except ImportError:  # pragma: no cover
    take_along_axis = npcompat.take_along_axis


def keepdims_wrapper(a_callable):
    """
    A wrapper for functions that don't provide keepdims to ensure that they do.

    Parameters
    ----------
    a_callable: callable
        Function invoked as ``a_callable(x, axis=axis, *args, **kwargs)``.

    Returns
    -------
    callable
        A function with the signature
        ``(x, axis=None, keepdims=None, *args, **kwargs)``. When ``keepdims``
        is falsy the result of ``a_callable`` is returned untouched; otherwise
        a length-one dimension is re-inserted for every reduced axis (all axes
        when ``axis`` is None).
    """

    @wraps(a_callable)
    def keepdims_wrapped_callable(x, axis=None, keepdims=None, *args, **kwargs):
        r = a_callable(x, axis=axis, *args, **kwargs)

        if not keepdims:
            return r

        ndim = x.ndim
        axes = axis

        if axes is None:
            axes = range(ndim)

        if not isinstance(axes, (Container, Iterable, Sequence)):
            axes = [axes]

        # Build an index with ``None`` (np.newaxis) at every reduced position.
        # ``each_axis - ndim`` is the negative spelling of the same axis, so
        # e.g. ``axis=-1`` is recognised as the last dimension instead of
        # being silently ignored (which made the indexing below fail with
        # "too many indices").
        r_slice = tuple(
            None
            if (each_axis in axes or (each_axis - ndim) in axes)
            else slice(None)
            for each_axis in range(ndim)
        )

        return r[r_slice]

    return keepdims_wrapped_callable


# Wrap NumPy functions to ensure they provide keepdims.
# NOTE: several of the names below (sum, min, max, any, all) deliberately
# shadow Python builtins inside this module; code further down that calls
# e.g. ``sum`` is calling the NumPy function.
sum = np.sum
prod = np.prod
min = np.min
max = np.max
argmin = keepdims_wrapper(np.argmin)
nanargmin = keepdims_wrapper(np.nanargmin)
argmax = keepdims_wrapper(np.argmax)
nanargmax = keepdims_wrapper(np.nanargmax)
any = np.any
all = np.all
nansum = np.nansum
nanprod = np.nanprod

nancumprod = np.nancumprod
nancumsum = np.nancumsum

nanmin = np.nanmin
nanmax = np.nanmax
mean = np.mean

with contextlib.suppress(AttributeError):
    nanmean = np.nanmean

var = np.var

with contextlib.suppress(AttributeError):
    nanvar = np.nanvar

std = np.std

with contextlib.suppress(AttributeError):
    nanstd = np.nanstd


def coarsen(reduction, x, axes, trim_excess=False, **kwargs):
    """Coarsen array by applying reduction to fixed size neighborhoods

    Parameters
    ----------
    reduction: function
        Function like np.sum, np.mean, etc...
    x: np.ndarray
        Array to be coarsened
    axes: dict
        Mapping of axis to coarsening factor. Axes that are not listed are
        left at their original resolution (factor 1). The mapping is not
        modified.
    trim_excess: bool, optional
        If True, drop trailing elements along each axis that do not fill a
        complete neighborhood before reshaping.
    **kwargs
        Forwarded to ``reduction``.

    Examples
    --------
    >>> x = np.array([1, 2, 3, 4, 5, 6])
    >>> coarsen(np.sum, x, {0: 2})
    array([ 3,  7, 11])
    >>> coarsen(np.max, x, {0: 3})
    array([3, 6])

    Provide dictionary of scale per dimension

    >>> x = np.arange(24).reshape((4, 6))
    >>> x
    array([[ 0,  1,  2,  3,  4,  5],
           [ 6,  7,  8,  9, 10, 11],
           [12, 13, 14, 15, 16, 17],
           [18, 19, 20, 21, 22, 23]])

    >>> coarsen(np.min, x, {0: 2, 1: 3})
    array([[ 0,  3],
           [12, 15]])

    You must avoid excess elements explicitly

    >>> x = np.array([1, 2, 3, 4, 5, 6, 7, 8])
    >>> coarsen(np.min, x, {0: 3}, trim_excess=True)
    array([1, 4])
    """
    # Per-axis coarsening factor, defaulting to 1 for axes that were not
    # requested. Built as a local list so the caller's ``axes`` mapping is
    # left untouched (it used to be filled in-place with the defaults).
    factors = [axes[i] if i in axes else 1 for i in range(x.ndim)]

    if trim_excess:
        ind = tuple(
            slice(0, -(d % f)) if d % f else slice(None, None)
            for d, f in zip(x.shape, factors)
        )
        x = x[ind]

    # (10, 10) -> (5, 2, 5, 2)
    newshape = tuple(concat([(x.shape[i] // f, f) for i, f in enumerate(factors)]))

    # Every odd axis of the reshaped array is one neighborhood; reduce them all.
    return reduction(x.reshape(newshape), axis=tuple(range(1, x.ndim * 2, 2)), **kwargs)


def trim(x, axes=None):
    """Trim boundaries off of array

    Parameters
    ----------
    x: np.ndarray
        Array to trim
    axes: int, dict, sequence of int or None
        Number of elements to remove from both ends of each axis. An integer
        applies to every axis; a dict maps axis to depth (missing axes are
        not trimmed); a sequence gives one depth per axis. ``None`` trims
        nothing.

    >>> x = np.arange(24).reshape((4, 6))
    >>> trim(x, axes={0: 0, 1: 1})
    array([[ 1,  2,  3,  4],
           [ 7,  8,  9, 10],
           [13, 14, 15, 16],
           [19, 20, 21, 22]])

    >>> trim(x, axes={0: 1, 1: 1})
    array([[ 7,  8,  9, 10],
           [13, 14, 15, 16]])
    """
    if axes is None:
        # The declared default used to fail with a TypeError; treat it as
        # "no trimming on any axis".
        axes = 0
    if isinstance(axes, Integral):
        axes = [axes] * x.ndim
    if isinstance(axes, dict):
        axes = [axes.get(i, 0) for i in range(x.ndim)]

    # A depth of 0 must map to an open-ended slice: ``-0`` would select nothing.
    return x[tuple(slice(ax, -ax if ax else None) for ax in axes)]


def topk(a, k, axis, keepdims):
    """Chunk and combine function of topk

    Extract the k largest elements from a on the given axis.
    If k is negative, extract the -k smallest elements instead.
    Note that, unlike in the parent function, the returned elements
    are not sorted internally.

    ``axis`` is a sequence whose first element is the axis to operate on and
    ``keepdims`` must be True. If ``a`` holds no more than ``abs(k)`` elements
    along the axis it is returned unchanged.
    """
    assert keepdims is True
    axis = axis[0]
    if abs(k) >= a.shape[axis]:
        return a

    # Partitioning at -k puts the k largest at the end (k > 0) or the
    # -k smallest at the front (k < 0).
    a = np.partition(a, -k, axis=axis)
    k_slice = slice(-k, None) if k > 0 else slice(-k)
    return a[tuple(k_slice if i == axis else slice(None) for i in range(a.ndim))]


def topk_aggregate(a, k, axis, keepdims):
    """Final aggregation function of topk

    Invoke topk one final time and then sort the results internally.
    The result is in descending order for positive k and in ascending order
    for negative k.
    """
    assert keepdims is True
    a = topk(a, k, axis, keepdims)
    axis = axis[0]
    a = np.sort(a, axis=axis)
    if k < 0:
        return a
    # Positive k: reverse the ascending sort along ``axis``.
    return a[
        tuple(
            slice(None, None, -1) if i == axis else slice(None) for i in range(a.ndim)
        )
    ]


def argtopk_preprocess(a, idx):
    """Preparatory step for argtopk

    Put data together with its original indices in a tuple.
    """
    return a, idx


def argtopk(a_plus_idx, k, axis, keepdims):
    """Chunk and combine function of argtopk

    Extract the indices of the k largest elements from a on the given axis.
    If k is negative, extract the indices of the -k smallest elements instead.
    Note that, unlike in the parent function, the returned elements
    are not sorted internally.

    ``a_plus_idx`` is either a ``(data, indices)`` tuple or a (possibly
    nested) list of such tuples, which are concatenated along the axis first.
    The return value is always a ``(data, indices)`` tuple.
    """
    assert keepdims is True
    axis = axis[0]

    if isinstance(a_plus_idx, list):
        a_plus_idx = list(flatten(a_plus_idx))
        a = np.concatenate([ai for ai, _ in a_plus_idx], axis)
        idx = np.concatenate(
            [np.broadcast_to(idxi, ai.shape) for ai, idxi in a_plus_idx], axis
        )
    else:
        a, idx = a_plus_idx

    if abs(k) >= a.shape[axis]:
        # Nothing to discard. Return the concatenated pair rather than the
        # input: when the input was a list, handing the list back broke the
        # ``a, idx = argtopk(...)`` unpacking in argtopk_aggregate.
        return a, idx

    idx2 = np.argpartition(a, -k, axis=axis)
    k_slice = slice(-k, None) if k > 0 else slice(-k)
    idx2 = idx2[tuple(k_slice if i == axis else slice(None) for i in range(a.ndim))]
    return take_along_axis(a, idx2, axis), take_along_axis(idx, idx2, axis)


def argtopk_aggregate(a_plus_idx, k, axis, keepdims):
    """Final aggregation function of argtopk

    Invoke argtopk one final time, sort the results internally, drop the data
    and return the index only.
    """
    assert keepdims is True
    a, idx = argtopk(a_plus_idx, k, axis, keepdims)
    axis = axis[0]

    # Order the indices by their data values (ascending).
    idx2 = np.argsort(a, axis=axis)
    idx = take_along_axis(idx, idx2, axis)
    if k < 0:
        return idx
    # Positive k: largest first, so reverse along ``axis``.
    return idx[
        tuple(
            slice(None, None, -1) if i == axis else slice(None) for i in range(idx.ndim)
        )
    ]


def arange(start, stop, step, length, dtype, like=None):
    """Build a range with ``arange_safe``, limited to ``length`` elements

    If the generated array has more than ``length`` elements, its last
    element is dropped.
    """
    from .utils import arange_safe

    res = arange_safe(start, stop, step, dtype, like=like)
    return res[:-1] if len(res) > length else res


def linspace(start, stop, num, endpoint=True, dtype=None):
    """Call ``np.linspace``, computing ``start``/``stop`` first if they are
    ``Array`` instances (as defined in ``.core``).
    """
    from .core import Array

    if isinstance(start, Array):
        start = start.compute()

    if isinstance(stop, Array):
        stop = stop.compute()

    return np.linspace(start, stop, num, endpoint=endpoint, dtype=dtype)


def astype(x, astype_dtype=None, **kwargs):
    """Return ``x.astype(astype_dtype, **kwargs)``"""
    return x.astype(astype_dtype, **kwargs)


def view(x, dtype, order="C"):
    """Reinterpret ``x`` as ``dtype``

    With ``order="C"`` the array is made C-contiguous and viewed directly;
    for any other value it is made Fortran-contiguous and viewed through its
    transpose. The ``like=x`` form of the conversion is tried first and the
    plain form is used if that raises TypeError.
    """
    if order == "C":
        try:
            x = np.ascontiguousarray(x, like=x)
        except TypeError:
            x = np.ascontiguousarray(x)
        return x.view(dtype)
    else:
        try:
            x = np.asfortranarray(x, like=x)
        except TypeError:
            x = np.asfortranarray(x)
        return x.T.view(dtype).T


def slice_with_int_dask_array(x, idx, offset, x_size, axis):
    """Chunk function of `slice_with_int_dask_array_on_axis`.
    Slice one chunk of x by one chunk of idx.

    Parameters
    ----------
    x: ndarray, any dtype, any shape
        i-th chunk of x
    idx: ndarray, ndim=1, dtype=any integer
        j-th chunk of idx (cartesian product with the chunks of x)
    offset: ndarray, shape=(1, ), dtype=int64
        Index of the first element along axis of the current chunk of x
    x_size: int
        Total size of the x da.Array along axis
    axis: int
        normalized axis to take elements from (0 <= axis < x.ndim)

    Returns
    -------
    x sliced along axis, using only the elements of idx that fall inside the
    current chunk.
    """
    from .utils import asarray_safe, meta_from_array

    idx = asarray_safe(idx, like=meta_from_array(x))

    # Needed when idx is unsigned
    idx = idx.astype(np.int64)

    # Normalize negative indices
    idx = np.where(idx < 0, idx + x_size, idx)

    # A chunk of the offset dask Array is a numpy array with shape (1, ).
    # It indicates the index of the first element along axis of the current
    # chunk of x.
    idx = idx - offset

    # Drop elements of idx that do not fall inside the current chunk of x
    idx_filter = (idx >= 0) & (idx < x.shape[axis])
    idx = idx[idx_filter]

    # np.take does not support slice indices
    # return np.take(x, idx, axis)
    return x[tuple(idx if i == axis else slice(None) for i in range(x.ndim))]


def slice_with_int_dask_array_aggregate(idx, chunk_outputs, x_chunks, axis):
    """Final aggregation function of `slice_with_int_dask_array_on_axis`.
    Aggregate all chunks of x by one chunk of idx, reordering the output of
    `slice_with_int_dask_array`.

    Note that there is no combine function, as a recursive aggregation (e.g.
    with split_every) would not give any benefit.

    Parameters
    ----------
    idx: ndarray, ndim=1, dtype=any integer
        j-th chunk of idx
    chunk_outputs: ndarray
        concatenation along axis of the outputs of `slice_with_int_dask_array`
        for all chunks of x and the j-th chunk of idx
    x_chunks: tuple
        dask chunks of the x da.Array along axis, e.g. ``(3, 3, 2)``
    axis: int
        normalized axis to take elements from (0 <= axis < x.ndim)

    Returns
    -------
    Selection from all chunks of x for the j-th chunk of idx, in the correct
    order
    """
    # Needed when idx is unsigned
    idx = idx.astype(np.int64)

    # Normalize negative indices. np.sum is spelled out because the bare name
    # ``sum`` is rebound to it at module level anyway.
    idx = np.where(idx < 0, idx + np.sum(x_chunks), idx)

    x_chunk_offset = 0
    chunk_output_offset = 0

    # Assemble the final index that picks from the output of the previous
    # kernel by adding together one layer per chunk of x
    # FIXME: this could probably be reimplemented with a faster search-based
    #        algorithm
    idx_final = np.zeros_like(idx)
    for x_chunk in x_chunks:
        # Elements of idx that were served by this chunk of x
        idx_filter = (idx >= x_chunk_offset) & (idx < x_chunk_offset + x_chunk)
        # Running count gives each selected element its rank within the
        # chunk's output
        idx_cum = np.cumsum(idx_filter)
        idx_final += np.where(idx_filter, idx_cum - 1 + chunk_output_offset, 0)
        x_chunk_offset += x_chunk
        if idx_cum.size > 0:
            # Total number of elements this chunk contributed
            chunk_output_offset += idx_cum[-1]

    # np.take does not support slice indices
    # return np.take(chunk_outputs, idx_final, axis)
    return chunk_outputs[
        tuple(
            idx_final if i == axis else slice(None) for i in range(chunk_outputs.ndim)
        )
    ]


def getitem(obj, index):
    """Getitem function

    This function creates a copy of the desired selection for array-like
    inputs when the selection is smaller than half of the original array. This
    avoids excess memory usage when extracting a small portion from a large array.
    For more information, see
    https://numpy.org/doc/stable/reference/arrays.indexing.html#basic-slicing-and-indexing.

    Parameters
    ----------
    obj: ndarray, string, tuple, list
        Object to get item from.
    index: int, list[int], slice()
        Desired selection to extract from obj.

    Returns
    -------
    Selection obj[index]

    """
    result = obj[index]
    try:
        # A result that does not own its data is a view keeping ``obj`` alive.
        if not result.flags.owndata and obj.size >= 2 * result.size:
            result = result.copy()
    except AttributeError:
        # Deliberate best effort: objects without ``flags``/``size``
        # (strings, tuples, lists, scalars) are returned as-is.
        pass
    return result