1import itertools
2from collections.abc import Sequence
3from functools import partial, reduce
4from itertools import product
5from numbers import Integral, Number
6from operator import getitem
7
8import numpy as np
9from tlz import sliding_window
10
11from ..base import tokenize
12from ..highlevelgraph import HighLevelGraph
13from ..utils import derived_from
14from . import chunk
15from .core import (
16    Array,
17    asarray,
18    block,
19    blockwise,
20    broadcast_arrays,
21    broadcast_to,
22    cached_cumsum,
23    concatenate,
24    normalize_chunks,
25    stack,
26)
27from .numpy_compat import _numpy_120
28from .ufunc import greater_equal, rint
29from .utils import meta_from_array
30from .wrap import empty, full, ones, zeros
31
32
def empty_like(a, dtype=None, order="C", chunks=None, name=None, shape=None):
    """
    Create an uninitialized array modelled on an existing array.

    Parameters
    ----------
    a : array_like
        Template whose shape, dtype and meta are used for the result unless
        overridden below.
    dtype : data-type, optional
        Data type of the result. Falls back to ``a.dtype`` when omitted
        (or otherwise falsy).
    order : {'C', 'F'}, optional
        Whether to store multidimensional data in C- or Fortran-contiguous
        (row- or column-wise) order in memory.
    chunks : sequence of ints, optional
        The number of samples on each block. When omitted, the chunks of
        `a` are reused if `shape` is not given; otherwise ``"auto"`` is used.
    name : str, optional
        An optional keyname for the array. Defaults to hashing the input
        keyword arguments.
    shape : int or sequence of ints, optional
        Overrides the shape of the result.

    Returns
    -------
    out : array
        The result of ``empty`` for the resolved shape, chunks and dtype.
        Its contents are arbitrary.

    See Also
    --------
    ones_like : Return an array of ones with shape and type of input.
    zeros_like : Return an array of zeros with shape and type of input.
    empty : Return a new uninitialized array.
    ones : Return a new array setting values to one.
    zeros : Return a new array setting values to zero.

    Notes
    -----
    The returned array is *not* initialized; use `zeros_like` or
    `ones_like` when defined values are required.
    """
    template = asarray(a, name=False)
    out_shape, out_chunks = _get_like_function_shapes_chunks(template, chunks, shape)
    out_dtype = dtype or template.dtype
    return empty(
        out_shape,
        dtype=out_dtype,
        order=order,
        chunks=out_chunks,
        name=name,
        meta=template._meta,
    )
87
88
def ones_like(a, dtype=None, order="C", chunks=None, name=None, shape=None):
    """
    Create an array of ones modelled on an existing array.

    Parameters
    ----------
    a : array_like
        Template whose shape, dtype and meta are used for the result unless
        overridden below.
    dtype : data-type, optional
        Data type of the result. Falls back to ``a.dtype`` when omitted
        (or otherwise falsy).
    order : {'C', 'F'}, optional
        Whether to store multidimensional data in C- or Fortran-contiguous
        (row- or column-wise) order in memory.
    chunks : sequence of ints, optional
        The number of samples on each block. When omitted, the chunks of
        `a` are reused if `shape` is not given; otherwise ``"auto"`` is used.
    name : str, optional
        An optional keyname for the array. Defaults to hashing the input
        keyword arguments.
    shape : int or sequence of ints, optional
        Overrides the shape of the result.

    Returns
    -------
    out : array
        The result of ``ones`` for the resolved shape, chunks and dtype.

    See Also
    --------
    zeros_like : Return an array of zeros with shape and type of input.
    empty_like : Return an empty array with shape and type of input.
    zeros : Return a new array setting values to zero.
    ones : Return a new array setting values to one.
    empty : Return a new uninitialized array.
    """
    template = asarray(a, name=False)
    out_shape, out_chunks = _get_like_function_shapes_chunks(template, chunks, shape)
    out_dtype = dtype or template.dtype
    return ones(
        out_shape,
        dtype=out_dtype,
        order=order,
        chunks=out_chunks,
        name=name,
        meta=template._meta,
    )
136
137
def zeros_like(a, dtype=None, order="C", chunks=None, name=None, shape=None):
    """
    Create an array of zeros modelled on an existing array.

    Parameters
    ----------
    a : array_like
        Template whose shape, dtype and meta are used for the result unless
        overridden below.
    dtype : data-type, optional
        Data type of the result. Falls back to ``a.dtype`` when omitted
        (or otherwise falsy).
    order : {'C', 'F'}, optional
        Whether to store multidimensional data in C- or Fortran-contiguous
        (row- or column-wise) order in memory.
    chunks : sequence of ints, optional
        The number of samples on each block. When omitted, the chunks of
        `a` are reused if `shape` is not given; otherwise ``"auto"`` is used.
    name : str, optional
        An optional keyname for the array. Defaults to hashing the input
        keyword arguments.
    shape : int or sequence of ints, optional
        Overrides the shape of the result.

    Returns
    -------
    out : array
        The result of ``zeros`` for the resolved shape, chunks and dtype.

    See Also
    --------
    ones_like : Return an array of ones with shape and type of input.
    empty_like : Return an empty array with shape and type of input.
    zeros : Return a new array setting values to zero.
    ones : Return a new array setting values to one.
    empty : Return a new uninitialized array.
    """
    template = asarray(a, name=False)
    out_shape, out_chunks = _get_like_function_shapes_chunks(template, chunks, shape)
    out_dtype = dtype or template.dtype
    return zeros(
        out_shape,
        dtype=out_dtype,
        order=order,
        chunks=out_chunks,
        name=name,
        meta=template._meta,
    )
