import math
import numbers
import re
import sys
import textwrap
import traceback
from collections.abc import Iterator, Mapping
from contextlib import contextmanager

import numpy as np
import pandas as pd
from pandas.api.types import is_scalar  # noqa: F401
from pandas.api.types import is_categorical_dtype, is_dtype_equal

from ..base import is_dask_collection
from ..core import get_deps
from ..local import get_sync
from ..utils import is_arraylike  # noqa: F401
from ..utils import asciitable
from ..utils import is_dataframe_like as dask_is_dataframe_like
from ..utils import is_index_like as dask_is_index_like
from ..utils import is_series_like as dask_is_series_like
from ..utils import typename
from . import _dtypes  # noqa: F401 register pandas extension types
from . import methods
from ._compat import PANDAS_GT_110, PANDAS_GT_120, tm  # noqa: F401
from .dispatch import make_meta  # noqa : F401
from .dispatch import make_meta_obj, meta_nonempty  # noqa : F401
from .extensions import make_scalar

# Container types treated as "meta" objects. scipy sparse matrices are only
# included when scipy is importable.
meta_object_types = (pd.Series, pd.DataFrame, pd.Index, pd.MultiIndex)
try:
    import scipy.sparse as sp

    meta_object_types += (sp.spmatrix,)
except ImportError:
    pass


def is_integer_na_dtype(t):
    """Return whether ``t`` (or ``t.dtype``) is a pandas nullable integer dtype.

    Both the signed ``Int*`` and unsigned ``UInt*`` extension dtypes count;
    plain NumPy integer dtypes do not.
    """
    dtype = getattr(t, "dtype", t)
    types = (
        pd.Int8Dtype,
        pd.Int16Dtype,
        pd.Int32Dtype,
        pd.Int64Dtype,
        pd.UInt8Dtype,
        pd.UInt16Dtype,
        pd.UInt32Dtype,
        pd.UInt64Dtype,
    )
    return isinstance(dtype, types)


def is_float_na_dtype(t):
    """Return whether ``t`` (or ``t.dtype``) is a pandas nullable float dtype.

    Always returns False unless ``PANDAS_GT_120`` is set, i.e. on pandas
    versions without ``Float32Dtype``/``Float64Dtype``.
    """
    if not PANDAS_GT_120:
        return False

    dtype = getattr(t, "dtype", t)
    types = (
        pd.Float32Dtype,
        pd.Float64Dtype,
    )
    return isinstance(dtype, types)


def shard_df_on_index(df, divisions):
    """Shard a DataFrame by ranges on its index

    ``df`` is sorted by its index and cut at the positions where each value
    of ``divisions`` would be inserted, yielding ``len(divisions) + 1``
    pieces. With no divisions the (unsorted) frame is yielded once.

    Examples
    --------

    >>> df = pd.DataFrame({'a': [0, 10, 20, 30, 40], 'b': [5, 4, 3, 2, 1]})
    >>> df
        a  b
    0   0  5
    1  10  4
    2  20  3
    3  30  2
    4  40  1

    >>> shards = list(shard_df_on_index(df, [2, 4]))
    >>> shards[0]
        a  b
    0   0  5
    1  10  4

    >>> shards[1]
        a  b
    2  20  3
    3  30  2

    >>> shards[2]
        a  b
    4  40  1

    >>> list(shard_df_on_index(df, []))[0]  # empty case
        a  b
    0   0  5
    1  10  4
    2  20  3
    3  30  2
    4  40  1
    """
    if isinstance(divisions, Iterator):
        # Materialize one-shot iterators so len() and indexing work below.
        divisions = list(divisions)
    if not len(divisions):
        yield df
        return

    divisions = np.array(divisions)
    df = df.sort_index()
    index = df.index
    if is_categorical_dtype(index):
        # presumably searchsorted needs an ordered categorical here
        index = index.as_ordered()
    indices = index.searchsorted(divisions)
    yield df.iloc[: indices[0]]
    for start, stop in zip(indices[:-1], indices[1:]):
        yield df.iloc[start:stop]
    yield df.iloc[indices[-1] :]


