1from __future__ import annotations
2
3import collections
4import itertools
5import operator
6from typing import (
7    TYPE_CHECKING,
8    Any,
9    Callable,
10    DefaultDict,
11    Dict,
12    Hashable,
13    Iterable,
14    List,
15    Mapping,
16    Sequence,
17    Tuple,
18    Union,
19)
20
21import numpy as np
22
23from .alignment import align
24from .dataarray import DataArray
25from .dataset import Dataset
26
27try:
28    import dask
29    import dask.array
30    from dask.array.utils import meta_from_array
31    from dask.highlevelgraph import HighLevelGraph
32
33except ImportError:
34    pass
35
36
37if TYPE_CHECKING:
38    from .types import T_Xarray
39
40
41def unzip(iterable):
42    return zip(*iterable)
43
44
45def assert_chunks_compatible(a: Dataset, b: Dataset):
46    a = a.unify_chunks()
47    b = b.unify_chunks()
48
49    for dim in set(a.chunks).intersection(set(b.chunks)):
50        if a.chunks[dim] != b.chunks[dim]:
51            raise ValueError(f"Chunk sizes along dimension {dim!r} are not equal.")
52
53
54def check_result_variables(
55    result: Union[DataArray, Dataset], expected: Mapping[str, Any], kind: str
56):
57
58    if kind == "coords":
59        nice_str = "coordinate"
60    elif kind == "data_vars":
61        nice_str = "data"
62
63    # check that coords and data variables are as expected
64    missing = expected[kind] - set(getattr(result, kind))
65    if missing:
66        raise ValueError(
67            "Result from applying user function does not contain "
68            f"{nice_str} variables {missing}."
69        )
70    extra = set(getattr(result, kind)) - expected[kind]
71    if extra:
72        raise ValueError(
73            "Result from applying user function has unexpected "
74            f"{nice_str} variables {extra}."
75        )
76
77
78def dataset_to_dataarray(obj: Dataset) -> DataArray:
79    if not isinstance(obj, Dataset):
80        raise TypeError(f"Expected Dataset, got {type(obj)}")
81
82    if len(obj.data_vars) > 1:
83        raise TypeError(
84            "Trying to convert Dataset with more than one data variable to DataArray"
85        )
86
87    return next(iter(obj.data_vars.values()))
88
89
90def dataarray_to_dataset(obj: DataArray) -> Dataset:
91    # only using _to_temp_dataset would break
92    # func = lambda x: x.to_dataset()
93    # since that relies on preserving name.
94    if obj.name is None:
95        dataset = obj._to_temp_dataset()
96    else:
97        dataset = obj.to_dataset()
98    return dataset
99
100
101def make_meta(obj):
102    """If obj is a DataArray or Dataset, return a new object of the same type and with
103    the same variables and dtypes, but where all variables have size 0 and numpy
104    backend.
105    If obj is neither a DataArray nor Dataset, return it unaltered.
106    """
107    if isinstance(obj, DataArray):
108        obj_array = obj
109        obj = dataarray_to_dataset(obj)
110    elif isinstance(obj, Dataset):
111        obj_array = None
112    else:
113        return obj
114
115    meta = Dataset()
116    for name, variable in obj.variables.items():
117        meta_obj = meta_from_array(variable.data, ndim=variable.ndim)
118        meta[name] = (variable.dims, meta_obj, variable.attrs)
119    meta.attrs = obj.attrs
120    meta = meta.set_coords(obj.coords)
121
122    if obj_array is not None:
123        return dataset_to_dataarray(meta)
124    return meta
125
126
127def infer_template(
128    func: Callable[..., T_Xarray], obj: Union[DataArray, Dataset], *args, **kwargs
129) -> T_Xarray:
130    """Infer return object by running the function on meta objects."""
131    meta_args = [make_meta(arg) for arg in (obj,) + args]
132
133    try:
134        template = func(*meta_args, **kwargs)
135    except Exception as e:
136        raise Exception(
137            "Cannot infer object returned from running user provided function. "
138            "Please supply the 'template' kwarg to map_blocks."
139        ) from e
140
141    if not isinstance(template, (Dataset, DataArray)):
142        raise TypeError(
143            "Function must return an xarray DataArray or Dataset. Instead it returned "
144            f"{type(template)}"
145        )
146
147    return template
148
149
150def make_dict(x: Union[DataArray, Dataset]) -> Dict[Hashable, Any]:
151    """Map variable name to numpy(-like) data
152    (Dataset.to_dict() is too complicated).
153    """
154    if isinstance(x, DataArray):
155        x = x._to_temp_dataset()
156
157    return {k: v.data for k, v in x.variables.items()}
158
159
160def _get_chunk_slicer(dim: Hashable, chunk_index: Mapping, chunk_bounds: Mapping):
161    if dim in chunk_index:
162        which_chunk = chunk_index[dim]
163        return slice(chunk_bounds[dim][which_chunk], chunk_bounds[dim][which_chunk + 1])
164    return slice(None)
165
166
167def map_blocks(
168    func: Callable[..., T_Xarray],
169    obj: Union[DataArray, Dataset],
170    args: Sequence[Any] = (),
171    kwargs: Mapping[str, Any] = None,
172    template: Union[DataArray, Dataset] = None,
173) -> T_Xarray:
174    """Apply a function to each block of a DataArray or Dataset.
175
176    .. warning::
177        This function is experimental and its signature may change.
178
179    Parameters
180    ----------
181    func : callable
182        User-provided function that accepts a DataArray or Dataset as its first
183        parameter ``obj``. The function will receive a subset or 'block' of ``obj`` (see below),
184        corresponding to one chunk along each chunked dimension. ``func`` will be
185        executed as ``func(subset_obj, *subset_args, **kwargs)``.
186
187        This function must return either a single DataArray or a single Dataset.
188
189        This function cannot add a new chunked dimension.
190    obj : DataArray, Dataset
191        Passed to the function as its first argument, one block at a time.
192    args : sequence
193        Passed to func after unpacking and subsetting any xarray objects by blocks.
194        xarray objects in args must be aligned with obj, otherwise an error is raised.
195    kwargs : mapping
196        Passed verbatim to func after unpacking. xarray objects, if any, will not be
197        subset to blocks. Passing dask collections in kwargs is not allowed.
198    template : DataArray or Dataset, optional
199        xarray object representing the final result after compute is called. If not provided,
200        the function will be first run on mocked-up data, that looks like ``obj`` but
201        has sizes 0, to determine properties of the returned object such as dtype,
202        variable names, attributes, new dimensions and new indexes (if any).
203        ``template`` must be provided if the function changes the size of existing dimensions.
204        When provided, ``attrs`` on variables in `template` are copied over to the result. Any
205        ``attrs`` set by ``func`` will be ignored.
206
207    Returns
208    -------
209    A single DataArray or Dataset with dask backend, reassembled from the outputs of the
210    function.
211
212    Notes
213    -----
214    This function is designed for when ``func`` needs to manipulate a whole xarray object
215    subset to each block. Each block is loaded into memory. In the more common case where
216    ``func`` can work on numpy arrays, it is recommended to use ``apply_ufunc``.
217
218    If none of the variables in ``obj`` is backed by dask arrays, calling this function is
219    equivalent to calling ``func(obj, *args, **kwargs)``.
220
221    See Also
222    --------
223    dask.array.map_blocks, xarray.apply_ufunc, xarray.Dataset.map_blocks
224    xarray.DataArray.map_blocks
225
226    Examples
227    --------
228    Calculate an anomaly from climatology using ``.groupby()``. Using
229    ``xr.map_blocks()`` allows for parallel operations with knowledge of ``xarray``,
230    its indices, and its methods like ``.groupby()``.
231
232    >>> def calculate_anomaly(da, groupby_type="time.month"):
233    ...     gb = da.groupby(groupby_type)
234    ...     clim = gb.mean(dim="time")
235    ...     return gb - clim
236    ...
237    >>> time = xr.cftime_range("1990-01", "1992-01", freq="M")
238    >>> month = xr.DataArray(time.month, coords={"time": time}, dims=["time"])
239    >>> np.random.seed(123)
240    >>> array = xr.DataArray(
241    ...     np.random.rand(len(time)),
242    ...     dims=["time"],
243    ...     coords={"time": time, "month": month},
244    ... ).chunk()
245    >>> array.map_blocks(calculate_anomaly, template=array).compute()
246    <xarray.DataArray (time: 24)>
247    array([ 0.12894847,  0.11323072, -0.0855964 , -0.09334032,  0.26848862,
248            0.12382735,  0.22460641,  0.07650108, -0.07673453, -0.22865714,
249           -0.19063865,  0.0590131 , -0.12894847, -0.11323072,  0.0855964 ,
250            0.09334032, -0.26848862, -0.12382735, -0.22460641, -0.07650108,
251            0.07673453,  0.22865714,  0.19063865, -0.0590131 ])
252    Coordinates:
253      * time     (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00
254        month    (time) int64 1 2 3 4 5 6 7 8 9 10 11 12 1 2 3 4 5 6 7 8 9 10 11 12
255
256    Note that one must explicitly use ``args=[]`` and ``kwargs={}`` to pass arguments
257    to the function being applied in ``xr.map_blocks()``:
258
259    >>> array.map_blocks(
260    ...     calculate_anomaly,
261    ...     kwargs={"groupby_type": "time.year"},
262    ...     template=array,
263    ... )  # doctest: +ELLIPSIS
264    <xarray.DataArray (time: 24)>
265    dask.array<<this-array>-calculate_anomaly, shape=(24,), dtype=float64, chunksize=(24,), chunktype=numpy.ndarray>
266    Coordinates:
267      * time     (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00
268        month    (time) int64 dask.array<chunksize=(24,), meta=np.ndarray>
269    """
270
271    def _wrapper(
272        func: Callable,
273        args: List,
274        kwargs: dict,
275        arg_is_array: Iterable[bool],
276        expected: dict,
277    ):
278        """
279        Wrapper function that receives datasets in args; converts to dataarrays when necessary;
280        passes these to the user function `func` and checks returned objects for expected shapes/sizes/etc.
281        """
282
283        converted_args = [
284            dataset_to_dataarray(arg) if is_array else arg
285            for is_array, arg in zip(arg_is_array, args)
286        ]
287
288        result = func(*converted_args, **kwargs)
289
290        # check all dims are present
291        missing_dimensions = set(expected["shapes"]) - set(result.sizes)
292        if missing_dimensions:
293            raise ValueError(
294                f"Dimensions {missing_dimensions} missing on returned object."
295            )
296
297        # check that index lengths and values are as expected
298        for name, index in result.xindexes.items():
299            if name in expected["shapes"]:
300                if result.sizes[name] != expected["shapes"][name]:
301                    raise ValueError(
302                        f"Received dimension {name!r} of length {result.sizes[name]}. "
303                        f"Expected length {expected['shapes'][name]}."
304                    )
305            if name in expected["indexes"]:
306                expected_index = expected["indexes"][name]
307                if not index.equals(expected_index):
308                    raise ValueError(
309                        f"Expected index {name!r} to be {expected_index!r}. Received {index!r} instead."
310                    )
311
312        # check that all expected variables were returned
313        check_result_variables(result, expected, "coords")
314        if isinstance(result, Dataset):
315            check_result_variables(result, expected, "data_vars")
316
317        return make_dict(result)
318
319    if template is not None and not isinstance(template, (DataArray, Dataset)):
320        raise TypeError(
321            f"template must be a DataArray or Dataset. Received {type(template).__name__} instead."
322        )
323    if not isinstance(args, Sequence):
324        raise TypeError("args must be a sequence (for example, a list or tuple).")
325    if kwargs is None:
326        kwargs = {}
327    elif not isinstance(kwargs, Mapping):
328        raise TypeError("kwargs must be a mapping (for example, a dict)")
329
330    for value in kwargs.values():
331        if dask.is_dask_collection(value):
332            raise TypeError(
333                "Cannot pass dask collections in kwargs yet. Please compute or "
334                "load values before passing to map_blocks."
335            )
336
337    if not dask.is_dask_collection(obj):
338        return func(obj, *args, **kwargs)
339
340    all_args = [obj] + list(args)
341    is_xarray = [isinstance(arg, (Dataset, DataArray)) for arg in all_args]
342    is_array = [isinstance(arg, DataArray) for arg in all_args]
343
344    # there should be a better way to group this. partition?
345    xarray_indices, xarray_objs = unzip(
346        (index, arg) for index, arg in enumerate(all_args) if is_xarray[index]
347    )
348    others = [
349        (index, arg) for index, arg in enumerate(all_args) if not is_xarray[index]
350    ]
351
352    # all xarray objects must be aligned. This is consistent with apply_ufunc.
353    aligned = align(*xarray_objs, join="exact")
354    xarray_objs = tuple(
355        dataarray_to_dataset(arg) if is_da else arg
356        for is_da, arg in zip(is_array, aligned)
357    )
358
359    _, npargs = unzip(
360        sorted(list(zip(xarray_indices, xarray_objs)) + others, key=lambda x: x[0])
361    )
362
363    # check that chunk sizes are compatible
364    input_chunks = dict(npargs[0].chunks)
365    input_indexes = dict(npargs[0].xindexes)
366    for arg in xarray_objs[1:]:
367        assert_chunks_compatible(npargs[0], arg)
368        input_chunks.update(arg.chunks)
369        input_indexes.update(arg.xindexes)
370
371    if template is None:
372        # infer template by providing zero-shaped arrays
373        template = infer_template(func, aligned[0], *args, **kwargs)
374        template_indexes = set(template.xindexes)
375        preserved_indexes = template_indexes & set(input_indexes)
376        new_indexes = template_indexes - set(input_indexes)
377        indexes = {dim: input_indexes[dim] for dim in preserved_indexes}
378        indexes.update({k: template.xindexes[k] for k in new_indexes})
379        output_chunks = {
380            dim: input_chunks[dim] for dim in template.dims if dim in input_chunks
381        }
382
383    else:
384        # template xarray object has been provided with proper sizes and chunk shapes
385        indexes = dict(template.xindexes)
386        if isinstance(template, DataArray):
387            output_chunks = dict(
388                zip(template.dims, template.chunks)  # type: ignore[arg-type]
389            )
390        else:
391            output_chunks = dict(template.chunks)
392
393    for dim in output_chunks:
394        if dim in input_chunks and len(input_chunks[dim]) != len(output_chunks[dim]):
395            raise ValueError(
396                "map_blocks requires that one block of the input maps to one block of output. "
397                f"Expected number of output chunks along dimension {dim!r} to be {len(input_chunks[dim])}. "
398                f"Received {len(output_chunks[dim])} instead. Please provide template if not provided, or "
399                "fix the provided template."
400            )
401
402    if isinstance(template, DataArray):
403        result_is_array = True
404        template_name = template.name
405        template = template._to_temp_dataset()
406    elif isinstance(template, Dataset):
407        result_is_array = False
408    else:
409        raise TypeError(
410            f"func output must be DataArray or Dataset; got {type(template)}"
411        )
412
413    # We're building a new HighLevelGraph hlg. We'll have one new layer
414    # for each variable in the dataset, which is the result of the
415    # func applied to the values.
416
417    graph: Dict[Any, Any] = {}
418    new_layers: DefaultDict[str, Dict[Any, Any]] = collections.defaultdict(dict)
419    gname = "{}-{}".format(
420        dask.utils.funcname(func), dask.base.tokenize(npargs[0], args, kwargs)
421    )
422
423    # map dims to list of chunk indexes
424    ichunk = {dim: range(len(chunks_v)) for dim, chunks_v in input_chunks.items()}
425    # mapping from chunk index to slice bounds
426    input_chunk_bounds = {
427        dim: np.cumsum((0,) + chunks_v) for dim, chunks_v in input_chunks.items()
428    }
429    output_chunk_bounds = {
430        dim: np.cumsum((0,) + chunks_v) for dim, chunks_v in output_chunks.items()
431    }
432
433    def subset_dataset_to_block(
434        graph: dict, gname: str, dataset: Dataset, input_chunk_bounds, chunk_index
435    ):
436        """
437        Creates a task that subsets an xarray dataset to a block determined by chunk_index.
438        Block extents are determined by input_chunk_bounds.
439        Also subtasks that subset the constituent variables of a dataset.
440        """
441
442        # this will become [[name1, variable1],
443        #                   [name2, variable2],
444        #                   ...]
445        # which is passed to dict and then to Dataset
446        data_vars = []
447        coords = []
448
449        chunk_tuple = tuple(chunk_index.values())
450        for name, variable in dataset.variables.items():
451            # make a task that creates tuple of (dims, chunk)
452            if dask.is_dask_collection(variable.data):
453                # recursively index into dask_keys nested list to get chunk
454                chunk = variable.__dask_keys__()
455                for dim in variable.dims:
456                    chunk = chunk[chunk_index[dim]]
457
458                chunk_variable_task = (f"{name}-{gname}-{chunk[0]}",) + chunk_tuple
459                graph[chunk_variable_task] = (
460                    tuple,
461                    [variable.dims, chunk, variable.attrs],
462                )
463            else:
464                # non-dask array possibly with dimensions chunked on other variables
465                # index into variable appropriately
466                subsetter = {
467                    dim: _get_chunk_slicer(dim, chunk_index, input_chunk_bounds)
468                    for dim in variable.dims
469                }
470                subset = variable.isel(subsetter)
471                chunk_variable_task = (
472                    f"{name}-{gname}-{dask.base.tokenize(subset)}",
473                ) + chunk_tuple
474                graph[chunk_variable_task] = (
475                    tuple,
476                    [subset.dims, subset, subset.attrs],
477                )
478
479            # this task creates dict mapping variable name to above tuple
480            if name in dataset._coord_names:
481                coords.append([name, chunk_variable_task])
482            else:
483                data_vars.append([name, chunk_variable_task])
484
485        return (Dataset, (dict, data_vars), (dict, coords), dataset.attrs)
486
487    # iterate over all possible chunk combinations
488    for chunk_tuple in itertools.product(*ichunk.values()):
489        # mapping from dimension name to chunk index
490        chunk_index = dict(zip(ichunk.keys(), chunk_tuple))
491
492        blocked_args = [
493            subset_dataset_to_block(graph, gname, arg, input_chunk_bounds, chunk_index)
494            if isxr
495            else arg
496            for isxr, arg in zip(is_xarray, npargs)
497        ]
498
499        # expected["shapes", "coords", "data_vars", "indexes"] are used to
500        # raise nice error messages in _wrapper
501        expected = {}
502        # input chunk 0 along a dimension maps to output chunk 0 along the same dimension
503        # even if length of dimension is changed by the applied function
504        expected["shapes"] = {
505            k: output_chunks[k][v] for k, v in chunk_index.items() if k in output_chunks
506        }
507        expected["data_vars"] = set(template.data_vars.keys())  # type: ignore[assignment]
508        expected["coords"] = set(template.coords.keys())  # type: ignore[assignment]
509        expected["indexes"] = {
510            dim: indexes[dim][_get_chunk_slicer(dim, chunk_index, output_chunk_bounds)]
511            for dim in indexes
512        }
513
514        from_wrapper = (gname,) + chunk_tuple
515        graph[from_wrapper] = (_wrapper, func, blocked_args, kwargs, is_array, expected)
516
517        # mapping from variable name to dask graph key
518        var_key_map: Dict[Hashable, str] = {}
519        for name, variable in template.variables.items():
520            if name in indexes:
521                continue
522            gname_l = f"{name}-{gname}"
523            var_key_map[name] = gname_l
524
525            key: Tuple[Any, ...] = (gname_l,)
526            for dim in variable.dims:
527                if dim in chunk_index:
528                    key += (chunk_index[dim],)
529                else:
530                    # unchunked dimensions in the input have one chunk in the result
531                    # output can have new dimensions with exactly one chunk
532                    key += (0,)
533
534            # We're adding multiple new layers to the graph:
535            # The first new layer is the result of the computation on
536            # the array.
537            # Then we add one layer per variable, which extracts the
538            # result for that variable, and depends on just the first new
539            # layer.
540            new_layers[gname_l][key] = (operator.getitem, from_wrapper, name)
541
542    hlg = HighLevelGraph.from_collections(
543        gname,
544        graph,
545        dependencies=[arg for arg in npargs if dask.is_dask_collection(arg)],
546    )
547
548    # This adds in the getitems for each variable in the dataset.
549    hlg = HighLevelGraph(
550        {**hlg.layers, **new_layers},
551        dependencies={
552            **hlg.dependencies,
553            **{name: {gname} for name in new_layers.keys()},
554        },
555    )
556
557    # TODO: benbovy - flexible indexes: make it work with custom indexes
558    # this will need to pass both indexes and coords to the Dataset constructor
559    result = Dataset(
560        coords={k: idx.to_pandas_index() for k, idx in indexes.items()},
561        attrs=template.attrs,
562    )
563
564    for index in result.xindexes:
565        result[index].attrs = template[index].attrs
566        result[index].encoding = template[index].encoding
567
568    for name, gname_l in var_key_map.items():
569        dims = template[name].dims
570        var_chunks = []
571        for dim in dims:
572            if dim in output_chunks:
573                var_chunks.append(output_chunks[dim])
574            elif dim in result.xindexes:
575                var_chunks.append((result.sizes[dim],))
576            elif dim in template.dims:
577                # new unindexed dimension
578                var_chunks.append((template.sizes[dim],))
579
580        data = dask.array.Array(
581            hlg, name=gname_l, chunks=var_chunks, dtype=template[name].dtype
582        )
583        result[name] = (dims, data, template[name].attrs)
584        result[name].encoding = template[name].encoding
585
586    result = result.set_coords(template._coord_names)
587
588    if result_is_array:
589        da = dataset_to_dataarray(result)
590        da.name = template_name
591        return da  # type: ignore[return-value]
592    return result  # type: ignore[return-value]
593