185
186
def full_like(a, fill_value, order="C", dtype=None, chunks=None, name=None, shape=None):
    """
    Create an array filled with `fill_value`, modelled on an existing array.

    Parameters
    ----------
    a : array_like
        Template whose shape, dtype and meta are used for the result unless
        overridden below.
    fill_value : scalar
        Fill value.
    order : {'C', 'F'}, optional
        Whether to store multidimensional data in C- or Fortran-contiguous
        (row- or column-wise) order in memory.
    dtype : data-type, optional
        Data type of the result. Falls back to ``a.dtype`` when omitted
        (or otherwise falsy).
    chunks : sequence of ints, optional
        The number of samples on each block. When omitted, the chunks of
        `a` are reused if `shape` is not given; otherwise ``"auto"`` is used.
    name : str, optional
        An optional keyname for the array. Defaults to hashing the input
        keyword arguments.
    shape : int or sequence of ints, optional
        Overrides the shape of the result.

    Returns
    -------
    out : array
        The result of ``full`` for the resolved shape, chunks and dtype.

    See Also
    --------
    zeros_like : Return an array of zeros with shape and type of input.
    ones_like : Return an array of ones with shape and type of input.
    empty_like : Return an empty array with shape and type of input.
    zeros : Return a new array setting values to zero.
    ones : Return a new array setting values to one.
    empty : Return a new uninitialized array.
    full : Fill a new array.
    """
    template = asarray(a, name=False)
    out_shape, out_chunks = _get_like_function_shapes_chunks(template, chunks, shape)
    out_dtype = dtype or template.dtype
    return full(
        out_shape,
        fill_value,
        dtype=out_dtype,
        order=order,
        chunks=out_chunks,
        name=name,
        meta=template._meta,
    )
239
240
def _get_like_function_shapes_chunks(a, chunks, shape):
    """
    Resolve the ``(shape, chunks)`` pair used by the ``*_like`` creation
    functions.

    An explicit `chunks` always wins. Otherwise the chunks of `a` are reused
    when the shape is inherited from `a`, and ``"auto"`` is used when the
    caller overrides `shape` (the chunks of `a` need not fit a new shape).
    """
    if shape is not None:
        return shape, ("auto" if chunks is None else chunks)
    return a.shape, (a.chunks if chunks is None else chunks)
253
254
def linspace(
    start, stop, num=50, endpoint=True, retstep=False, chunks="auto", dtype=None
):
    """
    Return `num` evenly spaced values over the closed interval [`start`,
    `stop`].

    Parameters
    ----------
    start : scalar
        The starting value of the sequence.
    stop : scalar
        The last value of the sequence.
    num : int, optional
        Number of samples to include in the returned dask array, including the
        endpoints. Default is 50.
    endpoint : bool, optional
        If True, ``stop`` is the last sample. Otherwise, it is not included.
        Default is True.
    retstep : bool, optional
        If True, return (samples, step), where step is the spacing between
        samples. Default is False.
    chunks :  int
        The number of samples on each block. Note that the last block will have
        fewer samples if `num % blocksize != 0`
    dtype : dtype, optional
        The type of the output array.

    Returns
    -------
    samples : dask array
    step : float, optional
        Only returned if ``retstep`` is True. Size of spacing between samples.
        When there is no interval to divide (``num == 1`` with
        ``endpoint=True``, or ``num == 0`` with ``endpoint=False``) the full
        range ``stop - start`` is reported.

    See Also
    --------
    dask.array.arange
    """
    num = int(num)

    if dtype is None:
        # Use whatever dtype NumPy itself picks for linspace output.
        dtype = np.linspace(0, 1, 1).dtype

    chunks = normalize_chunks(chunks, (num,), dtype=dtype)

    range_ = stop - start

    # Number of intervals between samples. A single sample with the endpoint
    # included (or zero samples without it) has no interval at all; fall back
    # to 1 so the step computation does not divide by zero.
    div = (num - 1) if endpoint else num
    if div == 0:
        div = 1
    step = float(range_) / div

    name = "linspace-" + tokenize((start, stop, num, endpoint, chunks, dtype))

    dsk = {}
    blockstart = start

    for i, bs in enumerate(chunks[0]):
        # With the endpoint included, a block of ``bs`` samples spans
        # ``bs - 1`` steps; otherwise it spans ``bs`` steps.
        bs_space = bs - 1 if endpoint else bs
        blockstop = blockstart + (bs_space * step)
        task = (
            partial(chunk.linspace, endpoint=endpoint, dtype=dtype),
            blockstart,
            blockstop,
            bs,
        )
        # The next block always begins ``bs`` steps after this one.
        blockstart = blockstart + (step * bs)
        dsk[(name, i)] = task

    if retstep:
        return Array(dsk, name, chunks, dtype=dtype), step
    else:
        return Array(dsk, name, chunks, dtype=dtype)
327
328
def arange(*args, chunks="auto", like=None, dtype=None, **kwargs):
    """
    Return evenly spaced values from `start` to `stop` with step size `step`.

    The interval is half-open, ``[start, stop)``: `start` is included and
    `stop` is excluded, like Python's ``range`` but producing a dask array.

    Non-integer steps such as 0.1 often give inconsistent results; prefer
    linspace in those cases.

    Parameters
    ----------
    start : int, optional
        The starting value of the sequence. The default is 0.
    stop : int
        The end of the interval, this value is excluded from the interval.
    step : int, optional
        The spacing between the values. The default is 1 when not specified.
    chunks :  int
        The number of samples on each block. Note that the last block will have
        fewer samples if ``len(array) % chunks != 0``.
        Defaults to "auto" which will automatically determine chunk sizes.
    dtype : numpy.dtype
        Output dtype. Omit to infer it from start, stop, step
        Defaults to ``None``.
    like : array type or ``None``
        Array to extract meta from. Defaults to ``None``.

    Returns
    -------
    samples : dask array

    Raises
    ------
    TypeError
        If the number of positional arguments is not 1, 2 or 3, or if
        unexpected keyword arguments are supplied.

    See Also
    --------
    dask.array.linspace
    """
    nargs = len(args)
    if nargs == 1:
        start, stop, step = 0, args[0], 1
    elif nargs == 2:
        start, stop = args
        step = 1
    elif nargs == 3:
        start, stop, step = args
    else:
        raise TypeError(
            """
        arange takes 3 positional arguments: arange([start], stop, [step])
        """
        )

    # Total number of elements, never negative.
    num = int(max(np.ceil((stop - start) / step), 0))

    meta = None if like is None else meta_from_array(like)

    if dtype is None:
        # Infer the dtype from a cheap NumPy arange covering the same range.
        probe_step = step * num if num else step
        dtype = np.arange(start, stop, probe_step).dtype

    chunks = normalize_chunks(chunks, (num,), dtype=dtype)

    if kwargs:
        raise TypeError("Unexpected keyword argument(s): %s" % ",".join(kwargs.keys()))

    name = "arange-" + tokenize((start, stop, step, chunks, dtype))

    dsk = {}
    consumed = 0
    for block_id, block_len in enumerate(chunks[0]):
        lower = start + (consumed * step)
        consumed += block_len
        upper = start + (consumed * step)
        dsk[(name, block_id)] = (
            partial(chunk.arange, like=like),
            lower,
            upper,
            step,
            block_len,
            dtype,
        )

    return Array(dsk, name, chunks, dtype=dtype, meta=meta)
