1import collections
2import itertools as it
3import operator
4import warnings
5from numbers import Integral
6
7import numpy as np
8import pandas as pd
9
10from ..base import tokenize
11from ..highlevelgraph import HighLevelGraph
12from ..utils import M, derived_from, funcname, itemgetter
13from .core import (
14    DataFrame,
15    Series,
16    _extract_meta,
17    aca,
18    map_partitions,
19    new_dd_object,
20    no_default,
21    split_out_on_index,
22)
23from .methods import concat, drop_columns
24from .shuffle import shuffle
25from .utils import (
26    PANDAS_GT_110,
27    insert_meta_param_description,
28    is_dataframe_like,
29    is_series_like,
30    make_meta,
31    raise_on_meta_error,
32)
33
34# #############################################
35#
36# GroupBy implementation notes
37#
38# Dask groupby supports reductions, i.e., mean, sum and alike, and apply. The
39# former do not shuffle the data and are efficiently implemented as tree
40# reductions. The latter is implemented by shuffling the underlying partiitons
41# such that all items of a group can be found in the same parititon.
42#
43# The argument to ``.groupby``, the index, can be a ``str``, ``dd.DataFrame``,
44# ``dd.Series``, or a list thereof. In operations on the grouped object, the
45# divisions of the the grouped object and the items of index have to align.
46# Currently, there is no support to shuffle the index values as part of the
47# groupby operation. Therefore, the alignment has to be guaranteed by the
48# caller.
49#
50# To operate on matching partitions, most groupby operations exploit the
51# corresponding support in ``apply_concat_apply``. Specifically, this function
52# operates on matching partitions of frame-like objects passed as varargs.
53#
54# After the initial chunk step, the passed index is implicitly passed along to
55# subsequent operations as the index of the partitions. Groupby operations on
56# the individual partitions can then access the index via the ``levels``
57# parameter of the ``groupby`` function. The correct argument is determined by
58# the ``_determine_levels`` function.
59#
60# To minimize overhead, series in an index that were obtained by getitem on the
61# object to group are not passed as series to the various operations, but as
62# columnn keys. This transformation is implemented as ``_normalize_index``.
63#
64# #############################################
65
66
67def _determine_levels(index):
68    """Determine the correct levels argument to groupby."""
69    if isinstance(index, (tuple, list)) and len(index) > 1:
70        return list(range(len(index)))
71    else:
72        return 0
73
74
75def _normalize_index(df, index):
76    """Replace series with column names in an index wherever possible."""
77    if not isinstance(df, DataFrame):
78        return index
79
80    elif isinstance(index, list):
81        return [_normalize_index(df, col) for col in index]
82
83    elif (
84        is_series_like(index)
85        and index.name in df.columns
86        and index._name == df[index.name]._name
87    ):
88        return index.name
89
90    elif (
91        isinstance(index, DataFrame)
92        and set(index.columns).issubset(df.columns)
93        and index._name == df[index.columns]._name
94    ):
95        return list(index.columns)
96
97    else:
98        return index
99
100
101def _maybe_slice(grouped, columns):
102    """
103    Slice columns if grouped is pd.DataFrameGroupBy
104    """
105    # FIXME: update with better groupby object detection (i.e.: ngroups, get_group)
106    if "groupby" in type(grouped).__name__.lower():
107        if columns is not None:
108            if isinstance(columns, (tuple, list, set, pd.Index)):
109                columns = list(columns)
110            return grouped[columns]
111    return grouped
112
113
114def _is_aligned(df, by):
115    """Check if `df` and `by` have aligned indices"""
116    if is_series_like(by) or is_dataframe_like(by):
117        return df.index.equals(by.index)
118    elif isinstance(by, (list, tuple)):
119        return all(_is_aligned(df, i) for i in by)
120    else:
121        return True
122
123
124def _groupby_raise_unaligned(df, **kwargs):
125    """Groupby, but raise if df and `by` key are unaligned.
126
127    Pandas supports grouping by a column that doesn't align with the input
128    frame/series/index. However, the reindexing does not seem to be
129    threadsafe, and can result in incorrect results. Since grouping by an
130    unaligned key is generally a bad idea, we just error loudly in dask.
131
132    For more information see pandas GH issue #15244 and Dask GH issue #1876."""
133    by = kwargs.get("by", None)
134    if by is not None and not _is_aligned(df, by):
135        msg = (
136            "Grouping by an unaligned index is unsafe and unsupported.\n"
137            "This can be caused by filtering only one of the object or\n"
138            "grouping key. For example, the following works in pandas,\n"
139            "but not in dask:\n"
140            "\n"
141            "df[df.foo < 0].groupby(df.bar)\n"
142            "\n"
143            "This can be avoided by either filtering beforehand, or\n"
144            "passing in the name of the column instead:\n"
145            "\n"
146            "df2 = df[df.foo < 0]\n"
147            "df2.groupby(df2.bar)\n"
148            "# or\n"
149            "df[df.foo < 0].groupby('bar')\n"
150            "\n"
151            "For more information see dask GH issue #1876."
152        )
153        raise ValueError(msg)
154    elif by is not None and len(by):
155        # since we're coming through apply, `by` will be a tuple.
156        # Pandas treats tuples as a single key, and lists as multiple keys
157        # We want multiple keys
158        if isinstance(by, str):
159            by = [by]
160        kwargs.update(by=list(by))
161    return df.groupby(**kwargs)
162
163
164def _groupby_slice_apply(
165    df, grouper, key, func, *args, group_keys=True, dropna=None, observed=None, **kwargs
166):
167    # No need to use raise if unaligned here - this is only called after
168    # shuffling, which makes everything aligned already
169    dropna = {"dropna": dropna} if dropna is not None else {}
170    observed = {"observed": observed} if observed is not None else {}
171    g = df.groupby(grouper, group_keys=group_keys, **observed, **dropna)
172    if key:
173        g = g[key]
174    return g.apply(func, *args, **kwargs)
175
176
177def _groupby_slice_transform(
178    df, grouper, key, func, *args, group_keys=True, dropna=None, observed=None, **kwargs
179):
180    # No need to use raise if unaligned here - this is only called after
181    # shuffling, which makes everything aligned already
182    dropna = {"dropna": dropna} if dropna is not None else {}
183    observed = {"observed": observed} if observed is not None else {}
184    g = df.groupby(grouper, group_keys=group_keys, **observed, **dropna)
185    if key:
186        g = g[key]
187
188    # Cannot call transform on an empty dataframe
189    if len(df) == 0:
190        return g.apply(func, *args, **kwargs)
191
192    return g.transform(func, *args, **kwargs)
193
194
195def _groupby_get_group(df, by_key, get_key, columns):
196    # SeriesGroupBy may pass df which includes group key
197    grouped = _groupby_raise_unaligned(df, by=by_key)
198
199    if get_key in grouped.groups:
200        if is_dataframe_like(df):
201            grouped = grouped[columns]
202        return grouped.get_group(get_key)
203
204    else:
205        # to create empty DataFrame/Series, which has the same
206        # dtype as the original
207        if is_dataframe_like(df):
208            # may be SeriesGroupBy
209            df = df[columns]
210        return df.iloc[0:0]
211
212
213###############################################################
214# Aggregation
215###############################################################
216
217
218class Aggregation:
219    """User defined groupby-aggregation.
220
221    This class allows users to define their own custom aggregation in terms of
222    operations on Pandas dataframes in a map-reduce style. You need to specify
223    what operation to do on each chunk of data, how to combine those chunks of
224    data together, and then how to finalize the result.
225
226    See :ref:`dataframe.groupby.aggregate` for more.
227
228    Parameters
229    ----------
230    name : str
231        the name of the aggregation. It should be unique, since intermediate
232        result will be identified by this name.
233    chunk : callable
234        a function that will be called with the grouped column of each
235        partition. It can either return a single series or a tuple of series.
236        The index has to be equal to the groups.
237    agg : callable
238        a function that will be called to aggregate the results of each chunk.
239        Again the argument(s) will be grouped series. If ``chunk`` returned a
240        tuple, ``agg`` will be called with all of them as individual positional
241        arguments.
242    finalize : callable
243        an optional finalizer that will be called with the results from the
244        aggregation.
245
246    Examples
247    --------
248    We could implement ``sum`` as follows:
249
250    >>> custom_sum = dd.Aggregation(
251    ...     name='custom_sum',
252    ...     chunk=lambda s: s.sum(),
253    ...     agg=lambda s0: s0.sum()
254    ... )  # doctest: +SKIP
255    >>> df.groupby('g').agg(custom_sum)  # doctest: +SKIP
256
257    We can implement ``mean`` as follows:
258
259    >>> custom_mean = dd.Aggregation(
260    ...     name='custom_mean',
261    ...     chunk=lambda s: (s.count(), s.sum()),
262    ...     agg=lambda count, sum: (count.sum(), sum.sum()),
263    ...     finalize=lambda count, sum: sum / count,
264    ... )  # doctest: +SKIP
265    >>> df.groupby('g').agg(custom_mean)  # doctest: +SKIP
266
267    Though of course, both of these are built-in and so you don't need to
268    implement them yourself.
269    """
270
271    def __init__(self, name, chunk, agg, finalize=None):
272        self.chunk = chunk
273        self.agg = agg
274        self.finalize = finalize
275        self.__name__ = name
276
277
278def _groupby_aggregate(
279    df, aggfunc=None, levels=None, dropna=None, sort=False, observed=None, **kwargs
280):
281    dropna = {"dropna": dropna} if dropna is not None else {}
282    observed = {"observed": observed} if observed is not None else {}
283
284    grouped = df.groupby(level=levels, sort=sort, **observed, **dropna)
285    return aggfunc(grouped, **kwargs)
286
287
288def _apply_chunk(df, *index, dropna=None, observed=None, **kwargs):
289    func = kwargs.pop("chunk")
290    columns = kwargs.pop("columns")
291    dropna = {"dropna": dropna} if dropna is not None else {}
292    observed = {"observed": observed} if observed is not None else {}
293
294    g = _groupby_raise_unaligned(df, by=index, **observed, **dropna)
295    if is_series_like(df) or columns is None:
296        return func(g, **kwargs)
297    else:
298        if isinstance(columns, (tuple, list, set, pd.Index)):
299            columns = list(columns)
300        return func(g[columns], **kwargs)
301
302
303def _var_chunk(df, *index):
304    if is_series_like(df):
305        df = df.to_frame()
306
307    df = df.copy()
308
309    g = _groupby_raise_unaligned(df, by=index)
310    x = g.sum()
311
312    n = g[x.columns].count().rename(columns=lambda c: (c, "-count"))
313
314    cols = x.columns
315    df[cols] = df[cols] ** 2
316
317    g2 = _groupby_raise_unaligned(df, by=index)
318    x2 = g2.sum().rename(columns=lambda c: (c, "-x2"))
319
320    return concat([x, x2, n], axis=1)
321
322
323def _var_combine(g, levels, sort=False):
324    return g.groupby(level=levels, sort=sort).sum()
325
326
327def _var_agg(g, levels, ddof, sort=False):
328    g = g.groupby(level=levels, sort=sort).sum()
329    nc = len(g.columns)
330    x = g[g.columns[: nc // 3]]
331    # chunks columns are tuples (value, name), so we just keep the value part
332    x2 = g[g.columns[nc // 3 : 2 * nc // 3]].rename(columns=lambda c: c[0])
333    n = g[g.columns[-nc // 3 :]].rename(columns=lambda c: c[0])
334
335    # TODO: replace with _finalize_var?
336    result = x2 - x ** 2 / n
337    div = n - ddof
338    div[div < 0] = 0
339    result /= div
340    result[(n - ddof) == 0] = np.nan
341    assert is_dataframe_like(result)
342    result[result < 0] = 0  # avoid rounding errors that take us to zero
343    return result
344
345
346def _cov_combine(g, levels):
347    return g
348
349
350def _cov_finalizer(df, cols, std=False):
351    vals = []
352    num_elements = len(list(it.product(cols, repeat=2)))
353    num_cols = len(cols)
354    vals = list(range(num_elements))
355    col_idx_mapping = dict(zip(cols, range(num_cols)))
356    for i, j in it.combinations_with_replacement(df[cols].columns, 2):
357        x = col_idx_mapping[i]
358        y = col_idx_mapping[j]
359        idx = x + num_cols * y
360        mul_col = f"{i}{j}"
361        ni = df["%s-count" % i]
362        nj = df["%s-count" % j]
363
364        n = np.sqrt(ni * nj)
365        div = n - 1
366        div[div < 0] = 0
367        val = (df[mul_col] - df[i] * df[j] / n).values[0] / div.values[0]
368        if std:
369            ii = f"{i}{i}"
370            jj = f"{j}{j}"
371            std_val_i = (df[ii] - (df[i] ** 2) / ni).values[0] / div.values[0]
372            std_val_j = (df[jj] - (df[j] ** 2) / nj).values[0] / div.values[0]
373            val = val / np.sqrt(std_val_i * std_val_j)
374
375        vals[idx] = val
376        if i != j:
377            idx = num_cols * x + y
378            vals[idx] = val
379
380    level_1 = cols
381    index = pd.MultiIndex.from_product([level_1, level_1])
382    return pd.Series(vals, index=index)
383
384
385def _mul_cols(df, cols):
386    """Internal function to be used with apply to multiply
387    each column in a dataframe by every other column
388
389    a b c -> a*a, a*b, b*b, b*c, c*c
390    """
391    _df = df.__class__()
392    for i, j in it.combinations_with_replacement(cols, 2):
393        col = f"{i}{j}"
394        _df[col] = df[i] * df[j]
395
396    # Fix index in a groupby().apply() context
397    # https://github.com/dask/dask/issues/8137
398    # https://github.com/pandas-dev/pandas/issues/43568
399    _df.index = [0] * len(_df)
400    return _df
401
402
403def _cov_chunk(df, *index):
404    """Covariance Chunk Logic
405
406    Parameters
407    ----------
408    df : Pandas.DataFrame
409    std : bool, optional
410        When std=True we are calculating with Correlation
411
412    Returns
413    -------
414    tuple
415        Processed X, Multiplied Cols,
416    """
417    if is_series_like(df):
418        df = df.to_frame()
419    df = df.copy()
420
421    # mapping columns to str(numerical) values allows us to easily handle
422    # arbitrary column names (numbers, string, empty strings)
423    col_mapping = collections.OrderedDict()
424    for i, c in enumerate(df.columns):
425        col_mapping[c] = str(i)
426    df = df.rename(columns=col_mapping)
427    cols = df._get_numeric_data().columns
428
429    # when grouping by external series don't exclude columns
430    is_mask = any(is_series_like(s) for s in index)
431    if not is_mask:
432        index = [col_mapping[k] for k in index]
433        cols = cols.drop(np.array(index))
434
435    g = _groupby_raise_unaligned(df, by=index)
436    x = g.sum()
437
438    mul = g.apply(_mul_cols, cols=cols).reset_index(level=-1, drop=True)
439
440    n = g[x.columns].count().rename(columns=lambda c: f"{c}-count")
441    return (x, mul, n, col_mapping)
442
443
444def _cov_agg(_t, levels, ddof, std=False, sort=False):
445    sums = []
446    muls = []
447    counts = []
448
449    # sometime we get a series back from concat combiner
450    t = list(_t)
451
452    cols = t[0][0].columns
453    for x, mul, n, col_mapping in t:
454        sums.append(x)
455        muls.append(mul)
456        counts.append(n)
457        col_mapping = col_mapping
458
459    total_sums = concat(sums).groupby(level=levels, sort=sort).sum()
460    total_muls = concat(muls).groupby(level=levels, sort=sort).sum()
461    total_counts = concat(counts).groupby(level=levels).sum()
462    result = (
463        concat([total_sums, total_muls, total_counts], axis=1)
464        .groupby(level=levels)
465        .apply(_cov_finalizer, cols=cols, std=std)
466    )
467
468    inv_col_mapping = {v: k for k, v in col_mapping.items()}
469    idx_vals = result.index.names
470    idx_mapping = list()
471
472    # when index is None we probably have selected a particular column
473    # df.groupby('a')[['b']].cov()
474    if len(idx_vals) == 1 and all(n is None for n in idx_vals):
475        idx_vals = list(inv_col_mapping.keys() - set(total_sums.columns))
476
477    for idx, val in enumerate(idx_vals):
478        idx_name = inv_col_mapping.get(val, val)
479        idx_mapping.append(idx_name)
480
481        if len(result.columns.levels[0]) < len(col_mapping):
482            # removing index from col_mapping (produces incorrect multiindexes)
483            try:
484                col_mapping.pop(idx_name)
485            except KeyError:
486                # when slicing the col_map will not have the index
487                pass
488
489    keys = list(col_mapping.keys())
490    for level in range(len(result.columns.levels)):
491        result.columns = result.columns.set_levels(keys, level=level)
492
493    result.index.set_names(idx_mapping, inplace=True)
494
495    # stacking can lead to a sorted index
496    s_result = result.stack(dropna=False)
497    assert is_dataframe_like(s_result)
498    return s_result
499
500
501###############################################################
502# nunique
503###############################################################
504def _drop_duplicates_reindex(df):
505    # Fix index in a groupby().apply() context
506    # https://github.com/dask/dask/issues/8137
507    # https://github.com/pandas-dev/pandas/issues/43568
508    result = df.drop_duplicates()
509    result.index = [0] * len(result)
510    return result
511
512
513def _nunique_df_chunk(df, *index, **kwargs):
514    name = kwargs.pop("name")
515
516    g = _groupby_raise_unaligned(df, by=index)
517    if len(df) > 0:
518        grouped = (
519            g[[name]].apply(_drop_duplicates_reindex).reset_index(level=-1, drop=True)
520        )
521    else:
522        # Manually create empty version, since groupby-apply for empty frame
523        # results in df with no columns
524        grouped = g[[name]].nunique()
525        grouped = grouped.astype(df.dtypes[grouped.columns].to_dict())
526
527    return grouped
528
529
530def _nunique_df_combine(df, levels, sort=False):
531    result = (
532        df.groupby(level=levels, sort=sort)
533        .apply(_drop_duplicates_reindex)
534        .reset_index(level=-1, drop=True)
535    )
536    return result
537
538
539def _nunique_df_aggregate(df, levels, name, sort=False):
540    return df.groupby(level=levels, sort=sort)[name].nunique()
541
542
543def _nunique_series_chunk(df, *index, **_ignored_):
544    # convert series to data frame, then hand over to dataframe code path
545    assert is_series_like(df)
546
547    df = df.to_frame()
548    kwargs = dict(name=df.columns[0], levels=_determine_levels(index))
549    return _nunique_df_chunk(df, *index, **kwargs)
550
551
552###############################################################
553# Aggregate support
554#
555# Aggregate is implemented as:
556#
557# 1. group-by-aggregate all partitions into intermediate values
558# 2. collect all partitions into a single partition
559# 3. group-by-aggregate the result into intermediate values
560# 4. transform all intermediate values into the result
561#
562# In Step 1 and 3 the dataframe is grouped on the same columns.
563#
564###############################################################
565def _make_agg_id(func, column):
566    return f"{func!s}-{column!s}-{tokenize(func, column)}"
567
568
569def _normalize_spec(spec, non_group_columns):
570    """
571    Return a list of ``(result_column, func, input_column)`` tuples.
572
573    Spec can be
574
575    - a function
576    - a list of functions
577    - a dictionary that maps input-columns to functions
578    - a dictionary that maps input-columns to a lists of functions
579    - a dictionary that maps input-columns to a dictionaries that map
580      output-columns to functions.
581
582    The non-group columns are a list of all column names that are not used in
583    the groupby operation.
584
585    Usually, the result columns are mutli-level names, returned as tuples.
586    If only a single function is supplied or dictionary mapping columns
587    to single functions, simple names are returned as strings (see the first
588    two examples below).
589
590    Examples
591    --------
592    >>> _normalize_spec('mean', ['a', 'b', 'c'])
593    [('a', 'mean', 'a'), ('b', 'mean', 'b'), ('c', 'mean', 'c')]
594
595    >>> spec = collections.OrderedDict([('a', 'mean'), ('b', 'count')])
596    >>> _normalize_spec(spec, ['a', 'b', 'c'])
597    [('a', 'mean', 'a'), ('b', 'count', 'b')]
598
599    >>> _normalize_spec(['var', 'mean'], ['a', 'b', 'c'])
600    ... # doctest: +NORMALIZE_WHITESPACE
601    [(('a', 'var'), 'var', 'a'), (('a', 'mean'), 'mean', 'a'), \
602     (('b', 'var'), 'var', 'b'), (('b', 'mean'), 'mean', 'b'), \
603     (('c', 'var'), 'var', 'c'), (('c', 'mean'), 'mean', 'c')]
604
605    >>> spec = collections.OrderedDict([('a', 'mean'), ('b', ['sum', 'count'])])
606    >>> _normalize_spec(spec, ['a', 'b', 'c'])
607    ... # doctest: +NORMALIZE_WHITESPACE
608    [(('a', 'mean'), 'mean', 'a'), (('b', 'sum'), 'sum', 'b'), \
609      (('b', 'count'), 'count', 'b')]
610
611    >>> spec = collections.OrderedDict()
612    >>> spec['a'] = ['mean', 'size']
613    >>> spec['b'] = collections.OrderedDict([('e', 'count'), ('f', 'var')])
614    >>> _normalize_spec(spec, ['a', 'b', 'c'])
615    ... # doctest: +NORMALIZE_WHITESPACE
616    [(('a', 'mean'), 'mean', 'a'), (('a', 'size'), 'size', 'a'), \
617     (('b', 'e'), 'count', 'b'), (('b', 'f'), 'var', 'b')]
618    """
619    if not isinstance(spec, dict):
620        spec = collections.OrderedDict(zip(non_group_columns, it.repeat(spec)))
621
622    res = []
623
624    if isinstance(spec, dict):
625        for input_column, subspec in spec.items():
626            if isinstance(subspec, dict):
627                res.extend(
628                    ((input_column, result_column), func, input_column)
629                    for result_column, func in subspec.items()
630                )
631
632            else:
633                if not isinstance(subspec, list):
634                    subspec = [subspec]
635
636                res.extend(
637                    ((input_column, funcname(func)), func, input_column)
638                    for func in subspec
639                )
640
641    else:
642        raise ValueError(f"unsupported agg spec of type {type(spec)}")
643
644    compounds = (list, tuple, dict)
645    use_flat_columns = not any(
646        isinstance(subspec, compounds) for subspec in spec.values()
647    )
648
649    if use_flat_columns:
650        res = [(input_col, func, input_col) for (_, func, input_col) in res]
651
652    return res
653
654
655def _build_agg_args(spec):
656    """
657    Create transformation functions for a normalized aggregate spec.
658
659    Parameters
660    ----------
661    spec: a list of (result-column, aggregation-function, input-column) triples.
662        To work with all argument forms understood by pandas use
663        ``_normalize_spec`` to normalize the argment before passing it on to
664        ``_build_agg_args``.
665
666    Returns
667    -------
668    chunk_funcs: a list of (intermediate-column, function, keyword) triples
669        that are applied on grouped chunks of the initial dataframe.
670
671    agg_funcs: a list of (intermediate-column, functions, keyword) triples that
672        are applied on the grouped concatination of the preprocessed chunks.
673
674    finalizers: a list of (result-column, function, keyword) triples that are
675        applied after the ``agg_funcs``. They are used to create final results
676        from intermediate representations.
677    """
678    known_np_funcs = {np.min: "min", np.max: "max"}
679
680    # check that there are no name conflicts for a single input column
681    by_name = {}
682    for _, func, input_column in spec:
683        key = funcname(known_np_funcs.get(func, func)), input_column
684        by_name.setdefault(key, []).append((func, input_column))
685
686    for funcs in by_name.values():
687        if len(funcs) != 1:
688            raise ValueError(f"conflicting aggregation functions: {funcs}")
689
690    chunks = {}
691    aggs = {}
692    finalizers = []
693
694    for (result_column, func, input_column) in spec:
695        if not isinstance(func, Aggregation):
696            func = funcname(known_np_funcs.get(func, func))
697
698        impls = _build_agg_args_single(result_column, func, input_column)
699
700        # overwrite existing result-columns, generate intermediates only once
701        for spec in impls["chunk_funcs"]:
702            chunks[spec[0]] = spec
703        for spec in impls["aggregate_funcs"]:
704            aggs[spec[0]] = spec
705
706        finalizers.append(impls["finalizer"])
707
708    chunks = sorted(chunks.values())
709    aggs = sorted(aggs.values())
710
711    return chunks, aggs, finalizers
712
713
714def _build_agg_args_single(result_column, func, input_column):
715    simple_impl = {
716        "sum": (M.sum, M.sum),
717        "min": (M.min, M.min),
718        "max": (M.max, M.max),
719        "count": (M.count, M.sum),
720        "size": (M.size, M.sum),
721        "first": (M.first, M.first),
722        "last": (M.last, M.last),
723        "prod": (M.prod, M.prod),
724    }
725
726    if func in simple_impl.keys():
727        return _build_agg_args_simple(
728            result_column, func, input_column, simple_impl[func]
729        )
730
731    elif func == "var":
732        return _build_agg_args_var(result_column, func, input_column)
733
734    elif func == "std":
735        return _build_agg_args_std(result_column, func, input_column)
736
737    elif func == "mean":
738        return _build_agg_args_mean(result_column, func, input_column)
739
740    elif func == "list":
741        return _build_agg_args_list(result_column, func, input_column)
742
743    elif isinstance(func, Aggregation):
744        return _build_agg_args_custom(result_column, func, input_column)
745
746    else:
747        raise ValueError(f"unknown aggregate {func}")
748
749
750def _build_agg_args_simple(result_column, func, input_column, impl_pair):
751    intermediate = _make_agg_id(func, input_column)
752    chunk_impl, agg_impl = impl_pair
753
754    return dict(
755        chunk_funcs=[
756            (
757                intermediate,
758                _apply_func_to_column,
759                dict(column=input_column, func=chunk_impl),
760            )
761        ],
762        aggregate_funcs=[
763            (
764                intermediate,
765                _apply_func_to_column,
766                dict(column=intermediate, func=agg_impl),
767            )
768        ],
769        finalizer=(result_column, itemgetter(intermediate), dict()),
770    )
771
772
773def _build_agg_args_var(result_column, func, input_column):
774    int_sum = _make_agg_id("sum", input_column)
775    int_sum2 = _make_agg_id("sum2", input_column)
776    int_count = _make_agg_id("count", input_column)
777
778    return dict(
779        chunk_funcs=[
780            (int_sum, _apply_func_to_column, dict(column=input_column, func=M.sum)),
781            (int_count, _apply_func_to_column, dict(column=input_column, func=M.count)),
782            (int_sum2, _compute_sum_of_squares, dict(column=input_column)),
783        ],
784        aggregate_funcs=[
785            (col, _apply_func_to_column, dict(column=col, func=M.sum))
786            for col in (int_sum, int_count, int_sum2)
787        ],
788        finalizer=(
789            result_column,
790            _finalize_var,
791            dict(sum_column=int_sum, count_column=int_count, sum2_column=int_sum2),
792        ),
793    )
794
795
796def _build_agg_args_std(result_column, func, input_column):
797    impls = _build_agg_args_var(result_column, func, input_column)
798
799    result_column, _, kwargs = impls["finalizer"]
800    impls["finalizer"] = (result_column, _finalize_std, kwargs)
801
802    return impls
803
804
805def _build_agg_args_mean(result_column, func, input_column):
806    int_sum = _make_agg_id("sum", input_column)
807    int_count = _make_agg_id("count", input_column)
808
809    return dict(
810        chunk_funcs=[
811            (int_sum, _apply_func_to_column, dict(column=input_column, func=M.sum)),
812            (int_count, _apply_func_to_column, dict(column=input_column, func=M.count)),
813        ],
814        aggregate_funcs=[
815            (col, _apply_func_to_column, dict(column=col, func=M.sum))
816            for col in (int_sum, int_count)
817        ],
818        finalizer=(
819            result_column,
820            _finalize_mean,
821            dict(sum_column=int_sum, count_column=int_count),
822        ),
823    )
824
825
826def _build_agg_args_list(result_column, func, input_column):
827    intermediate = _make_agg_id("list", input_column)
828
829    return dict(
830        chunk_funcs=[
831            (
832                intermediate,
833                _apply_func_to_column,
834                dict(column=input_column, func=lambda s: s.apply(list)),
835            )
836        ],
837        aggregate_funcs=[
838            (
839                intermediate,
840                _apply_func_to_column,
841                dict(
842                    column=intermediate,
843                    func=lambda s0: s0.apply(
844                        lambda chunks: list(it.chain.from_iterable(chunks))
845                    ),
846                ),
847            )
848        ],
849        finalizer=(result_column, itemgetter(intermediate), dict()),
850    )
851
852
853def _build_agg_args_custom(result_column, func, input_column):
854    col = _make_agg_id(funcname(func), input_column)
855
856    if func.finalize is None:
857        finalizer = (result_column, operator.itemgetter(col), dict())
858
859    else:
860        finalizer = (
861            result_column,
862            _apply_func_to_columns,
863            dict(func=func.finalize, prefix=col),
864        )
865
866    return dict(
867        chunk_funcs=[
868            (col, _apply_func_to_column, dict(func=func.chunk, column=input_column))
869        ],
870        aggregate_funcs=[
871            (col, _apply_func_to_columns, dict(func=func.agg, prefix=col))
872        ],
873        finalizer=finalizer,
874    )
875
876
877def _groupby_apply_funcs(df, *index, **kwargs):
878    """
879    Group a dataframe and apply multiple aggregation functions.
880
881    Parameters
882    ----------
883    df: pandas.DataFrame
884        The dataframe to work on.
885    index: list of groupers
886        If given, they are added to the keyword arguments as the ``by``
887        argument.
888    funcs: list of result-colum, function, keywordargument triples
889        The list of functions that are applied on the grouped data frame.
890        Has to be passed as a keyword argument.
891    kwargs:
892        All keyword arguments, but ``funcs``, are passed verbatim to the groupby
893        operation of the dataframe
894
895    Returns
896    -------
897    aggregated:
898        the aggregated dataframe.
899    """
900    if len(index):
901        # since we're coming through apply, `by` will be a tuple.
902        # Pandas treats tuples as a single key, and lists as multiple keys
903        # We want multiple keys
904        kwargs.update(by=list(index))
905
906    funcs = kwargs.pop("funcs")
907    grouped = _groupby_raise_unaligned(df, **kwargs)
908
909    result = collections.OrderedDict()
910    for result_column, func, func_kwargs in funcs:
911        r = func(grouped, **func_kwargs)
912
913        if isinstance(r, tuple):
914            for idx, s in enumerate(r):
915                result[f"{result_column}-{idx}"] = s
916
917        else:
918            result[result_column] = r
919
920    if is_dataframe_like(df):
921        return df.__class__(result)
922    else:
923        # Get the DataFrame type of this Series object
924        return df.head(0).to_frame().__class__(result)
925
926
927def _compute_sum_of_squares(grouped, column):
928    # Note: CuDF cannot use `groupby.apply`.
929    # Need to unpack groupby to compute sum of squares
930    if hasattr(grouped, "grouper"):
931        keys = grouped.grouper
932    else:
933        # Handle CuDF groupby object (different from pandas)
934        keys = grouped.grouping.keys
935    df = grouped.obj[column].pow(2) if column else grouped.obj.pow(2)
936    return df.groupby(keys).sum()
937
938
939def _agg_finalize(df, aggregate_funcs, finalize_funcs, level, sort=False, **kwargs):
940    # finish the final aggregation level
941    df = _groupby_apply_funcs(
942        df, funcs=aggregate_funcs, level=level, sort=sort, **kwargs
943    )
944
945    # and finalize the result
946    result = collections.OrderedDict()
947    for result_column, func, finalize_kwargs in finalize_funcs:
948        result[result_column] = func(df, **finalize_kwargs)
949
950    return df.__class__(result)
951
952
953def _apply_func_to_column(df_like, column, func):
954    if column is None:
955        return func(df_like)
956
957    return func(df_like[column])
958
959
960def _apply_func_to_columns(df_like, prefix, func):
961    if is_dataframe_like(df_like):
962        columns = df_like.columns
963    else:
964        # handle GroupBy objects
965        columns = df_like.obj.columns
966
967    columns = sorted(col for col in columns if col.startswith(prefix))
968
969    columns = [df_like[col] for col in columns]
970    return func(*columns)
971
972
973def _finalize_mean(df, sum_column, count_column):
974    return df[sum_column] / df[count_column]
975
976
977def _finalize_var(df, count_column, sum_column, sum2_column, ddof=1):
978    n = df[count_column]
979    x = df[sum_column]
980    x2 = df[sum2_column]
981
982    result = x2 - x ** 2 / n
983    div = n - ddof
984    div[div < 0] = 0
985    result /= div
986    result[(n - ddof) == 0] = np.nan
987
988    return result
989
990
991def _finalize_std(df, count_column, sum_column, sum2_column, ddof=1):
992    result = _finalize_var(df, count_column, sum_column, sum2_column, ddof)
993    return np.sqrt(result)
994
995
996def _cum_agg_aligned(part, cum_last, index, columns, func, initial):
997    align = cum_last.reindex(part.set_index(index).index, fill_value=initial)
998    align.index = part.index
999    return func(part[columns], align)
1000
1001
1002def _cum_agg_filled(a, b, func, initial):
1003    union = a.index.union(b.index)
1004    return func(
1005        a.reindex(union, fill_value=initial),
1006        b.reindex(union, fill_value=initial),
1007        fill_value=initial,
1008    )
1009
1010
1011def _cumcount_aggregate(a, b, fill_value=None):
1012    return a.add(b, fill_value=fill_value) + 1
1013
1014
1015class _GroupBy:
1016    """Superclass for DataFrameGroupBy and SeriesGroupBy
1017
1018    Parameters
1019    ----------
1020
1021    obj: DataFrame or Series
1022        DataFrame or Series to be grouped
1023    by: str, list or Series
1024        The key for grouping
1025    slice: str, list
1026        The slice keys applied to GroupBy result
1027    group_keys: bool
1028        Passed to pandas.DataFrame.groupby()
1029    dropna: bool
1030        Whether to drop null values from groupby index
1031    sort: bool, defult None
1032        Passed along to aggregation methods. If allowed,
1033        the output aggregation will have sorted keys.
1034    observed: bool, default False
1035        This only applies if any of the groupers are Categoricals.
1036        If True: only show observed values for categorical groupers.
1037        If False: show all values for categorical groupers.
1038    """
1039
1040    def __init__(
1041        self,
1042        df,
1043        by=None,
1044        slice=None,
1045        group_keys=True,
1046        dropna=None,
1047        sort=None,
1048        observed=None,
1049    ):
1050
1051        by_ = by if isinstance(by, (tuple, list)) else [by]
1052        if any(isinstance(key, pd.Grouper) for key in by_):
1053            raise NotImplementedError("pd.Grouper is currently not supported by Dask.")
1054
1055        assert isinstance(df, (DataFrame, Series))
1056        self.group_keys = group_keys
1057        self.obj = df
1058        # grouping key passed via groupby method
1059        self.index = _normalize_index(df, by)
1060        self.sort = sort
1061
1062        if isinstance(self.index, list):
1063            do_index_partition_align = all(
1064                item.npartitions == df.npartitions if isinstance(item, Series) else True
1065                for item in self.index
1066            )
1067        elif isinstance(self.index, Series):
1068            do_index_partition_align = df.npartitions == self.index.npartitions
1069        else:
1070            do_index_partition_align = True
1071
1072        if not do_index_partition_align:
1073            raise NotImplementedError(
1074                "The grouped object and index of the "
1075                "groupby must have the same divisions."
1076            )
1077
1078        # slicing key applied to _GroupBy instance
1079        self._slice = slice
1080
1081        if isinstance(self.index, list):
1082            index_meta = [
1083                item._meta if isinstance(item, Series) else item for item in self.index
1084            ]
1085
1086        elif isinstance(self.index, Series):
1087            index_meta = self.index._meta
1088
1089        else:
1090            index_meta = self.index
1091
1092        self.dropna = {}
1093        if dropna is not None:
1094            self.dropna["dropna"] = dropna
1095
1096        # Hold off on setting observed by default: https://github.com/dask/dask/issues/6951
1097        self.observed = {}
1098        if observed is not None:
1099            self.observed["observed"] = observed
1100
1101        self._meta = self.obj._meta.groupby(
1102            index_meta, group_keys=group_keys, **self.observed, **self.dropna
1103        )
1104
1105    @property
1106    def _groupby_kwargs(self):
1107        return {
1108            "by": self.index,
1109            "group_keys": self.group_keys,
1110            **self.dropna,
1111            "sort": self.sort,
1112            **self.observed,
1113        }
1114
1115    @property
1116    def _meta_nonempty(self):
1117        """
1118        Return a pd.DataFrameGroupBy / pd.SeriesGroupBy which contains sample data.
1119        """
1120        sample = self.obj._meta_nonempty
1121
1122        if isinstance(self.index, list):
1123            index_meta = [
1124                item._meta_nonempty if isinstance(item, Series) else item
1125                for item in self.index
1126            ]
1127
1128        elif isinstance(self.index, Series):
1129            index_meta = self.index._meta_nonempty
1130
1131        else:
1132            index_meta = self.index
1133
1134        grouped = sample.groupby(
1135            index_meta,
1136            group_keys=self.group_keys,
1137            **self.observed,
1138            **self.dropna,
1139        )
1140        return _maybe_slice(grouped, self._slice)
1141
1142    def _aca_agg(
1143        self,
1144        token,
1145        func,
1146        aggfunc=None,
1147        meta=None,
1148        split_every=None,
1149        split_out=1,
1150        chunk_kwargs={},
1151        aggregate_kwargs={},
1152    ):
1153        if aggfunc is None:
1154            aggfunc = func
1155
1156        if meta is None:
1157            meta = func(self._meta_nonempty)
1158
1159        columns = meta.name if is_series_like(meta) else meta.columns
1160
1161        token = self._token_prefix + token
1162        levels = _determine_levels(self.index)
1163
1164        return aca(
1165            [self.obj, self.index]
1166            if not isinstance(self.index, list)
1167            else [self.obj] + self.index,
1168            chunk=_apply_chunk,
1169            chunk_kwargs=dict(
1170                chunk=func,
1171                columns=columns,
1172                **self.observed,
1173                **chunk_kwargs,
1174                **self.dropna,
1175            ),
1176            aggregate=_groupby_aggregate,
1177            meta=meta,
1178            token=token,
1179            split_every=split_every,
1180            aggregate_kwargs=dict(
1181                aggfunc=aggfunc,
1182                levels=levels,
1183                **self.observed,
1184                **aggregate_kwargs,
1185                **self.dropna,
1186            ),
1187            split_out=split_out,
1188            split_out_setup=split_out_on_index,
1189            sort=self.sort,
1190        )
1191
1192    def _cum_agg(self, token, chunk, aggregate, initial):
1193        """Wrapper for cumulative groupby operation"""
1194        meta = chunk(self._meta)
1195        columns = meta.name if is_series_like(meta) else meta.columns
1196        index = self.index if isinstance(self.index, list) else [self.index]
1197
1198        name = self._token_prefix + token
1199        name_part = name + "-map"
1200        name_last = name + "-take-last"
1201        name_cum = name + "-cum-last"
1202
1203        # cumulate each partitions
1204        cumpart_raw = map_partitions(
1205            _apply_chunk,
1206            self.obj,
1207            *index,
1208            chunk=chunk,
1209            columns=columns,
1210            token=name_part,
1211            meta=meta,
1212            **self.dropna,
1213        )
1214
1215        cumpart_raw_frame = (
1216            cumpart_raw.to_frame() if is_series_like(meta) else cumpart_raw
1217        )
1218
1219        cumpart_ext = cumpart_raw_frame.assign(
1220            **{
1221                i: self.obj[i]
1222                if np.isscalar(i) and i in self.obj.columns
1223                else self.obj.index
1224                for i in index
1225            }
1226        )
1227
1228        # Use pd.Grouper objects to specify that we are grouping by columns.
1229        # Otherwise, pandas will throw an ambiguity warning if the
1230        # DataFrame's index (self.obj.index) was included in the grouping
1231        # specification (self.index). See pandas #14432
1232        index_groupers = [pd.Grouper(key=ind) for ind in index]
1233        cumlast = map_partitions(
1234            _apply_chunk,
1235            cumpart_ext,
1236            *index_groupers,
1237            columns=0 if columns is None else columns,
1238            chunk=M.last,
1239            meta=meta,
1240            token=name_last,
1241            **self.dropna,
1242        )
1243
1244        # aggregate cumulated partitions and its previous last element
1245        _hash = tokenize(self, token, chunk, aggregate, initial)
1246        name += "-" + _hash
1247        name_cum += "-" + _hash
1248        dask = {}
1249        dask[(name, 0)] = (cumpart_raw._name, 0)
1250
1251        for i in range(1, self.obj.npartitions):
1252            # store each cumulative step to graph to reduce computation
1253            if i == 1:
1254                dask[(name_cum, i)] = (cumlast._name, i - 1)
1255            else:
1256                # aggregate with previous cumulation results
1257                dask[(name_cum, i)] = (
1258                    _cum_agg_filled,
1259                    (name_cum, i - 1),
1260                    (cumlast._name, i - 1),
1261                    aggregate,
1262                    initial,
1263                )
1264            dask[(name, i)] = (
1265                _cum_agg_aligned,
1266                (cumpart_ext._name, i),
1267                (name_cum, i),
1268                index,
1269                0 if columns is None else columns,
1270                aggregate,
1271                initial,
1272            )
1273        graph = HighLevelGraph.from_collections(
1274            name, dask, dependencies=[cumpart_raw, cumpart_ext, cumlast]
1275        )
1276        return new_dd_object(graph, name, chunk(self._meta), self.obj.divisions)
1277
1278    def _shuffle(self, meta):
1279        df = self.obj
1280
1281        if isinstance(self.obj, Series):
1282            # Temporarily convert series to dataframe for shuffle
1283            df = df.to_frame("__series__")
1284            convert_back_to_series = True
1285        else:
1286            convert_back_to_series = False
1287
1288        if isinstance(self.index, DataFrame):  # add index columns to dataframe
1289            df2 = df.assign(
1290                **{"_index_" + c: self.index[c] for c in self.index.columns}
1291            )
1292            index = self.index
1293        elif isinstance(self.index, Series):
1294            df2 = df.assign(_index=self.index)
1295            index = self.index
1296        else:
1297            df2 = df
1298            index = df._select_columns_or_index(self.index)
1299
1300        df3 = shuffle(df2, index)  # shuffle dataframe and index
1301
1302        if isinstance(self.index, DataFrame):
1303            # extract index from dataframe
1304            cols = ["_index_" + c for c in self.index.columns]
1305            index2 = df3[cols]
1306            if is_dataframe_like(meta):
1307                df4 = df3.map_partitions(drop_columns, cols, meta.columns.dtype)
1308            else:
1309                df4 = df3.drop(cols, axis=1)
1310        elif isinstance(self.index, Series):
1311            index2 = df3["_index"]
1312            index2.name = self.index.name
1313            if is_dataframe_like(meta):
1314                df4 = df3.map_partitions(drop_columns, "_index", meta.columns.dtype)
1315            else:
1316                df4 = df3.drop("_index", axis=1)
1317        else:
1318            df4 = df3
1319            index2 = self.index
1320
1321        if convert_back_to_series:
1322            df4 = df4["__series__"].rename(self.obj.name)
1323
1324        return df4, index2
1325
1326    @derived_from(pd.core.groupby.GroupBy)
1327    def cumsum(self, axis=0):
1328        if axis:
1329            return self.obj.cumsum(axis=axis)
1330        else:
1331            return self._cum_agg("cumsum", chunk=M.cumsum, aggregate=M.add, initial=0)
1332
1333    @derived_from(pd.core.groupby.GroupBy)
1334    def cumprod(self, axis=0):
1335        if axis:
1336            return self.obj.cumprod(axis=axis)
1337        else:
1338            return self._cum_agg("cumprod", chunk=M.cumprod, aggregate=M.mul, initial=1)
1339
1340    @derived_from(pd.core.groupby.GroupBy)
1341    def cumcount(self, axis=None):
1342        return self._cum_agg(
1343            "cumcount", chunk=M.cumcount, aggregate=_cumcount_aggregate, initial=-1
1344        )
1345
1346    @derived_from(pd.core.groupby.GroupBy)
1347    def sum(self, split_every=None, split_out=1, min_count=None):
1348        result = self._aca_agg(
1349            token="sum", func=M.sum, split_every=split_every, split_out=split_out
1350        )
1351        if min_count:
1352            return result.where(self.count() >= min_count, other=np.NaN)
1353        else:
1354            return result
1355
1356    @derived_from(pd.core.groupby.GroupBy)
1357    def prod(self, split_every=None, split_out=1, min_count=None):
1358        result = self._aca_agg(
1359            token="prod", func=M.prod, split_every=split_every, split_out=split_out
1360        )
1361        if min_count:
1362            return result.where(self.count() >= min_count, other=np.NaN)
1363        else:
1364            return result
1365
1366    @derived_from(pd.core.groupby.GroupBy)
1367    def min(self, split_every=None, split_out=1):
1368        return self._aca_agg(
1369            token="min", func=M.min, split_every=split_every, split_out=split_out
1370        )
1371
1372    @derived_from(pd.core.groupby.GroupBy)
1373    def max(self, split_every=None, split_out=1):
1374        return self._aca_agg(
1375            token="max", func=M.max, split_every=split_every, split_out=split_out
1376        )
1377
1378    @derived_from(pd.DataFrame)
1379    def idxmin(self, split_every=None, split_out=1, axis=None, skipna=True):
1380        return self._aca_agg(
1381            token="idxmin",
1382            func=M.idxmin,
1383            aggfunc=M.first,
1384            split_every=split_every,
1385            split_out=split_out,
1386            chunk_kwargs=dict(skipna=skipna),
1387        )
1388
1389    @derived_from(pd.DataFrame)
1390    def idxmax(self, split_every=None, split_out=1, axis=None, skipna=True):
1391        return self._aca_agg(
1392            token="idxmax",
1393            func=M.idxmax,
1394            aggfunc=M.first,
1395            split_every=split_every,
1396            split_out=split_out,
1397            chunk_kwargs=dict(skipna=skipna),
1398        )
1399
1400    @derived_from(pd.core.groupby.GroupBy)
1401    def count(self, split_every=None, split_out=1):
1402        return self._aca_agg(
1403            token="count",
1404            func=M.count,
1405            aggfunc=M.sum,
1406            split_every=split_every,
1407            split_out=split_out,
1408        )
1409
1410    @derived_from(pd.core.groupby.GroupBy)
1411    def mean(self, split_every=None, split_out=1):
1412        s = self.sum(split_every=split_every, split_out=split_out)
1413        c = self.count(split_every=split_every, split_out=split_out)
1414        if is_dataframe_like(s):
1415            c = c[s.columns]
1416        return s / c
1417
1418    @derived_from(pd.core.groupby.GroupBy)
1419    def size(self, split_every=None, split_out=1):
1420        return self._aca_agg(
1421            token="size",
1422            func=M.size,
1423            aggfunc=M.sum,
1424            split_every=split_every,
1425            split_out=split_out,
1426        )
1427
1428    @derived_from(pd.core.groupby.GroupBy)
1429    def var(self, ddof=1, split_every=None, split_out=1):
1430        levels = _determine_levels(self.index)
1431        result = aca(
1432            [self.obj, self.index]
1433            if not isinstance(self.index, list)
1434            else [self.obj] + self.index,
1435            chunk=_var_chunk,
1436            aggregate=_var_agg,
1437            combine=_var_combine,
1438            token=self._token_prefix + "var",
1439            aggregate_kwargs={"ddof": ddof, "levels": levels},
1440            combine_kwargs={"levels": levels},
1441            split_every=split_every,
1442            split_out=split_out,
1443            split_out_setup=split_out_on_index,
1444            sort=self.sort,
1445        )
1446
1447        if isinstance(self.obj, Series):
1448            result = result[result.columns[0]]
1449        if self._slice:
1450            result = result[self._slice]
1451
1452        return result
1453
1454    @derived_from(pd.core.groupby.GroupBy)
1455    def std(self, ddof=1, split_every=None, split_out=1):
1456        v = self.var(ddof, split_every=split_every, split_out=split_out)
1457        result = map_partitions(np.sqrt, v, meta=v)
1458        return result
1459
1460    @derived_from(pd.DataFrame)
1461    def corr(self, ddof=1, split_every=None, split_out=1):
1462        """Groupby correlation:
1463        corr(X, Y) = cov(X, Y) / (std_x * std_y)
1464        """
1465        return self.cov(split_every=split_every, split_out=split_out, std=True)
1466
1467    @derived_from(pd.DataFrame)
1468    def cov(self, ddof=1, split_every=None, split_out=1, std=False):
1469        """Groupby covariance is accomplished by
1470
1471        1. Computing intermediate values for sum, count, and the product of
1472           all columns: a b c -> a*a, a*b, b*b, b*c, c*c.
1473
1474        2. The values are then aggregated and the final covariance value is calculated:
1475           cov(X, Y) = X*Y - Xbar * Ybar
1476
1477        When `std` is True calculate Correlation
1478        """
1479
1480        levels = _determine_levels(self.index)
1481
1482        is_mask = any(is_series_like(s) for s in self.index)
1483        if self._slice:
1484            if is_mask:
1485                self.obj = self.obj[self._slice]
1486            else:
1487                sliced_plus = list(self._slice) + list(self.index)
1488                self.obj = self.obj[sliced_plus]
1489
1490        result = aca(
1491            [self.obj, self.index]
1492            if not isinstance(self.index, list)
1493            else [self.obj] + self.index,
1494            chunk=_cov_chunk,
1495            aggregate=_cov_agg,
1496            combine=_cov_combine,
1497            token=self._token_prefix + "cov",
1498            aggregate_kwargs={"ddof": ddof, "levels": levels, "std": std},
1499            combine_kwargs={"levels": levels},
1500            split_every=split_every,
1501            split_out=split_out,
1502            split_out_setup=split_out_on_index,
1503            sort=self.sort,
1504        )
1505
1506        if isinstance(self.obj, Series):
1507            result = result[result.columns[0]]
1508        if self._slice:
1509            result = result[self._slice]
1510        return result
1511
1512    @derived_from(pd.core.groupby.GroupBy)
1513    def first(self, split_every=None, split_out=1):
1514        return self._aca_agg(
1515            token="first", func=M.first, split_every=split_every, split_out=split_out
1516        )
1517
1518    @derived_from(pd.core.groupby.GroupBy)
1519    def last(self, split_every=None, split_out=1):
1520        return self._aca_agg(
1521            token="last", func=M.last, split_every=split_every, split_out=split_out
1522        )
1523
1524    @derived_from(pd.core.groupby.GroupBy)
1525    def get_group(self, key):
1526        token = self._token_prefix + "get_group"
1527
1528        meta = self._meta.obj
1529        if is_dataframe_like(meta) and self._slice is not None:
1530            meta = meta[self._slice]
1531        columns = meta.columns if is_dataframe_like(meta) else meta.name
1532
1533        return map_partitions(
1534            _groupby_get_group,
1535            self.obj,
1536            self.index,
1537            key,
1538            columns,
1539            meta=meta,
1540            token=token,
1541        )
1542
1543    def aggregate(self, arg, split_every, split_out=1):
1544        if isinstance(self.obj, DataFrame):
1545            if isinstance(self.index, tuple) or np.isscalar(self.index):
1546                group_columns = {self.index}
1547
1548            elif isinstance(self.index, list):
1549                group_columns = {
1550                    i for i in self.index if isinstance(i, tuple) or np.isscalar(i)
1551                }
1552
1553            else:
1554                group_columns = set()
1555
1556            if self._slice:
1557                # pandas doesn't exclude the grouping column in a SeriesGroupBy
1558                # like df.groupby('a')['a'].agg(...)
1559                non_group_columns = self._slice
1560                if not isinstance(non_group_columns, list):
1561                    non_group_columns = [non_group_columns]
1562            else:
1563                # NOTE: this step relies on the index normalization to replace
1564                #       series with their name in an index.
1565                non_group_columns = [
1566                    col for col in self.obj.columns if col not in group_columns
1567                ]
1568
1569            spec = _normalize_spec(arg, non_group_columns)
1570
1571        elif isinstance(self.obj, Series):
1572            if isinstance(arg, (list, tuple, dict)):
1573                # implementation detail: if self.obj is a series, a pseudo column
1574                # None is used to denote the series itself. This pseudo column is
1575                # removed from the result columns before passing the spec along.
1576                spec = _normalize_spec({None: arg}, [])
1577                spec = [
1578                    (result_column, func, input_column)
1579                    for ((_, result_column), func, input_column) in spec
1580                ]
1581
1582            else:
1583                spec = _normalize_spec({None: arg}, [])
1584                spec = [
1585                    (self.obj.name, func, input_column)
1586                    for (_, func, input_column) in spec
1587                ]
1588
1589        else:
1590            raise ValueError(f"aggregate on unknown object {self.obj}")
1591
1592        chunk_funcs, aggregate_funcs, finalizers = _build_agg_args(spec)
1593
1594        if isinstance(self.index, (tuple, list)) and len(self.index) > 1:
1595            levels = list(range(len(self.index)))
1596        else:
1597            levels = 0
1598
1599        if not isinstance(self.index, list):
1600            chunk_args = [self.obj, self.index]
1601
1602        else:
1603            chunk_args = [self.obj] + self.index
1604
1605        if not PANDAS_GT_110 and self.dropna:
1606            raise NotImplementedError(
1607                "dropna is not a valid argument for dask.groupby.agg"
1608                f"if pandas < 1.1.0. Pandas version is {pd.__version__}"
1609            )
1610
1611        return aca(
1612            chunk_args,
1613            chunk=_groupby_apply_funcs,
1614            chunk_kwargs=dict(funcs=chunk_funcs, **self.observed, **self.dropna),
1615            combine=_groupby_apply_funcs,
1616            combine_kwargs=dict(
1617                funcs=aggregate_funcs, level=levels, **self.observed, **self.dropna
1618            ),
1619            aggregate=_agg_finalize,
1620            aggregate_kwargs=dict(
1621                aggregate_funcs=aggregate_funcs,
1622                finalize_funcs=finalizers,
1623                level=levels,
1624                **self.observed,
1625                **self.dropna,
1626            ),
1627            token="aggregate",
1628            split_every=split_every,
1629            split_out=split_out,
1630            split_out_setup=split_out_on_index,
1631            sort=self.sort,
1632        )
1633
1634    @insert_meta_param_description(pad=12)
1635    def apply(self, func, *args, **kwargs):
1636        """Parallel version of pandas GroupBy.apply
1637
1638        This mimics the pandas version except for the following:
1639
1640        1.  If the grouper does not align with the index then this causes a full
1641            shuffle.  The order of rows within each group may not be preserved.
1642        2.  Dask's GroupBy.apply is not appropriate for aggregations. For custom
1643            aggregations, use :class:`dask.dataframe.groupby.Aggregation`.
1644
1645        .. warning::
1646
1647           Pandas' groupby-apply can be used to to apply arbitrary functions,
1648           including aggregations that result in one row per group. Dask's
1649           groupby-apply will apply ``func`` once to each partition-group pair,
1650           so when ``func`` is a reduction you'll end up with one row per
1651           partition-group pair. To apply a custom aggregation with Dask,
1652           use :class:`dask.dataframe.groupby.Aggregation`.
1653
1654        Parameters
1655        ----------
1656        func: function
1657            Function to apply
1658        args, kwargs : Scalar, Delayed or object
1659            Arguments and keywords to pass to the function.
1660        $META
1661
1662        Returns
1663        -------
1664        applied : Series or DataFrame depending on columns keyword
1665        """
1666        meta = kwargs.get("meta", no_default)
1667
1668        if meta is no_default:
1669            with raise_on_meta_error(f"groupby.apply({funcname(func)})", udf=True):
1670                meta_args, meta_kwargs = _extract_meta((args, kwargs), nonempty=True)
1671                meta = self._meta_nonempty.apply(func, *meta_args, **meta_kwargs)
1672
1673            msg = (
1674                "`meta` is not specified, inferred from partial data. "
1675                "Please provide `meta` if the result is unexpected.\n"
1676                "  Before: .apply(func)\n"
1677                "  After:  .apply(func, meta={'x': 'f8', 'y': 'f8'}) for dataframe result\n"
1678                "  or:     .apply(func, meta=('x', 'f8'))            for series result"
1679            )
1680            warnings.warn(msg, stacklevel=2)
1681
1682        meta = make_meta(meta, parent_meta=self._meta.obj)
1683
1684        # Validate self.index
1685        if isinstance(self.index, list) and any(
1686            isinstance(item, Series) for item in self.index
1687        ):
1688            raise NotImplementedError(
1689                "groupby-apply with a multiple Series is currently not supported"
1690            )
1691
1692        df = self.obj
1693        should_shuffle = not (
1694            df.known_divisions and df._contains_index_name(self.index)
1695        )
1696
1697        if should_shuffle:
1698            df2, index = self._shuffle(meta)
1699        else:
1700            df2 = df
1701            index = self.index
1702
1703        # Perform embarrassingly parallel groupby-apply
1704        kwargs["meta"] = meta
1705        df3 = map_partitions(
1706            _groupby_slice_apply,
1707            df2,
1708            index,
1709            self._slice,
1710            func,
1711            token=funcname(func),
1712            *args,
1713            group_keys=self.group_keys,
1714            **self.observed,
1715            **self.dropna,
1716            **kwargs,
1717        )
1718
1719        return df3
1720
1721    @insert_meta_param_description(pad=12)
1722    def transform(self, func, *args, **kwargs):
1723        """Parallel version of pandas GroupBy.transform
1724
1725        This mimics the pandas version except for the following:
1726
1727        1.  If the grouper does not align with the index then this causes a full
1728            shuffle.  The order of rows within each group may not be preserved.
1729        2.  Dask's GroupBy.transform is not appropriate for aggregations. For custom
1730            aggregations, use :class:`dask.dataframe.groupby.Aggregation`.
1731
1732        .. warning::
1733
1734           Pandas' groupby-transform can be used to to apply arbitrary functions,
1735           including aggregations that result in one row per group. Dask's
1736           groupby-transform will apply ``func`` once to each partition-group pair,
1737           so when ``func`` is a reduction you'll end up with one row per
1738           partition-group pair. To apply a custom aggregation with Dask,
1739           use :class:`dask.dataframe.groupby.Aggregation`.
1740
1741        Parameters
1742        ----------
1743        func: function
1744            Function to apply
1745        args, kwargs : Scalar, Delayed or object
1746            Arguments and keywords to pass to the function.
1747        $META
1748
1749        Returns
1750        -------
1751        applied : Series or DataFrame depending on columns keyword
1752        """
1753        meta = kwargs.get("meta", no_default)
1754
1755        if meta is no_default:
1756            with raise_on_meta_error(f"groupby.transform({funcname(func)})", udf=True):
1757                meta_args, meta_kwargs = _extract_meta((args, kwargs), nonempty=True)
1758                meta = self._meta_nonempty.transform(func, *meta_args, **meta_kwargs)
1759
1760            msg = (
1761                "`meta` is not specified, inferred from partial data. "
1762                "Please provide `meta` if the result is unexpected.\n"
1763                "  Before: .transform(func)\n"
1764                "  After:  .transform(func, meta={'x': 'f8', 'y': 'f8'}) for dataframe result\n"
1765                "  or:     .transform(func, meta=('x', 'f8'))            for series result"
1766            )
1767            warnings.warn(msg, stacklevel=2)
1768
1769        meta = make_meta(meta, parent_meta=self._meta.obj)
1770
1771        # Validate self.index
1772        if isinstance(self.index, list) and any(
1773            isinstance(item, Series) for item in self.index
1774        ):
1775            raise NotImplementedError(
1776                "groupby-transform with a multiple Series is currently not supported"
1777            )
1778
1779        df = self.obj
1780        should_shuffle = not (
1781            df.known_divisions and df._contains_index_name(self.index)
1782        )
1783
1784        if should_shuffle:
1785            df2, index = self._shuffle(meta)
1786        else:
1787            df2 = df
1788            index = self.index
1789
1790        # Perform embarrassingly parallel groupby-transform
1791        kwargs["meta"] = meta
1792        df3 = map_partitions(
1793            _groupby_slice_transform,
1794            df2,
1795            index,
1796            self._slice,
1797            func,
1798            token=funcname(func),
1799            *args,
1800            group_keys=self.group_keys,
1801            **self.observed,
1802            **self.dropna,
1803            **kwargs,
1804        )
1805
1806        return df3
1807
1808    def rolling(self, window, min_periods=None, center=False, win_type=None, axis=0):
1809        """Provides rolling transformations.
1810
1811        .. note::
1812
1813            Since MultiIndexes are not well supported in Dask, this method returns a
1814            dataframe with the same index as the original data. The groupby column is
1815            not added as the first level of the index like pandas does.
1816
1817            This method works differently from other groupby methods. It does a groupby
1818            on each partition (plus some overlap). This means that the output has the
1819            same shape and number of partitions as the original.
1820
1821        Parameters
1822        ----------
1823        window : str, offset
1824           Size of the moving window. This is the number of observations used
1825           for calculating the statistic. Data must have a ``DatetimeIndex``
1826        min_periods : int, default None
1827            Minimum number of observations in window required to have a value
1828            (otherwise result is NA).
1829        center : boolean, default False
1830            Set the labels at the center of the window.
1831        win_type : string, default None
1832            Provide a window type. The recognized window types are identical
1833            to pandas.
1834        axis : int, default 0
1835
1836        Returns
1837        -------
1838        a Rolling object on which to call a method to compute a statistic
1839
1840        Examples
1841        --------
1842        >>> import dask
1843        >>> ddf = dask.datasets.timeseries(freq="1H")
1844        >>> result = ddf.groupby("name").x.rolling('1D').max()
1845        """
1846        from dask.dataframe.rolling import RollingGroupby
1847
1848        if isinstance(window, Integral):
1849            raise ValueError(
1850                "Only time indexes are supported for rolling groupbys in dask dataframe. "
1851                "``window`` must be a ``freq`` (e.g. '1H')."
1852            )
1853
1854        if min_periods is not None:
1855            if not isinstance(min_periods, Integral):
1856                raise ValueError("min_periods must be an integer")
1857            if min_periods < 0:
1858                raise ValueError("min_periods must be >= 0")
1859
1860        return RollingGroupby(
1861            self,
1862            window=window,
1863            min_periods=min_periods,
1864            center=center,
1865            win_type=win_type,
1866            axis=axis,
1867        )
1868
1869
1870class DataFrameGroupBy(_GroupBy):
1871    _token_prefix = "dataframe-groupby-"
1872
1873    def __getitem__(self, key):
1874        if isinstance(key, list):
1875            g = DataFrameGroupBy(
1876                self.obj, by=self.index, slice=key, sort=self.sort, **self.dropna
1877            )
1878        else:
1879            g = SeriesGroupBy(
1880                self.obj, by=self.index, slice=key, sort=self.sort, **self.dropna
1881            )
1882
1883        # error is raised from pandas
1884        g._meta = g._meta[key]
1885        return g
1886
1887    def __dir__(self):
1888        return sorted(
1889            set(
1890                dir(type(self))
1891                + list(self.__dict__)
1892                + list(filter(M.isidentifier, self.obj.columns))
1893            )
1894        )
1895
1896    def __getattr__(self, key):
1897        try:
1898            return self[key]
1899        except KeyError as e:
1900            raise AttributeError(e) from e
1901
1902    @derived_from(pd.core.groupby.DataFrameGroupBy)
1903    def aggregate(self, arg, split_every=None, split_out=1):
1904        if arg == "size":
1905            return self.size()
1906
1907        return super().aggregate(arg, split_every=split_every, split_out=split_out)
1908
1909    @derived_from(pd.core.groupby.DataFrameGroupBy)
1910    def agg(self, arg, split_every=None, split_out=1):
1911        return self.aggregate(arg, split_every=split_every, split_out=split_out)
1912
1913
1914class SeriesGroupBy(_GroupBy):
1915    _token_prefix = "series-groupby-"
1916
1917    def __init__(self, df, by=None, slice=None, observed=None, **kwargs):
1918        # for any non series object, raise pandas-compat error message
1919        # Hold off on setting observed by default: https://github.com/dask/dask/issues/6951
1920        observed = {"observed": observed} if observed is not None else {}
1921
1922        if isinstance(df, Series):
1923            if isinstance(by, Series):
1924                pass
1925            elif isinstance(by, list):
1926                if len(by) == 0:
1927                    raise ValueError("No group keys passed!")
1928
1929                non_series_items = [item for item in by if not isinstance(item, Series)]
1930                # raise error from pandas, if applicable
1931
1932                df._meta.groupby(non_series_items, **observed)
1933            else:
1934                # raise error from pandas, if applicable
1935                df._meta.groupby(by, **observed)
1936
1937        super().__init__(df, by=by, slice=slice, **observed, **kwargs)
1938
1939    @derived_from(pd.core.groupby.SeriesGroupBy)
1940    def nunique(self, split_every=None, split_out=1):
1941        """
1942        Examples
1943        --------
1944        >>> import pandas as pd
1945        >>> import dask.dataframe as dd
1946        >>> d = {'col1': [1, 2, 3, 4], 'col2': [5, 6, 7, 8]}
1947        >>> df = pd.DataFrame(data=d)
1948        >>> ddf = dd.from_pandas(df, 2)
1949        >>> ddf.groupby(['col1']).col2.nunique().compute()
1950        """
1951        name = self._meta.obj.name
1952        levels = _determine_levels(self.index)
1953
1954        if isinstance(self.obj, DataFrame):
1955            chunk = _nunique_df_chunk
1956
1957        else:
1958            chunk = _nunique_series_chunk
1959
1960        return aca(
1961            [self.obj, self.index]
1962            if not isinstance(self.index, list)
1963            else [self.obj] + self.index,
1964            chunk=chunk,
1965            aggregate=_nunique_df_aggregate,
1966            combine=_nunique_df_combine,
1967            token="series-groupby-nunique",
1968            chunk_kwargs={"levels": levels, "name": name},
1969            aggregate_kwargs={"levels": levels, "name": name},
1970            combine_kwargs={"levels": levels},
1971            split_every=split_every,
1972            split_out=split_out,
1973            split_out_setup=split_out_on_index,
1974            sort=self.sort,
1975        )
1976
1977    @derived_from(pd.core.groupby.SeriesGroupBy)
1978    def aggregate(self, arg, split_every=None, split_out=1):
1979        result = super().aggregate(arg, split_every=split_every, split_out=split_out)
1980        if self._slice:
1981            result = result[self._slice]
1982
1983        if not isinstance(arg, (list, dict)) and isinstance(result, DataFrame):
1984            result = result[result.columns[0]]
1985
1986        return result
1987
1988    @derived_from(pd.core.groupby.SeriesGroupBy)
1989    def agg(self, arg, split_every=None, split_out=1):
1990        return self.aggregate(arg, split_every=split_every, split_out=split_out)
1991
1992    @derived_from(pd.core.groupby.SeriesGroupBy)
1993    def value_counts(self, split_every=None, split_out=1):
1994        return self._aca_agg(
1995            token="value_counts",
1996            func=_value_counts,
1997            aggfunc=_value_counts_aggregate,
1998            split_every=split_every,
1999            split_out=split_out,
2000        )
2001
2002    @derived_from(pd.core.groupby.SeriesGroupBy)
2003    def unique(self, split_every=None, split_out=1):
2004        name = self._meta.obj.name
2005        return self._aca_agg(
2006            token="unique",
2007            func=M.unique,
2008            aggfunc=_unique_aggregate,
2009            aggregate_kwargs={"name": name},
2010            split_every=split_every,
2011            split_out=split_out,
2012        )
2013
2014    @derived_from(pd.core.groupby.SeriesGroupBy)
2015    def tail(self, n=5, split_every=None, split_out=1):
2016        index_levels = len(self.index) if isinstance(self.index, list) else 1
2017        return self._aca_agg(
2018            token="tail",
2019            func=_tail_chunk,
2020            aggfunc=_tail_aggregate,
2021            meta=M.tail(self._meta_nonempty),
2022            chunk_kwargs={"n": n},
2023            aggregate_kwargs={"n": n, "index_levels": index_levels},
2024            split_every=split_every,
2025            split_out=split_out,
2026        )
2027
2028    @derived_from(pd.core.groupby.SeriesGroupBy)
2029    def head(self, n=5, split_every=None, split_out=1):
2030        index_levels = len(self.index) if isinstance(self.index, list) else 1
2031        return self._aca_agg(
2032            token="head",
2033            func=_head_chunk,
2034            aggfunc=_head_aggregate,
2035            meta=M.head(self._meta_nonempty),
2036            chunk_kwargs={"n": n},
2037            aggregate_kwargs={"n": n, "index_levels": index_levels},
2038            split_every=split_every,
2039            split_out=split_out,
2040        )
2041
2042
2043def _unique_aggregate(series_gb, name=None):
2044    ret = type(series_gb.obj)(
2045        {k: v.explode().unique() for k, v in series_gb}, name=name
2046    )
2047    ret.index.names = series_gb.obj.index.names
2048    return ret
2049
2050
2051def _value_counts(x, **kwargs):
2052    if len(x):
2053        return M.value_counts(x, **kwargs)
2054    else:
2055        return pd.Series(dtype=int)
2056
2057
2058def _value_counts_aggregate(series_gb):
2059    to_concat = {k: v.groupby(level=1).sum() for k, v in series_gb}
2060    names = list(series_gb.obj.index.names)
2061    return pd.Series(pd.concat(to_concat, names=names))
2062
2063
2064def _tail_chunk(series_gb, **kwargs):
2065    keys, groups = zip(*series_gb) if len(series_gb) else ((True,), (series_gb,))
2066    return pd.concat([group.tail(**kwargs) for group in groups], keys=keys)
2067
2068
2069def _tail_aggregate(series_gb, **kwargs):
2070    levels = kwargs.pop("index_levels")
2071    return series_gb.tail(**kwargs).droplevel(list(range(levels)))
2072
2073
2074def _head_chunk(series_gb, **kwargs):
2075    keys, groups = zip(*series_gb) if len(series_gb) else ((True,), (series_gb,))
2076    return pd.concat([group.head(**kwargs) for group in groups], keys=keys)
2077
2078
2079def _head_aggregate(series_gb, **kwargs):
2080    levels = kwargs.pop("index_levels")
2081    return series_gb.head(**kwargs).droplevel(list(range(levels)))
2082