1import bz2 2from collections import Counter 3from contextlib import contextmanager 4from datetime import datetime 5from functools import wraps 6import gzip 7import operator 8import os 9from pathlib import Path 10import random 11import re 12from shutil import rmtree 13import string 14import tempfile 15from typing import IO, Any, Callable, ContextManager, List, Optional, Type, Union, cast 16import warnings 17import zipfile 18 19import numpy as np 20from numpy.random import rand, randn 21 22from pandas._config.localization import ( # noqa:F401 23 can_set_locale, 24 get_locales, 25 set_locale, 26) 27 28from pandas._libs.lib import no_default 29import pandas._libs.testing as _testing 30from pandas._typing import Dtype, FilePathOrBuffer, FrameOrSeries 31from pandas.compat import get_lzma_file, import_lzma 32 33from pandas.core.dtypes.common import ( 34 is_bool, 35 is_categorical_dtype, 36 is_datetime64_dtype, 37 is_datetime64tz_dtype, 38 is_extension_array_dtype, 39 is_interval_dtype, 40 is_number, 41 is_numeric_dtype, 42 is_period_dtype, 43 is_sequence, 44 is_timedelta64_dtype, 45 needs_i8_conversion, 46) 47from pandas.core.dtypes.missing import array_equivalent 48 49import pandas as pd 50from pandas import ( 51 Categorical, 52 CategoricalIndex, 53 DataFrame, 54 DatetimeIndex, 55 Index, 56 IntervalIndex, 57 MultiIndex, 58 RangeIndex, 59 Series, 60 bdate_range, 61) 62from pandas.core.algorithms import safe_sort, take_1d 63from pandas.core.arrays import ( 64 DatetimeArray, 65 ExtensionArray, 66 IntervalArray, 67 PeriodArray, 68 TimedeltaArray, 69 period_array, 70) 71from pandas.core.arrays.datetimelike import DatetimeLikeArrayMixin 72 73from pandas.io.common import urlopen 74from pandas.io.formats.printing import pprint_thing 75 76lzma = import_lzma() 77 78_N = 30 79_K = 4 80_RAISE_NETWORK_ERROR_DEFAULT = False 81 82UNSIGNED_INT_DTYPES: List[Dtype] = ["uint8", "uint16", "uint32", "uint64"] 83UNSIGNED_EA_INT_DTYPES: List[Dtype] = ["UInt8", "UInt16", "UInt32", "UInt64"] 84SIGNED_INT_DTYPES: List[Dtype] = [int, "int8", "int16", "int32", "int64"] 85SIGNED_EA_INT_DTYPES: List[Dtype] = ["Int8", "Int16", "Int32", "Int64"] 86ALL_INT_DTYPES = UNSIGNED_INT_DTYPES + SIGNED_INT_DTYPES 87ALL_EA_INT_DTYPES = UNSIGNED_EA_INT_DTYPES + SIGNED_EA_INT_DTYPES 88 89FLOAT_DTYPES: List[Dtype] = [float, "float32", "float64"] 90FLOAT_EA_DTYPES: List[Dtype] = ["Float32", "Float64"] 91COMPLEX_DTYPES: List[Dtype] = [complex, "complex64", "complex128"] 92STRING_DTYPES: List[Dtype] = [str, "str", "U"] 93 94DATETIME64_DTYPES: List[Dtype] = ["datetime64[ns]", "M8[ns]"] 95TIMEDELTA64_DTYPES: List[Dtype] = ["timedelta64[ns]", "m8[ns]"] 96 97BOOL_DTYPES = [bool, "bool"] 98BYTES_DTYPES = [bytes, "bytes"] 99OBJECT_DTYPES = [object, "object"] 100 101ALL_REAL_DTYPES = FLOAT_DTYPES + ALL_INT_DTYPES 102ALL_NUMPY_DTYPES = ( 103 ALL_REAL_DTYPES 104 + COMPLEX_DTYPES 105 + STRING_DTYPES 106 + DATETIME64_DTYPES 107 + TIMEDELTA64_DTYPES 108 + BOOL_DTYPES 109 + OBJECT_DTYPES 110 + BYTES_DTYPES 111) 112 113NULL_OBJECTS = [None, np.nan, pd.NaT, float("nan"), pd.NA] 114 115 116# set testing_mode 117_testing_mode_warnings = (DeprecationWarning, ResourceWarning) 118 119 120def set_testing_mode(): 121 # set the testing mode filters 122 testing_mode = os.environ.get("PANDAS_TESTING_MODE", "None") 123 if "deprecate" in testing_mode: 124 # pandas\_testing.py:119: error: Argument 2 to "simplefilter" has 125 # incompatible type "Tuple[Type[DeprecationWarning], 126 # Type[ResourceWarning]]"; expected "Type[Warning]" 127 warnings.simplefilter( 128 "always", _testing_mode_warnings # type: ignore[arg-type] 129 ) 130 131 132def reset_testing_mode(): 133 # reset the testing mode filters 134 testing_mode = os.environ.get("PANDAS_TESTING_MODE", "None") 135 if "deprecate" in testing_mode: 136 # pandas\_testing.py:126: error: Argument 2 to "simplefilter" has 137 # incompatible type "Tuple[Type[DeprecationWarning], 138 # Type[ResourceWarning]]"; expected "Type[Warning]" 139 warnings.simplefilter( 140 "ignore", _testing_mode_warnings # type: ignore[arg-type] 141 ) 142 143 144set_testing_mode() 145 146 147def reset_display_options(): 148 """ 149 Reset the display options for printing and representing objects. 150 """ 151 pd.reset_option("^display.", silent=True) 152 153 154def round_trip_pickle( 155 obj: Any, path: Optional[FilePathOrBuffer] = None 156) -> FrameOrSeries: 157 """ 158 Pickle an object and then read it again. 159 160 Parameters 161 ---------- 162 obj : any object 163 The object to pickle and then re-read. 164 path : str, path object or file-like object, default None 165 The path where the pickled object is written and then read. 166 167 Returns 168 ------- 169 pandas object 170 The original object that was pickled and then re-read. 171 """ 172 _path = path 173 if _path is None: 174 _path = f"__{rands(10)}__.pickle" 175 with ensure_clean(_path) as temp_path: 176 pd.to_pickle(obj, temp_path) 177 return pd.read_pickle(temp_path) 178 179 180def round_trip_pathlib(writer, reader, path: Optional[str] = None): 181 """ 182 Write an object to file specified by a pathlib.Path and read it back 183 184 Parameters 185 ---------- 186 writer : callable bound to pandas object 187 IO writing function (e.g. DataFrame.to_csv ) 188 reader : callable 189 IO reading function (e.g. pd.read_csv ) 190 path : str, default None 191 The path where the object is written and then read. 192 193 Returns 194 ------- 195 pandas object 196 The original object that was serialized and then re-read. 197 """ 198 import pytest 199 200 Path = pytest.importorskip("pathlib").Path 201 if path is None: 202 path = "___pathlib___" 203 with ensure_clean(path) as path: 204 writer(Path(path)) 205 obj = reader(Path(path)) 206 return obj 207 208 209def round_trip_localpath(writer, reader, path: Optional[str] = None): 210 """ 211 Write an object to file specified by a py.path LocalPath and read it back. 212 213 Parameters 214 ---------- 215 writer : callable bound to pandas object 216 IO writing function (e.g. DataFrame.to_csv ) 217 reader : callable 218 IO reading function (e.g. pd.read_csv ) 219 path : str, default None 220 The path where the object is written and then read. 221 222 Returns 223 ------- 224 pandas object 225 The original object that was serialized and then re-read. 226 """ 227 import pytest 228 229 LocalPath = pytest.importorskip("py.path").local 230 if path is None: 231 path = "___localpath___" 232 with ensure_clean(path) as path: 233 writer(LocalPath(path)) 234 obj = reader(LocalPath(path)) 235 return obj 236 237 238@contextmanager 239def decompress_file(path, compression): 240 """ 241 Open a compressed file and return a file object. 242 243 Parameters 244 ---------- 245 path : str 246 The path where the file is read from. 247 248 compression : {'gzip', 'bz2', 'zip', 'xz', None} 249 Name of the decompression to use 250 251 Returns 252 ------- 253 file object 254 """ 255 if compression is None: 256 f = open(path, "rb") 257 elif compression == "gzip": 258 # pandas\_testing.py:243: error: Incompatible types in assignment 259 # (expression has type "IO[Any]", variable has type "BinaryIO") 260 f = gzip.open(path, "rb") # type: ignore[assignment] 261 elif compression == "bz2": 262 # pandas\_testing.py:245: error: Incompatible types in assignment 263 # (expression has type "BZ2File", variable has type "BinaryIO") 264 f = bz2.BZ2File(path, "rb") # type: ignore[assignment] 265 elif compression == "xz": 266 f = get_lzma_file(lzma)(path, "rb") 267 elif compression == "zip": 268 zip_file = zipfile.ZipFile(path) 269 zip_names = zip_file.namelist() 270 if len(zip_names) == 1: 271 # pandas\_testing.py:252: error: Incompatible types in assignment 272 # (expression has type "IO[bytes]", variable has type "BinaryIO") 273 f = zip_file.open(zip_names.pop()) # type: ignore[assignment] 274 else: 275 raise ValueError(f"ZIP file {path} error. Only one file per ZIP.") 276 else: 277 raise ValueError(f"Unrecognized compression type: {compression}") 278 279 try: 280 yield f 281 finally: 282 f.close() 283 if compression == "zip": 284 zip_file.close() 285 286 287def write_to_compressed(compression, path, data, dest="test"): 288 """ 289 Write data to a compressed file. 290 291 Parameters 292 ---------- 293 compression : {'gzip', 'bz2', 'zip', 'xz'} 294 The compression type to use. 295 path : str 296 The file path to write the data. 297 data : str 298 The data to write. 299 dest : str, default "test" 300 The destination file (for ZIP only) 301 302 Raises 303 ------ 304 ValueError : An invalid compression value was passed in. 305 """ 306 if compression == "zip": 307 compress_method = zipfile.ZipFile 308 elif compression == "gzip": 309 # pandas\_testing.py:288: error: Incompatible types in assignment 310 # (expression has type "Type[GzipFile]", variable has type 311 # "Type[ZipFile]") 312 compress_method = gzip.GzipFile # type: ignore[assignment] 313 elif compression == "bz2": 314 # pandas\_testing.py:290: error: Incompatible types in assignment 315 # (expression has type "Type[BZ2File]", variable has type 316 # "Type[ZipFile]") 317 compress_method = bz2.BZ2File # type: ignore[assignment] 318 elif compression == "xz": 319 compress_method = get_lzma_file(lzma) 320 else: 321 raise ValueError(f"Unrecognized compression type: {compression}") 322 323 if compression == "zip": 324 mode = "w" 325 args = (dest, data) 326 method = "writestr" 327 else: 328 mode = "wb" 329 # pandas\_testing.py:302: error: Incompatible types in assignment 330 # (expression has type "Tuple[Any]", variable has type "Tuple[Any, 331 # Any]") 332 args = (data,) # type: ignore[assignment] 333 method = "write" 334 335 with compress_method(path, mode=mode) as f: 336 getattr(f, method)(*args) 337 338 339def _get_tol_from_less_precise(check_less_precise: Union[bool, int]) -> float: 340 """ 341 Return the tolerance equivalent to the deprecated `check_less_precise` 342 parameter. 343 344 Parameters 345 ---------- 346 check_less_precise : bool or int 347 348 Returns 349 ------- 350 float 351 Tolerance to be used as relative/absolute tolerance. 352 353 Examples 354 -------- 355 >>> # Using check_less_precise as a bool: 356 >>> _get_tol_from_less_precise(False) 357 0.5e-5 358 >>> _get_tol_from_less_precise(True) 359 0.5e-3 360 >>> # Using check_less_precise as an int representing the decimal 361 >>> # tolerance intended: 362 >>> _get_tol_from_less_precise(2) 363 0.5e-2 364 >>> _get_tol_from_less_precise(8) 365 0.5e-8 366 367 """ 368 if isinstance(check_less_precise, bool): 369 if check_less_precise: 370 # 3-digit tolerance 371 return 0.5e-3 372 else: 373 # 5-digit tolerance 374 return 0.5e-5 375 else: 376 # Equivalent to setting checking_less_precise=<decimals> 377 return 0.5 * 10 ** -check_less_precise 378 379 380def assert_almost_equal( 381 left, 382 right, 383 check_dtype: Union[bool, str] = "equiv", 384 check_less_precise: Union[bool, int] = no_default, 385 rtol: float = 1.0e-5, 386 atol: float = 1.0e-8, 387 **kwargs, 388): 389 """ 390 Check that the left and right objects are approximately equal. 391 392 By approximately equal, we refer to objects that are numbers or that 393 contain numbers which may be equivalent to specific levels of precision. 394 395 Parameters 396 ---------- 397 left : object 398 right : object 399 check_dtype : bool or {'equiv'}, default 'equiv' 400 Check dtype if both a and b are the same type. If 'equiv' is passed in, 401 then `RangeIndex` and `Int64Index` are also considered equivalent 402 when doing type checking. 403 check_less_precise : bool or int, default False 404 Specify comparison precision. 5 digits (False) or 3 digits (True) 405 after decimal points are compared. If int, then specify the number 406 of digits to compare. 407 408 When comparing two numbers, if the first number has magnitude less 409 than 1e-5, we compare the two numbers directly and check whether 410 they are equivalent within the specified precision. Otherwise, we 411 compare the **ratio** of the second number to the first number and 412 check whether it is equivalent to 1 within the specified precision. 413 414 .. deprecated:: 1.1.0 415 Use `rtol` and `atol` instead to define relative/absolute 416 tolerance, respectively. Similar to :func:`math.isclose`. 417 rtol : float, default 1e-5 418 Relative tolerance. 419 420 .. versionadded:: 1.1.0 421 atol : float, default 1e-8 422 Absolute tolerance. 423 424 .. versionadded:: 1.1.0 425 """ 426 if check_less_precise is not no_default: 427 warnings.warn( 428 "The 'check_less_precise' keyword in testing.assert_*_equal " 429 "is deprecated and will be removed in a future version. " 430 "You can stop passing 'check_less_precise' to silence this warning.", 431 FutureWarning, 432 stacklevel=2, 433 ) 434 rtol = atol = _get_tol_from_less_precise(check_less_precise) 435 436 if isinstance(left, pd.Index): 437 assert_index_equal( 438 left, 439 right, 440 check_exact=False, 441 exact=check_dtype, 442 rtol=rtol, 443 atol=atol, 444 **kwargs, 445 ) 446 447 elif isinstance(left, pd.Series): 448 assert_series_equal( 449 left, 450 right, 451 check_exact=False, 452 check_dtype=check_dtype, 453 rtol=rtol, 454 atol=atol, 455 **kwargs, 456 ) 457 458 elif isinstance(left, pd.DataFrame): 459 assert_frame_equal( 460 left, 461 right, 462 check_exact=False, 463 check_dtype=check_dtype, 464 rtol=rtol, 465 atol=atol, 466 **kwargs, 467 ) 468 469 else: 470 # Other sequences. 471 if check_dtype: 472 if is_number(left) and is_number(right): 473 # Do not compare numeric classes, like np.float64 and float. 474 pass 475 elif is_bool(left) and is_bool(right): 476 # Do not compare bool classes, like np.bool_ and bool. 477 pass 478 else: 479 if isinstance(left, np.ndarray) or isinstance(right, np.ndarray): 480 obj = "numpy array" 481 else: 482 obj = "Input" 483 assert_class_equal(left, right, obj=obj) 484 _testing.assert_almost_equal( 485 left, right, check_dtype=check_dtype, rtol=rtol, atol=atol, **kwargs 486 ) 487 488 489def _check_isinstance(left, right, cls): 490 """ 491 Helper method for our assert_* methods that ensures that 492 the two objects being compared have the right type before 493 proceeding with the comparison. 494 495 Parameters 496 ---------- 497 left : The first object being compared. 498 right : The second object being compared. 499 cls : The class type to check against. 500 501 Raises 502 ------ 503 AssertionError : Either `left` or `right` is not an instance of `cls`. 504 """ 505 cls_name = cls.__name__ 506 507 if not isinstance(left, cls): 508 raise AssertionError( 509 f"{cls_name} Expected type {cls}, found {type(left)} instead" 510 ) 511 if not isinstance(right, cls): 512 raise AssertionError( 513 f"{cls_name} Expected type {cls}, found {type(right)} instead" 514 ) 515 516 517def assert_dict_equal(left, right, compare_keys: bool = True): 518 519 _check_isinstance(left, right, dict) 520 _testing.assert_dict_equal(left, right, compare_keys=compare_keys) 521 522 523def randbool(size=(), p: float = 0.5): 524 return rand(*size) <= p 525 526 527RANDS_CHARS = np.array(list(string.ascii_letters + string.digits), dtype=(np.str_, 1)) 528RANDU_CHARS = np.array( 529 list("".join(map(chr, range(1488, 1488 + 26))) + string.digits), 530 dtype=(np.unicode_, 1), 531) 532 533 534def rands_array(nchars, size, dtype="O"): 535 """ 536 Generate an array of byte strings. 537 """ 538 retval = ( 539 np.random.choice(RANDS_CHARS, size=nchars * np.prod(size)) 540 .view((np.str_, nchars)) 541 .reshape(size) 542 ) 543 return retval.astype(dtype) 544 545 546def randu_array(nchars, size, dtype="O"): 547 """ 548 Generate an array of unicode strings. 549 """ 550 retval = ( 551 np.random.choice(RANDU_CHARS, size=nchars * np.prod(size)) 552 .view((np.unicode_, nchars)) 553 .reshape(size) 554 ) 555 return retval.astype(dtype) 556 557 558def rands(nchars): 559 """ 560 Generate one random byte string. 561 562 See `rands_array` if you want to create an array of random strings. 563 564 """ 565 return "".join(np.random.choice(RANDS_CHARS, nchars)) 566 567 568def close(fignum=None): 569 from matplotlib.pyplot import close as _close, get_fignums 570 571 if fignum is None: 572 for fignum in get_fignums(): 573 _close(fignum) 574 else: 575 _close(fignum) 576 577 578# ----------------------------------------------------------------------------- 579# contextmanager to ensure the file cleanup 580 581 582@contextmanager 583def ensure_clean(filename=None, return_filelike: bool = False, **kwargs: Any): 584 """ 585 Gets a temporary path and agrees to remove on close. 586 587 This implementation does not use tempfile.mkstemp to avoid having a file handle. 588 If the code using the returned path wants to delete the file itself, windows 589 requires that no program has a file handle to it. 590 591 Parameters 592 ---------- 593 filename : str (optional) 594 suffix of the created file. 595 return_filelike : bool (default False) 596 if True, returns a file-like which is *always* cleaned. Necessary for 597 savefig and other functions which want to append extensions. 598 **kwargs 599 Additional keywords are passed to open(). 600 601 """ 602 folder = Path(tempfile.gettempdir()) 603 604 if filename is None: 605 filename = "" 606 filename = ( 607 "".join(random.choices(string.ascii_letters + string.digits, k=30)) + filename 608 ) 609 path = folder / filename 610 611 path.touch() 612 613 handle_or_str: Union[str, IO] = str(path) 614 if return_filelike: 615 kwargs.setdefault("mode", "w+b") 616 handle_or_str = open(path, **kwargs) 617 618 try: 619 yield handle_or_str 620 finally: 621 if not isinstance(handle_or_str, str): 622 handle_or_str.close() 623 if path.is_file(): 624 path.unlink() 625 626 627@contextmanager 628def ensure_clean_dir(): 629 """ 630 Get a temporary directory path and agrees to remove on close. 631 632 Yields 633 ------ 634 Temporary directory path 635 """ 636 directory_name = tempfile.mkdtemp(suffix="") 637 try: 638 yield directory_name 639 finally: 640 try: 641 rmtree(directory_name) 642 except OSError: 643 pass 644 645 646@contextmanager 647def ensure_safe_environment_variables(): 648 """ 649 Get a context manager to safely set environment variables 650 651 All changes will be undone on close, hence environment variables set 652 within this contextmanager will neither persist nor change global state. 653 """ 654 saved_environ = dict(os.environ) 655 try: 656 yield 657 finally: 658 os.environ.clear() 659 os.environ.update(saved_environ) 660 661 662# ----------------------------------------------------------------------------- 663# Comparators 664 665 666def equalContents(arr1, arr2) -> bool: 667 """ 668 Checks if the set of unique elements of arr1 and arr2 are equivalent. 669 """ 670 return frozenset(arr1) == frozenset(arr2) 671 672 673def assert_index_equal( 674 left: Index, 675 right: Index, 676 exact: Union[bool, str] = "equiv", 677 check_names: bool = True, 678 check_less_precise: Union[bool, int] = no_default, 679 check_exact: bool = True, 680 check_categorical: bool = True, 681 check_order: bool = True, 682 rtol: float = 1.0e-5, 683 atol: float = 1.0e-8, 684 obj: str = "Index", 685) -> None: 686 """ 687 Check that left and right Index are equal. 688 689 Parameters 690 ---------- 691 left : Index 692 right : Index 693 exact : bool or {'equiv'}, default 'equiv' 694 Whether to check the Index class, dtype and inferred_type 695 are identical. If 'equiv', then RangeIndex can be substituted for 696 Int64Index as well. 697 check_names : bool, default True 698 Whether to check the names attribute. 699 check_less_precise : bool or int, default False 700 Specify comparison precision. Only used when check_exact is False. 701 5 digits (False) or 3 digits (True) after decimal points are compared. 702 If int, then specify the digits to compare. 703 704 .. deprecated:: 1.1.0 705 Use `rtol` and `atol` instead to define relative/absolute 706 tolerance, respectively. Similar to :func:`math.isclose`. 707 check_exact : bool, default True 708 Whether to compare number exactly. 709 check_categorical : bool, default True 710 Whether to compare internal Categorical exactly. 711 check_order : bool, default True 712 Whether to compare the order of index entries as well as their values. 713 If True, both indexes must contain the same elements, in the same order. 714 If False, both indexes must contain the same elements, but in any order. 715 716 .. versionadded:: 1.2.0 717 rtol : float, default 1e-5 718 Relative tolerance. Only used when check_exact is False. 719 720 .. versionadded:: 1.1.0 721 atol : float, default 1e-8 722 Absolute tolerance. Only used when check_exact is False. 723 724 .. versionadded:: 1.1.0 725 obj : str, default 'Index' 726 Specify object name being compared, internally used to show appropriate 727 assertion message. 728 729 Examples 730 -------- 731 >>> from pandas.testing import assert_index_equal 732 >>> a = pd.Index([1, 2, 3]) 733 >>> b = pd.Index([1, 2, 3]) 734 >>> assert_index_equal(a, b) 735 """ 736 __tracebackhide__ = True 737 738 def _check_types(left, right, obj="Index"): 739 if exact: 740 assert_class_equal(left, right, exact=exact, obj=obj) 741 742 # Skip exact dtype checking when `check_categorical` is False 743 if check_categorical: 744 assert_attr_equal("dtype", left, right, obj=obj) 745 746 # allow string-like to have different inferred_types 747 if left.inferred_type in ("string"): 748 assert right.inferred_type in ("string") 749 else: 750 assert_attr_equal("inferred_type", left, right, obj=obj) 751 752 def _get_ilevel_values(index, level): 753 # accept level number only 754 unique = index.levels[level] 755 level_codes = index.codes[level] 756 filled = take_1d(unique._values, level_codes, fill_value=unique._na_value) 757 return unique._shallow_copy(filled, name=index.names[level]) 758 759 if check_less_precise is not no_default: 760 warnings.warn( 761 "The 'check_less_precise' keyword in testing.assert_*_equal " 762 "is deprecated and will be removed in a future version. " 763 "You can stop passing 'check_less_precise' to silence this warning.", 764 FutureWarning, 765 stacklevel=2, 766 ) 767 rtol = atol = _get_tol_from_less_precise(check_less_precise) 768 769 # instance validation 770 _check_isinstance(left, right, Index) 771 772 # class / dtype comparison 773 _check_types(left, right, obj=obj) 774 775 # level comparison 776 if left.nlevels != right.nlevels: 777 msg1 = f"{obj} levels are different" 778 msg2 = f"{left.nlevels}, {left}" 779 msg3 = f"{right.nlevels}, {right}" 780 raise_assert_detail(obj, msg1, msg2, msg3) 781 782 # length comparison 783 if len(left) != len(right): 784 msg1 = f"{obj} length are different" 785 msg2 = f"{len(left)}, {left}" 786 msg3 = f"{len(right)}, {right}" 787 raise_assert_detail(obj, msg1, msg2, msg3) 788 789 # If order doesn't matter then sort the index entries 790 if not check_order: 791 left = Index(safe_sort(left)) 792 right = Index(safe_sort(right)) 793 794 # MultiIndex special comparison for little-friendly error messages 795 if left.nlevels > 1: 796 left = cast(MultiIndex, left) 797 right = cast(MultiIndex, right) 798 799 for level in range(left.nlevels): 800 # cannot use get_level_values here because it can change dtype 801 llevel = _get_ilevel_values(left, level) 802 rlevel = _get_ilevel_values(right, level) 803 804 lobj = f"MultiIndex level [{level}]" 805 assert_index_equal( 806 llevel, 807 rlevel, 808 exact=exact, 809 check_names=check_names, 810 check_exact=check_exact, 811 rtol=rtol, 812 atol=atol, 813 obj=lobj, 814 ) 815 # get_level_values may change dtype 816 _check_types(left.levels[level], right.levels[level], obj=obj) 817 818 # skip exact index checking when `check_categorical` is False 819 if check_exact and check_categorical: 820 if not left.equals(right): 821 diff = np.sum((left.values != right.values).astype(int)) * 100.0 / len(left) 822 msg = f"{obj} values are different ({np.round(diff, 5)} %)" 823 raise_assert_detail(obj, msg, left, right) 824 else: 825 _testing.assert_almost_equal( 826 left.values, 827 right.values, 828 rtol=rtol, 829 atol=atol, 830 check_dtype=exact, 831 obj=obj, 832 lobj=left, 833 robj=right, 834 ) 835 836 # metadata comparison 837 if check_names: 838 assert_attr_equal("names", left, right, obj=obj) 839 if isinstance(left, pd.PeriodIndex) or isinstance(right, pd.PeriodIndex): 840 assert_attr_equal("freq", left, right, obj=obj) 841 if isinstance(left, pd.IntervalIndex) or isinstance(right, pd.IntervalIndex): 842 assert_interval_array_equal(left._values, right._values) 843 844 if check_categorical: 845 if is_categorical_dtype(left.dtype) or is_categorical_dtype(right.dtype): 846 assert_categorical_equal(left._values, right._values, obj=f"{obj} category") 847 848 849def assert_class_equal(left, right, exact: Union[bool, str] = True, obj="Input"): 850 """ 851 Checks classes are equal. 852 """ 853 __tracebackhide__ = True 854 855 def repr_class(x): 856 if isinstance(x, Index): 857 # return Index as it is to include values in the error message 858 return x 859 860 return type(x).__name__ 861 862 if exact == "equiv": 863 if type(left) != type(right): 864 # allow equivalence of Int64Index/RangeIndex 865 types = {type(left).__name__, type(right).__name__} 866 if len(types - {"Int64Index", "RangeIndex"}): 867 msg = f"{obj} classes are not equivalent" 868 raise_assert_detail(obj, msg, repr_class(left), repr_class(right)) 869 elif exact: 870 if type(left) != type(right): 871 msg = f"{obj} classes are different" 872 raise_assert_detail(obj, msg, repr_class(left), repr_class(right)) 873 874 875def assert_attr_equal(attr: str, left, right, obj: str = "Attributes"): 876 """ 877 Check attributes are equal. Both objects must have attribute. 878 879 Parameters 880 ---------- 881 attr : str 882 Attribute name being compared. 883 left : object 884 right : object 885 obj : str, default 'Attributes' 886 Specify object name being compared, internally used to show appropriate 887 assertion message 888 """ 889 __tracebackhide__ = True 890 891 left_attr = getattr(left, attr) 892 right_attr = getattr(right, attr) 893 894 if left_attr is right_attr: 895 return True 896 elif ( 897 is_number(left_attr) 898 and np.isnan(left_attr) 899 and is_number(right_attr) 900 and np.isnan(right_attr) 901 ): 902 # np.nan 903 return True 904 905 try: 906 result = left_attr == right_attr 907 except TypeError: 908 # datetimetz on rhs may raise TypeError 909 result = False 910 if not isinstance(result, bool): 911 result = result.all() 912 913 if result: 914 return True 915 else: 916 msg = f'Attribute "{attr}" are different' 917 raise_assert_detail(obj, msg, left_attr, right_attr) 918 919 920def assert_is_valid_plot_return_object(objs): 921 import matplotlib.pyplot as plt 922 923 if isinstance(objs, (pd.Series, np.ndarray)): 924 for el in objs.ravel(): 925 msg = ( 926 "one of 'objs' is not a matplotlib Axes instance, " 927 f"type encountered {repr(type(el).__name__)}" 928 ) 929 assert isinstance(el, (plt.Axes, dict)), msg 930 else: 931 msg = ( 932 "objs is neither an ndarray of Artist instances nor a single " 933 "ArtistArtist instance, tuple, or dict, 'objs' is a " 934 f"{repr(type(objs).__name__)}" 935 ) 936 assert isinstance(objs, (plt.Artist, tuple, dict)), msg 937 938 939def assert_is_sorted(seq): 940 """Assert that the sequence is sorted.""" 941 if isinstance(seq, (Index, Series)): 942 seq = seq.values 943 # sorting does not change precisions 944 assert_numpy_array_equal(seq, np.sort(np.array(seq))) 945 946 947def assert_categorical_equal( 948 left, right, check_dtype=True, check_category_order=True, obj="Categorical" 949): 950 """ 951 Test that Categoricals are equivalent. 952 953 Parameters 954 ---------- 955 left : Categorical 956 right : Categorical 957 check_dtype : bool, default True 958 Check that integer dtype of the codes are the same 959 check_category_order : bool, default True 960 Whether the order of the categories should be compared, which 961 implies identical integer codes. If False, only the resulting 962 values are compared. The ordered attribute is 963 checked regardless. 964 obj : str, default 'Categorical' 965 Specify object name being compared, internally used to show appropriate 966 assertion message 967 """ 968 _check_isinstance(left, right, Categorical) 969 970 if check_category_order: 971 assert_index_equal(left.categories, right.categories, obj=f"{obj}.categories") 972 assert_numpy_array_equal( 973 left.codes, right.codes, check_dtype=check_dtype, obj=f"{obj}.codes" 974 ) 975 else: 976 try: 977 lc = left.categories.sort_values() 978 rc = right.categories.sort_values() 979 except TypeError: 980 # e.g. '<' not supported between instances of 'int' and 'str' 981 lc, rc = left.categories, right.categories 982 assert_index_equal(lc, rc, obj=f"{obj}.categories") 983 assert_index_equal( 984 left.categories.take(left.codes), 985 right.categories.take(right.codes), 986 obj=f"{obj}.values", 987 ) 988 989 assert_attr_equal("ordered", left, right, obj=obj) 990 991 992def assert_interval_array_equal(left, right, exact="equiv", obj="IntervalArray"): 993 """ 994 Test that two IntervalArrays are equivalent. 995 996 Parameters 997 ---------- 998 left, right : IntervalArray 999 The IntervalArrays to compare. 1000 exact : bool or {'equiv'}, default 'equiv' 1001 Whether to check the Index class, dtype and inferred_type 1002 are identical. If 'equiv', then RangeIndex can be substituted for 1003 Int64Index as well. 1004 obj : str, default 'IntervalArray' 1005 Specify object name being compared, internally used to show appropriate 1006 assertion message 1007 """ 1008 _check_isinstance(left, right, IntervalArray) 1009 1010 kwargs = {} 1011 if left._left.dtype.kind in ["m", "M"]: 1012 # We have a DatetimeArray or TimedeltaArray 1013 kwargs["check_freq"] = False 1014 1015 assert_equal(left._left, right._left, obj=f"{obj}.left", **kwargs) 1016 assert_equal(left._right, right._right, obj=f"{obj}.left", **kwargs) 1017 1018 assert_attr_equal("closed", left, right, obj=obj) 1019 1020 1021def assert_period_array_equal(left, right, obj="PeriodArray"): 1022 _check_isinstance(left, right, PeriodArray) 1023 1024 assert_numpy_array_equal(left._data, right._data, obj=f"{obj}._data") 1025 assert_attr_equal("freq", left, right, obj=obj) 1026 1027 1028def assert_datetime_array_equal(left, right, obj="DatetimeArray", check_freq=True): 1029 __tracebackhide__ = True 1030 _check_isinstance(left, right, DatetimeArray) 1031 1032 assert_numpy_array_equal(left._data, right._data, obj=f"{obj}._data") 1033 if check_freq: 1034 assert_attr_equal("freq", left, right, obj=obj) 1035 assert_attr_equal("tz", left, right, obj=obj) 1036 1037 1038def assert_timedelta_array_equal(left, right, obj="TimedeltaArray", check_freq=True): 1039 __tracebackhide__ = True 1040 _check_isinstance(left, right, TimedeltaArray) 1041 assert_numpy_array_equal(left._data, right._data, obj=f"{obj}._data") 1042 if check_freq: 1043 assert_attr_equal("freq", left, right, obj=obj) 1044 1045 1046def raise_assert_detail(obj, message, left, right, diff=None, index_values=None): 1047 __tracebackhide__ = True 1048 1049 msg = f"""{obj} are different 1050 1051{message}""" 1052 1053 if isinstance(index_values, np.ndarray): 1054 msg += f"\n[index]: {pprint_thing(index_values)}" 1055 1056 if isinstance(left, np.ndarray): 1057 left = pprint_thing(left) 1058 elif is_categorical_dtype(left): 1059 left = repr(left) 1060 1061 if isinstance(right, np.ndarray): 1062 right = pprint_thing(right) 1063 elif is_categorical_dtype(right): 1064 right = repr(right) 1065 1066 msg += f""" 1067[left]: {left} 1068[right]: {right}""" 1069 1070 if diff is not None: 1071 msg += f"\n[diff]: {diff}" 1072 1073 raise AssertionError(msg) 1074 1075 1076def assert_numpy_array_equal( 1077 left, 1078 right, 1079 strict_nan=False, 1080 check_dtype=True, 1081 err_msg=None, 1082 check_same=None, 1083 obj="numpy array", 1084 index_values=None, 1085): 1086 """ 1087 Check that 'np.ndarray' is equivalent. 1088 1089 Parameters 1090 ---------- 1091 left, right : numpy.ndarray or iterable 1092 The two arrays to be compared. 1093 strict_nan : bool, default False 1094 If True, consider NaN and None to be different. 1095 check_dtype : bool, default True 1096 Check dtype if both a and b are np.ndarray. 1097 err_msg : str, default None 1098 If provided, used as assertion message. 1099 check_same : None|'copy'|'same', default None 1100 Ensure left and right refer/do not refer to the same memory area. 1101 obj : str, default 'numpy array' 1102 Specify object name being compared, internally used to show appropriate 1103 assertion message. 1104 index_values : numpy.ndarray, default None 1105 optional index (shared by both left and right), used in output. 1106 """ 1107 __tracebackhide__ = True 1108 1109 # instance validation 1110 # Show a detailed error message when classes are different 1111 assert_class_equal(left, right, obj=obj) 1112 # both classes must be an np.ndarray 1113 _check_isinstance(left, right, np.ndarray) 1114 1115 def _get_base(obj): 1116 return obj.base if getattr(obj, "base", None) is not None else obj 1117 1118 left_base = _get_base(left) 1119 right_base = _get_base(right) 1120 1121 if check_same == "same": 1122 if left_base is not right_base: 1123 raise AssertionError(f"{repr(left_base)} is not {repr(right_base)}") 1124 elif check_same == "copy": 1125 if left_base is right_base: 1126 raise AssertionError(f"{repr(left_base)} is {repr(right_base)}") 1127 1128 def _raise(left, right, err_msg): 1129 if err_msg is None: 1130 if left.shape != right.shape: 1131 raise_assert_detail( 1132 obj, f"{obj} shapes are different", left.shape, right.shape 1133 ) 1134 1135 diff = 0 1136 for left_arr, right_arr in zip(left, right): 1137 # count up differences 1138 if not array_equivalent(left_arr, right_arr, strict_nan=strict_nan): 1139 diff += 1 1140 1141 diff = diff * 100.0 / left.size 1142 msg = f"{obj} values are different ({np.round(diff, 5)} %)" 1143 raise_assert_detail(obj, msg, left, right, index_values=index_values) 1144 1145 raise AssertionError(err_msg) 1146 1147 # compare shape and values 1148 if not array_equivalent(left, right, strict_nan=strict_nan): 1149 _raise(left, right, err_msg) 1150 1151 if check_dtype: 1152 if isinstance(left, np.ndarray) and isinstance(right, np.ndarray): 1153 assert_attr_equal("dtype", left, right, obj=obj) 1154 1155 1156def assert_extension_array_equal( 1157 left, 1158 right, 1159 check_dtype=True, 1160 index_values=None, 1161 check_less_precise=no_default, 1162 check_exact=False, 1163 rtol: float = 1.0e-5, 1164 atol: float = 1.0e-8, 1165): 1166 """ 1167 Check that left and right ExtensionArrays are equal. 1168 1169 Parameters 1170 ---------- 1171 left, right : ExtensionArray 1172 The two arrays to compare. 1173 check_dtype : bool, default True 1174 Whether to check if the ExtensionArray dtypes are identical. 1175 index_values : numpy.ndarray, default None 1176 Optional index (shared by both left and right), used in output. 1177 check_less_precise : bool or int, default False 1178 Specify comparison precision. Only used when check_exact is False. 1179 5 digits (False) or 3 digits (True) after decimal points are compared. 1180 If int, then specify the digits to compare. 1181 1182 .. deprecated:: 1.1.0 1183 Use `rtol` and `atol` instead to define relative/absolute 1184 tolerance, respectively. Similar to :func:`math.isclose`. 1185 check_exact : bool, default False 1186 Whether to compare number exactly. 1187 rtol : float, default 1e-5 1188 Relative tolerance. Only used when check_exact is False. 1189 1190 .. versionadded:: 1.1.0 1191 atol : float, default 1e-8 1192 Absolute tolerance. Only used when check_exact is False. 1193 1194 .. versionadded:: 1.1.0 1195 1196 Notes 1197 ----- 1198 Missing values are checked separately from valid values. 1199 A mask of missing values is computed for each and checked to match. 1200 The remaining all-valid values are cast to object dtype and checked. 1201 1202 Examples 1203 -------- 1204 >>> from pandas.testing import assert_extension_array_equal 1205 >>> a = pd.Series([1, 2, 3, 4]) 1206 >>> b, c = a.array, a.array 1207 >>> assert_extension_array_equal(b, c) 1208 """ 1209 if check_less_precise is not no_default: 1210 warnings.warn( 1211 "The 'check_less_precise' keyword in testing.assert_*_equal " 1212 "is deprecated and will be removed in a future version. " 1213 "You can stop passing 'check_less_precise' to silence this warning.", 1214 FutureWarning, 1215 stacklevel=2, 1216 ) 1217 rtol = atol = _get_tol_from_less_precise(check_less_precise) 1218 1219 assert isinstance(left, ExtensionArray), "left is not an ExtensionArray" 1220 assert isinstance(right, ExtensionArray), "right is not an ExtensionArray" 1221 if check_dtype: 1222 assert_attr_equal("dtype", left, right, obj="ExtensionArray") 1223 1224 if ( 1225 isinstance(left, DatetimeLikeArrayMixin) 1226 and isinstance(right, DatetimeLikeArrayMixin) 1227 and type(right) == type(left) 1228 ): 1229 # Avoid slow object-dtype comparisons 1230 # np.asarray for case where we have a np.MaskedArray 1231 assert_numpy_array_equal( 1232 np.asarray(left.asi8), np.asarray(right.asi8), index_values=index_values 1233 ) 1234 return 1235 1236 left_na = np.asarray(left.isna()) 1237 right_na = np.asarray(right.isna()) 1238 assert_numpy_array_equal( 1239 left_na, right_na, obj="ExtensionArray NA mask", index_values=index_values 1240 ) 1241 1242 left_valid = np.asarray(left[~left_na].astype(object)) 1243 right_valid = np.asarray(right[~right_na].astype(object)) 1244 if check_exact: 1245 assert_numpy_array_equal( 1246 left_valid, right_valid, obj="ExtensionArray", index_values=index_values 1247 ) 1248 else: 1249 _testing.assert_almost_equal( 1250 left_valid, 1251 right_valid, 1252 check_dtype=check_dtype, 1253 rtol=rtol, 1254 atol=atol, 1255 obj="ExtensionArray", 1256 index_values=index_values, 1257 ) 1258 1259 1260# This could be refactored to use the NDFrame.equals method 1261def assert_series_equal( 1262 left, 1263 right, 1264 check_dtype=True, 1265 check_index_type="equiv", 1266 check_series_type=True, 1267 check_less_precise=no_default, 1268 check_names=True, 1269 check_exact=False, 1270 check_datetimelike_compat=False, 1271 check_categorical=True, 1272 check_category_order=True, 1273 check_freq=True, 1274 check_flags=True, 1275 rtol=1.0e-5, 1276 atol=1.0e-8, 1277 obj="Series", 1278): 1279 """ 1280 Check that left and right Series are equal. 1281 1282 Parameters 1283 ---------- 1284 left : Series 1285 right : Series 1286 check_dtype : bool, default True 1287 Whether to check the Series dtype is identical. 1288 check_index_type : bool or {'equiv'}, default 'equiv' 1289 Whether to check the Index class, dtype and inferred_type 1290 are identical. 1291 check_series_type : bool, default True 1292 Whether to check the Series class is identical. 1293 check_less_precise : bool or int, default False 1294 Specify comparison precision. Only used when check_exact is False. 1295 5 digits (False) or 3 digits (True) after decimal points are compared. 1296 If int, then specify the digits to compare. 1297 1298 When comparing two numbers, if the first number has magnitude less 1299 than 1e-5, we compare the two numbers directly and check whether 1300 they are equivalent within the specified precision. Otherwise, we 1301 compare the **ratio** of the second number to the first number and 1302 check whether it is equivalent to 1 within the specified precision. 1303 1304 .. deprecated:: 1.1.0 1305 Use `rtol` and `atol` instead to define relative/absolute 1306 tolerance, respectively. Similar to :func:`math.isclose`. 1307 check_names : bool, default True 1308 Whether to check the Series and Index names attribute. 1309 check_exact : bool, default False 1310 Whether to compare number exactly. 1311 check_datetimelike_compat : bool, default False 1312 Compare datetime-like which is comparable ignoring dtype. 1313 check_categorical : bool, default True 1314 Whether to compare internal Categorical exactly. 1315 check_category_order : bool, default True 1316 Whether to compare category order of internal Categoricals. 1317 1318 .. versionadded:: 1.0.2 1319 check_freq : bool, default True 1320 Whether to check the `freq` attribute on a DatetimeIndex or TimedeltaIndex. 1321 1322 .. versionadded:: 1.1.0 1323 check_flags : bool, default True 1324 Whether to check the `flags` attribute. 1325 1326 .. versionadded:: 1.2.0 1327 1328 rtol : float, default 1e-5 1329 Relative tolerance. Only used when check_exact is False. 1330 1331 .. versionadded:: 1.1.0 1332 atol : float, default 1e-8 1333 Absolute tolerance. Only used when check_exact is False. 1334 1335 .. versionadded:: 1.1.0 1336 obj : str, default 'Series' 1337 Specify object name being compared, internally used to show appropriate 1338 assertion message. 1339 1340 Examples 1341 -------- 1342 >>> from pandas.testing import assert_series_equal 1343 >>> a = pd.Series([1, 2, 3, 4]) 1344 >>> b = pd.Series([1, 2, 3, 4]) 1345 >>> assert_series_equal(a, b) 1346 """ 1347 __tracebackhide__ = True 1348 1349 if check_less_precise is not no_default: 1350 warnings.warn( 1351 "The 'check_less_precise' keyword in testing.assert_*_equal " 1352 "is deprecated and will be removed in a future version. " 1353 "You can stop passing 'check_less_precise' to silence this warning.", 1354 FutureWarning, 1355 stacklevel=2, 1356 ) 1357 rtol = atol = _get_tol_from_less_precise(check_less_precise) 1358 1359 # instance validation 1360 _check_isinstance(left, right, Series) 1361 1362 if check_series_type: 1363 assert_class_equal(left, right, obj=obj) 1364 1365 # length comparison 1366 if len(left) != len(right): 1367 msg1 = f"{len(left)}, {left.index}" 1368 msg2 = f"{len(right)}, {right.index}" 1369 raise_assert_detail(obj, "Series length are different", msg1, msg2) 1370 1371 if check_flags: 1372 assert left.flags == right.flags, f"{repr(left.flags)} != {repr(right.flags)}" 1373 1374 # index comparison 1375 assert_index_equal( 1376 left.index, 1377 right.index, 1378 exact=check_index_type, 1379 check_names=check_names, 1380 check_exact=check_exact, 1381 check_categorical=check_categorical, 1382 rtol=rtol, 1383 atol=atol, 1384 obj=f"{obj}.index", 1385 ) 1386 if check_freq and isinstance(left.index, (pd.DatetimeIndex, pd.TimedeltaIndex)): 1387 lidx = left.index 1388 ridx = right.index 1389 assert lidx.freq == ridx.freq, (lidx.freq, ridx.freq) 1390 1391 if check_dtype: 1392 # We want to skip exact dtype checking when `check_categorical` 1393 # is False. We'll still raise if only one is a `Categorical`, 1394 # regardless of `check_categorical` 1395 if ( 1396 is_categorical_dtype(left.dtype) 1397 and is_categorical_dtype(right.dtype) 1398 and not check_categorical 1399 ): 1400 pass 1401 else: 1402 assert_attr_equal("dtype", left, right, obj=f"Attributes of {obj}") 1403 1404 if check_exact and is_numeric_dtype(left.dtype) and is_numeric_dtype(right.dtype): 1405 left_values = left._values 1406 right_values = right._values 1407 # Only check exact if dtype is numeric 1408 if is_extension_array_dtype(left_values) and is_extension_array_dtype( 1409 right_values 1410 ): 1411 assert_extension_array_equal( 1412 left_values, 1413 right_values, 1414 check_dtype=check_dtype, 1415 index_values=np.asarray(left.index), 1416 ) 1417 else: 1418 assert_numpy_array_equal( 1419 left_values, 1420 right_values, 1421 check_dtype=check_dtype, 1422 obj=str(obj), 1423 index_values=np.asarray(left.index), 1424 ) 1425 elif check_datetimelike_compat and ( 1426 needs_i8_conversion(left.dtype) or needs_i8_conversion(right.dtype) 1427 ): 1428 # we want to check only if we have compat dtypes 1429 # e.g. integer and M|m are NOT compat, but we can simply check 1430 # the values in that case 1431 1432 # datetimelike may have different objects (e.g. datetime.datetime 1433 # vs Timestamp) but will compare equal 1434 if not Index(left._values).equals(Index(right._values)): 1435 msg = ( 1436 f"[datetimelike_compat=True] {left._values} " 1437 f"is not equal to {right._values}." 1438 ) 1439 raise AssertionError(msg) 1440 elif is_interval_dtype(left.dtype) and is_interval_dtype(right.dtype): 1441 assert_interval_array_equal(left.array, right.array) 1442 elif is_categorical_dtype(left.dtype) or is_categorical_dtype(right.dtype): 1443 _testing.assert_almost_equal( 1444 left._values, 1445 right._values, 1446 rtol=rtol, 1447 atol=atol, 1448 check_dtype=check_dtype, 1449 obj=str(obj), 1450 index_values=np.asarray(left.index), 1451 ) 1452 elif is_extension_array_dtype(left.dtype) and is_extension_array_dtype(right.dtype): 1453 assert_extension_array_equal( 1454 left._values, 1455 right._values, 1456 check_dtype=check_dtype, 1457 index_values=np.asarray(left.index), 1458 ) 1459 elif is_extension_array_dtype_and_needs_i8_conversion( 1460 left.dtype, right.dtype 1461 ) or is_extension_array_dtype_and_needs_i8_conversion(right.dtype, left.dtype): 1462 assert_extension_array_equal( 1463 left._values, 1464 right._values, 1465 check_dtype=check_dtype, 1466 index_values=np.asarray(left.index), 1467 ) 1468 elif needs_i8_conversion(left.dtype) and needs_i8_conversion(right.dtype): 1469 # DatetimeArray or TimedeltaArray 1470 assert_extension_array_equal( 1471 left._values, 1472 right._values, 1473 check_dtype=check_dtype, 1474 index_values=np.asarray(left.index), 1475 ) 1476 else: 1477 _testing.assert_almost_equal( 1478 left._values, 1479 right._values, 1480 rtol=rtol, 1481 atol=atol, 1482 check_dtype=check_dtype, 1483 obj=str(obj), 1484 index_values=np.asarray(left.index), 1485 ) 1486 1487 # metadata comparison 1488 if check_names: 1489 assert_attr_equal("name", left, right, obj=obj) 1490 1491 if check_categorical: 1492 if is_categorical_dtype(left.dtype) or is_categorical_dtype(right.dtype): 1493 assert_categorical_equal( 1494 left._values, 1495 right._values, 1496 obj=f"{obj} category", 1497 check_category_order=check_category_order, 1498 ) 1499 1500 1501# This could be refactored to use the NDFrame.equals method 1502def assert_frame_equal( 1503 left, 1504 right, 1505 check_dtype=True, 1506 check_index_type="equiv", 1507 check_column_type="equiv", 1508 check_frame_type=True, 1509 check_less_precise=no_default, 1510 check_names=True, 1511 by_blocks=False, 1512 check_exact=False, 1513 check_datetimelike_compat=False, 1514 check_categorical=True, 1515 check_like=False, 1516 check_freq=True, 1517 check_flags=True, 1518 rtol=1.0e-5, 1519 atol=1.0e-8, 1520 obj="DataFrame", 1521): 1522 """ 1523 Check that left and right DataFrame are equal. 1524 1525 This function is intended to compare two DataFrames and output any 1526 differences. Is is mostly intended for use in unit tests. 1527 Additional parameters allow varying the strictness of the 1528 equality checks performed. 1529 1530 Parameters 1531 ---------- 1532 left : DataFrame 1533 First DataFrame to compare. 1534 right : DataFrame 1535 Second DataFrame to compare. 1536 check_dtype : bool, default True 1537 Whether to check the DataFrame dtype is identical. 1538 check_index_type : bool or {'equiv'}, default 'equiv' 1539 Whether to check the Index class, dtype and inferred_type 1540 are identical. 1541 check_column_type : bool or {'equiv'}, default 'equiv' 1542 Whether to check the columns class, dtype and inferred_type 1543 are identical. Is passed as the ``exact`` argument of 1544 :func:`assert_index_equal`. 1545 check_frame_type : bool, default True 1546 Whether to check the DataFrame class is identical. 1547 check_less_precise : bool or int, default False 1548 Specify comparison precision. Only used when check_exact is False. 1549 5 digits (False) or 3 digits (True) after decimal points are compared. 1550 If int, then specify the digits to compare. 1551 1552 When comparing two numbers, if the first number has magnitude less 1553 than 1e-5, we compare the two numbers directly and check whether 1554 they are equivalent within the specified precision. Otherwise, we 1555 compare the **ratio** of the second number to the first number and 1556 check whether it is equivalent to 1 within the specified precision. 1557 1558 .. deprecated:: 1.1.0 1559 Use `rtol` and `atol` instead to define relative/absolute 1560 tolerance, respectively. Similar to :func:`math.isclose`. 1561 check_names : bool, default True 1562 Whether to check that the `names` attribute for both the `index` 1563 and `column` attributes of the DataFrame is identical. 1564 by_blocks : bool, default False 1565 Specify how to compare internal data. If False, compare by columns. 1566 If True, compare by blocks. 1567 check_exact : bool, default False 1568 Whether to compare number exactly. 1569 check_datetimelike_compat : bool, default False 1570 Compare datetime-like which is comparable ignoring dtype. 1571 check_categorical : bool, default True 1572 Whether to compare internal Categorical exactly. 1573 check_like : bool, default False 1574 If True, ignore the order of index & columns. 1575 Note: index labels must match their respective rows 1576 (same as in columns) - same labels must be with the same data. 1577 check_freq : bool, default True 1578 Whether to check the `freq` attribute on a DatetimeIndex or TimedeltaIndex. 1579 1580 .. versionadded:: 1.1.0 1581 check_flags : bool, default True 1582 Whether to check the `flags` attribute. 1583 rtol : float, default 1e-5 1584 Relative tolerance. Only used when check_exact is False. 1585 1586 .. versionadded:: 1.1.0 1587 atol : float, default 1e-8 1588 Absolute tolerance. Only used when check_exact is False. 1589 1590 .. versionadded:: 1.1.0 1591 obj : str, default 'DataFrame' 1592 Specify object name being compared, internally used to show appropriate 1593 assertion message. 1594 1595 See Also 1596 -------- 1597 assert_series_equal : Equivalent method for asserting Series equality. 1598 DataFrame.equals : Check DataFrame equality. 1599 1600 Examples 1601 -------- 1602 This example shows comparing two DataFrames that are equal 1603 but with columns of differing dtypes. 1604 1605 >>> from pandas._testing import assert_frame_equal 1606 >>> df1 = pd.DataFrame({'a': [1, 2], 'b': [3, 4]}) 1607 >>> df2 = pd.DataFrame({'a': [1, 2], 'b': [3.0, 4.0]}) 1608 1609 df1 equals itself. 1610 1611 >>> assert_frame_equal(df1, df1) 1612 1613 df1 differs from df2 as column 'b' is of a different type. 1614 1615 >>> assert_frame_equal(df1, df2) 1616 Traceback (most recent call last): 1617 ... 1618 AssertionError: Attributes of DataFrame.iloc[:, 1] (column name="b") are different 1619 1620 Attribute "dtype" are different 1621 [left]: int64 1622 [right]: float64 1623 1624 Ignore differing dtypes in columns with check_dtype. 1625 1626 >>> assert_frame_equal(df1, df2, check_dtype=False) 1627 """ 1628 __tracebackhide__ = True 1629 1630 if check_less_precise is not no_default: 1631 warnings.warn( 1632 "The 'check_less_precise' keyword in testing.assert_*_equal " 1633 "is deprecated and will be removed in a future version. " 1634 "You can stop passing 'check_less_precise' to silence this warning.", 1635 FutureWarning, 1636 stacklevel=2, 1637 ) 1638 rtol = atol = _get_tol_from_less_precise(check_less_precise) 1639 1640 # instance validation 1641 _check_isinstance(left, right, DataFrame) 1642 1643 if check_frame_type: 1644 assert isinstance(left, type(right)) 1645 # assert_class_equal(left, right, obj=obj) 1646 1647 # shape comparison 1648 if left.shape != right.shape: 1649 raise_assert_detail( 1650 obj, f"{obj} shape mismatch", f"{repr(left.shape)}", f"{repr(right.shape)}" 1651 ) 1652 1653 if check_flags: 1654 assert left.flags == right.flags, f"{repr(left.flags)} != {repr(right.flags)}" 1655 1656 # index comparison 1657 assert_index_equal( 1658 left.index, 1659 right.index, 1660 exact=check_index_type, 1661 check_names=check_names, 1662 check_exact=check_exact, 1663 check_categorical=check_categorical, 1664 check_order=not check_like, 1665 rtol=rtol, 1666 atol=atol, 1667 obj=f"{obj}.index", 1668 ) 1669 1670 # column comparison 1671 assert_index_equal( 1672 left.columns, 1673 right.columns, 1674 exact=check_column_type, 1675 check_names=check_names, 1676 check_exact=check_exact, 1677 check_categorical=check_categorical, 1678 check_order=not check_like, 1679 rtol=rtol, 1680 atol=atol, 1681 obj=f"{obj}.columns", 1682 ) 1683 1684 if check_like: 1685 left, right = left.reindex_like(right), right 1686 1687 # compare by blocks 1688 if by_blocks: 1689 rblocks = right._to_dict_of_blocks() 1690 lblocks = left._to_dict_of_blocks() 1691 for dtype in list(set(list(lblocks.keys()) + list(rblocks.keys()))): 1692 assert dtype in lblocks 1693 assert dtype in rblocks 1694 assert_frame_equal( 1695 lblocks[dtype], rblocks[dtype], check_dtype=check_dtype, obj=obj 1696 ) 1697 1698 # compare by columns 1699 else: 1700 for i, col in enumerate(left.columns): 1701 assert col in right 1702 lcol = left.iloc[:, i] 1703 rcol = right.iloc[:, i] 1704 assert_series_equal( 1705 lcol, 1706 rcol, 1707 check_dtype=check_dtype, 1708 check_index_type=check_index_type, 1709 check_exact=check_exact, 1710 check_names=check_names, 1711 check_datetimelike_compat=check_datetimelike_compat, 1712 check_categorical=check_categorical, 1713 check_freq=check_freq, 1714 obj=f'{obj}.iloc[:, {i}] (column name="{col}")', 1715 rtol=rtol, 1716 atol=atol, 1717 ) 1718 1719 1720def assert_equal(left, right, **kwargs): 1721 """ 1722 Wrapper for tm.assert_*_equal to dispatch to the appropriate test function. 1723 1724 Parameters 1725 ---------- 1726 left, right : Index, Series, DataFrame, ExtensionArray, or np.ndarray 1727 The two items to be compared. 1728 **kwargs 1729 All keyword arguments are passed through to the underlying assert method. 1730 """ 1731 __tracebackhide__ = True 1732 1733 if isinstance(left, pd.Index): 1734 assert_index_equal(left, right, **kwargs) 1735 if isinstance(left, (pd.DatetimeIndex, pd.TimedeltaIndex)): 1736 assert left.freq == right.freq, (left.freq, right.freq) 1737 elif isinstance(left, pd.Series): 1738 assert_series_equal(left, right, **kwargs) 1739 elif isinstance(left, pd.DataFrame): 1740 assert_frame_equal(left, right, **kwargs) 1741 elif isinstance(left, IntervalArray): 1742 assert_interval_array_equal(left, right, **kwargs) 1743 elif isinstance(left, PeriodArray): 1744 assert_period_array_equal(left, right, **kwargs) 1745 elif isinstance(left, DatetimeArray): 1746 assert_datetime_array_equal(left, right, **kwargs) 1747 elif isinstance(left, TimedeltaArray): 1748 assert_timedelta_array_equal(left, right, **kwargs) 1749 elif isinstance(left, ExtensionArray): 1750 assert_extension_array_equal(left, right, **kwargs) 1751 elif isinstance(left, np.ndarray): 1752 assert_numpy_array_equal(left, right, **kwargs) 1753 elif isinstance(left, str): 1754 assert kwargs == {} 1755 assert left == right 1756 else: 1757 raise NotImplementedError(type(left)) 1758 1759 1760def box_expected(expected, box_cls, transpose=True): 1761 """ 1762 Helper function to wrap the expected output of a test in a given box_class. 1763 1764 Parameters 1765 ---------- 1766 expected : np.ndarray, Index, Series 1767 box_cls : {Index, Series, DataFrame} 1768 1769 Returns 1770 ------- 1771 subclass of box_cls 1772 """ 1773 if box_cls is pd.array: 1774 expected = pd.array(expected) 1775 elif box_cls is pd.Index: 1776 expected = pd.Index(expected) 1777 elif box_cls is pd.Series: 1778 expected = pd.Series(expected) 1779 elif box_cls is pd.DataFrame: 1780 expected = pd.Series(expected).to_frame() 1781 if transpose: 1782 # for vector operations, we need a DataFrame to be a single-row, 1783 # not a single-column, in order to operate against non-DataFrame 1784 # vectors of the same length. 1785 expected = expected.T 1786 elif box_cls is PeriodArray: 1787 # the PeriodArray constructor is not as flexible as period_array 1788 expected = period_array(expected) 1789 elif box_cls is DatetimeArray: 1790 expected = DatetimeArray(expected) 1791 elif box_cls is TimedeltaArray: 1792 expected = TimedeltaArray(expected) 1793 elif box_cls is np.ndarray: 1794 expected = np.array(expected) 1795 elif box_cls is to_array: 1796 expected = to_array(expected) 1797 else: 1798 raise NotImplementedError(box_cls) 1799 return expected 1800 1801 1802def to_array(obj): 1803 # temporary implementation until we get pd.array in place 1804 dtype = getattr(obj, "dtype", None) 1805 1806 if is_period_dtype(dtype): 1807 return period_array(obj) 1808 elif is_datetime64_dtype(dtype) or is_datetime64tz_dtype(dtype): 1809 return DatetimeArray._from_sequence(obj) 1810 elif is_timedelta64_dtype(dtype): 1811 return TimedeltaArray._from_sequence(obj) 1812 else: 1813 return np.array(obj) 1814 1815 1816# ----------------------------------------------------------------------------- 1817# Sparse 1818 1819 1820def assert_sp_array_equal(left, right): 1821 """ 1822 Check that the left and right SparseArray are equal. 1823 1824 Parameters 1825 ---------- 1826 left : SparseArray 1827 right : SparseArray 1828 """ 1829 _check_isinstance(left, right, pd.arrays.SparseArray) 1830 1831 assert_numpy_array_equal(left.sp_values, right.sp_values) 1832 1833 # SparseIndex comparison 1834 assert isinstance(left.sp_index, pd._libs.sparse.SparseIndex) 1835 assert isinstance(right.sp_index, pd._libs.sparse.SparseIndex) 1836 1837 left_index = left.sp_index 1838 right_index = right.sp_index 1839 1840 if not left_index.equals(right_index): 1841 raise_assert_detail( 1842 "SparseArray.index", "index are not equal", left_index, right_index 1843 ) 1844 else: 1845 # Just ensure a 1846 pass 1847 1848 assert_attr_equal("fill_value", left, right) 1849 assert_attr_equal("dtype", left, right) 1850 assert_numpy_array_equal(left.to_dense(), right.to_dense()) 1851 1852 1853# ----------------------------------------------------------------------------- 1854# Others 1855 1856 1857def assert_contains_all(iterable, dic): 1858 for k in iterable: 1859 assert k in dic, f"Did not contain item: {repr(k)}" 1860 1861 1862def assert_copy(iter1, iter2, **eql_kwargs): 1863 """ 1864 iter1, iter2: iterables that produce elements 1865 comparable with assert_almost_equal 1866 1867 Checks that the elements are equal, but not 1868 the same object. (Does not check that items 1869 in sequences are also not the same object) 1870 """ 1871 for elem1, elem2 in zip(iter1, iter2): 1872 assert_almost_equal(elem1, elem2, **eql_kwargs) 1873 msg = ( 1874 f"Expected object {repr(type(elem1))} and object {repr(type(elem2))} to be " 1875 "different objects, but they were the same object." 1876 ) 1877 assert elem1 is not elem2, msg 1878 1879 1880def is_extension_array_dtype_and_needs_i8_conversion(left_dtype, right_dtype) -> bool: 1881 """ 1882 Checks that we have the combination of an ExtensionArraydtype and 1883 a dtype that should be converted to int64 1884 1885 Returns 1886 ------- 1887 bool 1888 1889 Related to issue #37609 1890 """ 1891 return is_extension_array_dtype(left_dtype) and needs_i8_conversion(right_dtype) 1892 1893 1894def getCols(k): 1895 return string.ascii_uppercase[:k] 1896 1897 1898# make index 1899def makeStringIndex(k=10, name=None): 1900 return Index(rands_array(nchars=10, size=k), name=name) 1901 1902 1903def makeUnicodeIndex(k=10, name=None): 1904 return Index(randu_array(nchars=10, size=k), name=name) 1905 1906 1907def makeCategoricalIndex(k=10, n=3, name=None, **kwargs): 1908 """ make a length k index or n categories """ 1909 x = rands_array(nchars=4, size=n) 1910 return CategoricalIndex( 1911 Categorical.from_codes(np.arange(k) % n, categories=x), name=name, **kwargs 1912 ) 1913 1914 1915def makeIntervalIndex(k=10, name=None, **kwargs): 1916 """ make a length k IntervalIndex """ 1917 x = np.linspace(0, 100, num=(k + 1)) 1918 return IntervalIndex.from_breaks(x, name=name, **kwargs) 1919 1920 1921def makeBoolIndex(k=10, name=None): 1922 if k == 1: 1923 return Index([True], name=name) 1924 elif k == 2: 1925 return Index([False, True], name=name) 1926 return Index([False, True] + [False] * (k - 2), name=name) 1927 1928 1929def makeIntIndex(k=10, name=None): 1930 return Index(list(range(k)), name=name) 1931 1932 1933def makeUIntIndex(k=10, name=None): 1934 return Index([2 ** 63 + i for i in range(k)], name=name) 1935 1936 1937def makeRangeIndex(k=10, name=None, **kwargs): 1938 return RangeIndex(0, k, 1, name=name, **kwargs) 1939 1940 1941def makeFloatIndex(k=10, name=None): 1942 values = sorted(np.random.random_sample(k)) - np.random.random_sample(1) 1943 return Index(values * (10 ** np.random.randint(0, 9)), name=name) 1944 1945 1946def makeDateIndex(k=10, freq="B", name=None, **kwargs): 1947 dt = datetime(2000, 1, 1) 1948 dr = bdate_range(dt, periods=k, freq=freq, name=name) 1949 return DatetimeIndex(dr, name=name, **kwargs) 1950 1951 1952def makeTimedeltaIndex(k=10, freq="D", name=None, **kwargs): 1953 return pd.timedelta_range(start="1 day", periods=k, freq=freq, name=name, **kwargs) 1954 1955 1956def makePeriodIndex(k=10, name=None, **kwargs): 1957 dt = datetime(2000, 1, 1) 1958 return pd.period_range(start=dt, periods=k, freq="B", name=name, **kwargs) 1959 1960 1961def makeMultiIndex(k=10, names=None, **kwargs): 1962 return MultiIndex.from_product((("foo", "bar"), (1, 2)), names=names, **kwargs) 1963 1964 1965_names = [ 1966 "Alice", 1967 "Bob", 1968 "Charlie", 1969 "Dan", 1970 "Edith", 1971 "Frank", 1972 "George", 1973 "Hannah", 1974 "Ingrid", 1975 "Jerry", 1976 "Kevin", 1977 "Laura", 1978 "Michael", 1979 "Norbert", 1980 "Oliver", 1981 "Patricia", 1982 "Quinn", 1983 "Ray", 1984 "Sarah", 1985 "Tim", 1986 "Ursula", 1987 "Victor", 1988 "Wendy", 1989 "Xavier", 1990 "Yvonne", 1991 "Zelda", 1992] 1993 1994 1995def _make_timeseries(start="2000-01-01", end="2000-12-31", freq="1D", seed=None): 1996 """ 1997 Make a DataFrame with a DatetimeIndex 1998 1999 Parameters 2000 ---------- 2001 start : str or Timestamp, default "2000-01-01" 2002 The start of the index. Passed to date_range with `freq`. 2003 end : str or Timestamp, default "2000-12-31" 2004 The end of the index. Passed to date_range with `freq`. 2005 freq : str or Freq 2006 The frequency to use for the DatetimeIndex 2007 seed : int, optional 2008 The random state seed. 2009 2010 * name : object dtype with string names 2011 * id : int dtype with 2012 * x, y : float dtype 2013 2014 Examples 2015 -------- 2016 >>> _make_timeseries() 2017 id name x y 2018 timestamp 2019 2000-01-01 982 Frank 0.031261 0.986727 2020 2000-01-02 1025 Edith -0.086358 -0.032920 2021 2000-01-03 982 Edith 0.473177 0.298654 2022 2000-01-04 1009 Sarah 0.534344 -0.750377 2023 2000-01-05 963 Zelda -0.271573 0.054424 2024 ... ... ... ... ... 2025 2000-12-27 980 Ingrid -0.132333 -0.422195 2026 2000-12-28 972 Frank -0.376007 -0.298687 2027 2000-12-29 1009 Ursula -0.865047 -0.503133 2028 2000-12-30 1000 Hannah -0.063757 -0.507336 2029 2000-12-31 972 Tim -0.869120 0.531685 2030 """ 2031 index = pd.date_range(start=start, end=end, freq=freq, name="timestamp") 2032 n = len(index) 2033 state = np.random.RandomState(seed) 2034 columns = { 2035 "name": state.choice(_names, size=n), 2036 "id": state.poisson(1000, size=n), 2037 "x": state.rand(n) * 2 - 1, 2038 "y": state.rand(n) * 2 - 1, 2039 } 2040 df = pd.DataFrame(columns, index=index, columns=sorted(columns)) 2041 if df.index[-1] == end: 2042 df = df.iloc[:-1] 2043 return df 2044 2045 2046def index_subclass_makers_generator(): 2047 make_index_funcs = [ 2048 makeDateIndex, 2049 makePeriodIndex, 2050 makeTimedeltaIndex, 2051 makeRangeIndex, 2052 makeIntervalIndex, 2053 makeCategoricalIndex, 2054 makeMultiIndex, 2055 ] 2056 yield from make_index_funcs 2057 2058 2059def all_timeseries_index_generator(k=10): 2060 """ 2061 Generator which can be iterated over to get instances of all the classes 2062 which represent time-series. 2063 2064 Parameters 2065 ---------- 2066 k: length of each of the index instances 2067 """ 2068 make_index_funcs = [makeDateIndex, makePeriodIndex, makeTimedeltaIndex] 2069 for make_index_func in make_index_funcs: 2070 # pandas\_testing.py:1986: error: Cannot call function of unknown type 2071 yield make_index_func(k=k) # type: ignore[operator] 2072 2073 2074# make series 2075def makeFloatSeries(name=None): 2076 index = makeStringIndex(_N) 2077 return Series(randn(_N), index=index, name=name) 2078 2079 2080def makeStringSeries(name=None): 2081 index = makeStringIndex(_N) 2082 return Series(randn(_N), index=index, name=name) 2083 2084 2085def makeObjectSeries(name=None): 2086 data = makeStringIndex(_N) 2087 data = Index(data, dtype=object) 2088 index = makeStringIndex(_N) 2089 return Series(data, index=index, name=name) 2090 2091 2092def getSeriesData(): 2093 index = makeStringIndex(_N) 2094 return {c: Series(randn(_N), index=index) for c in getCols(_K)} 2095 2096 2097def makeTimeSeries(nper=None, freq="B", name=None): 2098 if nper is None: 2099 nper = _N 2100 return Series(randn(nper), index=makeDateIndex(nper, freq=freq), name=name) 2101 2102 2103def makePeriodSeries(nper=None, name=None): 2104 if nper is None: 2105 nper = _N 2106 return Series(randn(nper), index=makePeriodIndex(nper), name=name) 2107 2108 2109def getTimeSeriesData(nper=None, freq="B"): 2110 return {c: makeTimeSeries(nper, freq) for c in getCols(_K)} 2111 2112 2113def getPeriodData(nper=None): 2114 return {c: makePeriodSeries(nper) for c in getCols(_K)} 2115 2116 2117# make frame 2118def makeTimeDataFrame(nper=None, freq="B"): 2119 data = getTimeSeriesData(nper, freq) 2120 return DataFrame(data) 2121 2122 2123def makeDataFrame(): 2124 data = getSeriesData() 2125 return DataFrame(data) 2126 2127 2128def getMixedTypeDict(): 2129 index = Index(["a", "b", "c", "d", "e"]) 2130 2131 data = { 2132 "A": [0.0, 1.0, 2.0, 3.0, 4.0], 2133 "B": [0.0, 1.0, 0.0, 1.0, 0.0], 2134 "C": ["foo1", "foo2", "foo3", "foo4", "foo5"], 2135 "D": bdate_range("1/1/2009", periods=5), 2136 } 2137 2138 return index, data 2139 2140 2141def makeMixedDataFrame(): 2142 return DataFrame(getMixedTypeDict()[1]) 2143 2144 2145def makePeriodFrame(nper=None): 2146 data = getPeriodData(nper) 2147 return DataFrame(data) 2148 2149 2150def makeCustomIndex( 2151 nentries, nlevels, prefix="#", names=False, ndupe_l=None, idx_type=None 2152): 2153 """ 2154 Create an index/multindex with given dimensions, levels, names, etc' 2155 2156 nentries - number of entries in index 2157 nlevels - number of levels (> 1 produces multindex) 2158 prefix - a string prefix for labels 2159 names - (Optional), bool or list of strings. if True will use default 2160 names, if false will use no names, if a list is given, the name of 2161 each level in the index will be taken from the list. 2162 ndupe_l - (Optional), list of ints, the number of rows for which the 2163 label will repeated at the corresponding level, you can specify just 2164 the first few, the rest will use the default ndupe_l of 1. 2165 len(ndupe_l) <= nlevels. 2166 idx_type - "i"/"f"/"s"/"u"/"dt"/"p"/"td". 2167 If idx_type is not None, `idx_nlevels` must be 1. 2168 "i"/"f" creates an integer/float index, 2169 "s"/"u" creates a string/unicode index 2170 "dt" create a datetime index. 2171 "td" create a datetime index. 2172 2173 if unspecified, string labels will be generated. 2174 """ 2175 if ndupe_l is None: 2176 ndupe_l = [1] * nlevels 2177 assert is_sequence(ndupe_l) and len(ndupe_l) <= nlevels 2178 assert names is None or names is False or names is True or len(names) is nlevels 2179 assert idx_type is None or ( 2180 idx_type in ("i", "f", "s", "u", "dt", "p", "td") and nlevels == 1 2181 ) 2182 2183 if names is True: 2184 # build default names 2185 names = [prefix + str(i) for i in range(nlevels)] 2186 if names is False: 2187 # pass None to index constructor for no name 2188 names = None 2189 2190 # make singleton case uniform 2191 if isinstance(names, str) and nlevels == 1: 2192 names = [names] 2193 2194 # specific 1D index type requested? 2195 idx_func = { 2196 "i": makeIntIndex, 2197 "f": makeFloatIndex, 2198 "s": makeStringIndex, 2199 "u": makeUnicodeIndex, 2200 "dt": makeDateIndex, 2201 "td": makeTimedeltaIndex, 2202 "p": makePeriodIndex, 2203 }.get(idx_type) 2204 if idx_func: 2205 # pandas\_testing.py:2120: error: Cannot call function of unknown type 2206 idx = idx_func(nentries) # type: ignore[operator] 2207 # but we need to fill in the name 2208 if names: 2209 idx.name = names[0] 2210 return idx 2211 elif idx_type is not None: 2212 raise ValueError( 2213 f"{repr(idx_type)} is not a legal value for `idx_type`, " 2214 "use 'i'/'f'/'s'/'u'/'dt'/'p'/'td'." 2215 ) 2216 2217 if len(ndupe_l) < nlevels: 2218 ndupe_l.extend([1] * (nlevels - len(ndupe_l))) 2219 assert len(ndupe_l) == nlevels 2220 2221 assert all(x > 0 for x in ndupe_l) 2222 2223 tuples = [] 2224 for i in range(nlevels): 2225 2226 def keyfunc(x): 2227 import re 2228 2229 numeric_tuple = re.sub(r"[^\d_]_?", "", x).split("_") 2230 return [int(num) for num in numeric_tuple] 2231 2232 # build a list of lists to create the index from 2233 div_factor = nentries // ndupe_l[i] + 1 2234 # pandas\_testing.py:2148: error: Need type annotation for 'cnt' 2235 cnt = Counter() # type: ignore[var-annotated] 2236 for j in range(div_factor): 2237 label = f"{prefix}_l{i}_g{j}" 2238 cnt[label] = ndupe_l[i] 2239 # cute Counter trick 2240 result = sorted(cnt.elements(), key=keyfunc)[:nentries] 2241 tuples.append(result) 2242 2243 tuples = list(zip(*tuples)) 2244 2245 # convert tuples to index 2246 if nentries == 1: 2247 # we have a single level of tuples, i.e. a regular Index 2248 index = Index(tuples[0], name=names[0]) 2249 elif nlevels == 1: 2250 name = None if names is None else names[0] 2251 index = Index((x[0] for x in tuples), name=name) 2252 else: 2253 index = MultiIndex.from_tuples(tuples, names=names) 2254 return index 2255 2256 2257def makeCustomDataframe( 2258 nrows, 2259 ncols, 2260 c_idx_names=True, 2261 r_idx_names=True, 2262 c_idx_nlevels=1, 2263 r_idx_nlevels=1, 2264 data_gen_f=None, 2265 c_ndupe_l=None, 2266 r_ndupe_l=None, 2267 dtype=None, 2268 c_idx_type=None, 2269 r_idx_type=None, 2270): 2271 """ 2272 Create a DataFrame using supplied parameters. 2273 2274 Parameters 2275 ---------- 2276 nrows, ncols - number of data rows/cols 2277 c_idx_names, idx_names - False/True/list of strings, yields No names , 2278 default names or uses the provided names for the levels of the 2279 corresponding index. You can provide a single string when 2280 c_idx_nlevels ==1. 2281 c_idx_nlevels - number of levels in columns index. > 1 will yield MultiIndex 2282 r_idx_nlevels - number of levels in rows index. > 1 will yield MultiIndex 2283 data_gen_f - a function f(row,col) which return the data value 2284 at that position, the default generator used yields values of the form 2285 "RxCy" based on position. 2286 c_ndupe_l, r_ndupe_l - list of integers, determines the number 2287 of duplicates for each label at a given level of the corresponding 2288 index. The default `None` value produces a multiplicity of 1 across 2289 all levels, i.e. a unique index. Will accept a partial list of length 2290 N < idx_nlevels, for just the first N levels. If ndupe doesn't divide 2291 nrows/ncol, the last label might have lower multiplicity. 2292 dtype - passed to the DataFrame constructor as is, in case you wish to 2293 have more control in conjunction with a custom `data_gen_f` 2294 r_idx_type, c_idx_type - "i"/"f"/"s"/"u"/"dt"/"td". 2295 If idx_type is not None, `idx_nlevels` must be 1. 2296 "i"/"f" creates an integer/float index, 2297 "s"/"u" creates a string/unicode index 2298 "dt" create a datetime index. 2299 "td" create a timedelta index. 2300 2301 if unspecified, string labels will be generated. 2302 2303 Examples 2304 -------- 2305 # 5 row, 3 columns, default names on both, single index on both axis 2306 >> makeCustomDataframe(5,3) 2307 2308 # make the data a random int between 1 and 100 2309 >> mkdf(5,3,data_gen_f=lambda r,c:randint(1,100)) 2310 2311 # 2-level multiindex on rows with each label duplicated 2312 # twice on first level, default names on both axis, single 2313 # index on both axis 2314 >> a=makeCustomDataframe(5,3,r_idx_nlevels=2,r_ndupe_l=[2]) 2315 2316 # DatetimeIndex on row, index with unicode labels on columns 2317 # no names on either axis 2318 >> a=makeCustomDataframe(5,3,c_idx_names=False,r_idx_names=False, 2319 r_idx_type="dt",c_idx_type="u") 2320 2321 # 4-level multindex on rows with names provided, 2-level multindex 2322 # on columns with default labels and default names. 2323 >> a=makeCustomDataframe(5,3,r_idx_nlevels=4, 2324 r_idx_names=["FEE","FI","FO","FAM"], 2325 c_idx_nlevels=2) 2326 2327 >> a=mkdf(5,3,r_idx_nlevels=2,c_idx_nlevels=4) 2328 """ 2329 assert c_idx_nlevels > 0 2330 assert r_idx_nlevels > 0 2331 assert r_idx_type is None or ( 2332 r_idx_type in ("i", "f", "s", "u", "dt", "p", "td") and r_idx_nlevels == 1 2333 ) 2334 assert c_idx_type is None or ( 2335 c_idx_type in ("i", "f", "s", "u", "dt", "p", "td") and c_idx_nlevels == 1 2336 ) 2337 2338 columns = makeCustomIndex( 2339 ncols, 2340 nlevels=c_idx_nlevels, 2341 prefix="C", 2342 names=c_idx_names, 2343 ndupe_l=c_ndupe_l, 2344 idx_type=c_idx_type, 2345 ) 2346 index = makeCustomIndex( 2347 nrows, 2348 nlevels=r_idx_nlevels, 2349 prefix="R", 2350 names=r_idx_names, 2351 ndupe_l=r_ndupe_l, 2352 idx_type=r_idx_type, 2353 ) 2354 2355 # by default, generate data based on location 2356 if data_gen_f is None: 2357 data_gen_f = lambda r, c: f"R{r}C{c}" 2358 2359 data = [[data_gen_f(r, c) for c in range(ncols)] for r in range(nrows)] 2360 2361 return DataFrame(data, index, columns, dtype=dtype) 2362 2363 2364def _create_missing_idx(nrows, ncols, density, random_state=None): 2365 if random_state is None: 2366 random_state = np.random 2367 else: 2368 random_state = np.random.RandomState(random_state) 2369 2370 # below is cribbed from scipy.sparse 2371 size = int(np.round((1 - density) * nrows * ncols)) 2372 # generate a few more to ensure unique values 2373 min_rows = 5 2374 fac = 1.02 2375 extra_size = min(size + min_rows, fac * size) 2376 2377 def _gen_unique_rand(rng, _extra_size): 2378 ind = rng.rand(int(_extra_size)) 2379 return np.unique(np.floor(ind * nrows * ncols))[:size] 2380 2381 ind = _gen_unique_rand(random_state, extra_size) 2382 while ind.size < size: 2383 extra_size *= 1.05 2384 ind = _gen_unique_rand(random_state, extra_size) 2385 2386 j = np.floor(ind * 1.0 / nrows).astype(int) 2387 i = (ind - j * nrows).astype(int) 2388 return i.tolist(), j.tolist() 2389 2390 2391def makeMissingDataframe(density=0.9, random_state=None): 2392 df = makeDataFrame() 2393 # pandas\_testing.py:2306: error: "_create_missing_idx" gets multiple 2394 # values for keyword argument "density" [misc] 2395 2396 # pandas\_testing.py:2306: error: "_create_missing_idx" gets multiple 2397 # values for keyword argument "random_state" [misc] 2398 i, j = _create_missing_idx( # type: ignore[misc] 2399 *df.shape, density=density, random_state=random_state 2400 ) 2401 df.values[i, j] = np.nan 2402 return df 2403 2404 2405def optional_args(decorator): 2406 """ 2407 allows a decorator to take optional positional and keyword arguments. 2408 Assumes that taking a single, callable, positional argument means that 2409 it is decorating a function, i.e. something like this:: 2410 2411 @my_decorator 2412 def function(): pass 2413 2414 Calls decorator with decorator(f, *args, **kwargs) 2415 """ 2416 2417 @wraps(decorator) 2418 def wrapper(*args, **kwargs): 2419 def dec(f): 2420 return decorator(f, *args, **kwargs) 2421 2422 is_decorating = not kwargs and len(args) == 1 and callable(args[0]) 2423 if is_decorating: 2424 f = args[0] 2425 # pandas\_testing.py:2331: error: Incompatible types in assignment 2426 # (expression has type "List[<nothing>]", variable has type 2427 # "Tuple[Any, ...]") 2428 args = [] # type: ignore[assignment] 2429 return dec(f) 2430 else: 2431 return dec 2432 2433 return wrapper 2434 2435 2436# skip tests on exceptions with this message 2437_network_error_messages = ( 2438 # 'urlopen error timed out', 2439 # 'timeout: timed out', 2440 # 'socket.timeout: timed out', 2441 "timed out", 2442 "Server Hangup", 2443 "HTTP Error 503: Service Unavailable", 2444 "502: Proxy Error", 2445 "HTTP Error 502: internal error", 2446 "HTTP Error 502", 2447 "HTTP Error 503", 2448 "HTTP Error 403", 2449 "HTTP Error 400", 2450 "Temporary failure in name resolution", 2451 "Name or service not known", 2452 "Connection refused", 2453 "certificate verify", 2454) 2455 2456# or this e.errno/e.reason.errno 2457_network_errno_vals = ( 2458 101, # Network is unreachable 2459 111, # Connection refused 2460 110, # Connection timed out 2461 104, # Connection reset Error 2462 54, # Connection reset by peer 2463 60, # urllib.error.URLError: [Errno 60] Connection timed out 2464) 2465 2466# Both of the above shouldn't mask real issues such as 404's 2467# or refused connections (changed DNS). 2468# But some tests (test_data yahoo) contact incredibly flakey 2469# servers. 2470 2471# and conditionally raise on exception types in _get_default_network_errors 2472 2473 2474def _get_default_network_errors(): 2475 # Lazy import for http.client because it imports many things from the stdlib 2476 import http.client 2477 2478 return (IOError, http.client.HTTPException, TimeoutError) 2479 2480 2481def can_connect(url, error_classes=None): 2482 """ 2483 Try to connect to the given url. True if succeeds, False if IOError 2484 raised 2485 2486 Parameters 2487 ---------- 2488 url : basestring 2489 The URL to try to connect to 2490 2491 Returns 2492 ------- 2493 connectable : bool 2494 Return True if no IOError (unable to connect) or URLError (bad url) was 2495 raised 2496 """ 2497 if error_classes is None: 2498 error_classes = _get_default_network_errors() 2499 2500 try: 2501 with urlopen(url): 2502 pass 2503 except error_classes: 2504 return False 2505 else: 2506 return True 2507 2508 2509@optional_args 2510def network( 2511 t, 2512 url="https://www.google.com", 2513 raise_on_error=_RAISE_NETWORK_ERROR_DEFAULT, 2514 check_before_test=False, 2515 error_classes=None, 2516 skip_errnos=_network_errno_vals, 2517 _skip_on_messages=_network_error_messages, 2518): 2519 """ 2520 Label a test as requiring network connection and, if an error is 2521 encountered, only raise if it does not find a network connection. 2522 2523 In comparison to ``network``, this assumes an added contract to your test: 2524 you must assert that, under normal conditions, your test will ONLY fail if 2525 it does not have network connectivity. 2526 2527 You can call this in 3 ways: as a standard decorator, with keyword 2528 arguments, or with a positional argument that is the url to check. 2529 2530 Parameters 2531 ---------- 2532 t : callable 2533 The test requiring network connectivity. 2534 url : path 2535 The url to test via ``pandas.io.common.urlopen`` to check 2536 for connectivity. Defaults to 'https://www.google.com'. 2537 raise_on_error : bool 2538 If True, never catches errors. 2539 check_before_test : bool 2540 If True, checks connectivity before running the test case. 2541 error_classes : tuple or Exception 2542 error classes to ignore. If not in ``error_classes``, raises the error. 2543 defaults to IOError. Be careful about changing the error classes here. 2544 skip_errnos : iterable of int 2545 Any exception that has .errno or .reason.erno set to one 2546 of these values will be skipped with an appropriate 2547 message. 2548 _skip_on_messages: iterable of string 2549 any exception e for which one of the strings is 2550 a substring of str(e) will be skipped with an appropriate 2551 message. Intended to suppress errors where an errno isn't available. 2552 2553 Notes 2554 ----- 2555 * ``raise_on_error`` supersedes ``check_before_test`` 2556 2557 Returns 2558 ------- 2559 t : callable 2560 The decorated test ``t``, with checks for connectivity errors. 2561 2562 Example 2563 ------- 2564 2565 Tests decorated with @network will fail if it's possible to make a network 2566 connection to another URL (defaults to google.com):: 2567 2568 >>> from pandas._testing import network 2569 >>> from pandas.io.common import urlopen 2570 >>> @network 2571 ... def test_network(): 2572 ... with urlopen("rabbit://bonanza.com"): 2573 ... pass 2574 Traceback 2575 ... 2576 URLError: <urlopen error unknown url type: rabit> 2577 2578 You can specify alternative URLs:: 2579 2580 >>> @network("https://www.yahoo.com") 2581 ... def test_something_with_yahoo(): 2582 ... raise IOError("Failure Message") 2583 >>> test_something_with_yahoo() 2584 Traceback (most recent call last): 2585 ... 2586 IOError: Failure Message 2587 2588 If you set check_before_test, it will check the url first and not run the 2589 test on failure:: 2590 2591 >>> @network("failing://url.blaher", check_before_test=True) 2592 ... def test_something(): 2593 ... print("I ran!") 2594 ... raise ValueError("Failure") 2595 >>> test_something() 2596 Traceback (most recent call last): 2597 ... 2598 2599 Errors not related to networking will always be raised. 2600 """ 2601 from pytest import skip 2602 2603 if error_classes is None: 2604 error_classes = _get_default_network_errors() 2605 2606 t.network = True 2607 2608 @wraps(t) 2609 def wrapper(*args, **kwargs): 2610 if ( 2611 check_before_test 2612 and not raise_on_error 2613 and not can_connect(url, error_classes) 2614 ): 2615 skip() 2616 try: 2617 return t(*args, **kwargs) 2618 except Exception as err: 2619 errno = getattr(err, "errno", None) 2620 if not errno and hasattr(errno, "reason"): 2621 # pandas\_testing.py:2521: error: "Exception" has no attribute 2622 # "reason" 2623 errno = getattr(err.reason, "errno", None) # type: ignore[attr-defined] 2624 2625 if errno in skip_errnos: 2626 skip(f"Skipping test due to known errno and error {err}") 2627 2628 e_str = str(err) 2629 2630 if any(m.lower() in e_str.lower() for m in _skip_on_messages): 2631 skip( 2632 f"Skipping test because exception message is known and error {err}" 2633 ) 2634 2635 if not isinstance(err, error_classes): 2636 raise 2637 2638 if raise_on_error or can_connect(url, error_classes): 2639 raise 2640 else: 2641 skip(f"Skipping test due to lack of connectivity and error {err}") 2642 2643 return wrapper 2644 2645 2646with_connectivity_check = network 2647 2648 2649@contextmanager 2650def assert_produces_warning( 2651 expected_warning: Optional[Union[Type[Warning], bool]] = Warning, 2652 filter_level="always", 2653 check_stacklevel: bool = True, 2654 raise_on_extra_warnings: bool = True, 2655 match: Optional[str] = None, 2656): 2657 """ 2658 Context manager for running code expected to either raise a specific 2659 warning, or not raise any warnings. Verifies that the code raises the 2660 expected warning, and that it does not raise any other unexpected 2661 warnings. It is basically a wrapper around ``warnings.catch_warnings``. 2662 2663 Parameters 2664 ---------- 2665 expected_warning : {Warning, False, None}, default Warning 2666 The type of Exception raised. ``exception.Warning`` is the base 2667 class for all warnings. To check that no warning is returned, 2668 specify ``False`` or ``None``. 2669 filter_level : str or None, default "always" 2670 Specifies whether warnings are ignored, displayed, or turned 2671 into errors. 2672 Valid values are: 2673 2674 * "error" - turns matching warnings into exceptions 2675 * "ignore" - discard the warning 2676 * "always" - always emit a warning 2677 * "default" - print the warning the first time it is generated 2678 from each location 2679 * "module" - print the warning the first time it is generated 2680 from each module 2681 * "once" - print the warning the first time it is generated 2682 2683 check_stacklevel : bool, default True 2684 If True, displays the line that called the function containing 2685 the warning to show were the function is called. Otherwise, the 2686 line that implements the function is displayed. 2687 raise_on_extra_warnings : bool, default True 2688 Whether extra warnings not of the type `expected_warning` should 2689 cause the test to fail. 2690 match : str, optional 2691 Match warning message. 2692 2693 Examples 2694 -------- 2695 >>> import warnings 2696 >>> with assert_produces_warning(): 2697 ... warnings.warn(UserWarning()) 2698 ... 2699 >>> with assert_produces_warning(False): 2700 ... warnings.warn(RuntimeWarning()) 2701 ... 2702 Traceback (most recent call last): 2703 ... 2704 AssertionError: Caused unexpected warning(s): ['RuntimeWarning']. 2705 >>> with assert_produces_warning(UserWarning): 2706 ... warnings.warn(RuntimeWarning()) 2707 Traceback (most recent call last): 2708 ... 2709 AssertionError: Did not see expected warning of class 'UserWarning'. 2710 2711 ..warn:: This is *not* thread-safe. 2712 """ 2713 __tracebackhide__ = True 2714 2715 with warnings.catch_warnings(record=True) as w: 2716 2717 saw_warning = False 2718 matched_message = False 2719 2720 warnings.simplefilter(filter_level) 2721 yield w 2722 extra_warnings = [] 2723 2724 for actual_warning in w: 2725 if not expected_warning: 2726 continue 2727 2728 expected_warning = cast(Type[Warning], expected_warning) 2729 if issubclass(actual_warning.category, expected_warning): 2730 saw_warning = True 2731 2732 if check_stacklevel and issubclass( 2733 actual_warning.category, (FutureWarning, DeprecationWarning) 2734 ): 2735 _assert_raised_with_correct_stacklevel(actual_warning) 2736 2737 if match is not None and re.search(match, str(actual_warning.message)): 2738 matched_message = True 2739 2740 else: 2741 extra_warnings.append( 2742 ( 2743 actual_warning.category.__name__, 2744 actual_warning.message, 2745 actual_warning.filename, 2746 actual_warning.lineno, 2747 ) 2748 ) 2749 2750 if expected_warning: 2751 expected_warning = cast(Type[Warning], expected_warning) 2752 if not saw_warning: 2753 raise AssertionError( 2754 f"Did not see expected warning of class " 2755 f"{repr(expected_warning.__name__)}" 2756 ) 2757 2758 if match and not matched_message: 2759 raise AssertionError( 2760 f"Did not see warning {repr(expected_warning.__name__)} " 2761 f"matching {match}" 2762 ) 2763 2764 if raise_on_extra_warnings and extra_warnings: 2765 raise AssertionError( 2766 f"Caused unexpected warning(s): {repr(extra_warnings)}" 2767 ) 2768 2769 2770def _assert_raised_with_correct_stacklevel( 2771 actual_warning: warnings.WarningMessage, 2772) -> None: 2773 from inspect import getframeinfo, stack 2774 2775 caller = getframeinfo(stack()[3][0]) 2776 msg = ( 2777 "Warning not set with correct stacklevel. " 2778 f"File where warning is raised: {actual_warning.filename} != " 2779 f"{caller.filename}. Warning message: {actual_warning.message}" 2780 ) 2781 assert actual_warning.filename == caller.filename, msg 2782 2783 2784class RNGContext: 2785 """ 2786 Context manager to set the numpy random number generator speed. Returns 2787 to the original value upon exiting the context manager. 2788 2789 Parameters 2790 ---------- 2791 seed : int 2792 Seed for numpy.random.seed 2793 2794 Examples 2795 -------- 2796 with RNGContext(42): 2797 np.random.randn() 2798 """ 2799 2800 def __init__(self, seed): 2801 self.seed = seed 2802 2803 def __enter__(self): 2804 2805 self.start_state = np.random.get_state() 2806 np.random.seed(self.seed) 2807 2808 def __exit__(self, exc_type, exc_value, traceback): 2809 2810 np.random.set_state(self.start_state) 2811 2812 2813@contextmanager 2814def with_csv_dialect(name, **kwargs): 2815 """ 2816 Context manager to temporarily register a CSV dialect for parsing CSV. 2817 2818 Parameters 2819 ---------- 2820 name : str 2821 The name of the dialect. 2822 kwargs : mapping 2823 The parameters for the dialect. 2824 2825 Raises 2826 ------ 2827 ValueError : the name of the dialect conflicts with a builtin one. 2828 2829 See Also 2830 -------- 2831 csv : Python's CSV library. 2832 """ 2833 import csv 2834 2835 _BUILTIN_DIALECTS = {"excel", "excel-tab", "unix"} 2836 2837 if name in _BUILTIN_DIALECTS: 2838 raise ValueError("Cannot override builtin dialect.") 2839 2840 csv.register_dialect(name, **kwargs) 2841 yield 2842 csv.unregister_dialect(name) 2843 2844 2845@contextmanager 2846def use_numexpr(use, min_elements=None): 2847 from pandas.core.computation import expressions as expr 2848 2849 if min_elements is None: 2850 min_elements = expr._MIN_ELEMENTS 2851 2852 olduse = expr.USE_NUMEXPR 2853 oldmin = expr._MIN_ELEMENTS 2854 expr.set_use_numexpr(use) 2855 expr._MIN_ELEMENTS = min_elements 2856 yield 2857 expr._MIN_ELEMENTS = oldmin 2858 expr.set_use_numexpr(olduse) 2859 2860 2861def test_parallel(num_threads=2, kwargs_list=None): 2862 """ 2863 Decorator to run the same function multiple times in parallel. 2864 2865 Parameters 2866 ---------- 2867 num_threads : int, optional 2868 The number of times the function is run in parallel. 2869 kwargs_list : list of dicts, optional 2870 The list of kwargs to update original 2871 function kwargs on different threads. 2872 2873 Notes 2874 ----- 2875 This decorator does not pass the return value of the decorated function. 2876 2877 Original from scikit-image: 2878 2879 https://github.com/scikit-image/scikit-image/pull/1519 2880 2881 """ 2882 assert num_threads > 0 2883 has_kwargs_list = kwargs_list is not None 2884 if has_kwargs_list: 2885 assert len(kwargs_list) == num_threads 2886 import threading 2887 2888 def wrapper(func): 2889 @wraps(func) 2890 def inner(*args, **kwargs): 2891 if has_kwargs_list: 2892 update_kwargs = lambda i: dict(kwargs, **kwargs_list[i]) 2893 else: 2894 update_kwargs = lambda i: kwargs 2895 threads = [] 2896 for i in range(num_threads): 2897 updated_kwargs = update_kwargs(i) 2898 thread = threading.Thread(target=func, args=args, kwargs=updated_kwargs) 2899 threads.append(thread) 2900 for thread in threads: 2901 thread.start() 2902 for thread in threads: 2903 thread.join() 2904 2905 return inner 2906 2907 return wrapper 2908 2909 2910class SubclassedSeries(Series): 2911 _metadata = ["testattr", "name"] 2912 2913 @property 2914 def _constructor(self): 2915 return SubclassedSeries 2916 2917 @property 2918 def _constructor_expanddim(self): 2919 return SubclassedDataFrame 2920 2921 2922class SubclassedDataFrame(DataFrame): 2923 _metadata = ["testattr"] 2924 2925 @property 2926 def _constructor(self): 2927 return SubclassedDataFrame 2928 2929 @property 2930 def _constructor_sliced(self): 2931 return SubclassedSeries 2932 2933 2934class SubclassedCategorical(Categorical): 2935 @property 2936 def _constructor(self): 2937 return SubclassedCategorical 2938 2939 2940@contextmanager 2941def set_timezone(tz: str): 2942 """ 2943 Context manager for temporarily setting a timezone. 2944 2945 Parameters 2946 ---------- 2947 tz : str 2948 A string representing a valid timezone. 2949 2950 Examples 2951 -------- 2952 >>> from datetime import datetime 2953 >>> from dateutil.tz import tzlocal 2954 >>> tzlocal().tzname(datetime.now()) 2955 'IST' 2956 2957 >>> with set_timezone('US/Eastern'): 2958 ... tzlocal().tzname(datetime.now()) 2959 ... 2960 'EDT' 2961 """ 2962 import os 2963 import time 2964 2965 def setTZ(tz): 2966 if tz is None: 2967 try: 2968 del os.environ["TZ"] 2969 except KeyError: 2970 pass 2971 else: 2972 os.environ["TZ"] = tz 2973 time.tzset() 2974 2975 orig_tz = os.environ.get("TZ") 2976 setTZ(tz) 2977 try: 2978 yield 2979 finally: 2980 setTZ(orig_tz) 2981 2982 2983def _make_skipna_wrapper(alternative, skipna_alternative=None): 2984 """ 2985 Create a function for calling on an array. 2986 2987 Parameters 2988 ---------- 2989 alternative : function 2990 The function to be called on the array with no NaNs. 2991 Only used when 'skipna_alternative' is None. 2992 skipna_alternative : function 2993 The function to be called on the original array 2994 2995 Returns 2996 ------- 2997 function 2998 """ 2999 if skipna_alternative: 3000 3001 def skipna_wrapper(x): 3002 return skipna_alternative(x.values) 3003 3004 else: 3005 3006 def skipna_wrapper(x): 3007 nona = x.dropna() 3008 if len(nona) == 0: 3009 return np.nan 3010 return alternative(nona) 3011 3012 return skipna_wrapper 3013 3014 3015def convert_rows_list_to_csv_str(rows_list: List[str]): 3016 """ 3017 Convert list of CSV rows to single CSV-formatted string for current OS. 3018 3019 This method is used for creating expected value of to_csv() method. 3020 3021 Parameters 3022 ---------- 3023 rows_list : List[str] 3024 Each element represents the row of csv. 3025 3026 Returns 3027 ------- 3028 str 3029 Expected output of to_csv() in current OS. 3030 """ 3031 sep = os.linesep 3032 return sep.join(rows_list) + sep 3033 3034 3035def external_error_raised(expected_exception: Type[Exception]) -> ContextManager: 3036 """ 3037 Helper function to mark pytest.raises that have an external error message. 3038 3039 Parameters 3040 ---------- 3041 expected_exception : Exception 3042 Expected error to raise. 3043 3044 Returns 3045 ------- 3046 Callable 3047 Regular `pytest.raises` function with `match` equal to `None`. 3048 """ 3049 import pytest 3050 3051 return pytest.raises(expected_exception, match=None) 3052 3053 3054cython_table = pd.core.base.SelectionMixin._cython_table.items() 3055 3056 3057def get_cython_table_params(ndframe, func_names_and_expected): 3058 """ 3059 Combine frame, functions from SelectionMixin._cython_table 3060 keys and expected result. 3061 3062 Parameters 3063 ---------- 3064 ndframe : DataFrame or Series 3065 func_names_and_expected : Sequence of two items 3066 The first item is a name of a NDFrame method ('sum', 'prod') etc. 3067 The second item is the expected return value. 3068 3069 Returns 3070 ------- 3071 list 3072 List of three items (DataFrame, function, expected result) 3073 """ 3074 results = [] 3075 for func_name, expected in func_names_and_expected: 3076 results.append((ndframe, func_name, expected)) 3077 results += [ 3078 (ndframe, func, expected) 3079 for func, name in cython_table 3080 if name == func_name 3081 ] 3082 return results 3083 3084 3085def get_op_from_name(op_name: str) -> Callable: 3086 """ 3087 The operator function for a given op name. 3088 3089 Parameters 3090 ---------- 3091 op_name : string 3092 The op name, in form of "add" or "__add__". 3093 3094 Returns 3095 ------- 3096 function 3097 A function performing the operation. 3098 """ 3099 short_opname = op_name.strip("_") 3100 try: 3101 op = getattr(operator, short_opname) 3102 except AttributeError: 3103 # Assume it is the reverse operator 3104 rop = getattr(operator, short_opname[1:]) 3105 op = lambda x, y: rop(y, x) 3106 3107 return op 3108