415
416
@derived_from(np)
def meshgrid(*xi, sparse=False, indexing="xy", **kwargs):
    # Dask version of ``np.meshgrid``: returns a list of coordinate arrays,
    # one per input vector. ``copy`` is not supported.
    sparse = bool(sparse)

    if "copy" in kwargs:
        raise NotImplementedError("`copy` not supported")

    if kwargs:
        raise TypeError("unsupported keyword argument(s) provided")

    if indexing not in ("ij", "xy"):
        raise ValueError("`indexing` must be `'ij'` or `'xy'`")

    vectors = [asarray(e) for e in xi]
    vectors = [vec.flatten() for vec in vectors]

    # Cartesian ("xy") indexing is matrix ("ij") indexing with the first two
    # inputs (and the first two outputs) exchanged.
    swap_first_two = indexing == "xy" and len(vectors) > 1
    if swap_first_two:
        vectors[0], vectors[1] = vectors[1], vectors[0]

    ndim = len(vectors)
    grid = []
    for axis, vec in enumerate(vectors):
        # Keep the vector along its own axis, add length-1 axes elsewhere.
        expander = tuple(slice(None) if pos == axis else None for pos in range(ndim))
        grid.append(vec[expander])

    if not sparse:
        grid = broadcast_arrays(*grid)

    if swap_first_two:
        grid[0], grid[1] = grid[1], grid[0]

    return grid
453
454
def indices(dimensions, dtype=int, chunks="auto"):
    """
    Implements NumPy's ``indices`` for Dask Arrays.

    Produces a grid of indices covering `dimensions`. The result has shape
    ``(len(dimensions), *dimensions)``; `chunks` describes the chunking of
    axes 1 through ``len(dimensions)``, while axis 0 always uses chunks of
    length 1.

    Parameters
    ----------
    dimensions : sequence of ints
        The shape of the index grid.
    dtype : dtype, optional
        Type to use for the array. Default is ``int``.
    chunks : sequence of ints, str
        The size of each block.  Must be one of the following forms:

        - A blocksize like (500, 1000)
        - A size in bytes, like "100 MiB" which will choose a uniform
          block-like shape
        - The word "auto" which acts like the above, but uses a configuration
          value ``array.chunk-size`` for the chunk size

        Note that the last block will have fewer samples if ``len(array) % chunks != 0``.

    Returns
    -------
    grid : dask array

    Raises
    ------
    ValueError
        If the normalized chunks do not have one entry per dimension.
    """
    dimensions = tuple(dimensions)
    dtype = np.dtype(dtype)
    chunks = normalize_chunks(chunks, shape=dimensions, dtype=dtype)

    if len(dimensions) != len(chunks):
        raise ValueError("Need same number of chunks as dimensions.")

    axes = [
        arange(extent, dtype=dtype, chunks=(axis_chunks,))
        for extent, axis_chunks in zip(dimensions, chunks)
    ]

    # Only build the mesh when no dimension has length zero.
    grid = meshgrid(*axes, indexing="ij") if np.prod(dimensions) else []

    if grid:
        return stack(grid)
    return empty((len(dimensions),) + dimensions, dtype=dtype, chunks=(1,) + chunks)
507
508
def eye(N, chunks="auto", M=None, k=0, dtype=float):
    """
    Return a 2-D Array with ones on the diagonal and zeros elsewhere.

    Parameters
    ----------
    N : int
      Number of rows in the output.
    chunks : int, str
        How to chunk the array. Must be one of the following forms:

        -   A blocksize like 1000.
        -   A size in bytes, like "100 MiB" which will choose a uniform
            block-like shape
        -   The word "auto" which acts like the above, but uses a configuration
            value ``array.chunk-size`` for the chunk size
    M : int, optional
      Number of columns in the output. If None, defaults to `N`.
    k : int, optional
      Index of the diagonal: 0 (the default) refers to the main diagonal,
      a positive value refers to an upper diagonal, and a negative value
      to a lower diagonal.
    dtype : data-type, optional
      Data-type of the returned array.

    Returns
    -------
    I : Array of shape (N,M)
      An array where all elements are equal to zero, except for the `k`-th
      diagonal, whose values are equal to one.

    Raises
    ------
    ValueError
        If `chunks` is neither an int nor a string.
    """
    if M is None:
        M = N

    if not isinstance(chunks, (int, str)):
        raise ValueError("chunks must be an int or string")

    vchunks, hchunks = normalize_chunks(chunks, shape=(N, M), dtype=dtype)

    token = tokenize(N, vchunks, hchunks, M, k, dtype)
    name_eye = "eye-" + token

    # The row and column block sizes are not assumed to be equal (string
    # chunk specifications may chunk the two axes differently), so track the
    # real global offset of every block instead of multiplying a single
    # block size.
    dsk = {}
    row_start = 0
    for i, vchunk in enumerate(vchunks):
        col_start = 0
        for j, hchunk in enumerate(hchunks):
            # Global diagonal: col - row == k. Inside this block that is
            # local_col - local_row == k - (col_start - row_start).
            local_k = k - (col_start - row_start)
            if -vchunk < local_k < hchunk:
                # The diagonal crosses this block.
                dsk[name_eye, i, j] = (np.eye, vchunk, hchunk, local_k, dtype)
            else:
                dsk[name_eye, i, j] = (np.zeros, (vchunk, hchunk), dtype)
            col_start += hchunk
        row_start += vchunk

    return Array(
        dsk, name_eye, shape=(N, M), chunks=(vchunks, hchunks), dtype=dtype
    )