# Type line and wrapped description spliced into docstrings by
# ``insert_meta_param_description``.
_META_TYPES = "meta : pd.DataFrame, pd.Series, dict, iterable, tuple, optional"
_META_DESCRIPTION = """\
An empty ``pd.DataFrame`` or ``pd.Series`` that matches the dtypes and
column names of the output. This metadata is necessary for many algorithms
in dask dataframe to work. For ease of use, some alternative inputs are
also available. Instead of a ``DataFrame``, a ``dict`` of ``{name: dtype}``
or iterable of ``(name, dtype)`` can be provided (note that the order of
the names should match the order of the columns). Instead of a series, a
tuple of ``(name, dtype)`` can be used. If not provided, dask will try to
infer the metadata. This may lead to unexpected results, so providing
``meta`` is recommended. For more information, see
``dask.dataframe.utils.make_meta``.
"""


def insert_meta_param_description(*args, **kwargs):
    """Replace `$META` in docstring with param description.

    Usable bare (``@insert_meta_param_description``) or with keyword options
    (``@insert_meta_param_description(pad=12)``).

    If pad keyword is provided, will pad description by that number of
    spaces (default is 8). If the docstring has no ``$META`` marker, the
    description is appended to the end of its ``Parameters`` section.
    Objects without a docstring are returned unchanged.

    Raises
    ------
    ValueError
        If the docstring has neither a ``$META`` marker nor a ``Parameters``
        section.
    """
    if not args:
        # Called with keyword options only: return the real decorator.
        return lambda f: insert_meta_param_description(f, **kwargs)
    f = args[0]
    indent = " " * kwargs.get("pad", 8)
    body = textwrap.wrap(
        _META_DESCRIPTION, initial_indent=indent, subsequent_indent=indent, width=78
    )
    descr = "{}\n{}".format(_META_TYPES, "\n".join(body))
    if f.__doc__:
        if "$META" in f.__doc__:
            f.__doc__ = f.__doc__.replace("$META", descr)
        else:
            # Put it at the end of the parameters section
            parameter_header = "Parameters\n%s----------" % indent[4:]
            parts = re.split(r"Parameters\n[ ]*----------", f.__doc__, maxsplit=1)
            if len(parts) != 2:
                raise ValueError(
                    "Docstring has no '$META' marker and no 'Parameters' "
                    f"section to extend: {f!r}"
                )
            first, last = parts
            # The parameters section ends at the first blank line; it may also
            # be the final section of the docstring.
            parameters, _, rest = last.partition("\n\n")
            f.__doc__ = "{}{}{}\n{}{}\n\n{}".format(
                first, parameter_header, parameters, indent[4:], descr, rest
            )
    return f


@contextmanager
def raise_on_meta_error(funcname=None, udf=False):
    """Reraise errors in this block to show metadata inference failure.

    Any ``Exception`` raised inside the ``with`` block is re-raised as a
    ``ValueError`` whose message contains the original error's repr and its
    formatted traceback; the original exception is chained as the cause.

    Parameters
    ----------
    funcname : str, optional
        If provided, will be added to the error message to indicate the
        name of the method that failed.
    udf : bool, optional
        If True, the message also explains that the output type of a
        user-supplied function could not be determined and suggests
        passing ``meta=``.

    Raises
    ------
    ValueError
        Wrapping any exception raised inside the block.
    """
    try:
        yield
    except Exception as e:
        tb = "".join(traceback.format_tb(sys.exc_info()[2]))
        msg = "Metadata inference failed{0}.\n\n"
        if udf:
            msg += (
                "You have supplied a custom function and Dask is unable to \n"
                "determine the type of output that that function returns. \n\n"
                "To resolve this please provide a meta= keyword.\n"
                "The docstring of the Dask function you ran should have more information.\n\n"
            )
        msg += (
            "Original error is below:\n"
            "------------------------\n"
            "{1}\n\n"
            "Traceback:\n"
            "---------\n"
            "{2}"
        )
        msg = msg.format(f" in `{funcname}`" if funcname else "", repr(e), tb)
        raise ValueError(msg) from e


