1from . import dtypes, nputils
2
3
def dask_rolling_wrapper(moving_func, a, window, min_count=None, axis=-1):
    """Wrapper to apply bottleneck moving window funcs on dask arrays.

    Each block is extended with data borrowed from its neighbours along
    ``axis`` (``dask.array.overlap.overlap``), ``moving_func`` is applied
    blockwise, and the borrowed margins are trimmed off again.

    Parameters
    ----------
    moving_func : callable
        A bottleneck-style moving-window function, called as
        ``moving_func(block, window, min_count=min_count, axis=axis)``.
    a : dask.array.Array
        Input array. It is cast to the dtype returned by
        ``dtypes.maybe_promote`` so that it can hold the boundary fill value.
    window : int
        Length of the moving window.
    min_count : int, optional
        Forwarded unchanged to ``moving_func``.
    axis : int, default -1
        Axis along which the window moves; negative values count from the end.

    Returns
    -------
    dask.array.Array
        Result of ``moving_func`` with the same chunking as the promoted ``a``.
    """
    import dask.array as da

    dtype, fill_value = dtypes.maybe_promote(a.dtype)
    a = a.astype(dtype)
    # inputs for overlap
    if axis < 0:
        axis = a.ndim + axis
    # bottleneck's move_* functions use a trailing window: the output at
    # position i reads positions i - window + 1 .. i. Every block must
    # therefore borrow ``window - 1`` elements from the preceding block,
    # otherwise outputs just after an internal chunk boundary see too few
    # values. Keep a minimum of 1 (the previous depth for tiny windows).
    depth = {d: 0 for d in range(a.ndim)}
    depth[axis] = max(window - 1, 1)
    # Outer edges of the array are padded with the promoted fill value.
    boundary = {d: fill_value for d in range(a.ndim)}
    # Create overlap array.
    ag = da.overlap.overlap(a, depth=depth, boundary=boundary)
    # apply rolling func
    out = da.map_blocks(
        moving_func, ag, window, min_count=min_count, axis=axis, dtype=a.dtype
    )
    # Drop the borrowed margins so every block regains its original size.
    result = da.overlap.trim_internal(out, depth)
    return result
25
26
def least_squares(lhs, rhs, rcond=None, skipna=False):
    """Solve ``lhs @ coeffs ~= rhs`` in least squares for a dask ``rhs``.

    ``lhs`` is wrapped into a dask array chunked like ``rhs`` along its first
    axis and whole along its second axis.

    With ``skipna=False`` the solve is delegated to ``dask.array.linalg.lstsq``
    (``rcond`` is not forwarded there). With ``skipna=True`` each column of
    ``rhs`` is fitted independently by ``nputils._nanpolyfit_1d`` through
    ``dask.array.apply_along_axis``; the last row of that output is taken as
    the residuals and the rows before it as the coefficients. A 1-D ``rhs``
    is temporarily treated as a single column and the outputs are flattened
    back afterwards.

    Returns
    -------
    tuple
        ``(coeffs, residuals)`` as dask arrays.
    """
    import dask.array as da

    # Built from the ORIGINAL rhs chunks, before any reshape below.
    lhs_chunked = da.from_array(lhs, chunks=(rhs.chunks[0], lhs.shape[1]))

    if not skipna:
        # Residuals here are (1, 1) but should be (K,) as rhs is (N, K)
        # See issue dask/dask#6516
        coeffs, residuals, _, _ = da.linalg.lstsq(lhs_chunked, rhs)
        return coeffs, residuals

    was_vector = rhs.ndim == 1
    rhs_matrix = rhs.reshape(rhs.shape[0], 1) if was_vector else rhs
    n_params = lhs.shape[1]

    # One output row per coefficient plus a trailing row for the residual.
    stacked = da.apply_along_axis(
        nputils._nanpolyfit_1d,
        0,
        rhs_matrix,
        lhs_chunked,
        dtype=float,
        shape=(n_params + 1,),
        rcond=rcond,
    )
    coeffs, residuals = stacked[:-1, ...], stacked[-1, ...]

    if was_vector:
        coeffs = coeffs.reshape(coeffs.shape[0])
        residuals = residuals.reshape(residuals.shape[0])
    return coeffs, residuals
54
55
def push(array, n, axis):
    """
    Dask-aware bottleneck.push: forward-fill NaNs along ``axis``.

    When ``axis`` spans several chunks, each block is first filled on its own
    with ``bottleneck.push``. The last value of every filled block is then
    carried forward across block boundaries with
    ``dask.array.reductions.cumreduction``. Values therefore propagate
    correctly even through chunks that are entirely NaN. A fixed one-element
    overlap cannot do this, because it only reaches the immediately
    preceding chunk.

    Parameters
    ----------
    array : dask.array.Array
        Array to fill.
    n : int or None
        Forwarded to ``bottleneck.push`` as the fill limit; ``None`` means no
        limit.
    axis : int
        Axis along which values are pushed forward.

    Returns
    -------
    dask.array.Array
        The forward-filled array.

    Raises
    ------
    NotImplementedError
        If ``axis`` is split into more than one chunk and ``n`` is an
        effective limit (``n < array.shape[axis]``).
    """
    import dask.array as da
    import numpy as np
    from bottleneck import push

    if len(array.chunks[axis]) > 1 and n is not None and n < array.shape[axis]:
        raise NotImplementedError(
            "Cannot fill along a chunked axis when limit is not None. "
            "Either rechunk to a single chunk along this axis or call .compute() or .load() first."
        )

    if len(array.chunks[axis]) == 1:
        # The whole axis lives in a single block, so a blockwise push is
        # exact, including the limit ``n``.
        return array.map_blocks(
            push, axis=axis, n=n, dtype=array.dtype, meta=array._meta
        )

    def _fill_from_previous(carry, block):
        # ``carry`` is the last non-NaN value seen in earlier blocks (NaN if
        # none). After the per-block push, the only NaNs left in ``block``
        # are its leading ones, which is exactly what ``carry`` must fill.
        return np.where(np.isnan(block), carry, block)

    # Reaching this point means n is None or n >= array.shape[axis]. No gap
    # inside the array can exceed n, so the fill is effectively unlimited,
    # matching push's default (n=None) as invoked by cumreduction.
    return da.reductions.cumreduction(
        func=push,
        binop=_fill_from_previous,
        ident=np.nan,
        x=array,
        axis=axis,
        dtype=array.dtype,
    )
81