566
567
@derived_from(np)
def diag(v):
    # Dask version of ``np.diag``.
    #
    # * 1-D input -> 2-D array with ``v`` on the main diagonal.
    # * 2-D input -> 1-D array holding the main diagonal of ``v``.
    #
    # ``v`` may be a dask Array, a NumPy array, or another object exposing
    # ``__array_function__``. Non-dask inputs are wrapped in a single-chunk
    # graph.
    #
    # Raises TypeError for unsupported input types, ValueError for inputs
    # that are not 1-D or 2-D, and NotImplementedError for 2-D dask arrays
    # whose row and column chunks differ.
    is_dask = isinstance(v, Array)
    is_array_like = isinstance(v, np.ndarray) or (
        hasattr(v, "__array_function__") and not is_dask
    )

    # Validate before touching ``v.ndim`` so unsupported inputs get the
    # intended TypeError rather than an AttributeError.
    if not (is_dask or is_array_like):
        raise TypeError(f"v must be a dask array or numpy array, got {type(v)}")
    if v.ndim not in (1, 2):
        raise ValueError("Array must be 1d or 2d only")

    name = "diag-" + tokenize(v)

    # The result is 2-D for 1-D input and 1-D for 2-D input.
    meta = meta_from_array(v, 2 if v.ndim == 1 else 1)

    if is_array_like:
        if v.ndim == 1:
            chunks = ((v.shape[0],), (v.shape[0],))
            dsk = {(name, 0, 0): (np.diag, v)}
        else:
            chunks = ((min(v.shape),),)
            dsk = {(name, 0): (np.diag, v)}
        return Array(dsk, name, chunks, meta=meta)

    if v.ndim == 2:
        if v.chunks[0] != v.chunks[1]:
            raise NotImplementedError(
                "Extracting diagonals from non-square chunked arrays"
            )
        # With identical row/column chunking the diagonal only passes through
        # the blocks at position (i, i).
        dsk = {(name, i): (np.diag, row[i]) for i, row in enumerate(v.__dask_keys__())}
        graph = HighLevelGraph.from_collections(name, dsk, dependencies=[v])
        return Array(graph, name, (v.chunks[0],), meta=meta)

    # 1-D dask array: diagonal blocks come from ``v``, the rest are zeros
    # created like ``meta``.
    chunks_1d = v.chunks[0]
    blocks = v.__dask_keys__()
    dsk = {}
    for i, m in enumerate(chunks_1d):
        for j, n in enumerate(chunks_1d):
            key = (name, i, j)
            if i == j:
                dsk[key] = (np.diag, blocks[i])
            else:
                dsk[key] = (partial(np.zeros_like, shape=(m, n)), meta)

    graph = HighLevelGraph.from_collections(name, dsk, dependencies=[v])
    return Array(graph, name, (chunks_1d, chunks_1d), meta=meta)
613
614
@derived_from(np)
def diagonal(a, offset=0, axis1=0, axis2=1):
    # Builds the result block by block: every (axis1-block, axis2-block) pair
    # of ``a`` contributes one output chunk computed by ``np.diagonal`` with a
    # per-block offset. The diagonal ends up as the last axis of the result,
    # after the remaining axes of ``a``.
    name = "diagonal-" + tokenize(a, offset, axis1, axis2)

    if a.ndim < 2:
        # NumPy uses `diag` as we do here.
        raise ValueError("diag requires an array of at least two dimensions")

    def _axis_fmt(axis, name, ndim):
        # Convert a negative axis to its non-negative equivalent, raising
        # AxisError when it is still negative afterwards. Non-negative axes
        # are returned unchanged (no upper-bound check is done here).
        if axis < 0:
            t = ndim + axis
            if t < 0:
                msg = "{}: axis {} is out of bounds for array of dimension {}"
                raise np.AxisError(msg.format(name, axis, ndim))
            axis = t
        return axis

    axis1 = _axis_fmt(axis1, "axis1", a.ndim)
    axis2 = _axis_fmt(axis2, "axis2", a.ndim)

    if axis1 == axis2:
        raise ValueError("axis1 and axis2 cannot be the same")

    a = asarray(a)

    # Normalize so that axis1 < axis2; exchanging the two axes mirrors the
    # diagonal, hence the sign flip of the offset.
    if axis1 > axis2:
        axis1, axis2 = axis2, axis1
        offset = -offset

    def _diag_len(dim1, dim2, offset):
        # Length of the ``offset`` diagonal of a (dim1, dim2) matrix, clamped
        # at zero when the diagonal lies outside the matrix.
        return max(0, min(min(dim1, dim2), dim1 + offset, dim2 - offset))

    # For every block pair (in axis1-major, axis2-minor order) record the
    # length of the diagonal piece it contributes and the offset to use
    # inside that block.
    diag_chunks = []
    chunk_offsets = []
    cum1 = cached_cumsum(a.chunks[axis1], initial_zero=True)[:-1]
    cum2 = cached_cumsum(a.chunks[axis2], initial_zero=True)[:-1]
    for co1, c1 in zip(cum1, a.chunks[axis1]):
        chunk_offsets.append([])
        for co2, c2 in zip(cum2, a.chunks[axis2]):
            # Shift the global offset by where this block starts on each axis.
            k = offset + co1 - co2
            diag_chunks.append(_diag_len(c1, c2, k))
            chunk_offsets[-1].append(k)

    dsk = {}
    # Axes of ``a`` other than the two that define the diagonal.
    # NOTE(review): the slicing of ``idx`` below assumes iterating this set
    # yields the axes in ascending order -- confirm.
    idx_set = set(range(a.ndim)) - {axis1, axis2}
    n1 = len(a.chunks[axis1])
    n2 = len(a.chunks[axis2])
    for idx in product(*(range(len(a.chunks[i])) for i in idx_set)):
        # ``i`` enumerates block pairs in the same order used to fill
        # ``diag_chunks`` above.
        for i, (i1, i2) in enumerate(product(range(n1), range(n2))):
            # Walk the nested key list in axis order: the block indices of the
            # remaining axes come from ``idx``, with ``i1`` inserted at axis1
            # and ``i2`` at axis2.
            tsk = reduce(getitem, idx[:axis1], a.__dask_keys__())[i1]
            tsk = reduce(getitem, idx[axis1 : axis2 - 1], tsk)[i2]
            tsk = reduce(getitem, idx[axis2 - 1 :], tsk)
            k = chunk_offsets[i1][i2]
            dsk[(name,) + idx + (i,)] = (np.diagonal, tsk, k, axis1, axis2)

    left_shape = tuple(a.shape[i] for i in idx_set)
    right_shape = (_diag_len(a.shape[axis1], a.shape[axis2], offset),)
    shape = left_shape + right_shape

    left_chunks = tuple(a.chunks[i] for i in idx_set)
    # ``right_shape`` is reused here to hold the chunks of the diagonal axis;
    # entries may be zero for block pairs the diagonal does not cross.
    right_shape = (tuple(diag_chunks),)
    chunks = left_chunks + right_shape

    graph = HighLevelGraph.from_collections(name, dsk, dependencies=[a])
    meta = meta_from_array(a, len(shape))
    return Array(graph, name, shape=shape, chunks=chunks, meta=meta)
