import datetime
import warnings

import numpy as np
import pandas as pd

from . import dtypes, duck_array_ops, nputils, ops
from .arithmetic import DataArrayGroupbyArithmetic, DatasetGroupbyArithmetic
from .concat import concat
from .formatting import format_array_flat
from .indexes import propagate_indexes
from .options import _get_keep_attrs
from .pycompat import integer_types
from .utils import (
    either_dict_or_kwargs,
    hashable,
    is_scalar,
    maybe_wrap_array,
    peek_at,
    safe_cast_to_index,
)
from .variable import IndexVariable, Variable, as_variable


def check_reduce_dims(reduce_dims, dimensions):
    """Validate that ``reduce_dims`` only names dimensions in ``dimensions``.

    Parameters
    ----------
    reduce_dims : ..., hashable or sequence of hashable
        Dimension(s) requested for a reduction. ``...`` (all dimensions) is
        always accepted; a scalar is treated as a single dimension.
    dimensions : container of hashable
        Dimensions available on the grouped object.

    Raises
    ------
    ValueError
        If any requested dimension is not in ``dimensions``.
    """
    if reduce_dims is not ...:
        if is_scalar(reduce_dims):
            reduce_dims = [reduce_dims]
        if any(dim not in dimensions for dim in reduce_dims):
            raise ValueError(
                f"cannot reduce over dimensions {reduce_dims!r}. expected either '...' "
                f"to reduce over all dimensions or one or more of {dimensions!r}."
            )


def unique_value_groups(ar, sort=True):
    """Group an array by its unique values.

    Parameters
    ----------
    ar : array-like
        1-D input values, passed to ``pandas.factorize`` (callers in this
        module pass a ``pandas.Index``).
    sort : bool, optional
        Whether or not to sort unique values.

    Returns
    -------
    values : np.ndarray or pandas.Index
        Unique values as returned by ``pandas.factorize``; sorted when
        ``sort`` is True. Missing values are not included.
    indices : list of lists of int
        Each element provides the integer indices in `ar` with values given by
        the corresponding value in `values`. Positions holding missing values
        belong to no group.
    """
    inverse, values = pd.factorize(ar, sort=sort)
    groups = [[] for _ in range(len(values))]
    for n, g in enumerate(inverse):
        if g >= 0:
            # pandas uses -1 to mark NaN, but doesn't include them in values
            groups[g].append(n)
    return values, groups


def _dummy_copy(xarray_obj):
    """Return a scalar stand-in for ``xarray_obj`` filled with missing values.

    Every data variable (Dataset) or the data itself (DataArray) is replaced
    by the fill value for its dtype (``dtypes.get_fill_value``). Non-dimension
    coordinates are kept the same way; dimension coordinates are dropped.
    Attributes and (for DataArray) the name are preserved. Used by grouped
    binary operations when a group label is missing from the other operand.

    Raises
    ------
    AssertionError
        If ``xarray_obj`` is neither a Dataset nor a DataArray.
    """
    from .dataarray import DataArray
    from .dataset import Dataset

    if isinstance(xarray_obj, Dataset):
        res = Dataset(
            {
                k: dtypes.get_fill_value(v.dtype)
                for k, v in xarray_obj.data_vars.items()
            },
            {
                k: dtypes.get_fill_value(v.dtype)
                for k, v in xarray_obj.coords.items()
                if k not in xarray_obj.dims
            },
            xarray_obj.attrs,
        )
    elif isinstance(xarray_obj, DataArray):
        res = DataArray(
            dtypes.get_fill_value(xarray_obj.dtype),
            {
                k: dtypes.get_fill_value(v.dtype)
                for k, v in xarray_obj.coords.items()
                if k not in xarray_obj.dims
            },
            dims=[],
            name=xarray_obj.name,
            attrs=xarray_obj.attrs,
        )
    else:  # pragma: no cover
        raise AssertionError
    return res


def _is_one_or_none(obj):
    """Return True if ``obj`` is a unit (or default) slice step."""
    return obj == 1 or obj is None


