1""" A set of NumPy functions to apply per chunk """
2import contextlib
3from collections.abc import Container, Iterable, Sequence
4from functools import wraps
5from numbers import Integral
6
7import numpy as np
8from tlz import concat
9
10from ..core import flatten
11from . import numpy_compat as npcompat
12
13try:
14    from numpy import take_along_axis
15except ImportError:  # pragma: no cover
16    take_along_axis = npcompat.take_along_axis
17
18
def keepdims_wrapper(a_callable):
    """
    A wrapper for functions that don't provide keepdims to ensure that they do.

    The returned callable accepts ``axis`` and ``keepdims`` keywords. It calls
    ``a_callable(x, axis=axis, ...)`` without ``keepdims``. When ``keepdims``
    is truthy, it re-inserts every reduced axis as a length-1 dimension.
    ``axis=None`` means "all axes". Negative axes count from the end, as in
    NumPy.
    """

    @wraps(a_callable)
    def keepdims_wrapped_callable(x, axis=None, keepdims=None, *args, **kwargs):
        r = a_callable(x, axis=axis, *args, **kwargs)

        if not keepdims:
            return r

        axes = axis

        if axes is None:
            axes = range(x.ndim)

        if not isinstance(axes, (Container, Iterable, Sequence)):
            axes = [axes]

        # Insert a new length-1 dimension (None) at every reduced position and
        # keep every other dimension as-is. An axis counts as reduced if it
        # was given either as a non-negative index or as its negative
        # equivalent (e.g. -1 for the last axis).
        r_slice = tuple(
            None
            if (each_axis in axes or each_axis - x.ndim in axes)
            else slice(None)
            for each_axis in range(x.ndim)
        )

        return r[r_slice]

    return keepdims_wrapped_callable
51
52
# Per-chunk reductions. NumPy's own reductions are re-exported directly; the
# arg-reductions go through keepdims_wrapper so that they accept keepdims too.
sum, prod, min, max = np.sum, np.prod, np.min, np.max
argmin, nanargmin = keepdims_wrapper(np.argmin), keepdims_wrapper(np.nanargmin)
argmax, nanargmax = keepdims_wrapper(np.argmax), keepdims_wrapper(np.nanargmax)
any, all = np.any, np.all
nansum, nanprod = np.nansum, np.nanprod

nancumprod, nancumsum = np.nancumprod, np.nancumsum

nanmin, nanmax = np.nanmin, np.nanmax
mean = np.mean

# The nan-aware moments are only exported if this NumPy provides them.
with contextlib.suppress(AttributeError):
    nanmean = np.nanmean

var = np.var

with contextlib.suppress(AttributeError):
    nanvar = np.nanvar

std = np.std

with contextlib.suppress(AttributeError):
    nanstd = np.nanstd