681
682
@derived_from(np)
def tri(N, M=None, k=0, dtype=float, chunks="auto", *, like=None):
    # Dask version of ``np.tri``: compares a column of row indices against a
    # row of column indices shifted by ``k`` with ``greater_equal``.
    if like is not None and not _numpy_120:
        raise RuntimeError("The use of ``like`` required NumPy >= 1.20")

    # NumPy helper that picks an integer dtype for the given index range.
    _min_int = np.lib.twodim_base._min_int

    if M is None:
        M = N

    chunks = normalize_chunks(chunks, shape=(N, M), dtype=dtype)
    row_chunk = chunks[0][0]
    col_chunk = chunks[1][0]

    row_ids = arange(N, chunks=row_chunk, dtype=_min_int(0, N), like=like)
    row_ids = row_ids.reshape(1, N).T
    col_ids = arange(
        -k, M - k, chunks=col_chunk, dtype=_min_int(-k, M - k), like=like
    )

    mask = greater_equal(row_ids, col_ids)

    # Avoid making a copy if the requested type is already bool
    return mask.astype(dtype, copy=False)
704
705
@derived_from(np)
def fromfunction(func, chunks="auto", shape=None, dtype=None, **kwargs):
    # Dask version of ``np.fromfunction``: builds one coordinate array per
    # axis and applies ``func`` to them blockwise.
    dtype = dtype or float
    chunks = normalize_chunks(chunks, shape, dtype=dtype)

    # Every coordinate array, and the output, uses all axes in order.
    inds = tuple(range(len(shape)))

    coords = meshgrid(
        *[arange(s, dtype=dtype, chunks=c) for s, c in zip(shape, chunks)],
        indexing="ij",
    )

    # blockwise expects alternating (array, index) positional arguments.
    args = tuple(itertools.chain.from_iterable((coord, inds) for coord in coords))

    return blockwise(func, inds, *args, token="fromfunction", **kwargs)
721
722
@derived_from(np)
def repeat(a, repeats, axis=None):
    # Dask version of ``np.repeat`` restricted to an integer ``repeats`` and
    # an explicit axis (``axis=None`` is accepted only for 1-D input).
    if axis is None:
        if a.ndim != 1:
            raise NotImplementedError("Must supply an integer axis value")
        axis = 0

    if not isinstance(repeats, Integral):
        raise NotImplementedError("Only integer valued repeats supported")

    if -a.ndim <= axis < 0:
        axis += a.ndim
    elif not 0 <= axis <= a.ndim - 1:
        raise ValueError("axis(=%d) out of bounds" % axis)

    if repeats == 0:
        # Empty along ``axis``, untouched elsewhere.
        return a[tuple(slice(None) if d != axis else slice(0) for d in range(a.ndim))]
    elif repeats == 1:
        return a

    # Cut every chunk along ``axis`` into pieces at the rounded linspace
    # points, dropping empty pieces, so each piece can be repeated on its own.
    boundaries = cached_cumsum(a.chunks[axis], initial_zero=True)
    pieces = []
    for chunk_lo, chunk_hi in sliding_window(2, boundaries):
        cuts = np.linspace(chunk_lo, chunk_hi, repeats).round(0)
        pieces.extend(
            slice(lo, hi) for lo, hi in sliding_window(2, cuts) if lo != hi
        )

    everything = slice(None, None, None)
    leading = (everything,) * axis
    trailing = (everything,) * (a.ndim - axis - 1)
    slabs = [a[leading + (piece,) + trailing] for piece in pieces]

    out = []
    for slab in slabs:
        new_chunks = list(slab.chunks)
        # Each piece lies inside one original chunk.
        assert len(new_chunks[axis]) == 1
        new_chunks[axis] = (new_chunks[axis][0] * repeats,)
        out.append(
            slab.map_blocks(
                np.repeat,
                repeats,
                axis=axis,
                chunks=tuple(new_chunks),
                dtype=slab.dtype,
            )
        )

    return concatenate(out, axis=axis)