def _consolidate_slices(slices):
    """Consolidate adjacent slices in a list of slices.

    Two neighbouring slices are merged when the first one's ``stop`` equals
    the second one's ``start`` and both have a step of 1 or None.

    Raises
    ------
    ValueError
        If any list element is not a ``slice``.
    """
    result = []
    last_slice = slice(None)
    for slice_ in slices:
        if not isinstance(slice_, slice):
            raise ValueError(f"list element is not a slice: {slice_!r}")
        if (
            result
            and last_slice.stop == slice_.start
            and _is_one_or_none(last_slice.step)
            and _is_one_or_none(slice_.step)
        ):
            last_slice = slice(last_slice.start, slice_.stop, slice_.step)
            result[-1] = last_slice
        else:
            result.append(slice_)
            last_slice = slice_
    return result


def _inverse_permutation_indices(positions):
    """Like inverse_permutation, but also handles slices.

    Parameters
    ----------
    positions : list of ndarray or slice
        If slice objects, all are assumed to be slices.

    Returns
    -------
    np.ndarray of indices or None, if no permutation is necessary.
    """
    if not positions:
        return None

    if isinstance(positions[0], slice):
        positions = _consolidate_slices(positions)
        # _consolidate_slices returns a *list*; a single full slice means the
        # original order is already intact, so no permutation is needed.
        if positions == [slice(None)]:
            return None
        positions = [np.arange(sl.start, sl.stop, sl.step) for sl in positions]

    return nputils.inverse_permutation(np.concatenate(positions))


class _DummyGroup:
    """Class for keeping track of grouped dimensions without coordinates.

    Should not be user visible.
    """

    __slots__ = ("name", "coords", "size")

    def __init__(self, obj, name, coords):
        self.name = name
        self.coords = coords
        # length of the grouped dimension on the object being grouped
        self.size = obj.sizes[name]

    @property
    def dims(self):
        return (self.name,)

    @property
    def ndim(self):
        return 1

    @property
    def values(self):
        # positional labels 0..size-1 stand in for the missing coordinate
        return range(self.size)

    @property
    def shape(self):
        return (self.size,)

    def __getitem__(self, key):
        if isinstance(key, tuple):
            key = key[0]
        return self.values[key]


def _ensure_1d(group, obj):
    """Stack a multi-dimensional ``group`` (and ``obj``) into one dimension.

    Returns
    -------
    group, obj, stacked_dim, inserted_dims
        The (possibly stacked) group and object, the name of the stacked
        dimension (None if no stacking was needed) and the dimensions that the
        stack operation created as coordinates.
    """
    if group.ndim != 1:
        # try to stack the dims of the group into a single dim
        orig_dims = group.dims
        stacked_dim = "stacked_" + "_".join(orig_dims)
        # these dimensions get created by the stack operation
        inserted_dims = [dim for dim in group.dims if dim not in group.coords]
        # the copy is necessary here, otherwise read only array raises error
        # in pandas: https://github.com/pydata/pandas/issues/12813
        group = group.stack(**{stacked_dim: orig_dims}).copy()
        obj = obj.stack(**{stacked_dim: orig_dims})
    else:
        stacked_dim = None
        inserted_dims = []
    return group, obj, stacked_dim, inserted_dims


def _unique_and_monotonic(group):
    """Return True if ``group`` values are unique and monotonically increasing."""
    if isinstance(group, _DummyGroup):
        return True
    index = safe_cast_to_index(group)
    # ``Index.is_monotonic`` was deprecated in pandas 1.5 and removed in 2.0;
    # it was an alias of ``is_monotonic_increasing``.
    return index.is_unique and index.is_monotonic_increasing


def _apply_loffset(grouper, result):
    """
    (copied from pandas)
    if loffset is set, offset the result index

    This is NOT an idempotent routine, it will be applied
    exactly once to the result.

    Parameters
    ----------
    grouper : pandas.Grouper-like
        Object carrying the ``loffset`` attribute; its ``loffset`` is reset
        to None afterwards.
    result : Series or DataFrame
        the result of resample
    """

    needs_offset = (
        isinstance(grouper.loffset, (pd.DateOffset, datetime.timedelta))
        and isinstance(result.index, pd.DatetimeIndex)
        and len(result.index) > 0
    )

    if needs_offset:
        result.index = result.index + grouper.loffset

    grouper.loffset = None