# Sentinel category marking categoricals whose real categories are unknown.
UNKNOWN_CATEGORIES = "__UNKNOWN_CATEGORIES__"


def has_known_categories(x):
    """Returns whether the categories in `x` are known.

    Categories count as unknown when they contain the ``UNKNOWN_CATEGORIES``
    sentinel. Objects with a ``_meta`` attribute are checked via ``_meta``.

    Parameters
    ----------
    x : Series or CategoricalIndex

    Raises
    ------
    TypeError
        If ``x`` is neither series-like nor an index with ``categories``.
    """
    x = getattr(x, "_meta", x)
    if is_series_like(x):
        return UNKNOWN_CATEGORIES not in x.cat.categories
    elif is_index_like(x) and hasattr(x, "categories"):
        return UNKNOWN_CATEGORIES not in x.categories
    raise TypeError("Expected Series or CategoricalIndex")


def strip_unknown_categories(x, just_drop_unknown=False):
    """Replace any unknown categoricals with empty categoricals.

    Useful for preventing ``UNKNOWN_CATEGORIES`` from leaking into results.

    Parameters
    ----------
    x : DataFrame, Series or CategoricalIndex
        Series and DataFrames are copied before being modified; other
        objects are returned unchanged.
    just_drop_unknown : bool, optional
        DataFrame columns only: if True, remove just the
        ``UNKNOWN_CATEGORIES`` sentinel and keep the remaining categories,
        instead of clearing all categories.
    """
    if isinstance(x, (pd.Series, pd.DataFrame)):
        x = x.copy()
        if isinstance(x, pd.DataFrame):
            cat_mask = x.dtypes == "category"
            if cat_mask.any():
                cats = cat_mask[cat_mask].index
                for c in cats:
                    if not has_known_categories(x[c]):
                        if just_drop_unknown:
                            # Assign the result back: an in-place call on the
                            # chained ``x[c]`` is deprecated/removed in pandas
                            # and may not reach ``x`` at all.
                            x[c] = x[c].cat.remove_categories(UNKNOWN_CATEGORIES)
                        else:
                            x[c] = x[c].cat.set_categories([])
        elif isinstance(x, pd.Series):
            if is_categorical_dtype(x.dtype) and not has_known_categories(x):
                x = x.cat.set_categories([])
        if isinstance(x.index, pd.CategoricalIndex) and not has_known_categories(
            x.index
        ):
            x.index = x.index.set_categories([])
    elif isinstance(x, pd.CategoricalIndex) and not has_known_categories(x):
        x = x.set_categories([])
    return x


def clear_known_categories(x, cols=None, index=True):
    """Set categories to be unknown.

    Parameters
    ----------
    x : DataFrame, Series, Index
        Series and DataFrames are copied before being modified; objects
        other than those and ``CategoricalIndex`` are returned unchanged.
    cols : iterable, optional
        If x is a DataFrame, set only categoricals in these columns to unknown.
        By default, all categorical columns are set to unknown categoricals
    index : bool, optional
        If True and x is a Series or DataFrame, also clear the known
        categories of a categorical index.

    Raises
    ------
    ValueError
        If ``cols`` names a column that is not categorical.
    """
    if isinstance(x, (pd.Series, pd.DataFrame)):
        x = x.copy()
        if isinstance(x, pd.DataFrame):
            mask = x.dtypes == "category"
            if cols is None:
                cols = mask[mask].index
            elif not mask.loc[cols].all():
                raise ValueError("Not all columns are categoricals")
            for c in cols:
                x[c] = x[c].cat.set_categories([UNKNOWN_CATEGORIES])
        elif isinstance(x, pd.Series):
            if is_categorical_dtype(x.dtype):
                x = x.cat.set_categories([UNKNOWN_CATEGORIES])
        if index and isinstance(x.index, pd.CategoricalIndex):
            x.index = x.index.set_categories([UNKNOWN_CATEGORIES])
    elif isinstance(x, pd.CategoricalIndex):
        x = x.set_categories([UNKNOWN_CATEGORIES])
    return x


