import math
import numbers
import re
import sys
import textwrap
import traceback
from collections.abc import Iterator, Mapping
from contextlib import contextmanager

import numpy as np
import pandas as pd
from pandas.api.types import is_scalar  # noqa: F401
from pandas.api.types import is_categorical_dtype, is_dtype_equal

from ..base import is_dask_collection
from ..core import get_deps
from ..local import get_sync
from ..utils import is_arraylike  # noqa: F401
from ..utils import asciitable
from ..utils import is_dataframe_like as dask_is_dataframe_like
from ..utils import is_index_like as dask_is_index_like
from ..utils import is_series_like as dask_is_series_like
from ..utils import typename
from . import _dtypes  # noqa: F401 register pandas extension types
from . import methods
from ._compat import PANDAS_GT_110, PANDAS_GT_120, tm  # noqa: F401
from .dispatch import make_meta  # noqa: F401
from .dispatch import make_meta_obj, meta_nonempty  # noqa: F401
from .extensions import make_scalar

meta_object_types = (pd.Series, pd.DataFrame, pd.Index, pd.MultiIndex)
try:
    import scipy.sparse as sp

    meta_object_types += (sp.spmatrix,)
except ImportError:
    pass


def is_integer_na_dtype(t):
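    """Return True if ``t`` (or its ``dtype``) is a pandas nullable-integer
    extension dtype."""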
    dtype = getattr(t, "dtype", t)
    types = (
        pd.Int8Dtype,
        pd.Int16Dtype,
        pd.Int32Dtype,
        pd.Int64Dtype,
        pd.UInt8Dtype,
        pd.UInt16Dtype,
        pd.UInt32Dtype,
        pd.UInt64Dtype,
    )
    return isinstance(dtype, types)


def is_float_na_dtype(t):
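    """Return True if ``t`` (or its ``dtype``) is a pandas nullable-float
    extension dtype (requires pandas >= 1.2)."""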
    if not PANDAS_GT_120:
        return False

    dtype = getattr(t, "dtype", t)
    types = (
        pd.Float32Dtype,
        pd.Float64Dtype,
    )
    return isinstance(dtype, types)


def shard_df_on_index(df, divisions):
    """Shard a DataFrame by ranges on its index

    Examples
    --------

    >>> df = pd.DataFrame({'a': [0, 10, 20, 30, 40], 'b': [5, 4, 3, 2, 1]})
    >>> df
        a  b
    0   0  5
    1  10  4
    2  20  3
    3  30  2
    4  40  1

    >>> shards = list(shard_df_on_index(df, [2, 4]))
    >>> shards[0]
        a  b
    0   0  5
    1  10  4

    >>> shards[1]
        a  b
    2  20  3
    3  30  2

    >>> shards[2]
        a  b
    4  40  1

    >>> list(shard_df_on_index(df, []))[0]  # empty case
        a  b
    0   0  5
    1  10  4
    2  20  3
    3  30  2
    4  40  1
    """

    if isinstance(divisions, Iterator):
        divisions = list(divisions)
    if not len(divisions):
        yield df
    else:
        divisions = np.array(divisions)
        df = df.sort_index()
        index = df.index
        if is_categorical_dtype(index):
            index = index.as_ordered()
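        # ``searchsorted`` (side="left") returns, for each division, the
        # position of the first row whose index value is >= that division.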
        indices = index.searchsorted(divisions)
        yield df.iloc[: indices[0]]
        for i in range(len(indices) - 1):
            yield df.iloc[indices[i] : indices[i + 1]]
        yield df.iloc[indices[-1] :]


_META_TYPES = "meta : pd.DataFrame, pd.Series, dict, iterable, tuple, optional"
_META_DESCRIPTION = """\
An empty ``pd.DataFrame`` or ``pd.Series`` that matches the dtypes and
column names of the output. This metadata is necessary for many algorithms
in dask dataframe to work.  For ease of use, some alternative inputs are
also available. Instead of a ``DataFrame``, a ``dict`` of ``{name: dtype}``
or iterable of ``(name, dtype)`` can be provided (note that the order of
the names should match the order of the columns). Instead of a series, a
tuple of ``(name, dtype)`` can be used. If not provided, dask will try to
infer the metadata. This may lead to unexpected results, so providing
``meta`` is recommended. For more information, see
``dask.dataframe.utils.make_meta``.
"""


def insert_meta_param_description(*args, **kwargs):
    """Replace `$META` in docstring with param description.

    If pad keyword is provided, will pad description by that number of
    spaces (default is 8)."""
    if not args:
        return lambda f: insert_meta_param_description(f, **kwargs)
    f = args[0]
    indent = " " * kwargs.get("pad", 8)
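    # Section headers ("Parameters" / "----------") sit one indentation level
    # (4 spaces) shallower than the padded description, hence the use of
    # ``indent[4:]`` below.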
    body = textwrap.wrap(
        _META_DESCRIPTION, initial_indent=indent, subsequent_indent=indent, width=78
    )
    descr = "{}\n{}".format(_META_TYPES, "\n".join(body))
    if f.__doc__:
        if "$META" in f.__doc__:
            f.__doc__ = f.__doc__.replace("$META", descr)
        else:
            # Put it at the end of the parameters section
            parameter_header = "Parameters\n%s----------" % indent[4:]
            first, last = re.split("Parameters\\n[ ]*----------", f.__doc__)
            parameters, rest = last.split("\n\n", 1)
            f.__doc__ = "{}{}{}\n{}{}\n\n{}".format(
                first, parameter_header, parameters, indent[4:], descr, rest
            )
    return f


@contextmanager
def raise_on_meta_error(funcname=None, udf=False):
    """Reraise errors in this block to show metadata inference failure.

    Parameters
    ----------
    funcname : str, optional
        If provided, will be added to the error message to indicate the
        name of the method that failed.
    udf : bool, optional
        If True, the error message notes that a user-defined function
        prevented metadata inference and that a ``meta=`` keyword should
        be provided.
    """
    try:
        yield
    except Exception as e:
        exc_type, exc_value, exc_traceback = sys.exc_info()
        tb = "".join(traceback.format_tb(exc_traceback))
        msg = "Metadata inference failed{0}.\n\n"
        if udf:
            msg += (
                "You have supplied a custom function and Dask is unable to\n"
                "determine the type of output that function returns.\n\n"
                "To resolve this please provide a meta= keyword.\n"
                "The docstring of the Dask function you ran should have more information.\n\n"
            )
        msg += (
            "Original error is below:\n"
            "------------------------\n"
            "{1}\n\n"
            "Traceback:\n"
            "----------\n"
            "{2}"
        )
        msg = msg.format(f" in `{funcname}`" if funcname else "", repr(e), tb)
        raise ValueError(msg) from e


UNKNOWN_CATEGORIES = "__UNKNOWN_CATEGORIES__"


def has_known_categories(x):
    """Returns whether the categories in `x` are known.

    Parameters
    ----------
    x : Series or CategoricalIndex
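
    Examples
    --------
    >>> has_known_categories(pd.Series(["a", "b"], dtype="category"))
    True
    >>> has_known_categories(pd.CategoricalIndex([UNKNOWN_CATEGORIES]))
    False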
    """
    x = getattr(x, "_meta", x)
    if is_series_like(x):
        return UNKNOWN_CATEGORIES not in x.cat.categories
    elif is_index_like(x) and hasattr(x, "categories"):
        return UNKNOWN_CATEGORIES not in x.categories
    raise TypeError("Expected Series or CategoricalIndex")


def strip_unknown_categories(x, just_drop_unknown=False):
    """Replace any unknown categoricals with empty categoricals.

    Useful for preventing ``UNKNOWN_CATEGORIES`` from leaking into results.
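
    Examples
    --------
    >>> s = pd.Series(pd.Categorical([UNKNOWN_CATEGORIES]))
    >>> list(strip_unknown_categories(s).cat.categories)
    []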
    """
    if isinstance(x, (pd.Series, pd.DataFrame)):
        x = x.copy()
        if isinstance(x, pd.DataFrame):
            cat_mask = x.dtypes == "category"
            if cat_mask.any():
                cats = cat_mask[cat_mask].index
                for c in cats:
                    if not has_known_categories(x[c]):
                        if just_drop_unknown:
                            x[c].cat.remove_categories(UNKNOWN_CATEGORIES, inplace=True)
                        else:
                            x[c] = x[c].cat.set_categories([])
        elif isinstance(x, pd.Series):
            if is_categorical_dtype(x.dtype) and not has_known_categories(x):
                x = x.cat.set_categories([])
        if isinstance(x.index, pd.CategoricalIndex) and not has_known_categories(
            x.index
        ):
            x.index = x.index.set_categories([])
    elif isinstance(x, pd.CategoricalIndex) and not has_known_categories(x):
        x = x.set_categories([])
    return x


def clear_known_categories(x, cols=None, index=True):
    """Set categories to be unknown.

    Parameters
    ----------
    x : DataFrame, Series, Index
    cols : iterable, optional
        If x is a DataFrame, set only categoricals in these columns to unknown.
        By default, all categorical columns are set to unknown categoricals.
    index : bool, optional
        If True and x is a Series or DataFrame, also clear the known
        categories of the index.
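
    Examples
    --------
    >>> s = pd.Series(["a", "b"], dtype="category")
    >>> has_known_categories(clear_known_categories(s))
    False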
    """
    if isinstance(x, (pd.Series, pd.DataFrame)):
        x = x.copy()
        if isinstance(x, pd.DataFrame):
            mask = x.dtypes == "category"
            if cols is None:
                cols = mask[mask].index
            elif not mask.loc[cols].all():
                raise ValueError("Not all columns are categoricals")
            for c in cols:
                x[c] = x[c].cat.set_categories([UNKNOWN_CATEGORIES])
        elif isinstance(x, pd.Series):
            if is_categorical_dtype(x.dtype):
                x = x.cat.set_categories([UNKNOWN_CATEGORIES])
        if index and isinstance(x.index, pd.CategoricalIndex):
            x.index = x.index.set_categories([UNKNOWN_CATEGORIES])
    elif isinstance(x, pd.CategoricalIndex):
        x = x.set_categories([UNKNOWN_CATEGORIES])
    return x


def _empty_series(name, dtype, index=None):
    if isinstance(dtype, str) and dtype == "category":
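        # Build a length-1 categorical holding the UNKNOWN_CATEGORIES sentinel,
        # then slice it empty so the dtype keeps its unknown-categories marker.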
        return pd.Series(
            pd.Categorical([UNKNOWN_CATEGORIES]), name=name, index=index
        ).iloc[:0]
    return pd.Series([], dtype=dtype, name=name, index=index)


_simple_fake_mapping = {
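    # Representative non-empty scalar values, keyed by NumPy dtype ``kind``
    # codes, used to fabricate plausible placeholder data for metadata.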
    "b": np.bool_(True),
    "V": np.void(b" "),
    "M": np.datetime64("1970-01-01"),
    "m": np.timedelta64(1),
    "S": np.str_("foo"),
    "a": np.str_("foo"),
    "U": np.unicode_("foo"),
    "O": "foo",
}


def _scalar_from_dtype(dtype):
    if dtype.kind in ("i", "f", "u"):
        return dtype.type(1)
    elif dtype.kind == "c":
        return dtype.type(complex(1, 0))
    elif dtype.kind in _simple_fake_mapping:
        o = _simple_fake_mapping[dtype.kind]
        return o.astype(dtype) if dtype.kind in ("m", "M") else o
    else:
        raise TypeError(f"Can't handle dtype: {dtype}")


def _nonempty_scalar(x):
    if type(x) in make_scalar._lookup:
        return make_scalar(x)

    if np.isscalar(x):
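        # For plain scalars, fall back to dispatching ``make_scalar`` on the
        # value's dtype.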
        dtype = x.dtype if hasattr(x, "dtype") else np.dtype(type(x))
        return make_scalar(dtype)

    raise TypeError(f"Can't handle meta of type '{typename(type(x))}'")


def is_dataframe_like(df):
    return dask_is_dataframe_like(df)


def is_series_like(s):
    return dask_is_series_like(s)


def is_index_like(s):
    return dask_is_index_like(s)


def check_meta(x, meta, funcname=None, numeric_equal=True):
    """Check that the dask metadata matches the result.

    If metadata matches, ``x`` is passed through unchanged. A nice error is
    raised if metadata doesn't match.

    Parameters
    ----------
    x : DataFrame, Series, or Index
    meta : DataFrame, Series, or Index
        The expected metadata that ``x`` should match
    funcname : str, optional
        The name of the function in which the metadata was specified. If
        provided, the function name will be included in the error message to be
        more helpful to users.
    numeric_equal : bool, optional
        If True, integer and floating dtypes compare equal. This is useful due
        to pandas' implicit conversion of integer to floating upon encountering
        missingness, which is hard to infer statically.
    """
    eq_types = {"i", "f", "u"} if numeric_equal else set()

    def equal_dtypes(a, b):
        if is_categorical_dtype(a) != is_categorical_dtype(b):
            return False
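        # A "-" string comes from ``dtypes.fillna("-")`` below and marks a
        # column present on only one side, which is always a mismatch.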
        if isinstance(a, str) and a == "-" or isinstance(b, str) and b == "-":
            return False
        if is_categorical_dtype(a) and is_categorical_dtype(b):
            if UNKNOWN_CATEGORIES in a.categories or UNKNOWN_CATEGORIES in b.categories:
                return True
            return a == b
        return (a.kind in eq_types and b.kind in eq_types) or is_dtype_equal(a, b)

    if not (
        is_dataframe_like(meta) or is_series_like(meta) or is_index_like(meta)
    ) or is_dask_collection(meta):
        raise TypeError(
            "Expected partition to be DataFrame, Series, or "
            "Index, got `%s`" % typename(type(meta))
        )

    # Notice, we use .__class__ as opposed to type() in order to support
    # object proxies; see <https://github.com/dask/dask/pull/6981>
    if x.__class__ != meta.__class__:
        errmsg = "Expected partition of type `{}` but got `{}`".format(
            typename(type(meta)),
            typename(type(x)),
        )
    elif is_dataframe_like(meta):
        dtypes = pd.concat([x.dtypes, meta.dtypes], axis=1, sort=True)
        bad_dtypes = [
            (repr(col), a, b)
            for col, a, b in dtypes.fillna("-").itertuples()
            if not equal_dtypes(a, b)
        ]
        if bad_dtypes:
            errmsg = "Partition type: `{}`\n{}".format(
                typename(type(meta)),
                asciitable(["Column", "Found", "Expected"], bad_dtypes),
            )
        else:
            check_matching_columns(meta, x)
            return x
    else:
        if equal_dtypes(x.dtype, meta.dtype):
            return x
        errmsg = "Partition type: `{}`\n{}".format(
            typename(type(meta)),
            asciitable(["", "dtype"], [("Found", x.dtype), ("Expected", meta.dtype)]),
        )

    raise ValueError(
        "Metadata mismatch found%s.\n\n"
        "%s" % ((" in `%s`" % funcname if funcname else ""), errmsg)
    )


def check_matching_columns(meta, actual):
    # Need nan_to_num otherwise nan comparison gives False
    if not np.array_equal(np.nan_to_num(meta.columns), np.nan_to_num(actual.columns)):
        extra = methods.tolist(actual.columns.difference(meta.columns))
        missing = methods.tolist(meta.columns.difference(actual.columns))
        if extra or missing:
            extra_info = f"  Extra:   {extra}\n  Missing: {missing}"
        else:
            extra_info = "Order of columns does not match"
        raise ValueError(
            "The columns in the computed data do not match"
            " the columns in the provided metadata\n"
            f"{extra_info}"
        )


def index_summary(idx, name=None):
    """Summarized representation of an Index.
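
    Examples
    --------
    >>> index_summary(pd.Index([1, 2, 3]), name="Index")
    'Index: 3 entries, 1 to 3'
    """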
    n = len(idx)
    if name is None:
        name = idx.__class__.__name__
    if n:
        head = idx[0]
        tail = idx[-1]
        summary = f", {head} to {tail}"
    else:
        summary = ""

    return f"{name}: {n} entries{summary}"


###############################################################
# Testing
###############################################################


def _check_dask(dsk, check_names=True, check_dtypes=True, result=None):
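    """Compute a dask collection and validate it (and, recursively, its index)
    against the concrete result, checking names and dtypes where requested."""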
    import dask.dataframe as dd

    if hasattr(dsk, "__dask_graph__"):
        graph = dsk.__dask_graph__()
        if hasattr(graph, "validate"):
            graph.validate()
        if result is None:
            result = dsk.compute(scheduler="sync")
        if isinstance(dsk, dd.Index):
            assert "Index" in type(result).__name__, type(result)
            # assert type(dsk._meta) == type(result), type(dsk._meta)
            if check_names:
                assert dsk.name == result.name
                assert dsk._meta.name == result.name
                if isinstance(result, pd.MultiIndex):
                    assert result.names == dsk._meta.names
            if check_dtypes:
                assert_dask_dtypes(dsk, result)
        elif isinstance(dsk, dd.Series):
            assert "Series" in type(result).__name__, type(result)
            assert type(dsk._meta) == type(result), type(dsk._meta)
            if check_names:
                assert dsk.name == result.name, (dsk.name, result.name)
                assert dsk._meta.name == result.name
            if check_dtypes:
                assert_dask_dtypes(dsk, result)
            _check_dask(
                dsk.index,
                check_names=check_names,
                check_dtypes=check_dtypes,
                result=result.index,
            )
        elif isinstance(dsk, dd.DataFrame):
            assert "DataFrame" in type(result).__name__, type(result)
            assert isinstance(dsk.columns, pd.Index), type(dsk.columns)
            assert type(dsk._meta) == type(result), type(dsk._meta)
            if check_names:
                tm.assert_index_equal(dsk.columns, result.columns)
                tm.assert_index_equal(dsk._meta.columns, result.columns)
            if check_dtypes:
                assert_dask_dtypes(dsk, result)
            _check_dask(
                dsk.index,
                check_names=check_names,
                check_dtypes=check_dtypes,
                result=result.index,
            )
        elif isinstance(dsk, dd.core.Scalar):
            assert np.isscalar(result) or isinstance(
                result, (pd.Timestamp, pd.Timedelta)
            )
            if check_dtypes:
                assert_dask_dtypes(dsk, result)
        else:
            msg = f"Unsupported dask instance {type(dsk)} found"
            raise AssertionError(msg)
        return result
    return dsk


def _maybe_sort(a, check_index: bool):
    # sort by value, then index
    try:
        if is_dataframe_like(a):
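            # ``sort_values`` is ambiguous when an index level shares a name
            # with a column, so rename any overlapping index levels first.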
            if set(a.index.names) & set(a.columns):
                a.index.names = [
                    "-overlapped-index-name-%d" % i for i in range(len(a.index.names))
                ]
            a = a.sort_values(by=methods.tolist(a.columns))
        else:
            a = a.sort_values()
    except (TypeError, IndexError, ValueError):
        pass
    return a.sort_index() if check_index else a


def assert_eq(
    a,
    b,
    check_names=True,
    check_dtype=True,
    check_divisions=True,
    check_index=True,
    **kwargs,
):
    if check_divisions:
        assert_divisions(a)
        assert_divisions(b)
        if hasattr(a, "divisions") and hasattr(b, "divisions"):
            at = type(np.asarray(a.divisions).tolist()[0])  # numpy to python
            bt = type(np.asarray(b.divisions).tolist()[0])  # scalar conversion
            assert at == bt, (at, bt)
    assert_sane_keynames(a)
    assert_sane_keynames(b)
    a = _check_dask(a, check_names=check_names, check_dtypes=check_dtype)
    b = _check_dask(b, check_names=check_names, check_dtypes=check_dtype)
    if hasattr(a, "to_pandas"):
        a = a.to_pandas()
    if hasattr(b, "to_pandas"):
        b = b.to_pandas()
    if isinstance(a, (pd.DataFrame, pd.Series)):
        a = _maybe_sort(a, check_index)
        b = _maybe_sort(b, check_index)
    if not check_index:
        a = a.reset_index(drop=True)
        b = b.reset_index(drop=True)
    if isinstance(a, pd.DataFrame):
        tm.assert_frame_equal(
            a, b, check_names=check_names, check_dtype=check_dtype, **kwargs
        )
    elif isinstance(a, pd.Series):
        tm.assert_series_equal(
            a, b, check_names=check_names, check_dtype=check_dtype, **kwargs
        )
    elif isinstance(a, pd.Index):
        tm.assert_index_equal(a, b, exact=check_dtype, **kwargs)
    else:
        if a == b:
            return True
        else:
            if np.isnan(a):
                assert np.isnan(b)
            else:
                assert np.allclose(a, b)
    return True


def assert_dask_graph(dask, label):
    if hasattr(dask, "dask"):
        dask = dask.dask
    assert isinstance(dask, Mapping)
    for k in dask:
        if isinstance(k, tuple):
            k = k[0]
        if k.startswith(label):
            return True
    raise AssertionError(f"given dask graph doesn't contain label: {label}")


def assert_divisions(ddf):
    if not hasattr(ddf, "divisions"):
        return
    if not getattr(ddf, "known_divisions", False):
        return

    def index(x):
        if is_index_like(x):
            return x
        try:
            return x.index.get_level_values(0)
        except AttributeError:
            return x.index

    results = get_sync(ddf.dask, ddf.__dask_keys__())
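    # Every partition except the last is half-open on the right:
    # [divisions[i], divisions[i + 1]). The last partition also includes
    # its upper division.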
    for i, df in enumerate(results[:-1]):
        if len(df):
            assert index(df).min() >= ddf.divisions[i]
            assert index(df).max() < ddf.divisions[i + 1]

    if len(results[-1]):
        assert index(results[-1]).min() >= ddf.divisions[-2]
        assert index(results[-1]).max() <= ddf.divisions[-1]


def assert_sane_keynames(ddf):
    if not hasattr(ddf, "dask"):
        return
    for k in ddf.dask.keys():
        while isinstance(k, tuple):
            k = k[0]
        assert isinstance(k, (str, bytes))
        assert len(k) < 100
        assert " " not in k
        assert k.split("-")[0].isidentifier(), k


def assert_dask_dtypes(ddf, res, numeric_equal=True):
    """Check that the dask metadata matches the result.

    If `numeric_equal`, integer and floating dtypes compare equal. This is
    useful due to the implicit conversion of integer to floating upon
    encountering missingness, which is hard to infer statically."""

    eq_type_sets = [{"O", "S", "U", "a"}]  # treat object and strings alike
    if numeric_equal:
        eq_type_sets.append({"i", "f", "u"})

    def eq_dtypes(a, b):
        return any(
            a.kind in eq_types and b.kind in eq_types for eq_types in eq_type_sets
        ) or (a == b)

    if not is_dask_collection(res) and is_dataframe_like(res):
        for col, a, b in pd.concat([ddf._meta.dtypes, res.dtypes], axis=1).itertuples():
            assert eq_dtypes(a, b)
    elif not is_dask_collection(res) and (is_index_like(res) or is_series_like(res)):
        a = ddf._meta.dtype
        b = res.dtype
        assert eq_dtypes(a, b)
    else:
        if hasattr(ddf._meta, "dtype"):
            a = ddf._meta.dtype
            if not hasattr(res, "dtype"):
                assert np.isscalar(res)
                b = np.dtype(type(res))
            else:
                b = res.dtype
            assert eq_dtypes(a, b)
        else:
            assert type(ddf._meta) == type(res)


def assert_max_deps(x, n, eq=True):
    dependencies, dependents = get_deps(x.dask)
    if eq:
        assert max(map(len, dependencies.values())) == n
    else:
        assert max(map(len, dependencies.values())) <= n


def valid_divisions(divisions):
    """Are the provided divisions valid?

    Examples
    --------
    >>> valid_divisions([1, 2, 3])
    True
    >>> valid_divisions([3, 2, 1])
    False
    >>> valid_divisions([1, 1, 1])
    False
    >>> valid_divisions([0, 1, 1])
    True
    >>> valid_divisions(123)
    False
    >>> valid_divisions([0, float('nan'), 1])
    False
    """
    if not isinstance(divisions, (tuple, list)):
        return False

    for i, x in enumerate(divisions[:-2]):
        if x >= divisions[i + 1]:
            return False
        if isinstance(x, numbers.Number) and math.isnan(x):
            return False

    for x in divisions[-2:]:
        if isinstance(x, numbers.Number) and math.isnan(x):
            return False

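    # The final two divisions may legitimately be equal (the last division is
    # inclusive), so only a strictly decreasing final pair is invalid.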
    if divisions[-2] > divisions[-1]:
        return False

    return True


def drop_by_shallow_copy(df, columns, errors="raise"):
    """Use shallow copy to drop columns in place"""
    df2 = df.copy(deep=False)
    if not pd.api.types.is_list_like(columns):
        columns = [columns]
    df2.drop(columns=columns, inplace=True, errors=errors)
    return df2