class GroupBy:
    """An object that implements the split-apply-combine pattern.

    Modeled after `pandas.GroupBy`. The `GroupBy` object can be iterated over
    (unique_value, grouped_array) pairs, but the main way to interact with a
    groupby object is with the `map` or `reduce` methods. You can also
    directly call numpy methods like `mean` or `std`.

    You should create a GroupBy object by using the `DataArray.groupby` or
    `Dataset.groupby` methods.

    See Also
    --------
    Dataset.groupby
    DataArray.groupby
    """

    __slots__ = (
        "_full_index",
        "_inserted_dims",
        "_group",
        "_group_dim",
        "_group_indices",
        "_groups",
        "_obj",
        "_restore_coord_dims",
        "_stacked_dim",
        "_unique_coord",
        "_dims",
    )

    def __init__(
        self,
        obj,
        group,
        squeeze=False,
        grouper=None,
        bins=None,
        restore_coord_dims=True,
        cut_kwargs=None,
    ):
        """Create a GroupBy object

        Parameters
        ----------
        obj : Dataset or DataArray
            Object to group.
        group : DataArray
            Array with the group values.
        squeeze : bool, optional
            If "group" is a coordinate of object, `squeeze` controls whether
            the subarrays have a dimension of length 1 along that coordinate or
            if the dimension is squeezed out.
        grouper : pandas.Grouper, optional
            Used for grouping values along the `group` array.
        bins : array-like, optional
            If `bins` is specified, the groups will be discretized into the
            specified bins by `pandas.cut`.
        restore_coord_dims : bool, default: True
            If True, also restore the dimension order of multi-dimensional
            coordinates.
        cut_kwargs : dict, optional
            Extra keyword arguments to pass to `pandas.cut`

        Raises
        ------
        TypeError
            If both `grouper` and `bins` are given, or `group` is neither an
            xarray object nor a hashable name.
        ValueError
            If `group` is empty, its length does not match `obj`, all bin
            edges are NaN, the index is not monotonic when resampling, or no
            groups could be formed.
        """
        if cut_kwargs is None:
            cut_kwargs = {}
        from .dataarray import DataArray

        if grouper is not None and bins is not None:
            raise TypeError("can't specify both `grouper` and `bins`")

        if not isinstance(group, (DataArray, IndexVariable)):
            if not hashable(group):
                raise TypeError(
                    "`group` must be an xarray.DataArray or the "
                    "name of an xarray variable or dimension. "
                    f"Received {group!r} instead."
                )
            group = obj[group]
            if len(group) == 0:
                raise ValueError(f"{group.name} must not be empty")

            if group.name not in obj.coords and group.name in obj.dims:
                # DummyGroups should not appear on groupby results
                group = _DummyGroup(obj, group.name, group.coords)

        if getattr(group, "name", None) is None:
            group.name = "group"

        group, obj, stacked_dim, inserted_dims = _ensure_1d(group, obj)
        (group_dim,) = group.dims

        expected_size = obj.sizes[group_dim]
        if group.size != expected_size:
            raise ValueError(
                "the group variable's length does not "
                "match the length of this variable along its "
                "dimension"
            )

        full_index = None

        if bins is not None:
            if duck_array_ops.isnull(bins).all():
                raise ValueError("All bin edges are NaN.")
            binned = pd.cut(group.values, bins, **cut_kwargs)
            # f-string so that non-str (but hashable) group names also work
            new_dim_name = f"{group.name}_bins"
            group = DataArray(binned, group.coords, name=new_dim_name)
            full_index = binned.categories

        if grouper is not None:
            index = safe_cast_to_index(group)
            if not index.is_monotonic_increasing:
                # TODO: sort instead of raising an error
                raise ValueError("index must be monotonic for resampling")
            full_index, first_items = self._get_index_and_items(index, grouper)
            sbins = first_items.values.astype(np.int64)
            group_indices = [slice(i, j) for i, j in zip(sbins[:-1], sbins[1:])] + [
                slice(sbins[-1], None)
            ]
            unique_coord = IndexVariable(group.name, first_items.index)
        elif group.dims == (group.name,) and _unique_and_monotonic(group):
            # no need to factorize
            group_indices = np.arange(group.size)
            if not squeeze:
                # use slices to do views instead of fancy indexing
                # equivalent to: group_indices = group_indices.reshape(-1, 1)
                group_indices = [slice(i, i + 1) for i in group_indices]
            unique_coord = group
        else:
            if group.isnull().any():
                # drop any NaN valued groups.
                # also drop obj values where group was NaN
                # Use where instead of reindex to account for duplicate coordinate labels.
                obj = obj.where(group.notnull(), drop=True)
                group = group.dropna(group_dim)

            # look through group to find the unique values
            group_as_index = safe_cast_to_index(group)
            sort = bins is None and (not isinstance(group_as_index, pd.MultiIndex))
            unique_values, group_indices = unique_value_groups(
                group_as_index, sort=sort
            )
            unique_coord = IndexVariable(group.name, unique_values)

        if len(group_indices) == 0:
            if bins is not None:
                raise ValueError(
                    f"None of the data falls within bins with edges {bins!r}"
                )
            else:
                raise ValueError(
                    "Failed to group data. Are you grouping by a variable that is all NaN?"
                )

        # specification for the groupby operation
        self._obj = obj
        self._group = group
        self._group_dim = group_dim
        self._group_indices = group_indices
        self._unique_coord = unique_coord
        self._stacked_dim = stacked_dim
        self._inserted_dims = inserted_dims
        self._full_index = full_index
        self._restore_coord_dims = restore_coord_dims

        # cached attributes
        self._groups = None
        self._dims = None

    @property
    def dims(self):
        """Dimensions of the first group (computed once, then cached)."""
        if self._dims is None:
            self._dims = self._obj.isel(
                **{self._group_dim: self._group_indices[0]}
            ).dims

        return self._dims

    @property
    def groups(self):
        """
        Mapping from group labels to indices. The indices can be used to index the underlying object.
        """
        # provided to mimic pandas.groupby
        if self._groups is None:
            self._groups = dict(zip(self._unique_coord.values, self._group_indices))
        return self._groups

    def __getitem__(self, key):
        """
        Get DataArray or Dataset corresponding to a particular group label.
        """
        return self._obj.isel({self._group_dim: self.groups[key]})

    def __len__(self):
        return self._unique_coord.size

    def __iter__(self):
        return zip(self._unique_coord.values, self._iter_grouped())

    def __repr__(self):
        return "{}, grouped over {!r}\n{!r} groups with labels {}.".format(
            self.__class__.__name__,
            self._unique_coord.name,
            self._unique_coord.size,
            ", ".join(format_array_flat(self._unique_coord, 30).split()),
        )

    def _get_index_and_items(self, index, grouper):
        """Resample ``index`` with ``grouper``.

        Returns the full resampled index (including empty bins) and a Series
        mapping each non-empty bin label to the position of its first item.
        """
        from .resample_cftime import CFTimeGrouper

        s = pd.Series(np.arange(index.size), index)
        if isinstance(grouper, CFTimeGrouper):
            first_items = grouper.first_items(index)
        else:
            first_items = s.groupby(grouper).first()
            _apply_loffset(grouper, first_items)
        full_index = first_items.index
        if first_items.isnull().any():
            # empty bins have no first item; they are restored later from
            # full_index by _maybe_restore_empty_groups
            first_items = first_items.dropna()
        return full_index, first_items

    def _iter_grouped(self):
        """Iterate over each element in this group"""
        for indices in self._group_indices:
            yield self._obj.isel(**{self._group_dim: indices})

    def _infer_concat_args(self, applied_example):
        """Decide which coordinate, dimension and positions to concatenate on.

        If the applied result still has the grouped dimension, concatenate
        along it and reorder by the original group positions; otherwise
        concatenate along a new dimension labelled by the unique group values.
        """
        if self._group_dim in applied_example.dims:
            coord = self._group
            positions = self._group_indices
        else:
            coord = self._unique_coord
            positions = None
        (dim,) = coord.dims
        if isinstance(coord, _DummyGroup):
            coord = None
        return coord, dim, positions

    def _binary_op(self, other, f, reflexive=False):
        g = f if not reflexive else lambda x, y: f(y, x)
        applied = self._yield_binary_applied(g, other)
        return self._combine(applied)

    def _yield_binary_applied(self, func, other):
        """Yield ``func(group, other_sel)`` for every group.

        ``other_sel`` is ``other`` selected at the group's label; when the
        label is missing a missing-value stand-in (``_dummy_copy``) is used.
        """
        dummy = None

        for group_value, obj in self:
            try:
                other_sel = other.sel(**{self._group.name: group_value})
            except AttributeError:
                raise TypeError(
                    "GroupBy objects only support binary ops "
                    "when the other argument is a Dataset or "
                    "DataArray"
                )
            except (KeyError, ValueError):
                if self._group.name not in other.dims:
                    raise ValueError(
                        "incompatible dimensions for a grouped "
                        f"binary operation: the group variable {self._group.name!r} "
                        "is not a dimension on the other argument"
                    )
                if dummy is None:
                    dummy = _dummy_copy(other)
                other_sel = dummy

            result = func(obj, other_sel)
            yield result

    def _maybe_restore_empty_groups(self, combined):
        """Our index contained empty groups (e.g., from a resampling). If we
        reduced on that dimension, we want to restore the full index.
        """
        if self._full_index is not None and self._group.name in combined.dims:
            indexers = {self._group.name: self._full_index}
            combined = combined.reindex(**indexers)
        return combined

    def _maybe_unstack(self, obj):
        """This gets called if we are applying on an array with a
        multidimensional group."""
        if self._stacked_dim is not None and self._stacked_dim in obj.dims:
            obj = obj.unstack(self._stacked_dim)
            for dim in self._inserted_dims:
                if dim in obj.coords:
                    del obj.coords[dim]
            obj._indexes = propagate_indexes(obj._indexes, exclude=self._inserted_dims)
        return obj

    def fillna(self, value):
        """Fill missing values in this object by group.

        This operation follows the normal broadcasting and alignment rules that
        xarray uses for binary arithmetic, except the result is aligned to this
        object (``join='left'``) instead of aligned to the intersection of
        index coordinates (``join='inner'``).

        Parameters
        ----------
        value
            Used to fill all matching missing values by group. Needs
            to be of a valid type for the wrapped object's fillna
            method.

        Returns
        -------
        same type as the grouped object

        See Also
        --------
        Dataset.fillna
        DataArray.fillna
        """
        return ops.fillna(self, value)

    def quantile(
        self, q, dim=None, interpolation="linear", keep_attrs=None, skipna=True
    ):
        """Compute the qth quantile over each array in the groups and
        concatenate them together into a new array.

        Parameters
        ----------
        q : float or sequence of float
            Quantile to compute, which must be between 0 and 1
            inclusive.
        dim : ..., str or sequence of str, optional
            Dimension(s) over which to apply quantile.
            Defaults to the grouped dimension.
        interpolation : {"linear", "lower", "higher", "midpoint", "nearest"}, default: "linear"
            This optional parameter specifies the interpolation method to
            use when the desired quantile lies between two data points
            ``i < j``:

                * linear: ``i + (j - i) * fraction``, where ``fraction`` is
                  the fractional part of the index surrounded by ``i`` and
                  ``j``.
                * lower: ``i``.
                * higher: ``j``.
                * nearest: ``i`` or ``j``, whichever is nearest.
                * midpoint: ``(i + j) / 2``.
        keep_attrs : bool, optional
            Passed through to the ``quantile`` method of the grouped
            DataArray or Dataset.
        skipna : bool, optional
            Whether to skip missing values when aggregating.

        Returns
        -------
        quantiles : DataArray or Dataset
            If `q` is a single quantile, then the result is a
            scalar. If multiple percentiles are given, first axis of
            the result corresponds to the quantile. In either case a
            quantile dimension is added to the return array. The other
            dimensions are the dimensions that remain after the
            reduction of the array.

        See Also
        --------
        numpy.nanquantile, numpy.quantile, pandas.Series.quantile, Dataset.quantile
        DataArray.quantile

        Examples
        --------
        >>> da = xr.DataArray(
        ...     [[1.3, 8.4, 0.7, 6.9], [0.7, 4.2, 9.4, 1.5], [6.5, 7.3, 2.6, 1.9]],
        ...     coords={"x": [0, 0, 1], "y": [1, 1, 2, 2]},
        ...     dims=("x", "y"),
        ... )
        >>> ds = xr.Dataset({"a": da})
        >>> da.groupby("x").quantile(0)
        <xarray.DataArray (x: 2, y: 4)>
        array([[0.7, 4.2, 0.7, 1.5],
               [6.5, 7.3, 2.6, 1.9]])
        Coordinates:
          * y         (y) int64 1 1 2 2
            quantile  float64 0.0
          * x         (x) int64 0 1
        >>> ds.groupby("y").quantile(0, dim=...)
        <xarray.Dataset>
        Dimensions:   (y: 2)
        Coordinates:
            quantile  float64 0.0
          * y         (y) int64 1 2
        Data variables:
            a         (y) float64 0.7 0.7
        >>> da.groupby("x").quantile([0, 0.5, 1])
        <xarray.DataArray (x: 2, y: 4, quantile: 3)>
        array([[[0.7 , 1.  , 1.3 ],
                [4.2 , 6.3 , 8.4 ],
                [0.7 , 5.05, 9.4 ],
                [1.5 , 4.2 , 6.9 ]],
        <BLANKLINE>
               [[6.5 , 6.5 , 6.5 ],
                [7.3 , 7.3 , 7.3 ],
                [2.6 , 2.6 , 2.6 ],
                [1.9 , 1.9 , 1.9 ]]])
        Coordinates:
          * y         (y) int64 1 1 2 2
          * quantile  (quantile) float64 0.0 0.5 1.0
          * x         (x) int64 0 1
        >>> ds.groupby("y").quantile([0, 0.5, 1], dim=...)
        <xarray.Dataset>
        Dimensions:   (y: 2, quantile: 3)
        Coordinates:
          * quantile  (quantile) float64 0.0 0.5 1.0
          * y         (y) int64 1 2
        Data variables:
            a         (y, quantile) float64 0.7 5.35 8.4 0.7 2.25 9.4
        """
        if dim is None:
            dim = self._group_dim

        out = self.map(
            self._obj.__class__.quantile,
            shortcut=False,
            q=q,
            dim=dim,
            interpolation=interpolation,
            keep_attrs=keep_attrs,
            skipna=skipna,
        )
        return out

    def where(self, cond, other=dtypes.NA):
        """Return elements from `self` or `other` depending on `cond`.

        Parameters
        ----------
        cond : DataArray or Dataset
            Locations at which to preserve this objects values. dtypes have to be `bool`
        other : scalar, DataArray or Dataset, optional
            Value to use for locations in this object where ``cond`` is False.
            By default, inserts missing values.

        Returns
        -------
        same type as the grouped object

        See Also
        --------
        Dataset.where
        """
        return ops.where_method(self, cond, other)

    def _first_or_last(self, op, skipna, keep_attrs):
        """Shared implementation of ``first`` and ``last``."""
        if isinstance(self._group_indices[0], integer_types):
            # NB. this is currently only used for reductions along an existing
            # dimension
            return self._obj
        if keep_attrs is None:
            keep_attrs = _get_keep_attrs(default=True)
        return self.reduce(op, self._group_dim, skipna=skipna, keep_attrs=keep_attrs)

    def first(self, skipna=None, keep_attrs=None):
        """Return the first element of each group along the group dimension"""
        return self._first_or_last(duck_array_ops.first, skipna, keep_attrs)

    def last(self, skipna=None, keep_attrs=None):
        """Return the last element of each group along the group dimension"""
        return self._first_or_last(duck_array_ops.last, skipna, keep_attrs)

    def assign_coords(self, coords=None, **coords_kwargs):
        """Assign coordinates by group.

        See Also
        --------
        Dataset.assign_coords
        Dataset.swap_dims
        """
        coords_kwargs = either_dict_or_kwargs(coords, coords_kwargs, "assign_coords")
        return self.map(lambda ds: ds.assign_coords(**coords_kwargs))