771
772
@derived_from(np)
def tile(A, reps):
    # Dask version of ``np.tile``. Positive repetitions are assembled with
    # ``block``; if any repetition is zero an empty array of the resulting
    # shape is returned instead.
    try:
        tup = tuple(reps)
    except TypeError:
        # A scalar ``reps`` is not iterable.
        tup = (reps,)
    if any(i < 0 for i in tup):
        raise ValueError("Negative `reps` are not allowed.")
    c = asarray(A)

    if all(tup):
        # Nest the array into lists from the innermost (last) axis outwards.
        nested = reduce(lambda inner, nrep: nrep * [inner], reversed(tup), c)
        return block(nested)

    # At least one repetition is zero. Left-pad ``reps`` and the shape with
    # ones so that they have the same length.
    d = len(tup)
    tup = (1,) * max(c.ndim - d, 0) + tup
    shape = (1,) * max(d - c.ndim, 0) + c.shape
    shape_out = tuple(extent * count for extent, count in zip(shape, tup))
    return empty(shape=shape_out, dtype=c.dtype)
797
798
def expand_pad_value(array, pad_value):
    """
    Expand `pad_value` into one ``(before, after)`` pair per axis of `array`.

    Accepted forms of `pad_value`:

    - a single number ``v``: ``(v, v)`` for every axis;
    - a one-element sequence ``(v,)``: ``(v, v)`` for every axis;
    - a two-number sequence ``(b, a)``: ``(b, a)`` for every axis;
    - a sequence of ``array.ndim`` two-number sequences: used as given;
    - a one-element sequence holding one two-number sequence ``((b, a),)``:
      ``(b, a)`` for every axis.

    Returns
    -------
    tuple of tuple
        ``array.ndim`` pairs.

    Raises
    ------
    TypeError
        If `pad_value` matches none of the accepted forms.
    """

    def _only_numbers(items):
        return all(isinstance(item, Number) for item in items)

    if isinstance(pad_value, Number):
        return array.ndim * ((pad_value, pad_value),)

    if isinstance(pad_value, Sequence):
        if _only_numbers(pad_value) and len(pad_value) == 1:
            return array.ndim * ((pad_value[0], pad_value[0]),)

        if len(pad_value) == 2 and _only_numbers(pad_value):
            return array.ndim * (tuple(pad_value),)

        if (
            len(pad_value) == array.ndim
            and all(isinstance(pair, Sequence) for pair in pad_value)
            and all(len(pair) == 2 for pair in pad_value)
            and all(_only_numbers(pair) for pair in pad_value)
        ):
            return tuple(tuple(pair) for pair in pad_value)

        if (
            len(pad_value) == 1
            and isinstance(pad_value[0], Sequence)
            and len(pad_value[0]) == 2
            and _only_numbers(pad_value[0])
        ):
            return array.ndim * (tuple(pad_value[0]),)

    raise TypeError("`pad_value` must be composed of integral typed values.")
834
835
def get_pad_shapes_chunks(array, pad_width, axes):
    """
    Compute the shapes and chunks of the leading and trailing pads.

    For each axis in `axes`, the pad takes its extent from `pad_width` and
    is a single chunk along that axis; all other axes keep the shape and
    chunks of `array`.

    Returns
    -------
    pad_shapes, pad_chunks : list of tuple
        Two-element lists; index 0 describes the pad placed before the
        array and index 1 the pad placed after it.
    """
    pad_shapes = []
    pad_chunks = []
    for side in (0, 1):
        side_shape = list(array.shape)
        side_chunks = list(array.chunks)
        for axis in axes:
            width = pad_width[axis][side]
            side_shape[axis] = width
            side_chunks[axis] = (width,)
        pad_shapes.append(tuple(side_shape))
        pad_chunks.append(tuple(side_chunks))

    return pad_shapes, pad_chunks
853
854
def linear_ramp_chunk(start, stop, num, dim, step):
    """
    Compute the linear ramp for one chunk.

    For every element of `start`, `num` values are generated along axis
    `dim` by interpolating linearly from that element towards `stop`
    (the element itself is excluded, `stop` is included). `step` is applied
    as a slice step to the ramp, so ``-1`` reverses its direction.

    Returns an array shaped like `start` except that axis `dim` has length
    `num`, with the dtype of `start`.
    """
    out_dtype = np.dtype(start.dtype)

    out_shape = list(start.shape)
    out_shape[dim] = num
    out = np.empty_like(start, shape=tuple(out_shape), dtype=out_dtype)

    # One extra point is generated because the starting value is dropped.
    n_points = num + 1
    whole_axis = slice(None)

    for position in np.ndindex(start.shape):
        target = list(position)
        target[dim] = whole_axis
        ramp = np.linspace(start[position], stop, n_points, dtype=out_dtype)
        out[tuple(target)] = ramp[1:][::step]

    return out
876
877
def pad_edge(array, pad_width, mode, **kwargs):
    """
    Helper function for padding edges.

    Covers the modes that need at most the values on the edge of the array:
    ``"constant"``, ``"edge"``, ``"linear_ramp"`` and ``"empty"``. The array
    is padded one axis at a time, so later axes see the padding already
    added along earlier ones.

    Keyword arguments (``constant_values`` for ``"constant"``,
    ``end_values`` for ``"linear_ramp"``) are first expanded to one
    ``(before, after)`` pair per axis.
    """

    expanded = {key: expand_pad_value(array, value) for key, value in kwargs.items()}

    padded = array
    for axis in range(array.ndim):
        shapes, chunkings = get_pad_shapes_chunks(padded, pad_width, (axis,))

        # Used as-is when ``mode`` matches none of the branches below.
        pads = [padded, padded]

        if mode == "constant":
            from .utils import asarray_safe

            fills = [
                asarray_safe(value, like=meta_from_array(array), dtype=padded.dtype)
                for value in expanded["constant_values"][axis]
            ]
            pads = [
                broadcast_to(fill, shp, chk)
                for fill, shp, chk in zip(fills, shapes, chunkings)
            ]
        elif mode in ["edge", "linear_ramp"]:
            # Select the first and the last slab of thickness 1 along ``axis``.
            first = tuple(
                slice(None, 1, None) if ax == axis else slice(None)
                for ax in range(padded.ndim)
            )
            last = tuple(
                slice(-1, None, None) if ax == axis else slice(None)
                for ax in range(padded.ndim)
            )
            pads = [padded[first], padded[last]]

            if mode == "edge":
                pads = [
                    broadcast_to(border, shp, chk)
                    for border, shp, chk in zip(pads, shapes, chunkings)
                ]
            elif mode == "linear_ramp":
                ramp_ends = expanded["end_values"][axis]

                # ``step`` is -1 for the leading pad and +1 for the trailing
                # one, which sets the direction of each ramp.
                pads = [
                    border.map_blocks(
                        linear_ramp_chunk,
                        end_value,
                        width,
                        chunks=chk,
                        dtype=padded.dtype,
                        dim=axis,
                        step=(2 * side - 1),
                    )
                    for side, (border, end_value, width, chk) in enumerate(
                        zip(pads, ramp_ends, pad_width[axis], chunkings)
                    )
                ]
        elif mode == "empty":
            pads = [
                empty_like(array, shape=shp, dtype=array.dtype, chunks=chk)
                for shp, chk in zip(shapes, chunkings)
            ]

        padded = concatenate([pads[0], padded, pads[1]], axis=axis)

    return padded