def _empty_series(name, dtype, index=None):
    """Return an empty Series with the given ``name`` and ``dtype``.

    For the string dtype ``"category"`` the result is an empty categorical
    whose categories are the ``UNKNOWN_CATEGORIES`` sentinel. ``index`` is
    passed to the Series constructor.
    """
    if isinstance(dtype, str) and dtype == "category":
        return pd.Series(
            pd.Categorical([UNKNOWN_CATEGORIES]), name=name, index=index
        ).iloc[:0]
    return pd.Series([], dtype=dtype, name=name, index=index)


# Fake values per NumPy dtype kind, used by ``_scalar_from_dtype``.
# ``np.str_`` replaces the ``np.unicode_`` alias (identical object in NumPy 1.x,
# removed in NumPy 2.0).
_simple_fake_mapping = {
    "b": np.bool_(True),
    "V": np.void(b" "),
    "M": np.datetime64("1970-01-01"),
    "m": np.timedelta64(1),
    "S": np.str_("foo"),
    "a": np.str_("foo"),
    "U": np.str_("foo"),
    "O": "foo",
}


def _scalar_from_dtype(dtype):
    """Return a representative non-missing scalar of NumPy ``dtype``.

    Integer and float kinds give ``1``, complex gives ``1+0j``, and kinds in
    ``_simple_fake_mapping`` give the value listed there (cast to ``dtype``
    for datetime/timedelta kinds).

    Raises
    ------
    TypeError
        For any other dtype kind.
    """
    if dtype.kind in ("i", "f", "u"):
        return dtype.type(1)
    elif dtype.kind == "c":
        return dtype.type(complex(1, 0))
    elif dtype.kind in _simple_fake_mapping:
        o = _simple_fake_mapping[dtype.kind]
        return o.astype(dtype) if dtype.kind in ("m", "M") else o
    else:
        raise TypeError(f"Can't handle dtype: {dtype}")


def _nonempty_scalar(x):
    """Return a fake non-missing scalar of the same kind as scalar ``x``.

    Types present in ``make_scalar``'s lookup table are handled by
    ``make_scalar`` directly; other scalars are mapped through their dtype.

    Raises
    ------
    TypeError
        If ``x`` is neither a registered type nor a scalar.
    """
    if type(x) in make_scalar._lookup:
        return make_scalar(x)

    if np.isscalar(x):
        dtype = x.dtype if hasattr(x, "dtype") else np.dtype(type(x))
        return make_scalar(dtype)

    raise TypeError(f"Can't handle meta of type '{typename(type(x))}'")


def is_dataframe_like(df):
    """Return whether ``df`` looks like a DataFrame (delegates to ``..utils``)."""
    return dask_is_dataframe_like(df)
def is_series_like(s):
    """Return whether ``s`` looks like a Series (delegates to ``..utils``)."""
    return dask_is_series_like(s)


def is_index_like(s):
    """Return whether ``s`` looks like an Index (delegates to ``..utils``)."""
    return dask_is_index_like(s)


