import datetime
import warnings

import numpy as np
import pandas as pd

from . import dtypes, duck_array_ops, nputils, ops
from .arithmetic import DataArrayGroupbyArithmetic, DatasetGroupbyArithmetic
from .concat import concat
from .formatting import format_array_flat
from .indexes import propagate_indexes
from .options import _get_keep_attrs
from .pycompat import integer_types
from .utils import (
    either_dict_or_kwargs,
    hashable,
    is_scalar,
    maybe_wrap_array,
    peek_at,
    safe_cast_to_index,
)
from .variable import IndexVariable, Variable, as_variable


def check_reduce_dims(reduce_dims, dimensions):
    if reduce_dims is not ...:
        if is_scalar(reduce_dims):
            reduce_dims = [reduce_dims]
        if any(dim not in dimensions for dim in reduce_dims):
            raise ValueError(
                f"cannot reduce over dimensions {reduce_dims!r}. Expected either '...' "
                f"to reduce over all dimensions or one or more of {dimensions!r}."
            )

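# For example (illustrative): check_reduce_dims("z", ("x", "y")) raises a
# ValueError, while check_reduce_dims(..., ("x", "y")) and
# check_reduce_dims("x", ("x", "y")) both pass silently.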

def unique_value_groups(ar, sort=True):
    """Group an array by its unique values.

    Parameters
    ----------
    ar : array-like
        Input array. This will be flattened if it is not already 1-D.
    sort : bool, optional
        Whether or not to sort unique values.

    Returns
    -------
    values : np.ndarray
        Unique values as returned by `pd.factorize`, sorted if `sort` is True.
    groups : list of lists of int
        Each element provides the integer indices in `ar` with values given by
        the corresponding value in `values`.
    """
    inverse, values = pd.factorize(ar, sort=sort)
    groups = [[] for _ in range(len(values))]
    for n, g in enumerate(inverse):
        if g >= 0:
            # pandas uses -1 to mark NaN, but doesn't include them in values
            groups[g].append(n)
    return values, groups

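# Illustrative example: grouping a small object array by its repeated values
# (NaN entries get code -1 from pd.factorize and are therefore dropped):
#     values, groups = unique_value_groups(np.array(["b", "a", "b"], dtype=object))
#     # values -> array(['a', 'b'], dtype=object); groups -> [[1], [0, 2]]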

def _dummy_copy(xarray_obj):
    from .dataarray import DataArray
    from .dataset import Dataset

    if isinstance(xarray_obj, Dataset):
        res = Dataset(
            {
                k: dtypes.get_fill_value(v.dtype)
                for k, v in xarray_obj.data_vars.items()
            },
            {
                k: dtypes.get_fill_value(v.dtype)
                for k, v in xarray_obj.coords.items()
                if k not in xarray_obj.dims
            },
            xarray_obj.attrs,
        )
    elif isinstance(xarray_obj, DataArray):
        res = DataArray(
            dtypes.get_fill_value(xarray_obj.dtype),
            {
                k: dtypes.get_fill_value(v.dtype)
                for k, v in xarray_obj.coords.items()
                if k not in xarray_obj.dims
            },
            dims=[],
            name=xarray_obj.name,
            attrs=xarray_obj.attrs,
        )
    else:  # pragma: no cover
        raise AssertionError
    return res


def _is_one_or_none(obj):
    return obj == 1 or obj is None


def _consolidate_slices(slices):
    """Consolidate adjacent slices in a list of slices."""
    result = []
    last_slice = slice(None)
    for slice_ in slices:
        if not isinstance(slice_, slice):
            raise ValueError(f"list element is not a slice: {slice_!r}")
        if (
            result
            and last_slice.stop == slice_.start
            and _is_one_or_none(last_slice.step)
            and _is_one_or_none(slice_.step)
        ):
            last_slice = slice(last_slice.start, slice_.stop, slice_.step)
            result[-1] = last_slice
        else:
            result.append(slice_)
            last_slice = slice_
    return result

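# Illustrative example: adjacent unit-step slices are merged, gaps are kept:
#     _consolidate_slices([slice(0, 2), slice(2, 4), slice(5, 6)])
#     # -> [slice(0, 4, None), slice(5, 6, None)]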

def _inverse_permutation_indices(positions):
    """Like inverse_permutation, but also handles slices.

    Parameters
    ----------
    positions : list of ndarray or slice
        If the first element is a slice, all elements are assumed to be slices.

    Returns
    -------
    np.ndarray of indices, or None if no permutation is necessary.
    """
    if not positions:
        return None

    if isinstance(positions[0], slice):
        positions = _consolidate_slices(positions)
        if positions == [slice(None)]:
            return None
        positions = [np.arange(sl.start, sl.stop, sl.step) for sl in positions]

    return nputils.inverse_permutation(np.concatenate(positions))

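# Illustrative example, assuming nputils.inverse_permutation returns the usual
# inverse (out[positions] == arange(len(positions))):
#     _inverse_permutation_indices([np.array([2, 0]), np.array([1])])
#     # concatenates to [2, 0, 1] and returns array([1, 2, 0])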

class _DummyGroup:
    """Class for keeping track of grouped dimensions without coordinates.

    Should not be user visible.
    """

    __slots__ = ("name", "coords", "size")

    def __init__(self, obj, name, coords):
        self.name = name
        self.coords = coords
        self.size = obj.sizes[name]

    @property
    def dims(self):
        return (self.name,)

    @property
    def ndim(self):
        return 1

    @property
    def values(self):
        return range(self.size)

    @property
    def shape(self):
        return (self.size,)

    def __getitem__(self, key):
        if isinstance(key, tuple):
            key = key[0]
        return self.values[key]

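# A _DummyGroup over a dimension "x" of size 3 exposes dims ("x",), shape (3,),
# and values range(0, 3) -- just enough of the DataArray interface for the
# grouping machinery below.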

def _ensure_1d(group, obj):
    if group.ndim != 1:
        # try to stack the dims of the group into a single dim
        orig_dims = group.dims
        stacked_dim = "stacked_" + "_".join(orig_dims)
        # these dimensions get created by the stack operation
        inserted_dims = [dim for dim in group.dims if dim not in group.coords]
        # the copy is necessary here, otherwise read only array raises error
        # in pandas: https://github.com/pydata/pandas/issues/12813
        group = group.stack(**{stacked_dim: orig_dims}).copy()
        obj = obj.stack(**{stacked_dim: orig_dims})
    else:
        stacked_dim = None
        inserted_dims = []
    return group, obj, stacked_dim, inserted_dims

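# For example, a 2-D group with dims ("y", "x") is stacked onto a new
# "stacked_y_x" dimension here; GroupBy._maybe_unstack() reverses this after
# the grouped operation has been applied.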

def _unique_and_monotonic(group):
    if isinstance(group, _DummyGroup):
        return True
    index = safe_cast_to_index(group)
    return index.is_unique and index.is_monotonic


def _apply_loffset(grouper, result):
    """
    (copied from pandas)
    if loffset is set, offset the result index

    This is NOT an idempotent routine, it will be applied
    exactly once to the result.

    Parameters
    ----------
    grouper : pandas.Grouper
        the grouper whose loffset is applied, then cleared
    result : Series or DataFrame
        the result of resample
    """

    needs_offset = (
        isinstance(grouper.loffset, (pd.DateOffset, datetime.timedelta))
        and isinstance(result.index, pd.DatetimeIndex)
        and len(result.index) > 0
    )

    if needs_offset:
        result.index = result.index + grouper.loffset

    grouper.loffset = None

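# For example (assuming a pandas version that still supports loffset), with
# grouper = pd.Grouper(freq="D", loffset=pd.Timedelta("12h")) the resampled
# index labels are shifted by 12 hours, exactly once.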

class GroupBy:
    """An object that implements the split-apply-combine pattern.

    Modeled after `pandas.GroupBy`. A `GroupBy` object can be iterated over
    as (unique_value, grouped_array) pairs, but the main way to interact with
    a groupby object is with the `apply` or `reduce` methods. You can also
    directly call numpy methods like `mean` or `std`.

    You should create a GroupBy object by using the `DataArray.groupby` or
    `Dataset.groupby` methods.

    See Also
    --------
    Dataset.groupby
    DataArray.groupby
    """

    __slots__ = (
        "_full_index",
        "_inserted_dims",
        "_group",
        "_group_dim",
        "_group_indices",
        "_groups",
        "_obj",
        "_restore_coord_dims",
        "_stacked_dim",
        "_unique_coord",
        "_dims",
    )

    def __init__(
        self,
        obj,
        group,
        squeeze=False,
        grouper=None,
        bins=None,
        restore_coord_dims=True,
        cut_kwargs=None,
    ):
        """Create a GroupBy object

        Parameters
        ----------
        obj : Dataset or DataArray
            Object to group.
        group : DataArray
            Array with the group values.
        squeeze : bool, optional
            If "group" is a coordinate of object, `squeeze` controls whether
            the subarrays have a dimension of length 1 along that coordinate or
            if the dimension is squeezed out.
        grouper : pandas.Grouper, optional
            Used for grouping values along the `group` array.
        bins : array-like, optional
            If `bins` is specified, the groups will be discretized into the
            specified bins by `pandas.cut`.
        restore_coord_dims : bool, default: True
            If True, also restore the dimension order of multi-dimensional
            coordinates.
        cut_kwargs : dict, optional
            Extra keyword arguments to pass to `pandas.cut`.

        """
        if cut_kwargs is None:
            cut_kwargs = {}
        from .dataarray import DataArray

        if grouper is not None and bins is not None:
            raise TypeError("can't specify both `grouper` and `bins`")

        if not isinstance(group, (DataArray, IndexVariable)):
            if not hashable(group):
                raise TypeError(
                    "`group` must be an xarray.DataArray or the "
                    "name of an xarray variable or dimension. "
                    f"Received {group!r} instead."
                )
            group = obj[group]
            if len(group) == 0:
                raise ValueError(f"{group.name} must not be empty")

            if group.name not in obj.coords and group.name in obj.dims:
                # DummyGroups should not appear on groupby results
                group = _DummyGroup(obj, group.name, group.coords)

        if getattr(group, "name", None) is None:
            group.name = "group"

        group, obj, stacked_dim, inserted_dims = _ensure_1d(group, obj)
        (group_dim,) = group.dims

        expected_size = obj.sizes[group_dim]
        if group.size != expected_size:
            raise ValueError(
                "the group variable's length does not "
                "match the length of this variable along its "
                "dimension"
            )

        full_index = None

        if bins is not None:
            if duck_array_ops.isnull(bins).all():
                raise ValueError("All bin edges are NaN.")
            binned = pd.cut(group.values, bins, **cut_kwargs)
            new_dim_name = group.name + "_bins"
            group = DataArray(binned, group.coords, name=new_dim_name)
            full_index = binned.categories

        if grouper is not None:
            index = safe_cast_to_index(group)
            if not index.is_monotonic:
                # TODO: sort instead of raising an error
                raise ValueError("index must be monotonic for resampling")
            full_index, first_items = self._get_index_and_items(index, grouper)
            sbins = first_items.values.astype(np.int64)
            group_indices = [slice(i, j) for i, j in zip(sbins[:-1], sbins[1:])] + [
                slice(sbins[-1], None)
            ]
            unique_coord = IndexVariable(group.name, first_items.index)
        elif group.dims == (group.name,) and _unique_and_monotonic(group):
            # no need to factorize
            group_indices = np.arange(group.size)
            if not squeeze:
                # use slices to do views instead of fancy indexing
                # equivalent to: group_indices = group_indices.reshape(-1, 1)
                group_indices = [slice(i, i + 1) for i in group_indices]
            unique_coord = group
        else:
            if group.isnull().any():
                # drop any NaN valued groups.
                # also drop obj values where group was NaN
                # Use where instead of reindex to account for duplicate coordinate labels.
                obj = obj.where(group.notnull(), drop=True)
                group = group.dropna(group_dim)

            # look through group to find the unique values
            group_as_index = safe_cast_to_index(group)
            sort = bins is None and (not isinstance(group_as_index, pd.MultiIndex))
            unique_values, group_indices = unique_value_groups(
                group_as_index, sort=sort
            )
            unique_coord = IndexVariable(group.name, unique_values)

        if len(group_indices) == 0:
            if bins is not None:
                raise ValueError(
                    f"None of the data falls within bins with edges {bins!r}"
                )
            else:
                raise ValueError(
                    "Failed to group data. Are you grouping by a variable that is all NaN?"
                )

        # specification for the groupby operation
        self._obj = obj
        self._group = group
        self._group_dim = group_dim
        self._group_indices = group_indices
        self._unique_coord = unique_coord
        self._stacked_dim = stacked_dim
        self._inserted_dims = inserted_dims
        self._full_index = full_index
        self._restore_coord_dims = restore_coord_dims

        # cached attributes
        self._groups = None
        self._dims = None

    @property
    def dims(self):
        if self._dims is None:
            self._dims = self._obj.isel(
                **{self._group_dim: self._group_indices[0]}
            ).dims

        return self._dims

    @property
    def groups(self):
        """
        Mapping from group labels to indices. The indices can be used to index the underlying object.
        """
        # provided to mimic pandas.groupby
        if self._groups is None:
            self._groups = dict(zip(self._unique_coord.values, self._group_indices))
        return self._groups

    def __getitem__(self, key):
        """
        Get DataArray or Dataset corresponding to a particular group label.
        """
        return self._obj.isel({self._group_dim: self.groups[key]})

    def __len__(self):
        return self._unique_coord.size

    def __iter__(self):
        return zip(self._unique_coord.values, self._iter_grouped())

    def __repr__(self):
        return "{}, grouped over {!r}\n{!r} groups with labels {}.".format(
            self.__class__.__name__,
            self._unique_coord.name,
            self._unique_coord.size,
            ", ".join(format_array_flat(self._unique_coord, 30).split()),
        )

    def _get_index_and_items(self, index, grouper):
        from .resample_cftime import CFTimeGrouper

        s = pd.Series(np.arange(index.size), index)
        if isinstance(grouper, CFTimeGrouper):
            first_items = grouper.first_items(index)
        else:
            first_items = s.groupby(grouper).first()
            _apply_loffset(grouper, first_items)
        full_index = first_items.index
        if first_items.isnull().any():
            first_items = first_items.dropna()
        return full_index, first_items

    def _iter_grouped(self):
        """Iterate over each element in this group"""
        for indices in self._group_indices:
            yield self._obj.isel(**{self._group_dim: indices})

    def _infer_concat_args(self, applied_example):
        if self._group_dim in applied_example.dims:
            coord = self._group
            positions = self._group_indices
        else:
            coord = self._unique_coord
            positions = None
        (dim,) = coord.dims
        if isinstance(coord, _DummyGroup):
            coord = None
        return coord, dim, positions

    def _binary_op(self, other, f, reflexive=False):
        g = f if not reflexive else lambda x, y: f(y, x)
        applied = self._yield_binary_applied(g, other)
        return self._combine(applied)

    def _yield_binary_applied(self, func, other):
        dummy = None

        for group_value, obj in self:
            try:
                other_sel = other.sel(**{self._group.name: group_value})
            except AttributeError:
                raise TypeError(
                    "GroupBy objects only support binary ops "
                    "when the other argument is a Dataset or "
                    "DataArray"
                )
            except (KeyError, ValueError):
                if self._group.name not in other.dims:
                    raise ValueError(
                        "incompatible dimensions for a grouped "
                        f"binary operation: the group variable {self._group.name!r} "
                        "is not a dimension on the other argument"
                    )
                if dummy is None:
                    dummy = _dummy_copy(other)
                other_sel = dummy

            result = func(obj, other_sel)
            yield result

    def _maybe_restore_empty_groups(self, combined):
        """Our index contained empty groups (e.g., from a resampling). If we
        reduced on that dimension, we want to restore the full index.
        """
        if self._full_index is not None and self._group.name in combined.dims:
            indexers = {self._group.name: self._full_index}
            combined = combined.reindex(**indexers)
        return combined

    def _maybe_unstack(self, obj):
        """This gets called if we are applying on an array with a
        multidimensional group."""
        if self._stacked_dim is not None and self._stacked_dim in obj.dims:
            obj = obj.unstack(self._stacked_dim)
            for dim in self._inserted_dims:
                if dim in obj.coords:
                    del obj.coords[dim]
            obj._indexes = propagate_indexes(obj._indexes, exclude=self._inserted_dims)
        return obj

    def fillna(self, value):
        """Fill missing values in this object by group.

        This operation follows the normal broadcasting and alignment rules that
        xarray uses for binary arithmetic, except the result is aligned to this
        object (``join='left'``) instead of aligned to the intersection of
        index coordinates (``join='inner'``).

        Parameters
        ----------
        value
            Used to fill all matching missing values by group. Needs
            to be of a valid type for the wrapped object's fillna
            method.

        Returns
        -------
        same type as the grouped object

        See Also
        --------
        Dataset.fillna
        DataArray.fillna
        """
        return ops.fillna(self, value)

    def quantile(
        self, q, dim=None, interpolation="linear", keep_attrs=None, skipna=True
    ):
        """Compute the qth quantile over each array in the groups and
        concatenate them together into a new array.

        Parameters
        ----------
        q : float or sequence of float
            Quantile to compute, which must be between 0 and 1
            inclusive.
        dim : ..., str or sequence of str, optional
            Dimension(s) over which to apply quantile.
            Defaults to the grouped dimension.
        interpolation : {"linear", "lower", "higher", "midpoint", "nearest"}, default: "linear"
            This optional parameter specifies the interpolation method to
            use when the desired quantile lies between two data points
            ``i < j``:

                * linear: ``i + (j - i) * fraction``, where ``fraction`` is
                  the fractional part of the index surrounded by ``i`` and
                  ``j``.
                * lower: ``i``.
                * higher: ``j``.
                * nearest: ``i`` or ``j``, whichever is nearest.
                * midpoint: ``(i + j) / 2``.
        keep_attrs : bool, optional
            If True, the grouped object's attributes (`attrs`) are copied to
            the result; if False, the result is returned without attributes.
        skipna : bool, optional
            Whether to skip missing values when aggregating.

        Returns
        -------
        quantiles : DataArray or Dataset
            If `q` is a single quantile, then the result is a
            scalar. If multiple quantiles are given, first axis of
            the result corresponds to the quantile. In either case a
            quantile dimension is added to the return array. The other
            dimensions are the dimensions that remain after the
            reduction of the array.

        See Also
        --------
        numpy.nanquantile, numpy.quantile, pandas.Series.quantile, Dataset.quantile
        DataArray.quantile

        Examples
        --------
        >>> da = xr.DataArray(
        ...     [[1.3, 8.4, 0.7, 6.9], [0.7, 4.2, 9.4, 1.5], [6.5, 7.3, 2.6, 1.9]],
        ...     coords={"x": [0, 0, 1], "y": [1, 1, 2, 2]},
        ...     dims=("x", "y"),
        ... )
        >>> ds = xr.Dataset({"a": da})
        >>> da.groupby("x").quantile(0)
        <xarray.DataArray (x: 2, y: 4)>
        array([[0.7, 4.2, 0.7, 1.5],
               [6.5, 7.3, 2.6, 1.9]])
        Coordinates:
          * y         (y) int64 1 1 2 2
            quantile  float64 0.0
          * x         (x) int64 0 1
        >>> ds.groupby("y").quantile(0, dim=...)
        <xarray.Dataset>
        Dimensions:   (y: 2)
        Coordinates:
            quantile  float64 0.0
          * y         (y) int64 1 2
        Data variables:
            a         (y) float64 0.7 0.7
        >>> da.groupby("x").quantile([0, 0.5, 1])
        <xarray.DataArray (x: 2, y: 4, quantile: 3)>
        array([[[0.7 , 1.  , 1.3 ],
                [4.2 , 6.3 , 8.4 ],
                [0.7 , 5.05, 9.4 ],
                [1.5 , 4.2 , 6.9 ]],
        <BLANKLINE>
               [[6.5 , 6.5 , 6.5 ],
                [7.3 , 7.3 , 7.3 ],
                [2.6 , 2.6 , 2.6 ],
                [1.9 , 1.9 , 1.9 ]]])
        Coordinates:
          * y         (y) int64 1 1 2 2
          * quantile  (quantile) float64 0.0 0.5 1.0
          * x         (x) int64 0 1
        >>> ds.groupby("y").quantile([0, 0.5, 1], dim=...)
        <xarray.Dataset>
        Dimensions:   (y: 2, quantile: 3)
        Coordinates:
          * quantile  (quantile) float64 0.0 0.5 1.0
          * y         (y) int64 1 2
        Data variables:
            a         (y, quantile) float64 0.7 5.35 8.4 0.7 2.25 9.4
        """
        if dim is None:
            dim = self._group_dim

        out = self.map(
            self._obj.__class__.quantile,
            shortcut=False,
            q=q,
            dim=dim,
            interpolation=interpolation,
            keep_attrs=keep_attrs,
            skipna=skipna,
        )
        return out

    def where(self, cond, other=dtypes.NA):
        """Return elements from `self` or `other` depending on `cond`.

        Parameters
        ----------
        cond : DataArray or Dataset
            Locations at which to preserve this object's values. dtypes have to be `bool`.
        other : scalar, DataArray or Dataset, optional
            Value to use for locations in this object where ``cond`` is False.
            By default, inserts missing values.

        Returns
        -------
        same type as the grouped object

        See Also
        --------
        Dataset.where
        """
        return ops.where_method(self, cond, other)

    def _first_or_last(self, op, skipna, keep_attrs):
        if isinstance(self._group_indices[0], integer_types):
            # NB. this is currently only used for reductions along an existing
            # dimension
            return self._obj
        if keep_attrs is None:
            keep_attrs = _get_keep_attrs(default=True)
        return self.reduce(op, self._group_dim, skipna=skipna, keep_attrs=keep_attrs)

    def first(self, skipna=None, keep_attrs=None):
        """Return the first element of each group along the group dimension"""
        return self._first_or_last(duck_array_ops.first, skipna, keep_attrs)

    def last(self, skipna=None, keep_attrs=None):
        """Return the last element of each group along the group dimension"""
        return self._first_or_last(duck_array_ops.last, skipna, keep_attrs)

    def assign_coords(self, coords=None, **coords_kwargs):
        """Assign coordinates by group.

        See Also
        --------
        Dataset.assign_coords
        Dataset.swap_dims
        """
        coords_kwargs = either_dict_or_kwargs(coords, coords_kwargs, "assign_coords")
        return self.map(lambda ds: ds.assign_coords(**coords_kwargs))


def _maybe_reorder(xarray_obj, dim, positions):
    order = _inverse_permutation_indices(positions)

    if order is None or len(order) != xarray_obj.sizes[dim]:
        return xarray_obj
    else:
        return xarray_obj[{dim: order}]


class DataArrayGroupBy(GroupBy, DataArrayGroupbyArithmetic):
    """GroupBy object specialized to grouping DataArray objects"""

    __slots__ = ()

    def _iter_grouped_shortcut(self):
        """Fast version of `_iter_grouped` that yields Variables without
        metadata
        """
        var = self._obj.variable
        for indices in self._group_indices:
            yield var[{self._group_dim: indices}]

    def _concat_shortcut(self, applied, dim, positions=None):
        # nb. don't worry too much about maintaining this method -- it does
        # speed things up, but it's not very interpretable and there are much
        # faster alternatives (e.g., doing the grouped aggregation in a
        # compiled language)
        stacked = Variable.concat(applied, dim, shortcut=True)
        reordered = _maybe_reorder(stacked, dim, positions)
        return self._obj._replace_maybe_drop_dims(reordered)

    def _restore_dim_order(self, stacked):
        def lookup_order(dimension):
            if dimension == self._group.name:
                (dimension,) = self._group.dims
            if dimension in self._obj.dims:
                axis = self._obj.get_axis_num(dimension)
            else:
                axis = 1e6  # some arbitrarily high value
            return axis

        new_order = sorted(stacked.dims, key=lookup_order)
        return stacked.transpose(*new_order, transpose_coords=self._restore_coord_dims)

    def map(self, func, shortcut=False, args=(), **kwargs):
        """Apply a function to each array in the group and concatenate them
        together into a new array.

        `func` is called like `func(ar, *args, **kwargs)` for each array `ar`
        in this group.

        Apply uses heuristics (like `pandas.GroupBy.apply`) to figure out how
        to stack together the array. The rule is:

        1. If the dimension along which the group coordinate is defined is
           still in the first grouped array after applying `func`, then stack
           over this dimension.
        2. Otherwise, stack over the new dimension given by name of this
           grouping (the argument to the `groupby` function).

        Parameters
        ----------
        func : callable
            Callable to apply to each array.
        shortcut : bool, optional
            Whether or not to shortcut evaluation under the assumptions that:

            (1) The action of `func` does not depend on any of the array
                metadata (attributes or coordinates) but only on the data and
                dimensions.
            (2) The action of `func` creates arrays with homogeneous metadata,
                that is, with the same dimensions and attributes.

            If these conditions are satisfied `shortcut` provides significant
            speedup. This should be the case for many common groupby operations
            (e.g., applying numpy ufuncs).
        args : tuple, optional
            Positional arguments passed to `func`.
        **kwargs
            Used to call `func(ar, **kwargs)` for each array `ar`.

        Returns
        -------
        applied : DataArray
            The result of splitting, applying and combining this array.
        """
        grouped = self._iter_grouped_shortcut() if shortcut else self._iter_grouped()
        applied = (maybe_wrap_array(arr, func(arr, *args, **kwargs)) for arr in grouped)
        return self._combine(applied, shortcut=shortcut)

    def apply(self, func, shortcut=False, args=(), **kwargs):
        """
        Backward compatible implementation of ``map``

        See Also
        --------
        DataArrayGroupBy.map
        """
        warnings.warn(
            "GroupBy.apply may be deprecated in the future. Using GroupBy.map is encouraged",
            PendingDeprecationWarning,
            stacklevel=2,
        )
        return self.map(func, shortcut=shortcut, args=args, **kwargs)

    def _combine(self, applied, shortcut=False):
        """Recombine the applied objects like the original."""
        applied_example, applied = peek_at(applied)
        coord, dim, positions = self._infer_concat_args(applied_example)
        if shortcut:
            combined = self._concat_shortcut(applied, dim, positions)
        else:
            combined = concat(applied, dim)
            combined = _maybe_reorder(combined, dim, positions)

        if isinstance(combined, type(self._obj)):
            # only restore dimension order for arrays
            combined = self._restore_dim_order(combined)
        # assign coord when the applied function does not return that coord
        if coord is not None and dim not in applied_example.dims:
            if shortcut:
                coord_var = as_variable(coord)
                combined._coords[coord.name] = coord_var
            else:
                combined.coords[coord.name] = coord
        combined = self._maybe_restore_empty_groups(combined)
        combined = self._maybe_unstack(combined)
        return combined

    def reduce(
        self, func, dim=None, axis=None, keep_attrs=None, shortcut=True, **kwargs
    ):
        """Reduce the items in this group by applying `func` along some
        dimension(s).

        Parameters
        ----------
        func : callable
            Function which can be called in the form
            `func(x, axis=axis, **kwargs)` to return the result of collapsing
            an np.ndarray over an integer valued axis.
        dim : ..., str or sequence of str, optional
            Dimension(s) over which to apply `func`.
        axis : int or sequence of int, optional
            Axis(es) over which to apply `func`. Only one of the 'dimension'
            and 'axis' arguments can be supplied. If neither are supplied, then
            `func` is calculated over all dimensions for each group item.
        keep_attrs : bool, optional
            If True, the object's attributes (`attrs`) will be copied from
            the original object to the new one. If False (default), the new
            object will be returned without attributes.
        **kwargs : dict
            Additional keyword arguments passed on to `func`.

        Returns
        -------
        reduced : DataArray
            Array with summarized data and the indicated dimension(s)
            removed.
        """
        if dim is None:
            dim = self._group_dim

        if keep_attrs is None:
            keep_attrs = _get_keep_attrs(default=False)

        def reduce_array(ar):
            return ar.reduce(func, dim, axis, keep_attrs=keep_attrs, **kwargs)

        check_reduce_dims(dim, self.dims)

        return self.map(reduce_array, shortcut=shortcut)


class DatasetGroupBy(GroupBy, DatasetGroupbyArithmetic):
    """GroupBy object specialized to grouping Dataset objects"""

    __slots__ = ()

    def map(self, func, args=(), shortcut=None, **kwargs):
        """Apply a function to each Dataset in the group and concatenate them
        together into a new Dataset.

        `func` is called like `func(ds, *args, **kwargs)` for each dataset `ds`
        in this group.

        Apply uses heuristics (like `pandas.GroupBy.apply`) to figure out how
        to stack together the datasets. The rule is:

        1. If the dimension along which the group coordinate is defined is
           still in the first grouped item after applying `func`, then stack
           over this dimension.
        2. Otherwise, stack over the new dimension given by name of this
           grouping (the argument to the `groupby` function).

        Parameters
        ----------
        func : callable
            Callable to apply to each sub-dataset.
        args : tuple, optional
            Positional arguments to pass to `func`.
        **kwargs
            Used to call `func(ds, **kwargs)` for each sub-dataset `ds`.

        Returns
        -------
        applied : Dataset
            The result of splitting, applying and combining this dataset.
        """
        # ignore shortcut if set (for now)
        applied = (func(ds, *args, **kwargs) for ds in self._iter_grouped())
        return self._combine(applied)

    def apply(self, func, args=(), shortcut=None, **kwargs):
        """
        Backward compatible implementation of ``map``

        See Also
        --------
        DatasetGroupBy.map
        """
        warnings.warn(
            "GroupBy.apply may be deprecated in the future. Using GroupBy.map is encouraged",
            PendingDeprecationWarning,
            stacklevel=2,
        )
        return self.map(func, shortcut=shortcut, args=args, **kwargs)

    def _combine(self, applied):
        """Recombine the applied objects like the original."""
        applied_example, applied = peek_at(applied)
        coord, dim, positions = self._infer_concat_args(applied_example)
        combined = concat(applied, dim)
        combined = _maybe_reorder(combined, dim, positions)
        # assign coord when the applied function does not return that coord
        if coord is not None and dim not in applied_example.dims:
            combined[coord.name] = coord
        combined = self._maybe_restore_empty_groups(combined)
        combined = self._maybe_unstack(combined)
        return combined

    def reduce(self, func, dim=None, keep_attrs=None, **kwargs):
        """Reduce the items in this group by applying `func` along some
        dimension(s).

        Parameters
        ----------
        func : callable
            Function which can be called in the form
            `func(x, axis=axis, **kwargs)` to return the result of collapsing
            an np.ndarray over an integer valued axis.
        dim : ..., str or sequence of str, optional
            Dimension(s) over which to apply `func`. If None (default),
            `func` is applied over the grouped dimension.
        keep_attrs : bool, optional
            If True, the dataset's attributes (`attrs`) will be copied from
            the original object to the new one. If False (default), the new
            object will be returned without attributes.
        **kwargs : dict
            Additional keyword arguments passed on to `func`.

        Returns
        -------
        reduced : Dataset
            Dataset with summarized data and the indicated dimension(s)
            removed.
        """
        if dim is None:
            dim = self._group_dim

        if keep_attrs is None:
            keep_attrs = _get_keep_attrs(default=False)

        def reduce_dataset(ds):
            return ds.reduce(func, dim, keep_attrs, **kwargs)

        check_reduce_dims(dim, self.dims)

        return self.map(reduce_dataset)

    def assign(self, **kwargs):
        """Assign data variables by group.

        See Also
        --------
        Dataset.assign
        """
        return self.map(lambda ds: ds.assign(**kwargs))
