from . import dtypes, nputils


def dask_rolling_wrapper(moving_func, a, window, min_count=None, axis=-1):
    """Wrapper to apply bottleneck moving window funcs on dask arrays.

    Parameters
    ----------
    moving_func : callable
        Called on every block as
        ``moving_func(block, window, min_count=min_count, axis=axis)``.
    a : dask array
        Input array. It is cast to the dtype returned by
        ``dtypes.maybe_promote(a.dtype)`` before processing.
    window : int
        Window size forwarded to ``moving_func``; it also determines the
        overlap depth, ``(window + 1) // 2``, along ``axis``.
    min_count : int, optional
        Forwarded unchanged to ``moving_func``.
    axis : int, default: -1
        Axis along which the window moves. Negative values are counted from
        the last axis.

    Returns
    -------
    dask array
        Result of ``moving_func`` on the overlapped blocks, with the overlap
        trimmed off again.
    """
    import dask.array as da

    # The fill value returned alongside the promoted dtype is used to pad the
    # array boundaries below.
    dtype, fill_value = dtypes.maybe_promote(a.dtype)
    a = a.astype(dtype)
    # inputs for overlap
    if axis < 0:
        axis = a.ndim + axis
    # Only the rolling axis gets a non-zero overlap depth.
    depth = {d: 0 for d in range(a.ndim)}
    depth[axis] = (window + 1) // 2
    boundary = {d: fill_value for d in range(a.ndim)}
    # Create overlap array.
    ag = da.overlap.overlap(a, depth=depth, boundary=boundary)
    # apply rolling func
    out = da.map_blocks(
        moving_func, ag, window, min_count=min_count, axis=axis, dtype=a.dtype
    )
    # trim array
    result = da.overlap.trim_internal(out, depth)
    return result


def least_squares(lhs, rhs, rcond=None, skipna=False):
    """Solve a least-squares problem where ``rhs`` is a dask array.

    Parameters
    ----------
    lhs : array
        Two-dimensional coefficient matrix. It is wrapped in a dask array
        whose chunks along the first axis match ``rhs.chunks[0]`` and which
        has a single chunk spanning all ``lhs.shape[1]`` columns.
    rhs : dask array
        Right-hand side. When ``skipna`` is true a 1-D ``rhs`` is temporarily
        reshaped to a single column.
    rcond : float, optional
        Forwarded to ``nputils._nanpolyfit_1d`` when ``skipna`` is true.
        It is not passed to ``da.linalg.lstsq`` in the other branch.
    skipna : bool, default: False
        If true, fit each 1-D slice along axis 0 of ``rhs`` with
        ``nputils._nanpolyfit_1d``; otherwise use ``da.linalg.lstsq``.

    Returns
    -------
    coeffs, residuals
        The fitted coefficients and the residuals.
    """
    import dask.array as da

    lhs_da = da.from_array(lhs, chunks=(rhs.chunks[0], lhs.shape[1]))
    if skipna:
        added_dim = rhs.ndim == 1
        if added_dim:
            rhs = rhs.reshape(rhs.shape[0], 1)
        # Each fitted slice yields ``lhs.shape[1] + 1`` values: the
        # coefficients followed by one residual entry (split just below).
        results = da.apply_along_axis(
            nputils._nanpolyfit_1d,
            0,
            rhs,
            lhs_da,
            dtype=float,
            shape=(lhs.shape[1] + 1,),
            rcond=rcond,
        )
        coeffs = results[:-1, ...]
        residuals = results[-1, ...]
        if added_dim:
            # Undo the temporary column dimension added above.
            coeffs = coeffs.reshape(coeffs.shape[0])
            residuals = residuals.reshape(residuals.shape[0])
    else:
        # Residuals here are (1, 1) but should be (K,) as rhs is (N, K)
        # See issue dask/dask#6516
        coeffs, residuals, _, _ = da.linalg.lstsq(lhs_da, rhs)
    return coeffs, residuals


def push(array, n, axis):
    """
    Dask-aware bottleneck.push

    Parameters
    ----------
    array : dask array
        Array to fill.
    n : int or None
        Fill limit, forwarded to ``bottleneck.push`` as ``n``.
    axis : int
        Axis along which to fill.

    Returns
    -------
    dask array
        The block-wise pushed array; when ``axis`` has more than one chunk, a
        second pass with a one-element overlap from the preceding chunk is
        applied.

    Raises
    ------
    NotImplementedError
        If ``axis`` is split into more than one chunk while ``n`` is not None
        and smaller than the length of that axis.
    """
    from bottleneck import push

    if len(array.chunks[axis]) > 1 and n is not None and n < array.shape[axis]:
        raise NotImplementedError(
            "Cannot fill along a chunked axis when limit is not None. "
            "Either rechunk to a single chunk along this axis or call .compute() or .load() first."
        )
    # NOTE(review): chunks of size 1 are merged into chunks of size 2 before
    # the overlap pass below; presumably size-1 chunks do not work with the
    # one-element overlap -- confirm the exact reason.
    if all(c == 1 for c in array.chunks[axis]):
        array = array.rechunk({axis: 2})
    # First pass: fill within each block independently.
    pushed = array.map_blocks(push, axis=axis, n=n, dtype=array.dtype, meta=array._meta)
    if len(array.chunks[axis]) > 1:
        # Second pass: let each block see one element from the block before it
        # so values can cross a chunk boundary.
        # NOTE(review): this looks like it carries values across a single
        # boundary only; a chunk that is entirely missing may stop the
        # propagation -- confirm.
        pushed = pushed.map_overlap(
            push,
            axis=axis,
            n=n,
            depth={axis: (1, 0)},
            boundary="none",
            dtype=array.dtype,
            meta=array._meta,
        )
    return pushed