def check_meta(x, meta, funcname=None, numeric_equal=True):
    """Check that the dask metadata matches the result.

    If metadata matches, ``x`` is passed through unchanged. A nice error is
    raised if metadata doesn't match.

    Parameters
    ----------
    x : DataFrame, Series, or Index
    meta : DataFrame, Series, or Index
        The expected metadata that ``x`` should match
    funcname : str, optional
        The name of the function in which the metadata was specified. If
        provided, the function name will be included in the error message to be
        more helpful to users.
    numeric_equal : bool, optional
        If True, integer and floating dtypes compare equal. This is useful due
        to pandas' implicit conversion of integer to floating upon encountering
        missingness, which is hard to infer statically.

    Returns
    -------
    x, unchanged, when the metadata matches.

    Raises
    ------
    TypeError
        If ``meta`` is not DataFrame/Series/Index-like, or is a dask collection.
    ValueError
        If the partition type, the dtypes, or the column labels/order differ.
    """
    eq_types = {"i", "f", "u"} if numeric_equal else set()

    def equal_dtypes(a, b):
        # A categorical never matches a non-categorical.
        if is_categorical_dtype(a) != is_categorical_dtype(b):
            return False
        # "-" is the fill value for a column present on only one side.
        if isinstance(a, str) and a == "-" or isinstance(b, str) and b == "-":
            return False
        if is_categorical_dtype(a) and is_categorical_dtype(b):
            # Unknown categories are compatible with any categories.
            if UNKNOWN_CATEGORIES in a.categories or UNKNOWN_CATEGORIES in b.categories:
                return True
            return a == b
        return (a.kind in eq_types and b.kind in eq_types) or is_dtype_equal(a, b)

    if not (
        is_dataframe_like(meta) or is_series_like(meta) or is_index_like(meta)
    ) or is_dask_collection(meta):
        raise TypeError(
            "Expected partition to be DataFrame, Series, or "
            f"Index, got `{typename(type(meta))}`"
        )

    # Notice, we use .__class__ as opposed to type() in order to support
    # object proxies see <https://github.com/dask/dask/pull/6981>
    if x.__class__ != meta.__class__:
        errmsg = (
            f"Expected partition of type `{typename(type(meta))}` "
            f"but got `{typename(type(x))}`"
        )
    elif is_dataframe_like(meta):
        # Align the two dtype Series by column label; a label present on only
        # one side becomes NaN and is filled with "-".
        dtypes = pd.concat([x.dtypes, meta.dtypes], axis=1, sort=True)
        bad_dtypes = [
            (repr(col), a, b)
            for col, a, b in dtypes.fillna("-").itertuples()
            if not equal_dtypes(a, b)
        ]
        if not bad_dtypes:
            # dtypes agree; column labels and order must match too.
            check_matching_columns(meta, x)
            return x
        table = asciitable(["Column", "Found", "Expected"], bad_dtypes)
        errmsg = f"Partition type: `{typename(type(meta))}`\n{table}"
    else:
        if equal_dtypes(x.dtype, meta.dtype):
            return x
        table = asciitable(
            ["", "dtype"], [("Found", x.dtype), ("Expected", meta.dtype)]
        )
        errmsg = f"Partition type: `{typename(type(meta))}`\n{table}"

    where = f" in `{funcname}`" if funcname else ""
    raise ValueError(f"Metadata mismatch found{where}.\n\n{errmsg}")


def check_matching_columns(meta, actual):
    """Raise ``ValueError`` unless ``actual`` has exactly ``meta``'s columns, in order.

    The error lists the extra and missing column labels, or states that only
    the order of the columns differs.
    """
    # Need nan_to_num otherwise nan comparison gives False
    if np.array_equal(np.nan_to_num(meta.columns), np.nan_to_num(actual.columns)):
        return
    extra = methods.tolist(actual.columns.difference(meta.columns))
    missing = methods.tolist(meta.columns.difference(actual.columns))
    if extra or missing:
        extra_info = f"  Extra:   {extra}\n  Missing: {missing}"
    else:
        extra_info = "Order of columns does not match"
    raise ValueError(
        "The columns in the computed data do not match"
        " the columns in the provided metadata\n"
        f"{extra_info}"
    )


def index_summary(idx, name=None):
    """Summarized representation of an Index.

    Returns ``"<name>: <n> entries, <first> to <last>"``, or
    ``"<name>: 0 entries"`` when ``idx`` is empty. ``name`` defaults to the
    class name of ``idx``.
    """
    n = len(idx)
    if name is None:
        name = idx.__class__.__name__
    if n:
        head = idx[0]
        tail = idx[-1]
        summary = f", {head} to {tail}"
    else:
        summary = ""

    return f"{name}: {n} entries{summary}"