944
945
def pad_reuse(array, pad_width, mode, **kwargs):
    """
    Helper function for padding boundaries with values in the array.

    Handles the cases where the padding is constructed from values in
    the array. Namely by reflecting them or tiling them to create periodic
    boundary constraints.

    Parameters
    ----------
    array : array
        The array to pad; it is sliced and reversed to build the borders.
    pad_width : sequence of (before, after) pairs
        One pair of pad widths per dimension of ``array``.
    mode : {"reflect", "symmetric", "wrap"}
        How the border values are taken from ``array``.
    **kwargs
        ``reflect_type`` is honoured for the "reflect" and "symmetric"
        modes. The legacy spelling ``reflect`` is still read as a fallback.

    Returns
    -------
    The result of ``block`` applied to the 3 x ... x 3 grid of pieces
    (before / center / after along every dimension).

    Raises
    ------
    NotImplementedError
        If ``reflect_type`` is ``"odd"``.
    ValueError
        If ``reflect_type`` is neither ``"even"`` nor ``"odd"``.
    """

    if mode in {"reflect", "symmetric"}:
        # ``pad`` forwards this option under the NumPy name ``reflect_type``;
        # looking it up only as ``reflect`` silently ignored the request.
        reflect_type = kwargs.get("reflect_type", kwargs.get("reflect", "even"))
        if reflect_type == "odd":
            raise NotImplementedError("`pad` does not support `reflect_type` of `odd`.")
        if reflect_type != "even":
            raise ValueError(
                "unsupported value for reflect_type, must be one of (`even`, `odd`)"
            )

    # One cell per (before, center, after) combination over all dimensions.
    result = np.empty(array.ndim * (3,), dtype=object)
    for idx in np.ndindex(result.shape):
        select = []
        orient = []
        for i, s, pw in zip(idx, array.shape, pad_width):
            if mode == "wrap":
                # Pieces are stored at the mirrored grid position below, so
                # the before/after widths are swapped here to match.
                pw = pw[::-1]

            if i < 1:
                if mode == "reflect":
                    # "reflect" does not repeat the edge value itself.
                    select.append(slice(1, pw[0] + 1, None))
                else:
                    select.append(slice(None, pw[0], None))
            elif i > 1:
                if mode == "reflect":
                    select.append(slice(s - pw[1] - 1, s - 1, None))
                else:
                    select.append(slice(s - pw[1], None, None))
            else:
                select.append(slice(None))

            # Border pieces of the mirroring modes are reversed along the axis.
            if i != 1 and mode in ["reflect", "symmetric"]:
                orient.append(slice(None, None, -1))
            else:
                orient.append(slice(None))

        select = tuple(select)
        orient = tuple(orient)

        if mode == "wrap":
            # The leading slice of the array becomes the trailing border and
            # vice versa.
            idx = tuple(2 - i for i in idx)

        result[idx] = array[select][orient]

    result = block(result.tolist())

    return result
1001
1002
def pad_stats(array, pad_width, mode, stat_length):
    """
    Pad an array with a statistic computed from its own values.

    The output is assembled from a 3 x ... x 3 grid of pieces (before /
    center / after along every dimension). The center piece is the array
    itself. Every other piece is a slice of the array, limited by
    ``stat_length``, reduced over the padded dimensions and broadcast to
    the shape of the border it fills.

    Parameters
    ----------
    array : array
        The array to pad.
    pad_width : sequence of (before, after) pairs
        One pair of pad widths per dimension.
    mode : {"maximum", "mean", "minimum"}
        The statistic to compute. ``"median"`` is rejected.
    stat_length
        Number of edge values used for the statistic; normalised with
        ``expand_pad_value`` before use.

    Raises
    ------
    NotImplementedError
        If ``mode`` is ``"median"``.
    """

    if mode == "median":
        raise NotImplementedError("`pad` does not support `mode` of `median`.")

    stat_length = expand_pad_value(array, stat_length)

    # Reduction method to call for the requested statistic, if any.
    reducer_name = {"maximum": "max", "mean": "mean", "minimum": "min"}.get(mode)

    pieces = np.empty(array.ndim * (3,), dtype=object)
    for position in np.ndindex(pieces.shape):
        reduced_axes = []
        window = []
        piece_shape = []
        piece_chunks = []
        per_axis = zip(position, array.shape, array.chunks, pad_width, stat_length)
        for axis, (where, size, axis_chunks, width, length) in enumerate(per_axis):
            if where == 1:
                # Center along this axis: keep the data untouched.
                window.append(slice(None))
                piece_shape.append(size)
                piece_chunks.append(axis_chunks)
                continue

            reduced_axes.append(axis)
            if where < 1:
                # Leading border: statistic over the first ``length[0]`` items.
                window.append(slice(None, length[0], None))
                piece_shape.append(width[0])
                piece_chunks.append(width[0])
            else:
                # Trailing border: statistic over the last ``length[1]`` items.
                window.append(slice(size - length[1], None, None))
                piece_shape.append(width[1])
                piece_chunks.append(width[1])

        piece = array[tuple(window)]
        if reduced_axes:
            if reducer_name is not None:
                reduce_method = getattr(piece, reducer_name)
                piece = reduce_method(axis=tuple(reduced_axes), keepdims=True)

            piece = broadcast_to(
                piece, tuple(piece_shape), chunks=tuple(piece_chunks)
            )

            if mode == "mean":
                # Integer inputs get a rounded mean before the cast back.
                if np.issubdtype(array.dtype, np.integer):
                    piece = rint(piece)
                piece = piece.astype(array.dtype)

        pieces[position] = piece

    return block(pieces.tolist())