def _maybe_reorder(xarray_obj, dim, positions):
    """Undo the grouping permutation along ``dim`` if one applies.

    The object is returned unchanged when no permutation is needed or when
    the permutation length does not match the size of ``dim``.
    """
    order = _inverse_permutation_indices(positions)

    if order is None or len(order) != xarray_obj.sizes[dim]:
        return xarray_obj
    else:
        return xarray_obj[{dim: order}]


class DataArrayGroupBy(GroupBy, DataArrayGroupbyArithmetic):
    """GroupBy object specialized to grouping DataArray objects"""

    __slots__ = ()

    def _iter_grouped_shortcut(self):
        """Fast version of `_iter_grouped` that yields Variables without
        metadata
        """
        var = self._obj.variable
        for indices in self._group_indices:
            yield var[{self._group_dim: indices}]

    def _concat_shortcut(self, applied, dim, positions=None):
        # nb. don't worry too much about maintaining this method -- it does
        # speed things up, but it's not very interpretable and there are much
        # faster alternatives (e.g., doing the grouped aggregation in a
        # compiled language)
        stacked = Variable.concat(applied, dim, shortcut=True)
        reordered = _maybe_reorder(stacked, dim, positions)
        return self._obj._replace_maybe_drop_dims(reordered)

    def _restore_dim_order(self, stacked):
        """Transpose ``stacked`` to follow the dimension order of the original.

        Dimensions not present on the original object are moved to the end.
        """

        def lookup_order(dimension):
            if dimension == self._group.name:
                (dimension,) = self._group.dims
            if dimension in self._obj.dims:
                axis = self._obj.get_axis_num(dimension)
            else:
                axis = 1e6  # some arbitrarily high value
            return axis

        new_order = sorted(stacked.dims, key=lookup_order)
        return stacked.transpose(*new_order, transpose_coords=self._restore_coord_dims)

    def map(self, func, shortcut=False, args=(), **kwargs):
        """Apply a function to each array in the group and concatenate them
        together into a new array.

        `func` is called like `func(ar, *args, **kwargs)` for each array `ar`
        in this group.

        Apply uses heuristics (like `pandas.GroupBy.apply`) to figure out how
        to stack together the array. The rule is:

        1. If the dimension along which the group coordinate is defined is
           still in the first grouped array after applying `func`, then stack
           over this dimension.
        2. Otherwise, stack over the new dimension given by name of this
           grouping (the argument to the `groupby` function).

        Parameters
        ----------
        func : callable
            Callable to apply to each array.
        shortcut : bool, optional
            Whether or not to shortcut evaluation under the assumptions that:

            (1) The action of `func` does not depend on any of the array
                metadata (attributes or coordinates) but only on the data and
                dimensions.
            (2) The action of `func` creates arrays with homogeneous metadata,
                that is, with the same dimensions and attributes.

            If these conditions are satisfied `shortcut` provides significant
            speedup. This should be the case for many common groupby operations
            (e.g., applying numpy ufuncs).
        args : tuple, optional
            Positional arguments passed to `func`.
        **kwargs
            Used to call `func(ar, **kwargs)` for each array `ar`.

        Returns
        -------
        applied : DataArray
            The result of splitting, applying and combining this array.
        """
        grouped = self._iter_grouped_shortcut() if shortcut else self._iter_grouped()
        applied = (maybe_wrap_array(arr, func(arr, *args, **kwargs)) for arr in grouped)
        return self._combine(applied, shortcut=shortcut)

    def apply(self, func, shortcut=False, args=(), **kwargs):
        """
        Backward compatible implementation of ``map``

        See Also
        --------
        DataArrayGroupBy.map
        """
        warnings.warn(
            "GroupBy.apply may be deprecated in the future. Using GroupBy.map is encouraged",
            PendingDeprecationWarning,
            stacklevel=2,
        )
        return self.map(func, shortcut=shortcut, args=args, **kwargs)

    def _combine(self, applied, shortcut=False):
        """Recombine the applied objects like the original."""
        applied_example, applied = peek_at(applied)
        coord, dim, positions = self._infer_concat_args(applied_example)
        if shortcut:
            combined = self._concat_shortcut(applied, dim, positions)
        else:
            combined = concat(applied, dim)
            combined = _maybe_reorder(combined, dim, positions)

        if isinstance(combined, type(self._obj)):
            # only restore dimension order for arrays
            combined = self._restore_dim_order(combined)
        # assign coord when the applied function does not return that coord
        if coord is not None and dim not in applied_example.dims:
            if shortcut:
                coord_var = as_variable(coord)
                combined._coords[coord.name] = coord_var
            else:
                combined.coords[coord.name] = coord
        combined = self._maybe_restore_empty_groups(combined)
        combined = self._maybe_unstack(combined)
        return combined

    def reduce(
        self, func, dim=None, axis=None, keep_attrs=None, shortcut=True, **kwargs
    ):
        """Reduce the items in this group by applying `func` along some
        dimension(s).

        Parameters
        ----------
        func : callable
            Function which can be called in the form
            `func(x, axis=axis, **kwargs)` to return the result of collapsing
            an np.ndarray over an integer valued axis.
        dim : ..., str or sequence of str, optional
            Dimension(s) over which to apply `func`.
        axis : int or sequence of int, optional
            Axis(es) over which to apply `func`. Only one of the 'dimension'
            and 'axis' arguments can be supplied. If neither are supplied, then
            `func` is calculated over all dimension for each group item.
        keep_attrs : bool, optional
            If True, the array's attributes (`attrs`) will be copied from
            the original object to the new one. If False (default), the new
            object will be returned without attributes.
        **kwargs : dict
            Additional keyword arguments passed on to `func`.

        Returns
        -------
        reduced : Array
            Array with summarized data and the indicated dimension(s)
            removed.

        Raises
        ------
        ValueError
            If `dim` names a dimension that the groups do not have.
        """
        if dim is None:
            dim = self._group_dim

        if keep_attrs is None:
            keep_attrs = _get_keep_attrs(default=False)

        def reduce_array(ar):
            return ar.reduce(func, dim, axis, keep_attrs=keep_attrs, **kwargs)

        check_reduce_dims(dim, self.dims)

        return self.map(reduce_array, shortcut=shortcut)