###############################################################
# Testing
###############################################################


def _check_dask(dsk, check_names=True, check_dtypes=True, result=None):
    """Validate a dask collection against its computed result.

    Objects without ``__dask_graph__`` are returned unchanged. Otherwise the
    graph is validated (when it has ``validate``), the collection is computed
    with the synchronous scheduler unless ``result`` is given, and the
    result's type, names and dtypes are compared with ``dsk._meta``. The
    indexes of Series and DataFrames are checked recursively.

    Returns
    -------
    The computed (or supplied) result, or ``dsk`` if it is not a dask
    collection.

    Raises
    ------
    AssertionError
        On any mismatch, or for an unsupported dask collection type.
    """
    import dask.dataframe as dd

    if not hasattr(dsk, "__dask_graph__"):
        return dsk

    graph = dsk.__dask_graph__()
    if hasattr(graph, "validate"):
        graph.validate()
    if result is None:
        result = dsk.compute(scheduler="sync")

    if isinstance(dsk, dd.Index):
        assert "Index" in type(result).__name__, type(result)
        # Unlike the Series/DataFrame branches, the exact type of ``_meta``
        # is not compared with the result's type for indexes.
        if check_names:
            assert dsk.name == result.name
            assert dsk._meta.name == result.name
            if isinstance(result, pd.MultiIndex):
                assert result.names == dsk._meta.names
        if check_dtypes:
            assert_dask_dtypes(dsk, result)
    elif isinstance(dsk, dd.Series):
        assert "Series" in type(result).__name__, type(result)
        assert type(dsk._meta) == type(result), type(dsk._meta)
        if check_names:
            assert dsk.name == result.name, (dsk.name, result.name)
            assert dsk._meta.name == result.name
        if check_dtypes:
            assert_dask_dtypes(dsk, result)
        _check_dask(
            dsk.index,
            check_names=check_names,
            check_dtypes=check_dtypes,
            result=result.index,
        )
    elif isinstance(dsk, dd.DataFrame):
        assert "DataFrame" in type(result).__name__, type(result)
        assert isinstance(dsk.columns, pd.Index), type(dsk.columns)
        assert type(dsk._meta) == type(result), type(dsk._meta)
        if check_names:
            tm.assert_index_equal(dsk.columns, result.columns)
            tm.assert_index_equal(dsk._meta.columns, result.columns)
        if check_dtypes:
            assert_dask_dtypes(dsk, result)
        _check_dask(
            dsk.index,
            check_names=check_names,
            check_dtypes=check_dtypes,
            result=result.index,
        )
    elif isinstance(dsk, dd.core.Scalar):
        assert np.isscalar(result) or isinstance(
            result, (pd.Timestamp, pd.Timedelta)
        )
        if check_dtypes:
            assert_dask_dtypes(dsk, result)
    else:
        msg = f"Unsupported dask instance {type(dsk)} found"
        raise AssertionError(msg)
    return result
def _maybe_sort(a, check_index: bool):
    """Sort ``a`` by its values, then by index if ``check_index``.

    Used by ``assert_eq`` so that row order does not affect comparisons.
    Objects whose values cannot be sorted keep their order. ``a`` itself is
    never modified.
    """
    # sort by value, then index
    try:
        if is_dataframe_like(a):
            if set(a.index.names) & set(a.columns):
                # Index level names clashing with column labels make
                # ``sort_values(by=...)`` ambiguous, so rename the levels on
                # a shallow copy instead of mutating the caller's object.
                a = a.copy(deep=False)
                a.index = a.index.set_names(
                    [f"-overlapped-index-name-{i}" for i in range(len(a.index.names))]
                )
            a = a.sort_values(by=methods.tolist(a.columns))
        else:
            a = a.sort_values()
    except (TypeError, IndexError, ValueError):
        pass
    return a.sort_index() if check_index else a