86
87
def coarsen(reduction, x, axes, trim_excess=False, **kwargs):
    """Coarsen array by applying reduction to fixed size neighborhoods

    Parameters
    ----------
    reduction: function
        Function like np.sum, np.mean, etc...
    x: np.ndarray
        Array to be coarsened
    axes: dict
        Mapping of axis to coarsening factor. Axes that are not listed keep a
        factor of 1. Negative axes count from the end. The mapping passed in
        is not modified.
    trim_excess: bool, optional
        If True, drop the trailing elements along each axis that do not fill
        a complete neighborhood. If False (default), each axis length must be
        divisible by its factor, otherwise the reshape fails.
    **kwargs
        Extra keyword arguments forwarded to ``reduction``.

    Examples
    --------
    >>> x = np.array([1, 2, 3, 4, 5, 6])
    >>> coarsen(np.sum, x, {0: 2})
    array([ 3,  7, 11])
    >>> coarsen(np.max, x, {0: 3})
    array([3, 6])

    Provide dictionary of scale per dimension

    >>> x = np.arange(24).reshape((4, 6))
    >>> x
    array([[ 0,  1,  2,  3,  4,  5],
           [ 6,  7,  8,  9, 10, 11],
           [12, 13, 14, 15, 16, 17],
           [18, 19, 20, 21, 22, 23]])

    >>> coarsen(np.min, x, {0: 2, 1: 3})
    array([[ 0,  3],
           [12, 15]])

    You must avoid excess elements explicitly

    >>> x = np.array([1, 2, 3, 4, 5, 6, 7, 8])
    >>> coarsen(np.min, x, {0: 3}, trim_excess=True)
    array([1, 4])
    """
    # Work on a normalized local copy: resolve negative axes and insert
    # singleton factors for unlisted axes. Do not mutate the caller's
    # mapping, which may be shared by many chunk calls.
    factors = {(ax + x.ndim if ax < 0 else ax): f for ax, f in axes.items()}
    axes = {i: factors.get(i, 1) for i in range(x.ndim)}

    if trim_excess:
        ind = tuple(
            slice(0, -(d % axes[i])) if d % axes[i] else slice(None, None)
            for i, d in enumerate(x.shape)
        )
        x = x[ind]

    # Split every axis into (blocks, factor): (10, 10) -> (5, 2, 5, 2)
    newshape = tuple(
        n for i in range(x.ndim) for n in (x.shape[i] // axes[i], axes[i])
    )

    # Reduce over the odd-numbered (within-neighborhood) axes
    return reduction(x.reshape(newshape), axis=tuple(range(1, x.ndim * 2, 2)), **kwargs)
143
144
def trim(x, axes=None):
    """Trim boundaries off of array

    Parameters
    ----------
    x: np.ndarray
        Array to trim.
    axes: int, dict, sequence of int, or None
        Number of elements to remove from *both* ends of an axis. An int
        applies to every axis. A dict maps axis -> depth, and missing axes are
        not trimmed. A sequence gives one depth per leading axis. ``None``
        (the default) trims nothing.

    >>> x = np.arange(24).reshape((4, 6))
    >>> trim(x, axes={0: 0, 1: 1})
    array([[ 1,  2,  3,  4],
           [ 7,  8,  9, 10],
           [13, 14, 15, 16],
           [19, 20, 21, 22]])

    >>> trim(x, axes={0: 1, 1: 1})
    array([[ 7,  8,  9, 10],
           [13, 14, 15, 16]])
    """
    # The default previously crashed when iterated; treat it as "no trimming".
    if axes is None:
        axes = 0
    if isinstance(axes, Integral):
        axes = [axes] * x.ndim
    if isinstance(axes, dict):
        axes = [axes.get(i, 0) for i in range(x.ndim)]

    # A depth of 0 needs an open stop (None): slice(0, -0) would be empty.
    return x[tuple(slice(ax, -ax if ax else None) for ax in axes)]
165
166
def topk(a, k, axis, keepdims):
    """Chunk and combine function of topk

    Keep the ``k`` largest entries of ``a`` along ``axis[0]``, or the ``-k``
    smallest ones when ``k`` is negative. The kept entries are in no
    particular order. If that axis is no longer than ``abs(k)``, ``a`` is
    returned as-is. Only ``keepdims=True`` is supported.
    """
    assert keepdims is True
    ax = axis[0]
    if a.shape[ax] <= abs(k):
        return a

    partitioned = np.partition(a, -k, axis=ax)
    if k > 0:
        keep = slice(-k, None)
    else:
        keep = slice(-k)
    selector = tuple(
        keep if dim == ax else slice(None) for dim in range(partitioned.ndim)
    )
    return partitioned[selector]
183
184
def topk_aggregate(a, k, axis, keepdims):
    """Final aggregation function of topk

    Reduce ``a`` with ``topk`` one last time, then sort the survivors along
    ``axis[0]``. The result is ascending when ``k`` is negative and descending
    otherwise.
    """
    assert keepdims is True
    result = np.sort(topk(a, k, axis, keepdims), axis=axis[0])
    if k < 0:
        return result
    # Reverse the ascending sort along the reduction axis only.
    flip = tuple(
        slice(None, None, -1) if dim == axis[0] else slice(None)
        for dim in range(result.ndim)
    )
    return result[flip]
201
202
def argtopk_preprocess(a, idx):
    """Preparatory step for argtopk

    Bundle the data together with its original indices as an ``(a, idx)`` tuple.
    """
    pair = (a, idx)
    return pair
209
210
def argtopk(a_plus_idx, k, axis, keepdims):
    """Chunk and combine function of argtopk

    Extract the indices of the k largest elements from a on the given axis.
    If k is negative, extract the indices of the -k smallest elements instead.
    Note that, unlike in the parent function, the returned elements
    are not sorted internally.

    Parameters
    ----------
    a_plus_idx: tuple or list
        Either a single ``(a, idx)`` pair, or a (possibly nested) list of such
        pairs. A list is first flattened, then concatenated along ``axis``,
        with each ``idx`` broadcast to its ``a``'s shape.
    k: int
        Number of entries to keep. Negative means the smallest ``-k``.
    axis: tuple
        Only ``axis[0]`` is used.
    keepdims: bool
        Must be True.

    Returns
    -------
    tuple
        ``(a, idx)`` restricted to the selected entries along ``axis``.
    """
    assert keepdims is True
    axis = axis[0]

    if isinstance(a_plus_idx, list):
        a_plus_idx = list(flatten(a_plus_idx))
        a = np.concatenate([ai for ai, _ in a_plus_idx], axis)
        idx = np.concatenate(
            [np.broadcast_to(idxi, ai.shape) for ai, idxi in a_plus_idx], axis
        )
    else:
        a, idx = a_plus_idx

    if abs(k) >= a.shape[axis]:
        # Nothing to drop. Return the (possibly concatenated) pair, not the
        # raw input: in the list case the input is a list of pairs, which
        # argtopk_aggregate could not unpack as ``(a, idx)``.
        return a, idx

    idx2 = np.argpartition(a, -k, axis=axis)
    k_slice = slice(-k, None) if k > 0 else slice(-k)
    idx2 = idx2[tuple(k_slice if i == axis else slice(None) for i in range(a.ndim))]
    return take_along_axis(a, idx2, axis), take_along_axis(idx, idx2, axis)
238
239
def argtopk_aggregate(a_plus_idx, k, axis, keepdims):
    """Final aggregation function of argtopk

    Run ``argtopk`` one last time and order the survivors by value along
    ``axis[0]``: ascending for negative ``k``, descending otherwise. Only the
    original indices are returned; the data is dropped.
    """
    assert keepdims is True
    values, indices = argtopk(a_plus_idx, k, axis, keepdims)
    ax = axis[0]

    order = np.argsort(values, axis=ax)
    indices = take_along_axis(indices, order, ax)
    if k >= 0:
        # Reverse the ascending order along the reduction axis only.
        reverse = tuple(
            slice(None, None, -1) if dim == ax else slice(None)
            for dim in range(indices.ndim)
        )
        indices = indices[reverse]
    return indices
259
260
def arange(start, stop, step, length, dtype, like=None):
    """Build one chunk with ``arange_safe`` and clip it to ``length``.

    If the generated chunk is longer than ``length``, its final element is
    dropped. This is presumably meant to absorb floating-point rounding in
    the element count; TODO confirm.
    """
    from .utils import arange_safe

    chunk = arange_safe(start, stop, step, dtype, like=like)
    if len(chunk) <= length:
        return chunk
    return chunk[:-1]
266
267
def linspace(start, stop, num, endpoint=True, dtype=None):
    """Evaluate ``np.linspace``, first computing any dask ``Array`` endpoints."""
    from .core import Array

    start, stop = (
        bound.compute() if isinstance(bound, Array) else bound
        for bound in (start, stop)
    )
    return np.linspace(start, stop, num, endpoint=endpoint, dtype=dtype)
278
279
def astype(x, astype_dtype=None, **kwargs):
    """Cast ``x`` to ``astype_dtype``, forwarding extra keywords to ``x.astype``."""
    cast = x.astype
    return cast(astype_dtype, **kwargs)
282
283
def view(x, dtype, order="C"):
    """Reinterpret ``x`` as ``dtype`` after making it contiguous in ``order``.

    For ``order="C"`` the array is made C-contiguous and viewed directly. For
    any other ``order`` it is made Fortran-contiguous and viewed through its
    transpose, so the itemsize change applies to the first axis. ``like=x``
    is tried first so array-likes keep their type. It is dropped if the NumPy
    function rejects it with a TypeError.
    """
    c_order = order == "C"
    make_contiguous = np.ascontiguousarray if c_order else np.asfortranarray
    try:
        x = make_contiguous(x, like=x)
    except TypeError:
        x = make_contiguous(x)
    if c_order:
        return x.view(dtype)
    return x.T.view(dtype).T
297
298
def slice_with_int_dask_array(x, idx, offset, x_size, axis):
    """Chunk function of `slice_with_int_dask_array_on_axis`.
    Select from one chunk of x the entries requested by one chunk of idx.

    Parameters
    ----------
    x: ndarray, any dtype, any shape
        i-th chunk of x
    idx: ndarray, ndim=1, dtype=any integer
        j-th chunk of idx (cartesian product with the chunks of x)
    offset: ndarray, shape=(1, ), dtype=int64
        Index of the first element along axis of the current chunk of x
    x_size: int
        Total size of the x da.Array along axis
    axis: int
        normalized axis to take elements from (0 <= axis < x.ndim)

    Returns
    -------
    x sliced along axis, using only the elements of idx that fall inside the
    current chunk, in the order they appear in idx.
    """
    from .utils import asarray_safe, meta_from_array

    # Cast to a signed type first: after subtracting the offset below, an
    # unsigned index would wrap around instead of going negative.
    local = asarray_safe(idx, like=meta_from_array(x)).astype(np.int64)

    # Wrap negative (from-the-end) indices, then shift into this chunk's
    # local coordinates.
    local = np.where(local < 0, local + x_size, local) - offset

    # Keep only the indices that land inside this chunk of x.
    local = local[(local >= 0) & (local < x.shape[axis])]

    # Index with a tuple rather than np.take, which does not accept slices.
    selector = tuple(local if dim == axis else slice(None) for dim in range(x.ndim))
    return x[selector]
343
344
def slice_with_int_dask_array_aggregate(idx, chunk_outputs, x_chunks, axis):
    """Final aggregation function of `slice_with_int_dask_array_on_axis`.
    Reorder the concatenated per-chunk outputs of `slice_with_int_dask_array`
    so they follow the order of one chunk of idx.

    Note that there is no combine function, as a recursive aggregation (e.g.
    with split_every) would not give any benefit.

    Parameters
    ----------
    idx: ndarray, ndim=1, dtype=any integer
        j-th chunk of idx
    chunk_outputs: ndarray
        concatenation along axis of the outputs of `slice_with_int_dask_array`
        for all chunks of x and the j-th chunk of idx
    x_chunks: tuple
        dask chunks of the x da.Array along axis, e.g. ``(3, 3, 2)``
    axis: int
        normalized axis to take elements from (0 <= axis < x.ndim)

    Returns
    -------
    Selection from all chunks of x for the j-th chunk of idx, in the correct
    order
    """
    # Signed copy so that unsigned input can be normalized safely.
    positions = idx.astype(np.int64)
    positions = np.where(positions < 0, positions + sum(x_chunks), positions)

    # For every requested position, compute where its value sits in
    # chunk_outputs: the count of hits in earlier chunks plus its running rank
    # among the hits of its own chunk. Each pass writes only that chunk's hits.
    # FIXME: this could probably be reimplemented with a faster search-based
    # algorithm
    gather = np.zeros_like(positions)
    lower = 0
    written = 0
    for size in x_chunks:
        upper = lower + size
        in_chunk = (positions >= lower) & (positions < upper)
        running = np.cumsum(in_chunk)
        gather += np.where(in_chunk, running - 1 + written, 0)
        lower = upper
        if running.size:
            written += running[-1]

    # Index with a tuple rather than np.take, which does not accept slices.
    selector = tuple(
        gather if dim == axis else slice(None) for dim in range(chunk_outputs.ndim)
    )
    return chunk_outputs[selector]
399
400
def getitem(obj, index):
    """Getitem function

    Return ``obj[index]``. If the selection is a view whose source is at
    least twice its size, return a copy instead, so that a small piece does
    not keep a large array alive. Objects without ``flags``/``size``
    (strings, tuples, lists, ...) are returned unchanged.
    For more information, see
    https://numpy.org/doc/stable/reference/arrays.indexing.html#basic-slicing-and-indexing.

    Parameters
    ----------
    obj: ndarray, string, tuple, list
        Object to get item from.
    index: int, list[int], slice()
        Desired selection to extract from obj.

    Returns
    -------
    Selection obj[index]

    """
    selection = obj[index]
    try:
        # Keep the selection as-is when it owns its memory or is large
        # relative to its source; otherwise detach it with a copy.
        if selection.flags.owndata or obj.size < 2 * selection.size:
            return selection
        return selection.copy()
    except AttributeError:
        return selection
429