class DatasetGroupBy(GroupBy, DatasetGroupbyArithmetic):
    """GroupBy object specialized to grouping Dataset objects"""

    __slots__ = ()

    def map(self, func, args=(), shortcut=None, **kwargs):
        """Apply a function to each Dataset in the group and concatenate them
        together into a new Dataset.

        `func` is called like `func(ds, *args, **kwargs)` for each dataset `ds`
        in this group.

        Apply uses heuristics (like `pandas.GroupBy.apply`) to figure out how
        to stack together the datasets. The rule is:

        1. If the dimension along which the group coordinate is defined is
           still in the first grouped item after applying `func`, then stack
           over this dimension.
        2. Otherwise, stack over the new dimension given by name of this
           grouping (the argument to the `groupby` function).

        Parameters
        ----------
        func : callable
            Callable to apply to each sub-dataset.
        args : tuple, optional
            Positional arguments to pass to `func`.
        shortcut : optional
            Accepted for signature compatibility and currently ignored.
        **kwargs
            Used to call `func(ds, **kwargs)` for each sub-dataset `ds`.

        Returns
        -------
        applied : Dataset or DataArray
            The result of splitting, applying and combining this dataset.
        """
        # ignore shortcut if set (for now)
        applied = (func(ds, *args, **kwargs) for ds in self._iter_grouped())
        return self._combine(applied)

    def apply(self, func, args=(), shortcut=None, **kwargs):
        """
        Backward compatible implementation of ``map``

        See Also
        --------
        DatasetGroupBy.map
        """

        warnings.warn(
            "GroupBy.apply may be deprecated in the future. Using GroupBy.map is encouraged",
            PendingDeprecationWarning,
            stacklevel=2,
        )
        return self.map(func, shortcut=shortcut, args=args, **kwargs)

    def _combine(self, applied):
        """Recombine the applied objects like the original."""
        applied_example, applied = peek_at(applied)
        coord, dim, positions = self._infer_concat_args(applied_example)
        combined = concat(applied, dim)
        combined = _maybe_reorder(combined, dim, positions)
        # assign coord when the applied function does not return that coord
        if coord is not None and dim not in applied_example.dims:
            combined[coord.name] = coord
        combined = self._maybe_restore_empty_groups(combined)
        combined = self._maybe_unstack(combined)
        return combined

    def reduce(self, func, dim=None, keep_attrs=None, **kwargs):
        """Reduce the items in this group by applying `func` along some
        dimension(s).

        Parameters
        ----------
        func : callable
            Function which can be called in the form
            `func(x, axis=axis, **kwargs)` to return the result of collapsing
            an np.ndarray over an integer valued axis.
        dim : ..., str or sequence of str, optional
            Dimension(s) over which to apply `func`. Defaults to the grouped
            dimension.
        keep_attrs : bool, optional
            If True, the dataset's attributes (`attrs`) will be copied from
            the original object to the new one. If False (default), the new
            object will be returned without attributes.
        **kwargs : dict
            Additional keyword arguments passed on to `func`.

        Returns
        -------
        reduced : Dataset
            Dataset with summarized data and the indicated dimension(s)
            removed.

        Raises
        ------
        ValueError
            If `dim` names a dimension that the groups do not have.
        """
        if dim is None:
            dim = self._group_dim

        if keep_attrs is None:
            keep_attrs = _get_keep_attrs(default=False)

        def reduce_dataset(ds):
            return ds.reduce(func, dim, keep_attrs, **kwargs)

        check_reduce_dims(dim, self.dims)

        return self.map(reduce_dataset)

    def assign(self, **kwargs):
        """Assign data variables by group.

        See Also
        --------
        Dataset.assign
        """
        return self.map(lambda ds: ds.assign(**kwargs))