def assert_eq(
    a,
    b,
    check_names=True,
    check_dtype=True,
    check_divisions=True,
    check_index=True,
    **kwargs,
):
    """Assert that two dask/pandas objects (or scalars) are equal.

    Dask collections are validated (divisions, key names, metadata) and
    computed first. DataFrames and Series are sorted by value so row order
    does not matter; with ``check_index=False`` the index is also reset
    before comparing. Extra ``kwargs`` go to the pandas testing assertion.
    Unequal scalars must both be NaN or be ``np.allclose``.

    Returns
    -------
    True when no assertion fails.
    """
    if check_divisions:
        assert_divisions(a)
        assert_divisions(b)
        if hasattr(a, "divisions") and hasattr(b, "divisions"):
            at = type(np.asarray(a.divisions).tolist()[0])  # numpy to python
            bt = type(np.asarray(b.divisions).tolist()[0])  # scalar conversion
            assert at == bt, (at, bt)
    assert_sane_keynames(a)
    assert_sane_keynames(b)
    a = _check_dask(a, check_names=check_names, check_dtypes=check_dtype)
    b = _check_dask(b, check_names=check_names, check_dtypes=check_dtype)
    if hasattr(a, "to_pandas"):
        a = a.to_pandas()
    if hasattr(b, "to_pandas"):
        b = b.to_pandas()
    if isinstance(a, (pd.DataFrame, pd.Series)):
        a = _maybe_sort(a, check_index)
        b = _maybe_sort(b, check_index)
    if not check_index:
        a = a.reset_index(drop=True)
        b = b.reset_index(drop=True)
    if isinstance(a, pd.DataFrame):
        tm.assert_frame_equal(
            a, b, check_names=check_names, check_dtype=check_dtype, **kwargs
        )
    elif isinstance(a, pd.Series):
        tm.assert_series_equal(
            a, b, check_names=check_names, check_dtype=check_dtype, **kwargs
        )
    elif isinstance(a, pd.Index):
        tm.assert_index_equal(a, b, exact=check_dtype, **kwargs)
    else:
        if a == b:
            return True
        if np.isnan(a):
            assert np.isnan(b)
        else:
            assert np.allclose(a, b)
    return True


def assert_dask_graph(dask, label):
    """Assert that some key of a dask graph starts with ``label``.

    ``dask`` is a graph mapping or an object exposing one as ``.dask``; for
    tuple keys the first element is checked. Returns True on success.
    """
    if hasattr(dask, "dask"):
        dask = dask.dask
    assert isinstance(dask, Mapping)
    for k in dask:
        if isinstance(k, tuple):
            k = k[0]
        if k.startswith(label):
            return True
    raise AssertionError(f"given dask graph doesn't contain label: {label}")


def assert_divisions(ddf):
    """Assert that each computed partition's index lies within its divisions.

    Does nothing for objects without ``divisions`` or with unknown
    divisions. For a MultiIndex only the first level is checked. The last
    partition's upper bound is inclusive; all others are exclusive.
    """
    if not hasattr(ddf, "divisions"):
        return
    if not getattr(ddf, "known_divisions", False):
        return

    def index(x):
        if is_index_like(x):
            return x
        try:
            return x.index.get_level_values(0)
        except AttributeError:
            return x.index

    results = get_sync(ddf.dask, ddf.__dask_keys__())
    for i, df in enumerate(results[:-1]):
        if len(df):
            assert index(df).min() >= ddf.divisions[i]
            assert index(df).max() < ddf.divisions[i + 1]

    if len(results[-1]):
        assert index(results[-1]).min() >= ddf.divisions[-2]
        assert index(results[-1]).max() <= ddf.divisions[-1]