1067
1068
def wrapped_pad_func(array, pad_func, iaxis_pad_width, iaxis, pad_func_kwargs):
    """
    Apply a user supplied padding function to every 1-D lane of a block.

    Parameters
    ----------
    array : ndarray
        The block to process.
    pad_func : callable
        Called as ``pad_func(vector, iaxis_pad_width, iaxis, pad_func_kwargs)``
        for every 1-D vector along ``iaxis``. It may either modify ``vector``
        in place and return ``None`` (the convention used by ``numpy.pad``)
        or return the values to store for that vector.
    iaxis_pad_width : tuple
        Passed through to ``pad_func`` unchanged.
    iaxis : int
        The axis along which the 1-D vectors are taken.
    pad_func_kwargs : dict
        Passed through to ``pad_func`` unchanged.

    Returns
    -------
    A new array with the same shape and dtype as ``array``; the input
    block itself is left unmodified.
    """
    # Work on a private copy so in-place edits made by ``pad_func`` are kept
    # in the result without leaking into the input block.
    result = np.empty_like(array)
    result[...] = array

    other_axes_shape = array.shape[:iaxis] + array.shape[iaxis + 1 :]
    for outer in np.ndindex(other_axes_shape):
        lane = outer[:iaxis] + (slice(None),) + outer[iaxis:]
        returned = pad_func(result[lane], iaxis_pad_width, iaxis, pad_func_kwargs)
        # ``None`` means the function followed the in-place convention.
        if returned is not None:
            result[lane] = returned

    return result
1076
1077
def pad_udf(array, pad_width, mode, **kwargs):
    """
    Pad an array using a user defined function.

    The array is first padded with zeros through ``pad_edge``. Then, one
    dimension at a time, the padded array is rechunked so that dimension
    is a single chunk, ``wrapped_pad_func`` is mapped over the blocks with
    ``mode`` as the padding function, and the chunking in place after the
    zero padding is restored.

    Parameters
    ----------
    array : Array
        The array to pad.
    pad_width : sequence of (before, after) pairs
        One pair of pad widths per dimension.
    mode : callable
        The user defined padding function.
    **kwargs
        Handed to ``mode`` as a single dictionary argument.
    """

    padded = pad_edge(array, pad_width, "constant", constant_values=0)
    original_chunks = padded.chunks

    for axis in range(padded.ndim):
        # Collapse ``axis`` into one chunk so each block holds full vectors.
        whole_axis_chunks = (
            original_chunks[:axis]
            + (padded.shape[axis : axis + 1],)
            + original_chunks[axis + 1 :]
        )
        unchunked = padded.rechunk(whole_axis_chunks)

        filled = unchunked.map_blocks(
            wrapped_pad_func,
            name="pad",
            dtype=unchunked.dtype,
            pad_func=mode,
            iaxis_pad_width=pad_width[axis],
            iaxis=axis,
            pad_func_kwargs=kwargs,
        )

        padded = filled.rechunk(original_chunks)

    return padded
1109
1110
@derived_from(np)
def pad(array, pad_width, mode="constant", **kwargs):
    # Entry point for padding: normalises the inputs, validates the keyword
    # arguments against the chosen mode and dispatches to the matching
    # helper (pad_udf, pad_stats, pad_edge or pad_reuse).
    #
    # Raises ValueError for an unknown string mode or for keyword arguments
    # that the chosen mode does not accept.
    array = asarray(array)

    pad_width = expand_pad_value(array, pad_width)

    # A callable mode is a user defined padding function.
    if callable(mode):
        return pad_udf(array, pad_width, mode, **kwargs)

    # Make sure that no unsupported keywords were passed for the current mode
    allowed_kwargs = {
        "empty": [],
        "edge": [],
        "wrap": [],
        "constant": ["constant_values"],
        "linear_ramp": ["end_values"],
        "maximum": ["stat_length"],
        "mean": ["stat_length"],
        "median": ["stat_length"],
        "minimum": ["stat_length"],
        "reflect": ["reflect_type"],
        "symmetric": ["reflect_type"],
    }
    try:
        unsupported_kwargs = set(kwargs) - set(allowed_kwargs[mode])
    except KeyError as e:
        raise ValueError(f"mode '{mode}' is not supported") from e
    if unsupported_kwargs:
        raise ValueError(
            f"unsupported keyword arguments for mode '{mode}': {unsupported_kwargs}"
        )

    if mode in {"maximum", "mean", "median", "minimum"}:
        # An explicit ``stat_length=None`` means the same as leaving it out:
        # use the full extent of every axis.
        stat_length = kwargs.get("stat_length")
        if stat_length is None:
            stat_length = tuple((n, n) for n in array.shape)
        return pad_stats(array, pad_width, mode, stat_length)
    elif mode == "constant":
        kwargs.setdefault("constant_values", 0)
        return pad_edge(array, pad_width, mode, **kwargs)
    elif mode == "linear_ramp":
        kwargs.setdefault("end_values", 0)
        return pad_edge(array, pad_width, mode, **kwargs)
    elif mode in {"edge", "empty"}:
        return pad_edge(array, pad_width, mode)
    elif mode in ["reflect", "symmetric", "wrap"]:
        return pad_reuse(array, pad_width, mode, **kwargs)

    # Every key of ``allowed_kwargs`` is handled above. Raise explicitly so
    # the guard is not stripped when Python runs with ``-O``.
    raise AssertionError("unreachable")