def assert_sane_keynames(ddf):
    """Assert that every key in ``ddf``'s graph has a sane name.

    The name (the innermost first element of tuple keys) must be ``str`` or
    ``bytes``, shorter than 100 characters, contain no spaces, and start
    with a valid identifier before its first ``-``. Objects without a
    ``.dask`` graph are ignored.
    """
    if not hasattr(ddf, "dask"):
        return
    for k in ddf.dask.keys():
        while isinstance(k, tuple):
            k = k[0]
        assert isinstance(k, (str, bytes))
        if isinstance(k, bytes):
            # bytes names are accepted above; decode them so the str-based
            # checks below work instead of raising TypeError.
            k = k.decode(errors="replace")
        assert len(k) < 100
        assert " " not in k
        assert k.split("-")[0].isidentifier(), k


def assert_dask_dtypes(ddf, res, numeric_equal=True):
    """Check that the dask metadata matches the result.

    If `numeric_equal`, integer and floating dtypes compare equal. This is
    useful due to the implicit conversion of integer to floating upon
    encountering missingness, which is hard to infer statically. Object and
    string dtype kinds always compare equal to each other."""

    eq_type_sets = [{"O", "S", "U", "a"}]  # treat object and strings alike
    if numeric_equal:
        eq_type_sets.append({"i", "f", "u"})

    def eq_dtypes(a, b):
        return any(
            a.kind in eq_types and b.kind in eq_types for eq_types in eq_type_sets
        ) or (a == b)

    if not is_dask_collection(res) and is_dataframe_like(res):
        for _, a, b in pd.concat([ddf._meta.dtypes, res.dtypes], axis=1).itertuples():
            assert eq_dtypes(a, b)
    elif not is_dask_collection(res) and (is_index_like(res) or is_series_like(res)):
        a = ddf._meta.dtype
        b = res.dtype
        assert eq_dtypes(a, b)
    else:
        if hasattr(ddf._meta, "dtype"):
            a = ddf._meta.dtype
            if not hasattr(res, "dtype"):
                assert np.isscalar(res)
                b = np.dtype(type(res))
            else:
                b = res.dtype
            assert eq_dtypes(a, b)
        else:
            assert type(ddf._meta) == type(res)


def assert_max_deps(x, n, eq=True):
    """Assert on the largest number of dependencies of any key in ``x.dask``.

    With ``eq=True`` that maximum must equal ``n``; otherwise it must be at
    most ``n``.
    """
    dependencies, _ = get_deps(x.dask)
    max_deps = max(map(len, dependencies.values()))
    if eq:
        assert max_deps == n
    else:
        assert max_deps <= n


def valid_divisions(divisions):
    """Are the provided divisions valid?

    Divisions must be a list or tuple of at least two non-NaN values that
    increase strictly, except that the last two may be equal.

    Examples
    --------
    >>> valid_divisions([1, 2, 3])
    True
    >>> valid_divisions([3, 2, 1])
    False
    >>> valid_divisions([1, 1, 1])
    False
    >>> valid_divisions([0, 1, 1])
    True
    >>> valid_divisions(123)
    False
    >>> valid_divisions([0, float('nan'), 1])
    False
    >>> valid_divisions([0])
    False
    >>> valid_divisions([])
    False
    """
    if not isinstance(divisions, (tuple, list)):
        return False

    # At least one partition needs both a lower and an upper bound; this also
    # keeps ``divisions[-2]`` below from raising IndexError.
    if len(divisions) < 2:
        return False

    for i, x in enumerate(divisions[:-2]):
        if x >= divisions[i + 1]:
            return False
        if isinstance(x, numbers.Number) and math.isnan(x):
            return False

    for x in divisions[-2:]:
        if isinstance(x, numbers.Number) and math.isnan(x):
            return False

    if divisions[-2] > divisions[-1]:
        return False

    return True


def drop_by_shallow_copy(df, columns, errors="raise"):
    """Use shallow copy to drop columns in place

    The columns are dropped in place from a shallow copy of ``df``, so
    ``df`` itself keeps all its columns. ``columns`` may be a single label
    or a list of labels; ``errors`` is forwarded to ``DataFrame.drop``.
    """
    df2 = df.copy(deep=False)
    if not pd.api.types.is_list_like(columns):
        columns = [columns]
    df2.drop(columns=columns, inplace=True, errors=errors)
    return df2