1import bz2
2from collections import Counter
3from contextlib import contextmanager
4from datetime import datetime
5from functools import wraps
6import gzip
7import operator
8import os
9from pathlib import Path
10import random
11import re
12from shutil import rmtree
13import string
14import tempfile
15from typing import IO, Any, Callable, ContextManager, List, Optional, Type, Union, cast
16import warnings
17import zipfile
18
19import numpy as np
20from numpy.random import rand, randn
21
22from pandas._config.localization import (  # noqa:F401
23    can_set_locale,
24    get_locales,
25    set_locale,
26)
27
28from pandas._libs.lib import no_default
29import pandas._libs.testing as _testing
30from pandas._typing import Dtype, FilePathOrBuffer, FrameOrSeries
31from pandas.compat import get_lzma_file, import_lzma
32
33from pandas.core.dtypes.common import (
34    is_bool,
35    is_categorical_dtype,
36    is_datetime64_dtype,
37    is_datetime64tz_dtype,
38    is_extension_array_dtype,
39    is_interval_dtype,
40    is_number,
41    is_numeric_dtype,
42    is_period_dtype,
43    is_sequence,
44    is_timedelta64_dtype,
45    needs_i8_conversion,
46)
47from pandas.core.dtypes.missing import array_equivalent
48
49import pandas as pd
50from pandas import (
51    Categorical,
52    CategoricalIndex,
53    DataFrame,
54    DatetimeIndex,
55    Index,
56    IntervalIndex,
57    MultiIndex,
58    RangeIndex,
59    Series,
60    bdate_range,
61)
62from pandas.core.algorithms import safe_sort, take_1d
63from pandas.core.arrays import (
64    DatetimeArray,
65    ExtensionArray,
66    IntervalArray,
67    PeriodArray,
68    TimedeltaArray,
69    period_array,
70)
71from pandas.core.arrays.datetimelike import DatetimeLikeArrayMixin
72
73from pandas.io.common import urlopen
74from pandas.io.formats.printing import pprint_thing
75
76lzma = import_lzma()
77
78_N = 30
79_K = 4
80_RAISE_NETWORK_ERROR_DEFAULT = False
81
82UNSIGNED_INT_DTYPES: List[Dtype] = ["uint8", "uint16", "uint32", "uint64"]
83UNSIGNED_EA_INT_DTYPES: List[Dtype] = ["UInt8", "UInt16", "UInt32", "UInt64"]
84SIGNED_INT_DTYPES: List[Dtype] = [int, "int8", "int16", "int32", "int64"]
85SIGNED_EA_INT_DTYPES: List[Dtype] = ["Int8", "Int16", "Int32", "Int64"]
86ALL_INT_DTYPES = UNSIGNED_INT_DTYPES + SIGNED_INT_DTYPES
87ALL_EA_INT_DTYPES = UNSIGNED_EA_INT_DTYPES + SIGNED_EA_INT_DTYPES
88
89FLOAT_DTYPES: List[Dtype] = [float, "float32", "float64"]
90FLOAT_EA_DTYPES: List[Dtype] = ["Float32", "Float64"]
91COMPLEX_DTYPES: List[Dtype] = [complex, "complex64", "complex128"]
92STRING_DTYPES: List[Dtype] = [str, "str", "U"]
93
94DATETIME64_DTYPES: List[Dtype] = ["datetime64[ns]", "M8[ns]"]
95TIMEDELTA64_DTYPES: List[Dtype] = ["timedelta64[ns]", "m8[ns]"]
96
97BOOL_DTYPES = [bool, "bool"]
98BYTES_DTYPES = [bytes, "bytes"]
99OBJECT_DTYPES = [object, "object"]
100
101ALL_REAL_DTYPES = FLOAT_DTYPES + ALL_INT_DTYPES
102ALL_NUMPY_DTYPES = (
103    ALL_REAL_DTYPES
104    + COMPLEX_DTYPES
105    + STRING_DTYPES
106    + DATETIME64_DTYPES
107    + TIMEDELTA64_DTYPES
108    + BOOL_DTYPES
109    + OBJECT_DTYPES
110    + BYTES_DTYPES
111)
112
113NULL_OBJECTS = [None, np.nan, pd.NaT, float("nan"), pd.NA]
114
115
116# set testing_mode
117_testing_mode_warnings = (DeprecationWarning, ResourceWarning)
118
119
120def set_testing_mode():
121    # set the testing mode filters
122    testing_mode = os.environ.get("PANDAS_TESTING_MODE", "None")
123    if "deprecate" in testing_mode:
124        # pandas\_testing.py:119: error: Argument 2 to "simplefilter" has
125        # incompatible type "Tuple[Type[DeprecationWarning],
126        # Type[ResourceWarning]]"; expected "Type[Warning]"
127        warnings.simplefilter(
128            "always", _testing_mode_warnings  # type: ignore[arg-type]
129        )
130
131
132def reset_testing_mode():
133    # reset the testing mode filters
134    testing_mode = os.environ.get("PANDAS_TESTING_MODE", "None")
135    if "deprecate" in testing_mode:
136        # pandas\_testing.py:126: error: Argument 2 to "simplefilter" has
137        # incompatible type "Tuple[Type[DeprecationWarning],
138        # Type[ResourceWarning]]"; expected "Type[Warning]"
139        warnings.simplefilter(
140            "ignore", _testing_mode_warnings  # type: ignore[arg-type]
141        )
142
143
144set_testing_mode()
145
146
147def reset_display_options():
148    """
149    Reset the display options for printing and representing objects.
150    """
151    pd.reset_option("^display.", silent=True)
152
153
154def round_trip_pickle(
155    obj: Any, path: Optional[FilePathOrBuffer] = None
156) -> FrameOrSeries:
157    """
158    Pickle an object and then read it again.
159
160    Parameters
161    ----------
162    obj : any object
163        The object to pickle and then re-read.
164    path : str, path object or file-like object, default None
165        The path where the pickled object is written and then read.
166
167    Returns
168    -------
169    pandas object
170        The original object that was pickled and then re-read.
171    """
172    _path = path
173    if _path is None:
174        _path = f"__{rands(10)}__.pickle"
175    with ensure_clean(_path) as temp_path:
176        pd.to_pickle(obj, temp_path)
177        return pd.read_pickle(temp_path)
178
179
180def round_trip_pathlib(writer, reader, path: Optional[str] = None):
181    """
182    Write an object to file specified by a pathlib.Path and read it back
183
184    Parameters
185    ----------
186    writer : callable bound to pandas object
187        IO writing function (e.g. DataFrame.to_csv )
188    reader : callable
189        IO reading function (e.g. pd.read_csv )
190    path : str, default None
191        The path where the object is written and then read.
192
193    Returns
194    -------
195    pandas object
196        The original object that was serialized and then re-read.
197    """
198    import pytest
199
200    Path = pytest.importorskip("pathlib").Path
201    if path is None:
202        path = "___pathlib___"
203    with ensure_clean(path) as path:
204        writer(Path(path))
205        obj = reader(Path(path))
206    return obj
207
208
209def round_trip_localpath(writer, reader, path: Optional[str] = None):
210    """
211    Write an object to file specified by a py.path LocalPath and read it back.
212
213    Parameters
214    ----------
215    writer : callable bound to pandas object
216        IO writing function (e.g. DataFrame.to_csv )
217    reader : callable
218        IO reading function (e.g. pd.read_csv )
219    path : str, default None
220        The path where the object is written and then read.
221
222    Returns
223    -------
224    pandas object
225        The original object that was serialized and then re-read.
226    """
227    import pytest
228
229    LocalPath = pytest.importorskip("py.path").local
230    if path is None:
231        path = "___localpath___"
232    with ensure_clean(path) as path:
233        writer(LocalPath(path))
234        obj = reader(LocalPath(path))
235    return obj
236
237
238@contextmanager
239def decompress_file(path, compression):
240    """
241    Open a compressed file and return a file object.
242
243    Parameters
244    ----------
245    path : str
246        The path where the file is read from.
247
248    compression : {'gzip', 'bz2', 'zip', 'xz', None}
249        Name of the decompression to use
250
251    Returns
252    -------
253    file object
254    """
255    if compression is None:
256        f = open(path, "rb")
257    elif compression == "gzip":
258        # pandas\_testing.py:243: error: Incompatible types in assignment
259        # (expression has type "IO[Any]", variable has type "BinaryIO")
260        f = gzip.open(path, "rb")  # type: ignore[assignment]
261    elif compression == "bz2":
262        # pandas\_testing.py:245: error: Incompatible types in assignment
263        # (expression has type "BZ2File", variable has type "BinaryIO")
264        f = bz2.BZ2File(path, "rb")  # type: ignore[assignment]
265    elif compression == "xz":
266        f = get_lzma_file(lzma)(path, "rb")
267    elif compression == "zip":
268        zip_file = zipfile.ZipFile(path)
269        zip_names = zip_file.namelist()
270        if len(zip_names) == 1:
271            # pandas\_testing.py:252: error: Incompatible types in assignment
272            # (expression has type "IO[bytes]", variable has type "BinaryIO")
273            f = zip_file.open(zip_names.pop())  # type: ignore[assignment]
274        else:
275            raise ValueError(f"ZIP file {path} error. Only one file per ZIP.")
276    else:
277        raise ValueError(f"Unrecognized compression type: {compression}")
278
279    try:
280        yield f
281    finally:
282        f.close()
283        if compression == "zip":
284            zip_file.close()
285
286
287def write_to_compressed(compression, path, data, dest="test"):
288    """
289    Write data to a compressed file.
290
291    Parameters
292    ----------
293    compression : {'gzip', 'bz2', 'zip', 'xz'}
294        The compression type to use.
295    path : str
296        The file path to write the data.
297    data : str
298        The data to write.
299    dest : str, default "test"
300        The destination file (for ZIP only)
301
302    Raises
303    ------
304    ValueError : An invalid compression value was passed in.
305    """
306    if compression == "zip":
307        compress_method = zipfile.ZipFile
308    elif compression == "gzip":
309        # pandas\_testing.py:288: error: Incompatible types in assignment
310        # (expression has type "Type[GzipFile]", variable has type
311        # "Type[ZipFile]")
312        compress_method = gzip.GzipFile  # type: ignore[assignment]
313    elif compression == "bz2":
314        # pandas\_testing.py:290: error: Incompatible types in assignment
315        # (expression has type "Type[BZ2File]", variable has type
316        # "Type[ZipFile]")
317        compress_method = bz2.BZ2File  # type: ignore[assignment]
318    elif compression == "xz":
319        compress_method = get_lzma_file(lzma)
320    else:
321        raise ValueError(f"Unrecognized compression type: {compression}")
322
323    if compression == "zip":
324        mode = "w"
325        args = (dest, data)
326        method = "writestr"
327    else:
328        mode = "wb"
329        # pandas\_testing.py:302: error: Incompatible types in assignment
330        # (expression has type "Tuple[Any]", variable has type "Tuple[Any,
331        # Any]")
332        args = (data,)  # type: ignore[assignment]
333        method = "write"
334
335    with compress_method(path, mode=mode) as f:
336        getattr(f, method)(*args)
337
338
339def _get_tol_from_less_precise(check_less_precise: Union[bool, int]) -> float:
340    """
341    Return the tolerance equivalent to the deprecated `check_less_precise`
342    parameter.
343
344    Parameters
345    ----------
346    check_less_precise : bool or int
347
348    Returns
349    -------
350    float
351        Tolerance to be used as relative/absolute tolerance.
352
353    Examples
354    --------
355    >>> # Using check_less_precise as a bool:
356    >>> _get_tol_from_less_precise(False)
357    0.5e-5
358    >>> _get_tol_from_less_precise(True)
359    0.5e-3
360    >>> # Using check_less_precise as an int representing the decimal
361    >>> # tolerance intended:
362    >>> _get_tol_from_less_precise(2)
363    0.5e-2
364    >>> _get_tol_from_less_precise(8)
365    0.5e-8
366
367    """
368    if isinstance(check_less_precise, bool):
369        if check_less_precise:
370            # 3-digit tolerance
371            return 0.5e-3
372        else:
373            # 5-digit tolerance
374            return 0.5e-5
375    else:
376        # Equivalent to setting checking_less_precise=<decimals>
377        return 0.5 * 10 ** -check_less_precise
378
379
380def assert_almost_equal(
381    left,
382    right,
383    check_dtype: Union[bool, str] = "equiv",
384    check_less_precise: Union[bool, int] = no_default,
385    rtol: float = 1.0e-5,
386    atol: float = 1.0e-8,
387    **kwargs,
388):
389    """
390    Check that the left and right objects are approximately equal.
391
392    By approximately equal, we refer to objects that are numbers or that
393    contain numbers which may be equivalent to specific levels of precision.
394
395    Parameters
396    ----------
397    left : object
398    right : object
399    check_dtype : bool or {'equiv'}, default 'equiv'
400        Check dtype if both a and b are the same type. If 'equiv' is passed in,
401        then `RangeIndex` and `Int64Index` are also considered equivalent
402        when doing type checking.
403    check_less_precise : bool or int, default False
404        Specify comparison precision. 5 digits (False) or 3 digits (True)
405        after decimal points are compared. If int, then specify the number
406        of digits to compare.
407
408        When comparing two numbers, if the first number has magnitude less
409        than 1e-5, we compare the two numbers directly and check whether
410        they are equivalent within the specified precision. Otherwise, we
411        compare the **ratio** of the second number to the first number and
412        check whether it is equivalent to 1 within the specified precision.
413
414        .. deprecated:: 1.1.0
415           Use `rtol` and `atol` instead to define relative/absolute
416           tolerance, respectively. Similar to :func:`math.isclose`.
417    rtol : float, default 1e-5
418        Relative tolerance.
419
420        .. versionadded:: 1.1.0
421    atol : float, default 1e-8
422        Absolute tolerance.
423
424        .. versionadded:: 1.1.0
425    """
426    if check_less_precise is not no_default:
427        warnings.warn(
428            "The 'check_less_precise' keyword in testing.assert_*_equal "
429            "is deprecated and will be removed in a future version. "
430            "You can stop passing 'check_less_precise' to silence this warning.",
431            FutureWarning,
432            stacklevel=2,
433        )
434        rtol = atol = _get_tol_from_less_precise(check_less_precise)
435
436    if isinstance(left, pd.Index):
437        assert_index_equal(
438            left,
439            right,
440            check_exact=False,
441            exact=check_dtype,
442            rtol=rtol,
443            atol=atol,
444            **kwargs,
445        )
446
447    elif isinstance(left, pd.Series):
448        assert_series_equal(
449            left,
450            right,
451            check_exact=False,
452            check_dtype=check_dtype,
453            rtol=rtol,
454            atol=atol,
455            **kwargs,
456        )
457
458    elif isinstance(left, pd.DataFrame):
459        assert_frame_equal(
460            left,
461            right,
462            check_exact=False,
463            check_dtype=check_dtype,
464            rtol=rtol,
465            atol=atol,
466            **kwargs,
467        )
468
469    else:
470        # Other sequences.
471        if check_dtype:
472            if is_number(left) and is_number(right):
473                # Do not compare numeric classes, like np.float64 and float.
474                pass
475            elif is_bool(left) and is_bool(right):
476                # Do not compare bool classes, like np.bool_ and bool.
477                pass
478            else:
479                if isinstance(left, np.ndarray) or isinstance(right, np.ndarray):
480                    obj = "numpy array"
481                else:
482                    obj = "Input"
483                assert_class_equal(left, right, obj=obj)
484        _testing.assert_almost_equal(
485            left, right, check_dtype=check_dtype, rtol=rtol, atol=atol, **kwargs
486        )
487
488
489def _check_isinstance(left, right, cls):
490    """
491    Helper method for our assert_* methods that ensures that
492    the two objects being compared have the right type before
493    proceeding with the comparison.
494
495    Parameters
496    ----------
497    left : The first object being compared.
498    right : The second object being compared.
499    cls : The class type to check against.
500
501    Raises
502    ------
503    AssertionError : Either `left` or `right` is not an instance of `cls`.
504    """
505    cls_name = cls.__name__
506
507    if not isinstance(left, cls):
508        raise AssertionError(
509            f"{cls_name} Expected type {cls}, found {type(left)} instead"
510        )
511    if not isinstance(right, cls):
512        raise AssertionError(
513            f"{cls_name} Expected type {cls}, found {type(right)} instead"
514        )
515
516
517def assert_dict_equal(left, right, compare_keys: bool = True):
518
519    _check_isinstance(left, right, dict)
520    _testing.assert_dict_equal(left, right, compare_keys=compare_keys)
521
522
523def randbool(size=(), p: float = 0.5):
524    return rand(*size) <= p
525
526
527RANDS_CHARS = np.array(list(string.ascii_letters + string.digits), dtype=(np.str_, 1))
528RANDU_CHARS = np.array(
529    list("".join(map(chr, range(1488, 1488 + 26))) + string.digits),
530    dtype=(np.unicode_, 1),
531)
532
533
534def rands_array(nchars, size, dtype="O"):
535    """
536    Generate an array of byte strings.
537    """
538    retval = (
539        np.random.choice(RANDS_CHARS, size=nchars * np.prod(size))
540        .view((np.str_, nchars))
541        .reshape(size)
542    )
543    return retval.astype(dtype)
544
545
546def randu_array(nchars, size, dtype="O"):
547    """
548    Generate an array of unicode strings.
549    """
550    retval = (
551        np.random.choice(RANDU_CHARS, size=nchars * np.prod(size))
552        .view((np.unicode_, nchars))
553        .reshape(size)
554    )
555    return retval.astype(dtype)
556
557
558def rands(nchars):
559    """
560    Generate one random byte string.
561
562    See `rands_array` if you want to create an array of random strings.
563
564    """
565    return "".join(np.random.choice(RANDS_CHARS, nchars))
566
567
568def close(fignum=None):
569    from matplotlib.pyplot import close as _close, get_fignums
570
571    if fignum is None:
572        for fignum in get_fignums():
573            _close(fignum)
574    else:
575        _close(fignum)
576
577
578# -----------------------------------------------------------------------------
579# contextmanager to ensure the file cleanup
580
581
582@contextmanager
583def ensure_clean(filename=None, return_filelike: bool = False, **kwargs: Any):
584    """
585    Gets a temporary path and agrees to remove on close.
586
587    This implementation does not use tempfile.mkstemp to avoid having a file handle.
588    If the code using the returned path wants to delete the file itself, windows
589    requires that no program has a file handle to it.
590
591    Parameters
592    ----------
593    filename : str (optional)
594        suffix of the created file.
595    return_filelike : bool (default False)
596        if True, returns a file-like which is *always* cleaned. Necessary for
597        savefig and other functions which want to append extensions.
598    **kwargs
599        Additional keywords are passed to open().
600
601    """
602    folder = Path(tempfile.gettempdir())
603
604    if filename is None:
605        filename = ""
606    filename = (
607        "".join(random.choices(string.ascii_letters + string.digits, k=30)) + filename
608    )
609    path = folder / filename
610
611    path.touch()
612
613    handle_or_str: Union[str, IO] = str(path)
614    if return_filelike:
615        kwargs.setdefault("mode", "w+b")
616        handle_or_str = open(path, **kwargs)
617
618    try:
619        yield handle_or_str
620    finally:
621        if not isinstance(handle_or_str, str):
622            handle_or_str.close()
623        if path.is_file():
624            path.unlink()
625
626
627@contextmanager
628def ensure_clean_dir():
629    """
630    Get a temporary directory path and agrees to remove on close.
631
632    Yields
633    ------
634    Temporary directory path
635    """
636    directory_name = tempfile.mkdtemp(suffix="")
637    try:
638        yield directory_name
639    finally:
640        try:
641            rmtree(directory_name)
642        except OSError:
643            pass
644
645
646@contextmanager
647def ensure_safe_environment_variables():
648    """
649    Get a context manager to safely set environment variables
650
651    All changes will be undone on close, hence environment variables set
652    within this contextmanager will neither persist nor change global state.
653    """
654    saved_environ = dict(os.environ)
655    try:
656        yield
657    finally:
658        os.environ.clear()
659        os.environ.update(saved_environ)
660
661
662# -----------------------------------------------------------------------------
663# Comparators
664
665
666def equalContents(arr1, arr2) -> bool:
667    """
668    Checks if the set of unique elements of arr1 and arr2 are equivalent.
669    """
670    return frozenset(arr1) == frozenset(arr2)
671
672
673def assert_index_equal(
674    left: Index,
675    right: Index,
676    exact: Union[bool, str] = "equiv",
677    check_names: bool = True,
678    check_less_precise: Union[bool, int] = no_default,
679    check_exact: bool = True,
680    check_categorical: bool = True,
681    check_order: bool = True,
682    rtol: float = 1.0e-5,
683    atol: float = 1.0e-8,
684    obj: str = "Index",
685) -> None:
686    """
687    Check that left and right Index are equal.
688
689    Parameters
690    ----------
691    left : Index
692    right : Index
693    exact : bool or {'equiv'}, default 'equiv'
694        Whether to check the Index class, dtype and inferred_type
695        are identical. If 'equiv', then RangeIndex can be substituted for
696        Int64Index as well.
697    check_names : bool, default True
698        Whether to check the names attribute.
699    check_less_precise : bool or int, default False
700        Specify comparison precision. Only used when check_exact is False.
701        5 digits (False) or 3 digits (True) after decimal points are compared.
702        If int, then specify the digits to compare.
703
704        .. deprecated:: 1.1.0
705           Use `rtol` and `atol` instead to define relative/absolute
706           tolerance, respectively. Similar to :func:`math.isclose`.
707    check_exact : bool, default True
708        Whether to compare number exactly.
709    check_categorical : bool, default True
710        Whether to compare internal Categorical exactly.
711    check_order : bool, default True
712        Whether to compare the order of index entries as well as their values.
713        If True, both indexes must contain the same elements, in the same order.
714        If False, both indexes must contain the same elements, but in any order.
715
716        .. versionadded:: 1.2.0
717    rtol : float, default 1e-5
718        Relative tolerance. Only used when check_exact is False.
719
720        .. versionadded:: 1.1.0
721    atol : float, default 1e-8
722        Absolute tolerance. Only used when check_exact is False.
723
724        .. versionadded:: 1.1.0
725    obj : str, default 'Index'
726        Specify object name being compared, internally used to show appropriate
727        assertion message.
728
729    Examples
730    --------
731    >>> from pandas.testing import assert_index_equal
732    >>> a = pd.Index([1, 2, 3])
733    >>> b = pd.Index([1, 2, 3])
734    >>> assert_index_equal(a, b)
735    """
736    __tracebackhide__ = True
737
738    def _check_types(left, right, obj="Index"):
739        if exact:
740            assert_class_equal(left, right, exact=exact, obj=obj)
741
742            # Skip exact dtype checking when `check_categorical` is False
743            if check_categorical:
744                assert_attr_equal("dtype", left, right, obj=obj)
745
746            # allow string-like to have different inferred_types
747            if left.inferred_type in ("string"):
748                assert right.inferred_type in ("string")
749            else:
750                assert_attr_equal("inferred_type", left, right, obj=obj)
751
752    def _get_ilevel_values(index, level):
753        # accept level number only
754        unique = index.levels[level]
755        level_codes = index.codes[level]
756        filled = take_1d(unique._values, level_codes, fill_value=unique._na_value)
757        return unique._shallow_copy(filled, name=index.names[level])
758
759    if check_less_precise is not no_default:
760        warnings.warn(
761            "The 'check_less_precise' keyword in testing.assert_*_equal "
762            "is deprecated and will be removed in a future version. "
763            "You can stop passing 'check_less_precise' to silence this warning.",
764            FutureWarning,
765            stacklevel=2,
766        )
767        rtol = atol = _get_tol_from_less_precise(check_less_precise)
768
769    # instance validation
770    _check_isinstance(left, right, Index)
771
772    # class / dtype comparison
773    _check_types(left, right, obj=obj)
774
775    # level comparison
776    if left.nlevels != right.nlevels:
777        msg1 = f"{obj} levels are different"
778        msg2 = f"{left.nlevels}, {left}"
779        msg3 = f"{right.nlevels}, {right}"
780        raise_assert_detail(obj, msg1, msg2, msg3)
781
782    # length comparison
783    if len(left) != len(right):
784        msg1 = f"{obj} length are different"
785        msg2 = f"{len(left)}, {left}"
786        msg3 = f"{len(right)}, {right}"
787        raise_assert_detail(obj, msg1, msg2, msg3)
788
789    # If order doesn't matter then sort the index entries
790    if not check_order:
791        left = Index(safe_sort(left))
792        right = Index(safe_sort(right))
793
794    # MultiIndex special comparison for little-friendly error messages
795    if left.nlevels > 1:
796        left = cast(MultiIndex, left)
797        right = cast(MultiIndex, right)
798
799        for level in range(left.nlevels):
800            # cannot use get_level_values here because it can change dtype
801            llevel = _get_ilevel_values(left, level)
802            rlevel = _get_ilevel_values(right, level)
803
804            lobj = f"MultiIndex level [{level}]"
805            assert_index_equal(
806                llevel,
807                rlevel,
808                exact=exact,
809                check_names=check_names,
810                check_exact=check_exact,
811                rtol=rtol,
812                atol=atol,
813                obj=lobj,
814            )
815            # get_level_values may change dtype
816            _check_types(left.levels[level], right.levels[level], obj=obj)
817
818    # skip exact index checking when `check_categorical` is False
819    if check_exact and check_categorical:
820        if not left.equals(right):
821            diff = np.sum((left.values != right.values).astype(int)) * 100.0 / len(left)
822            msg = f"{obj} values are different ({np.round(diff, 5)} %)"
823            raise_assert_detail(obj, msg, left, right)
824    else:
825        _testing.assert_almost_equal(
826            left.values,
827            right.values,
828            rtol=rtol,
829            atol=atol,
830            check_dtype=exact,
831            obj=obj,
832            lobj=left,
833            robj=right,
834        )
835
836    # metadata comparison
837    if check_names:
838        assert_attr_equal("names", left, right, obj=obj)
839    if isinstance(left, pd.PeriodIndex) or isinstance(right, pd.PeriodIndex):
840        assert_attr_equal("freq", left, right, obj=obj)
841    if isinstance(left, pd.IntervalIndex) or isinstance(right, pd.IntervalIndex):
842        assert_interval_array_equal(left._values, right._values)
843
844    if check_categorical:
845        if is_categorical_dtype(left.dtype) or is_categorical_dtype(right.dtype):
846            assert_categorical_equal(left._values, right._values, obj=f"{obj} category")
847
848
849def assert_class_equal(left, right, exact: Union[bool, str] = True, obj="Input"):
850    """
851    Checks classes are equal.
852    """
853    __tracebackhide__ = True
854
855    def repr_class(x):
856        if isinstance(x, Index):
857            # return Index as it is to include values in the error message
858            return x
859
860        return type(x).__name__
861
862    if exact == "equiv":
863        if type(left) != type(right):
864            # allow equivalence of Int64Index/RangeIndex
865            types = {type(left).__name__, type(right).__name__}
866            if len(types - {"Int64Index", "RangeIndex"}):
867                msg = f"{obj} classes are not equivalent"
868                raise_assert_detail(obj, msg, repr_class(left), repr_class(right))
869    elif exact:
870        if type(left) != type(right):
871            msg = f"{obj} classes are different"
872            raise_assert_detail(obj, msg, repr_class(left), repr_class(right))
873
874
875def assert_attr_equal(attr: str, left, right, obj: str = "Attributes"):
876    """
877    Check attributes are equal. Both objects must have attribute.
878
879    Parameters
880    ----------
881    attr : str
882        Attribute name being compared.
883    left : object
884    right : object
885    obj : str, default 'Attributes'
886        Specify object name being compared, internally used to show appropriate
887        assertion message
888    """
889    __tracebackhide__ = True
890
891    left_attr = getattr(left, attr)
892    right_attr = getattr(right, attr)
893
894    if left_attr is right_attr:
895        return True
896    elif (
897        is_number(left_attr)
898        and np.isnan(left_attr)
899        and is_number(right_attr)
900        and np.isnan(right_attr)
901    ):
902        # np.nan
903        return True
904
905    try:
906        result = left_attr == right_attr
907    except TypeError:
908        # datetimetz on rhs may raise TypeError
909        result = False
910    if not isinstance(result, bool):
911        result = result.all()
912
913    if result:
914        return True
915    else:
916        msg = f'Attribute "{attr}" are different'
917        raise_assert_detail(obj, msg, left_attr, right_attr)
918
919
920def assert_is_valid_plot_return_object(objs):
921    import matplotlib.pyplot as plt
922
923    if isinstance(objs, (pd.Series, np.ndarray)):
924        for el in objs.ravel():
925            msg = (
926                "one of 'objs' is not a matplotlib Axes instance, "
927                f"type encountered {repr(type(el).__name__)}"
928            )
929            assert isinstance(el, (plt.Axes, dict)), msg
930    else:
931        msg = (
932            "objs is neither an ndarray of Artist instances nor a single "
933            "ArtistArtist instance, tuple, or dict, 'objs' is a "
934            f"{repr(type(objs).__name__)}"
935        )
936        assert isinstance(objs, (plt.Artist, tuple, dict)), msg
937
938
939def assert_is_sorted(seq):
940    """Assert that the sequence is sorted."""
941    if isinstance(seq, (Index, Series)):
942        seq = seq.values
943    # sorting does not change precisions
944    assert_numpy_array_equal(seq, np.sort(np.array(seq)))
945
946
947def assert_categorical_equal(
948    left, right, check_dtype=True, check_category_order=True, obj="Categorical"
949):
950    """
951    Test that Categoricals are equivalent.
952
953    Parameters
954    ----------
955    left : Categorical
956    right : Categorical
957    check_dtype : bool, default True
958        Check that integer dtype of the codes are the same
959    check_category_order : bool, default True
960        Whether the order of the categories should be compared, which
961        implies identical integer codes.  If False, only the resulting
962        values are compared.  The ordered attribute is
963        checked regardless.
964    obj : str, default 'Categorical'
965        Specify object name being compared, internally used to show appropriate
966        assertion message
967    """
968    _check_isinstance(left, right, Categorical)
969
970    if check_category_order:
971        assert_index_equal(left.categories, right.categories, obj=f"{obj}.categories")
972        assert_numpy_array_equal(
973            left.codes, right.codes, check_dtype=check_dtype, obj=f"{obj}.codes"
974        )
975    else:
976        try:
977            lc = left.categories.sort_values()
978            rc = right.categories.sort_values()
979        except TypeError:
980            # e.g. '<' not supported between instances of 'int' and 'str'
981            lc, rc = left.categories, right.categories
982        assert_index_equal(lc, rc, obj=f"{obj}.categories")
983        assert_index_equal(
984            left.categories.take(left.codes),
985            right.categories.take(right.codes),
986            obj=f"{obj}.values",
987        )
988
989    assert_attr_equal("ordered", left, right, obj=obj)
990
991
992def assert_interval_array_equal(left, right, exact="equiv", obj="IntervalArray"):
993    """
994    Test that two IntervalArrays are equivalent.
995
996    Parameters
997    ----------
998    left, right : IntervalArray
999        The IntervalArrays to compare.
1000    exact : bool or {'equiv'}, default 'equiv'
1001        Whether to check the Index class, dtype and inferred_type
1002        are identical. If 'equiv', then RangeIndex can be substituted for
1003        Int64Index as well.
1004    obj : str, default 'IntervalArray'
1005        Specify object name being compared, internally used to show appropriate
1006        assertion message
1007    """
1008    _check_isinstance(left, right, IntervalArray)
1009
1010    kwargs = {}
1011    if left._left.dtype.kind in ["m", "M"]:
1012        # We have a DatetimeArray or TimedeltaArray
1013        kwargs["check_freq"] = False
1014
1015    assert_equal(left._left, right._left, obj=f"{obj}.left", **kwargs)
1016    assert_equal(left._right, right._right, obj=f"{obj}.left", **kwargs)
1017
1018    assert_attr_equal("closed", left, right, obj=obj)
1019
1020
1021def assert_period_array_equal(left, right, obj="PeriodArray"):
1022    _check_isinstance(left, right, PeriodArray)
1023
1024    assert_numpy_array_equal(left._data, right._data, obj=f"{obj}._data")
1025    assert_attr_equal("freq", left, right, obj=obj)
1026
1027
1028def assert_datetime_array_equal(left, right, obj="DatetimeArray", check_freq=True):
1029    __tracebackhide__ = True
1030    _check_isinstance(left, right, DatetimeArray)
1031
1032    assert_numpy_array_equal(left._data, right._data, obj=f"{obj}._data")
1033    if check_freq:
1034        assert_attr_equal("freq", left, right, obj=obj)
1035    assert_attr_equal("tz", left, right, obj=obj)
1036
1037
1038def assert_timedelta_array_equal(left, right, obj="TimedeltaArray", check_freq=True):
1039    __tracebackhide__ = True
1040    _check_isinstance(left, right, TimedeltaArray)
1041    assert_numpy_array_equal(left._data, right._data, obj=f"{obj}._data")
1042    if check_freq:
1043        assert_attr_equal("freq", left, right, obj=obj)
1044
1045
1046def raise_assert_detail(obj, message, left, right, diff=None, index_values=None):
1047    __tracebackhide__ = True
1048
1049    msg = f"""{obj} are different
1050
1051{message}"""
1052
1053    if isinstance(index_values, np.ndarray):
1054        msg += f"\n[index]: {pprint_thing(index_values)}"
1055
1056    if isinstance(left, np.ndarray):
1057        left = pprint_thing(left)
1058    elif is_categorical_dtype(left):
1059        left = repr(left)
1060
1061    if isinstance(right, np.ndarray):
1062        right = pprint_thing(right)
1063    elif is_categorical_dtype(right):
1064        right = repr(right)
1065
1066    msg += f"""
1067[left]:  {left}
1068[right]: {right}"""
1069
1070    if diff is not None:
1071        msg += f"\n[diff]: {diff}"
1072
1073    raise AssertionError(msg)
1074
1075
1076def assert_numpy_array_equal(
1077    left,
1078    right,
1079    strict_nan=False,
1080    check_dtype=True,
1081    err_msg=None,
1082    check_same=None,
1083    obj="numpy array",
1084    index_values=None,
1085):
1086    """
1087    Check that 'np.ndarray' is equivalent.
1088
1089    Parameters
1090    ----------
1091    left, right : numpy.ndarray or iterable
1092        The two arrays to be compared.
1093    strict_nan : bool, default False
1094        If True, consider NaN and None to be different.
1095    check_dtype : bool, default True
1096        Check dtype if both a and b are np.ndarray.
1097    err_msg : str, default None
1098        If provided, used as assertion message.
1099    check_same : None|'copy'|'same', default None
1100        Ensure left and right refer/do not refer to the same memory area.
1101    obj : str, default 'numpy array'
1102        Specify object name being compared, internally used to show appropriate
1103        assertion message.
1104    index_values : numpy.ndarray, default None
1105        optional index (shared by both left and right), used in output.
1106    """
1107    __tracebackhide__ = True
1108
1109    # instance validation
1110    # Show a detailed error message when classes are different
1111    assert_class_equal(left, right, obj=obj)
1112    # both classes must be an np.ndarray
1113    _check_isinstance(left, right, np.ndarray)
1114
1115    def _get_base(obj):
1116        return obj.base if getattr(obj, "base", None) is not None else obj
1117
1118    left_base = _get_base(left)
1119    right_base = _get_base(right)
1120
1121    if check_same == "same":
1122        if left_base is not right_base:
1123            raise AssertionError(f"{repr(left_base)} is not {repr(right_base)}")
1124    elif check_same == "copy":
1125        if left_base is right_base:
1126            raise AssertionError(f"{repr(left_base)} is {repr(right_base)}")
1127
1128    def _raise(left, right, err_msg):
1129        if err_msg is None:
1130            if left.shape != right.shape:
1131                raise_assert_detail(
1132                    obj, f"{obj} shapes are different", left.shape, right.shape
1133                )
1134
1135            diff = 0
1136            for left_arr, right_arr in zip(left, right):
1137                # count up differences
1138                if not array_equivalent(left_arr, right_arr, strict_nan=strict_nan):
1139                    diff += 1
1140
1141            diff = diff * 100.0 / left.size
1142            msg = f"{obj} values are different ({np.round(diff, 5)} %)"
1143            raise_assert_detail(obj, msg, left, right, index_values=index_values)
1144
1145        raise AssertionError(err_msg)
1146
1147    # compare shape and values
1148    if not array_equivalent(left, right, strict_nan=strict_nan):
1149        _raise(left, right, err_msg)
1150
1151    if check_dtype:
1152        if isinstance(left, np.ndarray) and isinstance(right, np.ndarray):
1153            assert_attr_equal("dtype", left, right, obj=obj)
1154
1155
1156def assert_extension_array_equal(
1157    left,
1158    right,
1159    check_dtype=True,
1160    index_values=None,
1161    check_less_precise=no_default,
1162    check_exact=False,
1163    rtol: float = 1.0e-5,
1164    atol: float = 1.0e-8,
1165):
1166    """
1167    Check that left and right ExtensionArrays are equal.
1168
1169    Parameters
1170    ----------
1171    left, right : ExtensionArray
1172        The two arrays to compare.
1173    check_dtype : bool, default True
1174        Whether to check if the ExtensionArray dtypes are identical.
1175    index_values : numpy.ndarray, default None
1176        Optional index (shared by both left and right), used in output.
1177    check_less_precise : bool or int, default False
1178        Specify comparison precision. Only used when check_exact is False.
1179        5 digits (False) or 3 digits (True) after decimal points are compared.
1180        If int, then specify the digits to compare.
1181
1182        .. deprecated:: 1.1.0
1183           Use `rtol` and `atol` instead to define relative/absolute
1184           tolerance, respectively. Similar to :func:`math.isclose`.
1185    check_exact : bool, default False
1186        Whether to compare number exactly.
1187    rtol : float, default 1e-5
1188        Relative tolerance. Only used when check_exact is False.
1189
1190        .. versionadded:: 1.1.0
1191    atol : float, default 1e-8
1192        Absolute tolerance. Only used when check_exact is False.
1193
1194        .. versionadded:: 1.1.0
1195
1196    Notes
1197    -----
1198    Missing values are checked separately from valid values.
1199    A mask of missing values is computed for each and checked to match.
1200    The remaining all-valid values are cast to object dtype and checked.
1201
1202    Examples
1203    --------
1204    >>> from pandas.testing import assert_extension_array_equal
1205    >>> a = pd.Series([1, 2, 3, 4])
1206    >>> b, c = a.array, a.array
1207    >>> assert_extension_array_equal(b, c)
1208    """
1209    if check_less_precise is not no_default:
1210        warnings.warn(
1211            "The 'check_less_precise' keyword in testing.assert_*_equal "
1212            "is deprecated and will be removed in a future version. "
1213            "You can stop passing 'check_less_precise' to silence this warning.",
1214            FutureWarning,
1215            stacklevel=2,
1216        )
1217        rtol = atol = _get_tol_from_less_precise(check_less_precise)
1218
1219    assert isinstance(left, ExtensionArray), "left is not an ExtensionArray"
1220    assert isinstance(right, ExtensionArray), "right is not an ExtensionArray"
1221    if check_dtype:
1222        assert_attr_equal("dtype", left, right, obj="ExtensionArray")
1223
1224    if (
1225        isinstance(left, DatetimeLikeArrayMixin)
1226        and isinstance(right, DatetimeLikeArrayMixin)
1227        and type(right) == type(left)
1228    ):
1229        # Avoid slow object-dtype comparisons
1230        # np.asarray for case where we have a np.MaskedArray
1231        assert_numpy_array_equal(
1232            np.asarray(left.asi8), np.asarray(right.asi8), index_values=index_values
1233        )
1234        return
1235
1236    left_na = np.asarray(left.isna())
1237    right_na = np.asarray(right.isna())
1238    assert_numpy_array_equal(
1239        left_na, right_na, obj="ExtensionArray NA mask", index_values=index_values
1240    )
1241
1242    left_valid = np.asarray(left[~left_na].astype(object))
1243    right_valid = np.asarray(right[~right_na].astype(object))
1244    if check_exact:
1245        assert_numpy_array_equal(
1246            left_valid, right_valid, obj="ExtensionArray", index_values=index_values
1247        )
1248    else:
1249        _testing.assert_almost_equal(
1250            left_valid,
1251            right_valid,
1252            check_dtype=check_dtype,
1253            rtol=rtol,
1254            atol=atol,
1255            obj="ExtensionArray",
1256            index_values=index_values,
1257        )
1258
1259
1260# This could be refactored to use the NDFrame.equals method
1261def assert_series_equal(
1262    left,
1263    right,
1264    check_dtype=True,
1265    check_index_type="equiv",
1266    check_series_type=True,
1267    check_less_precise=no_default,
1268    check_names=True,
1269    check_exact=False,
1270    check_datetimelike_compat=False,
1271    check_categorical=True,
1272    check_category_order=True,
1273    check_freq=True,
1274    check_flags=True,
1275    rtol=1.0e-5,
1276    atol=1.0e-8,
1277    obj="Series",
1278):
1279    """
1280    Check that left and right Series are equal.
1281
1282    Parameters
1283    ----------
1284    left : Series
1285    right : Series
1286    check_dtype : bool, default True
1287        Whether to check the Series dtype is identical.
1288    check_index_type : bool or {'equiv'}, default 'equiv'
1289        Whether to check the Index class, dtype and inferred_type
1290        are identical.
1291    check_series_type : bool, default True
1292         Whether to check the Series class is identical.
1293    check_less_precise : bool or int, default False
1294        Specify comparison precision. Only used when check_exact is False.
1295        5 digits (False) or 3 digits (True) after decimal points are compared.
1296        If int, then specify the digits to compare.
1297
1298        When comparing two numbers, if the first number has magnitude less
1299        than 1e-5, we compare the two numbers directly and check whether
1300        they are equivalent within the specified precision. Otherwise, we
1301        compare the **ratio** of the second number to the first number and
1302        check whether it is equivalent to 1 within the specified precision.
1303
1304        .. deprecated:: 1.1.0
1305           Use `rtol` and `atol` instead to define relative/absolute
1306           tolerance, respectively. Similar to :func:`math.isclose`.
1307    check_names : bool, default True
1308        Whether to check the Series and Index names attribute.
1309    check_exact : bool, default False
1310        Whether to compare number exactly.
1311    check_datetimelike_compat : bool, default False
1312        Compare datetime-like which is comparable ignoring dtype.
1313    check_categorical : bool, default True
1314        Whether to compare internal Categorical exactly.
1315    check_category_order : bool, default True
1316        Whether to compare category order of internal Categoricals.
1317
1318        .. versionadded:: 1.0.2
1319    check_freq : bool, default True
1320        Whether to check the `freq` attribute on a DatetimeIndex or TimedeltaIndex.
1321
1322        .. versionadded:: 1.1.0
1323    check_flags : bool, default True
1324        Whether to check the `flags` attribute.
1325
1326        .. versionadded:: 1.2.0
1327
1328    rtol : float, default 1e-5
1329        Relative tolerance. Only used when check_exact is False.
1330
1331        .. versionadded:: 1.1.0
1332    atol : float, default 1e-8
1333        Absolute tolerance. Only used when check_exact is False.
1334
1335        .. versionadded:: 1.1.0
1336    obj : str, default 'Series'
1337        Specify object name being compared, internally used to show appropriate
1338        assertion message.
1339
1340    Examples
1341    --------
1342    >>> from pandas.testing import assert_series_equal
1343    >>> a = pd.Series([1, 2, 3, 4])
1344    >>> b = pd.Series([1, 2, 3, 4])
1345    >>> assert_series_equal(a, b)
1346    """
1347    __tracebackhide__ = True
1348
1349    if check_less_precise is not no_default:
1350        warnings.warn(
1351            "The 'check_less_precise' keyword in testing.assert_*_equal "
1352            "is deprecated and will be removed in a future version. "
1353            "You can stop passing 'check_less_precise' to silence this warning.",
1354            FutureWarning,
1355            stacklevel=2,
1356        )
1357        rtol = atol = _get_tol_from_less_precise(check_less_precise)
1358
1359    # instance validation
1360    _check_isinstance(left, right, Series)
1361
1362    if check_series_type:
1363        assert_class_equal(left, right, obj=obj)
1364
1365    # length comparison
1366    if len(left) != len(right):
1367        msg1 = f"{len(left)}, {left.index}"
1368        msg2 = f"{len(right)}, {right.index}"
1369        raise_assert_detail(obj, "Series length are different", msg1, msg2)
1370
1371    if check_flags:
1372        assert left.flags == right.flags, f"{repr(left.flags)} != {repr(right.flags)}"
1373
1374    # index comparison
1375    assert_index_equal(
1376        left.index,
1377        right.index,
1378        exact=check_index_type,
1379        check_names=check_names,
1380        check_exact=check_exact,
1381        check_categorical=check_categorical,
1382        rtol=rtol,
1383        atol=atol,
1384        obj=f"{obj}.index",
1385    )
1386    if check_freq and isinstance(left.index, (pd.DatetimeIndex, pd.TimedeltaIndex)):
1387        lidx = left.index
1388        ridx = right.index
1389        assert lidx.freq == ridx.freq, (lidx.freq, ridx.freq)
1390
1391    if check_dtype:
1392        # We want to skip exact dtype checking when `check_categorical`
1393        # is False. We'll still raise if only one is a `Categorical`,
1394        # regardless of `check_categorical`
1395        if (
1396            is_categorical_dtype(left.dtype)
1397            and is_categorical_dtype(right.dtype)
1398            and not check_categorical
1399        ):
1400            pass
1401        else:
1402            assert_attr_equal("dtype", left, right, obj=f"Attributes of {obj}")
1403
1404    if check_exact and is_numeric_dtype(left.dtype) and is_numeric_dtype(right.dtype):
1405        left_values = left._values
1406        right_values = right._values
1407        # Only check exact if dtype is numeric
1408        if is_extension_array_dtype(left_values) and is_extension_array_dtype(
1409            right_values
1410        ):
1411            assert_extension_array_equal(
1412                left_values,
1413                right_values,
1414                check_dtype=check_dtype,
1415                index_values=np.asarray(left.index),
1416            )
1417        else:
1418            assert_numpy_array_equal(
1419                left_values,
1420                right_values,
1421                check_dtype=check_dtype,
1422                obj=str(obj),
1423                index_values=np.asarray(left.index),
1424            )
1425    elif check_datetimelike_compat and (
1426        needs_i8_conversion(left.dtype) or needs_i8_conversion(right.dtype)
1427    ):
1428        # we want to check only if we have compat dtypes
1429        # e.g. integer and M|m are NOT compat, but we can simply check
1430        # the values in that case
1431
1432        # datetimelike may have different objects (e.g. datetime.datetime
1433        # vs Timestamp) but will compare equal
1434        if not Index(left._values).equals(Index(right._values)):
1435            msg = (
1436                f"[datetimelike_compat=True] {left._values} "
1437                f"is not equal to {right._values}."
1438            )
1439            raise AssertionError(msg)
1440    elif is_interval_dtype(left.dtype) and is_interval_dtype(right.dtype):
1441        assert_interval_array_equal(left.array, right.array)
1442    elif is_categorical_dtype(left.dtype) or is_categorical_dtype(right.dtype):
1443        _testing.assert_almost_equal(
1444            left._values,
1445            right._values,
1446            rtol=rtol,
1447            atol=atol,
1448            check_dtype=check_dtype,
1449            obj=str(obj),
1450            index_values=np.asarray(left.index),
1451        )
1452    elif is_extension_array_dtype(left.dtype) and is_extension_array_dtype(right.dtype):
1453        assert_extension_array_equal(
1454            left._values,
1455            right._values,
1456            check_dtype=check_dtype,
1457            index_values=np.asarray(left.index),
1458        )
1459    elif is_extension_array_dtype_and_needs_i8_conversion(
1460        left.dtype, right.dtype
1461    ) or is_extension_array_dtype_and_needs_i8_conversion(right.dtype, left.dtype):
1462        assert_extension_array_equal(
1463            left._values,
1464            right._values,
1465            check_dtype=check_dtype,
1466            index_values=np.asarray(left.index),
1467        )
1468    elif needs_i8_conversion(left.dtype) and needs_i8_conversion(right.dtype):
1469        # DatetimeArray or TimedeltaArray
1470        assert_extension_array_equal(
1471            left._values,
1472            right._values,
1473            check_dtype=check_dtype,
1474            index_values=np.asarray(left.index),
1475        )
1476    else:
1477        _testing.assert_almost_equal(
1478            left._values,
1479            right._values,
1480            rtol=rtol,
1481            atol=atol,
1482            check_dtype=check_dtype,
1483            obj=str(obj),
1484            index_values=np.asarray(left.index),
1485        )
1486
1487    # metadata comparison
1488    if check_names:
1489        assert_attr_equal("name", left, right, obj=obj)
1490
1491    if check_categorical:
1492        if is_categorical_dtype(left.dtype) or is_categorical_dtype(right.dtype):
1493            assert_categorical_equal(
1494                left._values,
1495                right._values,
1496                obj=f"{obj} category",
1497                check_category_order=check_category_order,
1498            )
1499
1500
1501# This could be refactored to use the NDFrame.equals method
1502def assert_frame_equal(
1503    left,
1504    right,
1505    check_dtype=True,
1506    check_index_type="equiv",
1507    check_column_type="equiv",
1508    check_frame_type=True,
1509    check_less_precise=no_default,
1510    check_names=True,
1511    by_blocks=False,
1512    check_exact=False,
1513    check_datetimelike_compat=False,
1514    check_categorical=True,
1515    check_like=False,
1516    check_freq=True,
1517    check_flags=True,
1518    rtol=1.0e-5,
1519    atol=1.0e-8,
1520    obj="DataFrame",
1521):
1522    """
1523    Check that left and right DataFrame are equal.
1524
1525    This function is intended to compare two DataFrames and output any
1526    differences. Is is mostly intended for use in unit tests.
1527    Additional parameters allow varying the strictness of the
1528    equality checks performed.
1529
1530    Parameters
1531    ----------
1532    left : DataFrame
1533        First DataFrame to compare.
1534    right : DataFrame
1535        Second DataFrame to compare.
1536    check_dtype : bool, default True
1537        Whether to check the DataFrame dtype is identical.
1538    check_index_type : bool or {'equiv'}, default 'equiv'
1539        Whether to check the Index class, dtype and inferred_type
1540        are identical.
1541    check_column_type : bool or {'equiv'}, default 'equiv'
1542        Whether to check the columns class, dtype and inferred_type
1543        are identical. Is passed as the ``exact`` argument of
1544        :func:`assert_index_equal`.
1545    check_frame_type : bool, default True
1546        Whether to check the DataFrame class is identical.
1547    check_less_precise : bool or int, default False
1548        Specify comparison precision. Only used when check_exact is False.
1549        5 digits (False) or 3 digits (True) after decimal points are compared.
1550        If int, then specify the digits to compare.
1551
1552        When comparing two numbers, if the first number has magnitude less
1553        than 1e-5, we compare the two numbers directly and check whether
1554        they are equivalent within the specified precision. Otherwise, we
1555        compare the **ratio** of the second number to the first number and
1556        check whether it is equivalent to 1 within the specified precision.
1557
1558        .. deprecated:: 1.1.0
1559           Use `rtol` and `atol` instead to define relative/absolute
1560           tolerance, respectively. Similar to :func:`math.isclose`.
1561    check_names : bool, default True
1562        Whether to check that the `names` attribute for both the `index`
1563        and `column` attributes of the DataFrame is identical.
1564    by_blocks : bool, default False
1565        Specify how to compare internal data. If False, compare by columns.
1566        If True, compare by blocks.
1567    check_exact : bool, default False
1568        Whether to compare number exactly.
1569    check_datetimelike_compat : bool, default False
1570        Compare datetime-like which is comparable ignoring dtype.
1571    check_categorical : bool, default True
1572        Whether to compare internal Categorical exactly.
1573    check_like : bool, default False
1574        If True, ignore the order of index & columns.
1575        Note: index labels must match their respective rows
1576        (same as in columns) - same labels must be with the same data.
1577    check_freq : bool, default True
1578        Whether to check the `freq` attribute on a DatetimeIndex or TimedeltaIndex.
1579
1580        .. versionadded:: 1.1.0
1581    check_flags : bool, default True
1582        Whether to check the `flags` attribute.
1583    rtol : float, default 1e-5
1584        Relative tolerance. Only used when check_exact is False.
1585
1586        .. versionadded:: 1.1.0
1587    atol : float, default 1e-8
1588        Absolute tolerance. Only used when check_exact is False.
1589
1590        .. versionadded:: 1.1.0
1591    obj : str, default 'DataFrame'
1592        Specify object name being compared, internally used to show appropriate
1593        assertion message.
1594
1595    See Also
1596    --------
1597    assert_series_equal : Equivalent method for asserting Series equality.
1598    DataFrame.equals : Check DataFrame equality.
1599
1600    Examples
1601    --------
1602    This example shows comparing two DataFrames that are equal
1603    but with columns of differing dtypes.
1604
1605    >>> from pandas._testing import assert_frame_equal
1606    >>> df1 = pd.DataFrame({'a': [1, 2], 'b': [3, 4]})
1607    >>> df2 = pd.DataFrame({'a': [1, 2], 'b': [3.0, 4.0]})
1608
1609    df1 equals itself.
1610
1611    >>> assert_frame_equal(df1, df1)
1612
1613    df1 differs from df2 as column 'b' is of a different type.
1614
1615    >>> assert_frame_equal(df1, df2)
1616    Traceback (most recent call last):
1617    ...
1618    AssertionError: Attributes of DataFrame.iloc[:, 1] (column name="b") are different
1619
1620    Attribute "dtype" are different
1621    [left]:  int64
1622    [right]: float64
1623
1624    Ignore differing dtypes in columns with check_dtype.
1625
1626    >>> assert_frame_equal(df1, df2, check_dtype=False)
1627    """
1628    __tracebackhide__ = True
1629
1630    if check_less_precise is not no_default:
1631        warnings.warn(
1632            "The 'check_less_precise' keyword in testing.assert_*_equal "
1633            "is deprecated and will be removed in a future version. "
1634            "You can stop passing 'check_less_precise' to silence this warning.",
1635            FutureWarning,
1636            stacklevel=2,
1637        )
1638        rtol = atol = _get_tol_from_less_precise(check_less_precise)
1639
1640    # instance validation
1641    _check_isinstance(left, right, DataFrame)
1642
1643    if check_frame_type:
1644        assert isinstance(left, type(right))
1645        # assert_class_equal(left, right, obj=obj)
1646
1647    # shape comparison
1648    if left.shape != right.shape:
1649        raise_assert_detail(
1650            obj, f"{obj} shape mismatch", f"{repr(left.shape)}", f"{repr(right.shape)}"
1651        )
1652
1653    if check_flags:
1654        assert left.flags == right.flags, f"{repr(left.flags)} != {repr(right.flags)}"
1655
1656    # index comparison
1657    assert_index_equal(
1658        left.index,
1659        right.index,
1660        exact=check_index_type,
1661        check_names=check_names,
1662        check_exact=check_exact,
1663        check_categorical=check_categorical,
1664        check_order=not check_like,
1665        rtol=rtol,
1666        atol=atol,
1667        obj=f"{obj}.index",
1668    )
1669
1670    # column comparison
1671    assert_index_equal(
1672        left.columns,
1673        right.columns,
1674        exact=check_column_type,
1675        check_names=check_names,
1676        check_exact=check_exact,
1677        check_categorical=check_categorical,
1678        check_order=not check_like,
1679        rtol=rtol,
1680        atol=atol,
1681        obj=f"{obj}.columns",
1682    )
1683
1684    if check_like:
1685        left, right = left.reindex_like(right), right
1686
1687    # compare by blocks
1688    if by_blocks:
1689        rblocks = right._to_dict_of_blocks()
1690        lblocks = left._to_dict_of_blocks()
1691        for dtype in list(set(list(lblocks.keys()) + list(rblocks.keys()))):
1692            assert dtype in lblocks
1693            assert dtype in rblocks
1694            assert_frame_equal(
1695                lblocks[dtype], rblocks[dtype], check_dtype=check_dtype, obj=obj
1696            )
1697
1698    # compare by columns
1699    else:
1700        for i, col in enumerate(left.columns):
1701            assert col in right
1702            lcol = left.iloc[:, i]
1703            rcol = right.iloc[:, i]
1704            assert_series_equal(
1705                lcol,
1706                rcol,
1707                check_dtype=check_dtype,
1708                check_index_type=check_index_type,
1709                check_exact=check_exact,
1710                check_names=check_names,
1711                check_datetimelike_compat=check_datetimelike_compat,
1712                check_categorical=check_categorical,
1713                check_freq=check_freq,
1714                obj=f'{obj}.iloc[:, {i}] (column name="{col}")',
1715                rtol=rtol,
1716                atol=atol,
1717            )
1718
1719
1720def assert_equal(left, right, **kwargs):
1721    """
1722    Wrapper for tm.assert_*_equal to dispatch to the appropriate test function.
1723
1724    Parameters
1725    ----------
1726    left, right : Index, Series, DataFrame, ExtensionArray, or np.ndarray
1727        The two items to be compared.
1728    **kwargs
1729        All keyword arguments are passed through to the underlying assert method.
1730    """
1731    __tracebackhide__ = True
1732
1733    if isinstance(left, pd.Index):
1734        assert_index_equal(left, right, **kwargs)
1735        if isinstance(left, (pd.DatetimeIndex, pd.TimedeltaIndex)):
1736            assert left.freq == right.freq, (left.freq, right.freq)
1737    elif isinstance(left, pd.Series):
1738        assert_series_equal(left, right, **kwargs)
1739    elif isinstance(left, pd.DataFrame):
1740        assert_frame_equal(left, right, **kwargs)
1741    elif isinstance(left, IntervalArray):
1742        assert_interval_array_equal(left, right, **kwargs)
1743    elif isinstance(left, PeriodArray):
1744        assert_period_array_equal(left, right, **kwargs)
1745    elif isinstance(left, DatetimeArray):
1746        assert_datetime_array_equal(left, right, **kwargs)
1747    elif isinstance(left, TimedeltaArray):
1748        assert_timedelta_array_equal(left, right, **kwargs)
1749    elif isinstance(left, ExtensionArray):
1750        assert_extension_array_equal(left, right, **kwargs)
1751    elif isinstance(left, np.ndarray):
1752        assert_numpy_array_equal(left, right, **kwargs)
1753    elif isinstance(left, str):
1754        assert kwargs == {}
1755        assert left == right
1756    else:
1757        raise NotImplementedError(type(left))
1758
1759
1760def box_expected(expected, box_cls, transpose=True):
1761    """
1762    Helper function to wrap the expected output of a test in a given box_class.
1763
1764    Parameters
1765    ----------
1766    expected : np.ndarray, Index, Series
1767    box_cls : {Index, Series, DataFrame}
1768
1769    Returns
1770    -------
1771    subclass of box_cls
1772    """
1773    if box_cls is pd.array:
1774        expected = pd.array(expected)
1775    elif box_cls is pd.Index:
1776        expected = pd.Index(expected)
1777    elif box_cls is pd.Series:
1778        expected = pd.Series(expected)
1779    elif box_cls is pd.DataFrame:
1780        expected = pd.Series(expected).to_frame()
1781        if transpose:
1782            # for vector operations, we need a DataFrame to be a single-row,
1783            #  not a single-column, in order to operate against non-DataFrame
1784            #  vectors of the same length.
1785            expected = expected.T
1786    elif box_cls is PeriodArray:
1787        # the PeriodArray constructor is not as flexible as period_array
1788        expected = period_array(expected)
1789    elif box_cls is DatetimeArray:
1790        expected = DatetimeArray(expected)
1791    elif box_cls is TimedeltaArray:
1792        expected = TimedeltaArray(expected)
1793    elif box_cls is np.ndarray:
1794        expected = np.array(expected)
1795    elif box_cls is to_array:
1796        expected = to_array(expected)
1797    else:
1798        raise NotImplementedError(box_cls)
1799    return expected
1800
1801
1802def to_array(obj):
1803    # temporary implementation until we get pd.array in place
1804    dtype = getattr(obj, "dtype", None)
1805
1806    if is_period_dtype(dtype):
1807        return period_array(obj)
1808    elif is_datetime64_dtype(dtype) or is_datetime64tz_dtype(dtype):
1809        return DatetimeArray._from_sequence(obj)
1810    elif is_timedelta64_dtype(dtype):
1811        return TimedeltaArray._from_sequence(obj)
1812    else:
1813        return np.array(obj)
1814
1815
1816# -----------------------------------------------------------------------------
1817# Sparse
1818
1819
1820def assert_sp_array_equal(left, right):
1821    """
1822    Check that the left and right SparseArray are equal.
1823
1824    Parameters
1825    ----------
1826    left : SparseArray
1827    right : SparseArray
1828    """
1829    _check_isinstance(left, right, pd.arrays.SparseArray)
1830
1831    assert_numpy_array_equal(left.sp_values, right.sp_values)
1832
1833    # SparseIndex comparison
1834    assert isinstance(left.sp_index, pd._libs.sparse.SparseIndex)
1835    assert isinstance(right.sp_index, pd._libs.sparse.SparseIndex)
1836
1837    left_index = left.sp_index
1838    right_index = right.sp_index
1839
1840    if not left_index.equals(right_index):
1841        raise_assert_detail(
1842            "SparseArray.index", "index are not equal", left_index, right_index
1843        )
1844    else:
1845        # Just ensure a
1846        pass
1847
1848    assert_attr_equal("fill_value", left, right)
1849    assert_attr_equal("dtype", left, right)
1850    assert_numpy_array_equal(left.to_dense(), right.to_dense())
1851
1852
1853# -----------------------------------------------------------------------------
1854# Others
1855
1856
1857def assert_contains_all(iterable, dic):
1858    for k in iterable:
1859        assert k in dic, f"Did not contain item: {repr(k)}"
1860
1861
1862def assert_copy(iter1, iter2, **eql_kwargs):
1863    """
1864    iter1, iter2: iterables that produce elements
1865    comparable with assert_almost_equal
1866
1867    Checks that the elements are equal, but not
1868    the same object. (Does not check that items
1869    in sequences are also not the same object)
1870    """
1871    for elem1, elem2 in zip(iter1, iter2):
1872        assert_almost_equal(elem1, elem2, **eql_kwargs)
1873        msg = (
1874            f"Expected object {repr(type(elem1))} and object {repr(type(elem2))} to be "
1875            "different objects, but they were the same object."
1876        )
1877        assert elem1 is not elem2, msg
1878
1879
1880def is_extension_array_dtype_and_needs_i8_conversion(left_dtype, right_dtype) -> bool:
1881    """
1882    Checks that we have the combination of an ExtensionArraydtype and
1883    a dtype that should be converted to int64
1884
1885    Returns
1886    -------
1887    bool
1888
1889    Related to issue #37609
1890    """
1891    return is_extension_array_dtype(left_dtype) and needs_i8_conversion(right_dtype)
1892
1893
1894def getCols(k):
1895    return string.ascii_uppercase[:k]
1896
1897
1898# make index
1899def makeStringIndex(k=10, name=None):
1900    return Index(rands_array(nchars=10, size=k), name=name)
1901
1902
1903def makeUnicodeIndex(k=10, name=None):
1904    return Index(randu_array(nchars=10, size=k), name=name)
1905
1906
1907def makeCategoricalIndex(k=10, n=3, name=None, **kwargs):
1908    """ make a length k index or n categories """
1909    x = rands_array(nchars=4, size=n)
1910    return CategoricalIndex(
1911        Categorical.from_codes(np.arange(k) % n, categories=x), name=name, **kwargs
1912    )
1913
1914
1915def makeIntervalIndex(k=10, name=None, **kwargs):
1916    """ make a length k IntervalIndex """
1917    x = np.linspace(0, 100, num=(k + 1))
1918    return IntervalIndex.from_breaks(x, name=name, **kwargs)
1919
1920
1921def makeBoolIndex(k=10, name=None):
1922    if k == 1:
1923        return Index([True], name=name)
1924    elif k == 2:
1925        return Index([False, True], name=name)
1926    return Index([False, True] + [False] * (k - 2), name=name)
1927
1928
1929def makeIntIndex(k=10, name=None):
1930    return Index(list(range(k)), name=name)
1931
1932
1933def makeUIntIndex(k=10, name=None):
1934    return Index([2 ** 63 + i for i in range(k)], name=name)
1935
1936
1937def makeRangeIndex(k=10, name=None, **kwargs):
1938    return RangeIndex(0, k, 1, name=name, **kwargs)
1939
1940
1941def makeFloatIndex(k=10, name=None):
1942    values = sorted(np.random.random_sample(k)) - np.random.random_sample(1)
1943    return Index(values * (10 ** np.random.randint(0, 9)), name=name)
1944
1945
1946def makeDateIndex(k=10, freq="B", name=None, **kwargs):
1947    dt = datetime(2000, 1, 1)
1948    dr = bdate_range(dt, periods=k, freq=freq, name=name)
1949    return DatetimeIndex(dr, name=name, **kwargs)
1950
1951
1952def makeTimedeltaIndex(k=10, freq="D", name=None, **kwargs):
1953    return pd.timedelta_range(start="1 day", periods=k, freq=freq, name=name, **kwargs)
1954
1955
1956def makePeriodIndex(k=10, name=None, **kwargs):
1957    dt = datetime(2000, 1, 1)
1958    return pd.period_range(start=dt, periods=k, freq="B", name=name, **kwargs)
1959
1960
1961def makeMultiIndex(k=10, names=None, **kwargs):
1962    return MultiIndex.from_product((("foo", "bar"), (1, 2)), names=names, **kwargs)
1963
1964
1965_names = [
1966    "Alice",
1967    "Bob",
1968    "Charlie",
1969    "Dan",
1970    "Edith",
1971    "Frank",
1972    "George",
1973    "Hannah",
1974    "Ingrid",
1975    "Jerry",
1976    "Kevin",
1977    "Laura",
1978    "Michael",
1979    "Norbert",
1980    "Oliver",
1981    "Patricia",
1982    "Quinn",
1983    "Ray",
1984    "Sarah",
1985    "Tim",
1986    "Ursula",
1987    "Victor",
1988    "Wendy",
1989    "Xavier",
1990    "Yvonne",
1991    "Zelda",
1992]
1993
1994
1995def _make_timeseries(start="2000-01-01", end="2000-12-31", freq="1D", seed=None):
1996    """
1997    Make a DataFrame with a DatetimeIndex
1998
1999    Parameters
2000    ----------
2001    start : str or Timestamp, default "2000-01-01"
2002        The start of the index. Passed to date_range with `freq`.
2003    end : str or Timestamp, default "2000-12-31"
2004        The end of the index. Passed to date_range with `freq`.
2005    freq : str or Freq
2006        The frequency to use for the DatetimeIndex
2007    seed : int, optional
2008        The random state seed.
2009
2010        * name : object dtype with string names
2011        * id : int dtype with
2012        * x, y : float dtype
2013
2014    Examples
2015    --------
2016    >>> _make_timeseries()
2017                  id    name         x         y
2018    timestamp
2019    2000-01-01   982   Frank  0.031261  0.986727
2020    2000-01-02  1025   Edith -0.086358 -0.032920
2021    2000-01-03   982   Edith  0.473177  0.298654
2022    2000-01-04  1009   Sarah  0.534344 -0.750377
2023    2000-01-05   963   Zelda -0.271573  0.054424
2024    ...          ...     ...       ...       ...
2025    2000-12-27   980  Ingrid -0.132333 -0.422195
2026    2000-12-28   972   Frank -0.376007 -0.298687
2027    2000-12-29  1009  Ursula -0.865047 -0.503133
2028    2000-12-30  1000  Hannah -0.063757 -0.507336
2029    2000-12-31   972     Tim -0.869120  0.531685
2030    """
2031    index = pd.date_range(start=start, end=end, freq=freq, name="timestamp")
2032    n = len(index)
2033    state = np.random.RandomState(seed)
2034    columns = {
2035        "name": state.choice(_names, size=n),
2036        "id": state.poisson(1000, size=n),
2037        "x": state.rand(n) * 2 - 1,
2038        "y": state.rand(n) * 2 - 1,
2039    }
2040    df = pd.DataFrame(columns, index=index, columns=sorted(columns))
2041    if df.index[-1] == end:
2042        df = df.iloc[:-1]
2043    return df
2044
2045
2046def index_subclass_makers_generator():
2047    make_index_funcs = [
2048        makeDateIndex,
2049        makePeriodIndex,
2050        makeTimedeltaIndex,
2051        makeRangeIndex,
2052        makeIntervalIndex,
2053        makeCategoricalIndex,
2054        makeMultiIndex,
2055    ]
2056    yield from make_index_funcs
2057
2058
2059def all_timeseries_index_generator(k=10):
2060    """
2061    Generator which can be iterated over to get instances of all the classes
2062    which represent time-series.
2063
2064    Parameters
2065    ----------
2066    k: length of each of the index instances
2067    """
2068    make_index_funcs = [makeDateIndex, makePeriodIndex, makeTimedeltaIndex]
2069    for make_index_func in make_index_funcs:
2070        # pandas\_testing.py:1986: error: Cannot call function of unknown type
2071        yield make_index_func(k=k)  # type: ignore[operator]
2072
2073
2074# make series
2075def makeFloatSeries(name=None):
2076    index = makeStringIndex(_N)
2077    return Series(randn(_N), index=index, name=name)
2078
2079
2080def makeStringSeries(name=None):
2081    index = makeStringIndex(_N)
2082    return Series(randn(_N), index=index, name=name)
2083
2084
2085def makeObjectSeries(name=None):
2086    data = makeStringIndex(_N)
2087    data = Index(data, dtype=object)
2088    index = makeStringIndex(_N)
2089    return Series(data, index=index, name=name)
2090
2091
2092def getSeriesData():
2093    index = makeStringIndex(_N)
2094    return {c: Series(randn(_N), index=index) for c in getCols(_K)}
2095
2096
2097def makeTimeSeries(nper=None, freq="B", name=None):
2098    if nper is None:
2099        nper = _N
2100    return Series(randn(nper), index=makeDateIndex(nper, freq=freq), name=name)
2101
2102
2103def makePeriodSeries(nper=None, name=None):
2104    if nper is None:
2105        nper = _N
2106    return Series(randn(nper), index=makePeriodIndex(nper), name=name)
2107
2108
2109def getTimeSeriesData(nper=None, freq="B"):
2110    return {c: makeTimeSeries(nper, freq) for c in getCols(_K)}
2111
2112
2113def getPeriodData(nper=None):
2114    return {c: makePeriodSeries(nper) for c in getCols(_K)}
2115
2116
2117# make frame
2118def makeTimeDataFrame(nper=None, freq="B"):
2119    data = getTimeSeriesData(nper, freq)
2120    return DataFrame(data)
2121
2122
2123def makeDataFrame():
2124    data = getSeriesData()
2125    return DataFrame(data)
2126
2127
2128def getMixedTypeDict():
2129    index = Index(["a", "b", "c", "d", "e"])
2130
2131    data = {
2132        "A": [0.0, 1.0, 2.0, 3.0, 4.0],
2133        "B": [0.0, 1.0, 0.0, 1.0, 0.0],
2134        "C": ["foo1", "foo2", "foo3", "foo4", "foo5"],
2135        "D": bdate_range("1/1/2009", periods=5),
2136    }
2137
2138    return index, data
2139
2140
2141def makeMixedDataFrame():
2142    return DataFrame(getMixedTypeDict()[1])
2143
2144
2145def makePeriodFrame(nper=None):
2146    data = getPeriodData(nper)
2147    return DataFrame(data)
2148
2149
2150def makeCustomIndex(
2151    nentries, nlevels, prefix="#", names=False, ndupe_l=None, idx_type=None
2152):
2153    """
2154    Create an index/multindex with given dimensions, levels, names, etc'
2155
2156    nentries - number of entries in index
2157    nlevels - number of levels (> 1 produces multindex)
2158    prefix - a string prefix for labels
2159    names - (Optional), bool or list of strings. if True will use default
2160       names, if false will use no names, if a list is given, the name of
2161       each level in the index will be taken from the list.
2162    ndupe_l - (Optional), list of ints, the number of rows for which the
2163       label will repeated at the corresponding level, you can specify just
2164       the first few, the rest will use the default ndupe_l of 1.
2165       len(ndupe_l) <= nlevels.
2166    idx_type - "i"/"f"/"s"/"u"/"dt"/"p"/"td".
2167       If idx_type is not None, `idx_nlevels` must be 1.
2168       "i"/"f" creates an integer/float index,
2169       "s"/"u" creates a string/unicode index
2170       "dt" create a datetime index.
2171       "td" create a datetime index.
2172
2173        if unspecified, string labels will be generated.
2174    """
2175    if ndupe_l is None:
2176        ndupe_l = [1] * nlevels
2177    assert is_sequence(ndupe_l) and len(ndupe_l) <= nlevels
2178    assert names is None or names is False or names is True or len(names) is nlevels
2179    assert idx_type is None or (
2180        idx_type in ("i", "f", "s", "u", "dt", "p", "td") and nlevels == 1
2181    )
2182
2183    if names is True:
2184        # build default names
2185        names = [prefix + str(i) for i in range(nlevels)]
2186    if names is False:
2187        # pass None to index constructor for no name
2188        names = None
2189
2190    # make singleton case uniform
2191    if isinstance(names, str) and nlevels == 1:
2192        names = [names]
2193
2194    # specific 1D index type requested?
2195    idx_func = {
2196        "i": makeIntIndex,
2197        "f": makeFloatIndex,
2198        "s": makeStringIndex,
2199        "u": makeUnicodeIndex,
2200        "dt": makeDateIndex,
2201        "td": makeTimedeltaIndex,
2202        "p": makePeriodIndex,
2203    }.get(idx_type)
2204    if idx_func:
2205        # pandas\_testing.py:2120: error: Cannot call function of unknown type
2206        idx = idx_func(nentries)  # type: ignore[operator]
2207        # but we need to fill in the name
2208        if names:
2209            idx.name = names[0]
2210        return idx
2211    elif idx_type is not None:
2212        raise ValueError(
2213            f"{repr(idx_type)} is not a legal value for `idx_type`, "
2214            "use  'i'/'f'/'s'/'u'/'dt'/'p'/'td'."
2215        )
2216
2217    if len(ndupe_l) < nlevels:
2218        ndupe_l.extend([1] * (nlevels - len(ndupe_l)))
2219    assert len(ndupe_l) == nlevels
2220
2221    assert all(x > 0 for x in ndupe_l)
2222
2223    tuples = []
2224    for i in range(nlevels):
2225
2226        def keyfunc(x):
2227            import re
2228
2229            numeric_tuple = re.sub(r"[^\d_]_?", "", x).split("_")
2230            return [int(num) for num in numeric_tuple]
2231
2232        # build a list of lists to create the index from
2233        div_factor = nentries // ndupe_l[i] + 1
2234        # pandas\_testing.py:2148: error: Need type annotation for 'cnt'
2235        cnt = Counter()  # type: ignore[var-annotated]
2236        for j in range(div_factor):
2237            label = f"{prefix}_l{i}_g{j}"
2238            cnt[label] = ndupe_l[i]
2239        # cute Counter trick
2240        result = sorted(cnt.elements(), key=keyfunc)[:nentries]
2241        tuples.append(result)
2242
2243    tuples = list(zip(*tuples))
2244
2245    # convert tuples to index
2246    if nentries == 1:
2247        # we have a single level of tuples, i.e. a regular Index
2248        index = Index(tuples[0], name=names[0])
2249    elif nlevels == 1:
2250        name = None if names is None else names[0]
2251        index = Index((x[0] for x in tuples), name=name)
2252    else:
2253        index = MultiIndex.from_tuples(tuples, names=names)
2254    return index
2255
2256
2257def makeCustomDataframe(
2258    nrows,
2259    ncols,
2260    c_idx_names=True,
2261    r_idx_names=True,
2262    c_idx_nlevels=1,
2263    r_idx_nlevels=1,
2264    data_gen_f=None,
2265    c_ndupe_l=None,
2266    r_ndupe_l=None,
2267    dtype=None,
2268    c_idx_type=None,
2269    r_idx_type=None,
2270):
2271    """
2272    Create a DataFrame using supplied parameters.
2273
2274    Parameters
2275    ----------
2276    nrows,  ncols - number of data rows/cols
2277    c_idx_names, idx_names  - False/True/list of strings,  yields No names ,
2278            default names or uses the provided names for the levels of the
2279            corresponding index. You can provide a single string when
2280            c_idx_nlevels ==1.
2281    c_idx_nlevels - number of levels in columns index. > 1 will yield MultiIndex
2282    r_idx_nlevels - number of levels in rows index. > 1 will yield MultiIndex
2283    data_gen_f - a function f(row,col) which return the data value
2284            at that position, the default generator used yields values of the form
2285            "RxCy" based on position.
2286    c_ndupe_l, r_ndupe_l - list of integers, determines the number
2287            of duplicates for each label at a given level of the corresponding
2288            index. The default `None` value produces a multiplicity of 1 across
2289            all levels, i.e. a unique index. Will accept a partial list of length
2290            N < idx_nlevels, for just the first N levels. If ndupe doesn't divide
2291            nrows/ncol, the last label might have lower multiplicity.
2292    dtype - passed to the DataFrame constructor as is, in case you wish to
2293            have more control in conjunction with a custom `data_gen_f`
2294    r_idx_type, c_idx_type -  "i"/"f"/"s"/"u"/"dt"/"td".
2295        If idx_type is not None, `idx_nlevels` must be 1.
2296        "i"/"f" creates an integer/float index,
2297        "s"/"u" creates a string/unicode index
2298        "dt" create a datetime index.
2299        "td" create a timedelta index.
2300
2301            if unspecified, string labels will be generated.
2302
2303    Examples
2304    --------
2305    # 5 row, 3 columns, default names on both, single index on both axis
2306    >> makeCustomDataframe(5,3)
2307
2308    # make the data a random int between 1 and 100
2309    >> mkdf(5,3,data_gen_f=lambda r,c:randint(1,100))
2310
2311    # 2-level multiindex on rows with each label duplicated
2312    # twice on first level, default names on both axis, single
2313    # index on both axis
2314    >> a=makeCustomDataframe(5,3,r_idx_nlevels=2,r_ndupe_l=[2])
2315
2316    # DatetimeIndex on row, index with unicode labels on columns
2317    # no names on either axis
2318    >> a=makeCustomDataframe(5,3,c_idx_names=False,r_idx_names=False,
2319                             r_idx_type="dt",c_idx_type="u")
2320
2321    # 4-level multindex on rows with names provided, 2-level multindex
2322    # on columns with default labels and default names.
2323    >> a=makeCustomDataframe(5,3,r_idx_nlevels=4,
2324                             r_idx_names=["FEE","FI","FO","FAM"],
2325                             c_idx_nlevels=2)
2326
2327    >> a=mkdf(5,3,r_idx_nlevels=2,c_idx_nlevels=4)
2328    """
2329    assert c_idx_nlevels > 0
2330    assert r_idx_nlevels > 0
2331    assert r_idx_type is None or (
2332        r_idx_type in ("i", "f", "s", "u", "dt", "p", "td") and r_idx_nlevels == 1
2333    )
2334    assert c_idx_type is None or (
2335        c_idx_type in ("i", "f", "s", "u", "dt", "p", "td") and c_idx_nlevels == 1
2336    )
2337
2338    columns = makeCustomIndex(
2339        ncols,
2340        nlevels=c_idx_nlevels,
2341        prefix="C",
2342        names=c_idx_names,
2343        ndupe_l=c_ndupe_l,
2344        idx_type=c_idx_type,
2345    )
2346    index = makeCustomIndex(
2347        nrows,
2348        nlevels=r_idx_nlevels,
2349        prefix="R",
2350        names=r_idx_names,
2351        ndupe_l=r_ndupe_l,
2352        idx_type=r_idx_type,
2353    )
2354
2355    # by default, generate data based on location
2356    if data_gen_f is None:
2357        data_gen_f = lambda r, c: f"R{r}C{c}"
2358
2359    data = [[data_gen_f(r, c) for c in range(ncols)] for r in range(nrows)]
2360
2361    return DataFrame(data, index, columns, dtype=dtype)
2362
2363
2364def _create_missing_idx(nrows, ncols, density, random_state=None):
2365    if random_state is None:
2366        random_state = np.random
2367    else:
2368        random_state = np.random.RandomState(random_state)
2369
2370    # below is cribbed from scipy.sparse
2371    size = int(np.round((1 - density) * nrows * ncols))
2372    # generate a few more to ensure unique values
2373    min_rows = 5
2374    fac = 1.02
2375    extra_size = min(size + min_rows, fac * size)
2376
2377    def _gen_unique_rand(rng, _extra_size):
2378        ind = rng.rand(int(_extra_size))
2379        return np.unique(np.floor(ind * nrows * ncols))[:size]
2380
2381    ind = _gen_unique_rand(random_state, extra_size)
2382    while ind.size < size:
2383        extra_size *= 1.05
2384        ind = _gen_unique_rand(random_state, extra_size)
2385
2386    j = np.floor(ind * 1.0 / nrows).astype(int)
2387    i = (ind - j * nrows).astype(int)
2388    return i.tolist(), j.tolist()
2389
2390
2391def makeMissingDataframe(density=0.9, random_state=None):
2392    df = makeDataFrame()
2393    # pandas\_testing.py:2306: error: "_create_missing_idx" gets multiple
2394    # values for keyword argument "density"  [misc]
2395
2396    # pandas\_testing.py:2306: error: "_create_missing_idx" gets multiple
2397    # values for keyword argument "random_state"  [misc]
2398    i, j = _create_missing_idx(  # type: ignore[misc]
2399        *df.shape, density=density, random_state=random_state
2400    )
2401    df.values[i, j] = np.nan
2402    return df
2403
2404
2405def optional_args(decorator):
2406    """
2407    allows a decorator to take optional positional and keyword arguments.
2408    Assumes that taking a single, callable, positional argument means that
2409    it is decorating a function, i.e. something like this::
2410
2411        @my_decorator
2412        def function(): pass
2413
2414    Calls decorator with decorator(f, *args, **kwargs)
2415    """
2416
2417    @wraps(decorator)
2418    def wrapper(*args, **kwargs):
2419        def dec(f):
2420            return decorator(f, *args, **kwargs)
2421
2422        is_decorating = not kwargs and len(args) == 1 and callable(args[0])
2423        if is_decorating:
2424            f = args[0]
2425            # pandas\_testing.py:2331: error: Incompatible types in assignment
2426            # (expression has type "List[<nothing>]", variable has type
2427            # "Tuple[Any, ...]")
2428            args = []  # type: ignore[assignment]
2429            return dec(f)
2430        else:
2431            return dec
2432
2433    return wrapper
2434
2435
2436# skip tests on exceptions with this message
2437_network_error_messages = (
2438    # 'urlopen error timed out',
2439    # 'timeout: timed out',
2440    # 'socket.timeout: timed out',
2441    "timed out",
2442    "Server Hangup",
2443    "HTTP Error 503: Service Unavailable",
2444    "502: Proxy Error",
2445    "HTTP Error 502: internal error",
2446    "HTTP Error 502",
2447    "HTTP Error 503",
2448    "HTTP Error 403",
2449    "HTTP Error 400",
2450    "Temporary failure in name resolution",
2451    "Name or service not known",
2452    "Connection refused",
2453    "certificate verify",
2454)
2455
2456# or this e.errno/e.reason.errno
2457_network_errno_vals = (
2458    101,  # Network is unreachable
2459    111,  # Connection refused
2460    110,  # Connection timed out
2461    104,  # Connection reset Error
2462    54,  # Connection reset by peer
2463    60,  # urllib.error.URLError: [Errno 60] Connection timed out
2464)
2465
2466# Both of the above shouldn't mask real issues such as 404's
2467# or refused connections (changed DNS).
2468# But some tests (test_data yahoo) contact incredibly flakey
2469# servers.
2470
2471# and conditionally raise on exception types in _get_default_network_errors
2472
2473
2474def _get_default_network_errors():
2475    # Lazy import for http.client because it imports many things from the stdlib
2476    import http.client
2477
2478    return (IOError, http.client.HTTPException, TimeoutError)
2479
2480
2481def can_connect(url, error_classes=None):
2482    """
2483    Try to connect to the given url. True if succeeds, False if IOError
2484    raised
2485
2486    Parameters
2487    ----------
2488    url : basestring
2489        The URL to try to connect to
2490
2491    Returns
2492    -------
2493    connectable : bool
2494        Return True if no IOError (unable to connect) or URLError (bad url) was
2495        raised
2496    """
2497    if error_classes is None:
2498        error_classes = _get_default_network_errors()
2499
2500    try:
2501        with urlopen(url):
2502            pass
2503    except error_classes:
2504        return False
2505    else:
2506        return True
2507
2508
2509@optional_args
2510def network(
2511    t,
2512    url="https://www.google.com",
2513    raise_on_error=_RAISE_NETWORK_ERROR_DEFAULT,
2514    check_before_test=False,
2515    error_classes=None,
2516    skip_errnos=_network_errno_vals,
2517    _skip_on_messages=_network_error_messages,
2518):
2519    """
2520    Label a test as requiring network connection and, if an error is
2521    encountered, only raise if it does not find a network connection.
2522
2523    In comparison to ``network``, this assumes an added contract to your test:
2524    you must assert that, under normal conditions, your test will ONLY fail if
2525    it does not have network connectivity.
2526
2527    You can call this in 3 ways: as a standard decorator, with keyword
2528    arguments, or with a positional argument that is the url to check.
2529
2530    Parameters
2531    ----------
2532    t : callable
2533        The test requiring network connectivity.
2534    url : path
2535        The url to test via ``pandas.io.common.urlopen`` to check
2536        for connectivity. Defaults to 'https://www.google.com'.
2537    raise_on_error : bool
2538        If True, never catches errors.
2539    check_before_test : bool
2540        If True, checks connectivity before running the test case.
2541    error_classes : tuple or Exception
2542        error classes to ignore. If not in ``error_classes``, raises the error.
2543        defaults to IOError. Be careful about changing the error classes here.
2544    skip_errnos : iterable of int
2545        Any exception that has .errno or .reason.erno set to one
2546        of these values will be skipped with an appropriate
2547        message.
2548    _skip_on_messages: iterable of string
2549        any exception e for which one of the strings is
2550        a substring of str(e) will be skipped with an appropriate
2551        message. Intended to suppress errors where an errno isn't available.
2552
2553    Notes
2554    -----
2555    * ``raise_on_error`` supersedes ``check_before_test``
2556
2557    Returns
2558    -------
2559    t : callable
2560        The decorated test ``t``, with checks for connectivity errors.
2561
2562    Example
2563    -------
2564
2565    Tests decorated with @network will fail if it's possible to make a network
2566    connection to another URL (defaults to google.com)::
2567
2568      >>> from pandas._testing import network
2569      >>> from pandas.io.common import urlopen
2570      >>> @network
2571      ... def test_network():
2572      ...     with urlopen("rabbit://bonanza.com"):
2573      ...         pass
2574      Traceback
2575         ...
2576      URLError: <urlopen error unknown url type: rabit>
2577
2578      You can specify alternative URLs::
2579
2580        >>> @network("https://www.yahoo.com")
2581        ... def test_something_with_yahoo():
2582        ...    raise IOError("Failure Message")
2583        >>> test_something_with_yahoo()
2584        Traceback (most recent call last):
2585            ...
2586        IOError: Failure Message
2587
2588    If you set check_before_test, it will check the url first and not run the
2589    test on failure::
2590
2591        >>> @network("failing://url.blaher", check_before_test=True)
2592        ... def test_something():
2593        ...     print("I ran!")
2594        ...     raise ValueError("Failure")
2595        >>> test_something()
2596        Traceback (most recent call last):
2597            ...
2598
2599    Errors not related to networking will always be raised.
2600    """
2601    from pytest import skip
2602
2603    if error_classes is None:
2604        error_classes = _get_default_network_errors()
2605
2606    t.network = True
2607
2608    @wraps(t)
2609    def wrapper(*args, **kwargs):
2610        if (
2611            check_before_test
2612            and not raise_on_error
2613            and not can_connect(url, error_classes)
2614        ):
2615            skip()
2616        try:
2617            return t(*args, **kwargs)
2618        except Exception as err:
2619            errno = getattr(err, "errno", None)
2620            if not errno and hasattr(errno, "reason"):
2621                # pandas\_testing.py:2521: error: "Exception" has no attribute
2622                # "reason"
2623                errno = getattr(err.reason, "errno", None)  # type: ignore[attr-defined]
2624
2625            if errno in skip_errnos:
2626                skip(f"Skipping test due to known errno and error {err}")
2627
2628            e_str = str(err)
2629
2630            if any(m.lower() in e_str.lower() for m in _skip_on_messages):
2631                skip(
2632                    f"Skipping test because exception message is known and error {err}"
2633                )
2634
2635            if not isinstance(err, error_classes):
2636                raise
2637
2638            if raise_on_error or can_connect(url, error_classes):
2639                raise
2640            else:
2641                skip(f"Skipping test due to lack of connectivity and error {err}")
2642
2643    return wrapper
2644
2645
2646with_connectivity_check = network
2647
2648
2649@contextmanager
2650def assert_produces_warning(
2651    expected_warning: Optional[Union[Type[Warning], bool]] = Warning,
2652    filter_level="always",
2653    check_stacklevel: bool = True,
2654    raise_on_extra_warnings: bool = True,
2655    match: Optional[str] = None,
2656):
2657    """
2658    Context manager for running code expected to either raise a specific
2659    warning, or not raise any warnings. Verifies that the code raises the
2660    expected warning, and that it does not raise any other unexpected
2661    warnings. It is basically a wrapper around ``warnings.catch_warnings``.
2662
2663    Parameters
2664    ----------
2665    expected_warning : {Warning, False, None}, default Warning
2666        The type of Exception raised. ``exception.Warning`` is the base
2667        class for all warnings. To check that no warning is returned,
2668        specify ``False`` or ``None``.
2669    filter_level : str or None, default "always"
2670        Specifies whether warnings are ignored, displayed, or turned
2671        into errors.
2672        Valid values are:
2673
2674        * "error" - turns matching warnings into exceptions
2675        * "ignore" - discard the warning
2676        * "always" - always emit a warning
2677        * "default" - print the warning the first time it is generated
2678          from each location
2679        * "module" - print the warning the first time it is generated
2680          from each module
2681        * "once" - print the warning the first time it is generated
2682
2683    check_stacklevel : bool, default True
2684        If True, displays the line that called the function containing
2685        the warning to show were the function is called. Otherwise, the
2686        line that implements the function is displayed.
2687    raise_on_extra_warnings : bool, default True
2688        Whether extra warnings not of the type `expected_warning` should
2689        cause the test to fail.
2690    match : str, optional
2691        Match warning message.
2692
2693    Examples
2694    --------
2695    >>> import warnings
2696    >>> with assert_produces_warning():
2697    ...     warnings.warn(UserWarning())
2698    ...
2699    >>> with assert_produces_warning(False):
2700    ...     warnings.warn(RuntimeWarning())
2701    ...
2702    Traceback (most recent call last):
2703        ...
2704    AssertionError: Caused unexpected warning(s): ['RuntimeWarning'].
2705    >>> with assert_produces_warning(UserWarning):
2706    ...     warnings.warn(RuntimeWarning())
2707    Traceback (most recent call last):
2708        ...
2709    AssertionError: Did not see expected warning of class 'UserWarning'.
2710
2711    ..warn:: This is *not* thread-safe.
2712    """
2713    __tracebackhide__ = True
2714
2715    with warnings.catch_warnings(record=True) as w:
2716
2717        saw_warning = False
2718        matched_message = False
2719
2720        warnings.simplefilter(filter_level)
2721        yield w
2722        extra_warnings = []
2723
2724        for actual_warning in w:
2725            if not expected_warning:
2726                continue
2727
2728            expected_warning = cast(Type[Warning], expected_warning)
2729            if issubclass(actual_warning.category, expected_warning):
2730                saw_warning = True
2731
2732                if check_stacklevel and issubclass(
2733                    actual_warning.category, (FutureWarning, DeprecationWarning)
2734                ):
2735                    _assert_raised_with_correct_stacklevel(actual_warning)
2736
2737                if match is not None and re.search(match, str(actual_warning.message)):
2738                    matched_message = True
2739
2740            else:
2741                extra_warnings.append(
2742                    (
2743                        actual_warning.category.__name__,
2744                        actual_warning.message,
2745                        actual_warning.filename,
2746                        actual_warning.lineno,
2747                    )
2748                )
2749
2750        if expected_warning:
2751            expected_warning = cast(Type[Warning], expected_warning)
2752            if not saw_warning:
2753                raise AssertionError(
2754                    f"Did not see expected warning of class "
2755                    f"{repr(expected_warning.__name__)}"
2756                )
2757
2758            if match and not matched_message:
2759                raise AssertionError(
2760                    f"Did not see warning {repr(expected_warning.__name__)} "
2761                    f"matching {match}"
2762                )
2763
2764        if raise_on_extra_warnings and extra_warnings:
2765            raise AssertionError(
2766                f"Caused unexpected warning(s): {repr(extra_warnings)}"
2767            )
2768
2769
2770def _assert_raised_with_correct_stacklevel(
2771    actual_warning: warnings.WarningMessage,
2772) -> None:
2773    from inspect import getframeinfo, stack
2774
2775    caller = getframeinfo(stack()[3][0])
2776    msg = (
2777        "Warning not set with correct stacklevel. "
2778        f"File where warning is raised: {actual_warning.filename} != "
2779        f"{caller.filename}. Warning message: {actual_warning.message}"
2780    )
2781    assert actual_warning.filename == caller.filename, msg
2782
2783
2784class RNGContext:
2785    """
2786    Context manager to set the numpy random number generator speed. Returns
2787    to the original value upon exiting the context manager.
2788
2789    Parameters
2790    ----------
2791    seed : int
2792        Seed for numpy.random.seed
2793
2794    Examples
2795    --------
2796    with RNGContext(42):
2797        np.random.randn()
2798    """
2799
2800    def __init__(self, seed):
2801        self.seed = seed
2802
2803    def __enter__(self):
2804
2805        self.start_state = np.random.get_state()
2806        np.random.seed(self.seed)
2807
2808    def __exit__(self, exc_type, exc_value, traceback):
2809
2810        np.random.set_state(self.start_state)
2811
2812
2813@contextmanager
2814def with_csv_dialect(name, **kwargs):
2815    """
2816    Context manager to temporarily register a CSV dialect for parsing CSV.
2817
2818    Parameters
2819    ----------
2820    name : str
2821        The name of the dialect.
2822    kwargs : mapping
2823        The parameters for the dialect.
2824
2825    Raises
2826    ------
2827    ValueError : the name of the dialect conflicts with a builtin one.
2828
2829    See Also
2830    --------
2831    csv : Python's CSV library.
2832    """
2833    import csv
2834
2835    _BUILTIN_DIALECTS = {"excel", "excel-tab", "unix"}
2836
2837    if name in _BUILTIN_DIALECTS:
2838        raise ValueError("Cannot override builtin dialect.")
2839
2840    csv.register_dialect(name, **kwargs)
2841    yield
2842    csv.unregister_dialect(name)
2843
2844
2845@contextmanager
2846def use_numexpr(use, min_elements=None):
2847    from pandas.core.computation import expressions as expr
2848
2849    if min_elements is None:
2850        min_elements = expr._MIN_ELEMENTS
2851
2852    olduse = expr.USE_NUMEXPR
2853    oldmin = expr._MIN_ELEMENTS
2854    expr.set_use_numexpr(use)
2855    expr._MIN_ELEMENTS = min_elements
2856    yield
2857    expr._MIN_ELEMENTS = oldmin
2858    expr.set_use_numexpr(olduse)
2859
2860
2861def test_parallel(num_threads=2, kwargs_list=None):
2862    """
2863    Decorator to run the same function multiple times in parallel.
2864
2865    Parameters
2866    ----------
2867    num_threads : int, optional
2868        The number of times the function is run in parallel.
2869    kwargs_list : list of dicts, optional
2870        The list of kwargs to update original
2871        function kwargs on different threads.
2872
2873    Notes
2874    -----
2875    This decorator does not pass the return value of the decorated function.
2876
2877    Original from scikit-image:
2878
2879    https://github.com/scikit-image/scikit-image/pull/1519
2880
2881    """
2882    assert num_threads > 0
2883    has_kwargs_list = kwargs_list is not None
2884    if has_kwargs_list:
2885        assert len(kwargs_list) == num_threads
2886    import threading
2887
2888    def wrapper(func):
2889        @wraps(func)
2890        def inner(*args, **kwargs):
2891            if has_kwargs_list:
2892                update_kwargs = lambda i: dict(kwargs, **kwargs_list[i])
2893            else:
2894                update_kwargs = lambda i: kwargs
2895            threads = []
2896            for i in range(num_threads):
2897                updated_kwargs = update_kwargs(i)
2898                thread = threading.Thread(target=func, args=args, kwargs=updated_kwargs)
2899                threads.append(thread)
2900            for thread in threads:
2901                thread.start()
2902            for thread in threads:
2903                thread.join()
2904
2905        return inner
2906
2907    return wrapper
2908
2909
2910class SubclassedSeries(Series):
2911    _metadata = ["testattr", "name"]
2912
2913    @property
2914    def _constructor(self):
2915        return SubclassedSeries
2916
2917    @property
2918    def _constructor_expanddim(self):
2919        return SubclassedDataFrame
2920
2921
2922class SubclassedDataFrame(DataFrame):
2923    _metadata = ["testattr"]
2924
2925    @property
2926    def _constructor(self):
2927        return SubclassedDataFrame
2928
2929    @property
2930    def _constructor_sliced(self):
2931        return SubclassedSeries
2932
2933
2934class SubclassedCategorical(Categorical):
2935    @property
2936    def _constructor(self):
2937        return SubclassedCategorical
2938
2939
2940@contextmanager
2941def set_timezone(tz: str):
2942    """
2943    Context manager for temporarily setting a timezone.
2944
2945    Parameters
2946    ----------
2947    tz : str
2948        A string representing a valid timezone.
2949
2950    Examples
2951    --------
2952    >>> from datetime import datetime
2953    >>> from dateutil.tz import tzlocal
2954    >>> tzlocal().tzname(datetime.now())
2955    'IST'
2956
2957    >>> with set_timezone('US/Eastern'):
2958    ...     tzlocal().tzname(datetime.now())
2959    ...
2960    'EDT'
2961    """
2962    import os
2963    import time
2964
2965    def setTZ(tz):
2966        if tz is None:
2967            try:
2968                del os.environ["TZ"]
2969            except KeyError:
2970                pass
2971        else:
2972            os.environ["TZ"] = tz
2973            time.tzset()
2974
2975    orig_tz = os.environ.get("TZ")
2976    setTZ(tz)
2977    try:
2978        yield
2979    finally:
2980        setTZ(orig_tz)
2981
2982
2983def _make_skipna_wrapper(alternative, skipna_alternative=None):
2984    """
2985    Create a function for calling on an array.
2986
2987    Parameters
2988    ----------
2989    alternative : function
2990        The function to be called on the array with no NaNs.
2991        Only used when 'skipna_alternative' is None.
2992    skipna_alternative : function
2993        The function to be called on the original array
2994
2995    Returns
2996    -------
2997    function
2998    """
2999    if skipna_alternative:
3000
3001        def skipna_wrapper(x):
3002            return skipna_alternative(x.values)
3003
3004    else:
3005
3006        def skipna_wrapper(x):
3007            nona = x.dropna()
3008            if len(nona) == 0:
3009                return np.nan
3010            return alternative(nona)
3011
3012    return skipna_wrapper
3013
3014
3015def convert_rows_list_to_csv_str(rows_list: List[str]):
3016    """
3017    Convert list of CSV rows to single CSV-formatted string for current OS.
3018
3019    This method is used for creating expected value of to_csv() method.
3020
3021    Parameters
3022    ----------
3023    rows_list : List[str]
3024        Each element represents the row of csv.
3025
3026    Returns
3027    -------
3028    str
3029        Expected output of to_csv() in current OS.
3030    """
3031    sep = os.linesep
3032    return sep.join(rows_list) + sep
3033
3034
3035def external_error_raised(expected_exception: Type[Exception]) -> ContextManager:
3036    """
3037    Helper function to mark pytest.raises that have an external error message.
3038
3039    Parameters
3040    ----------
3041    expected_exception : Exception
3042        Expected error to raise.
3043
3044    Returns
3045    -------
3046    Callable
3047        Regular `pytest.raises` function with `match` equal to `None`.
3048    """
3049    import pytest
3050
3051    return pytest.raises(expected_exception, match=None)
3052
3053
3054cython_table = pd.core.base.SelectionMixin._cython_table.items()
3055
3056
3057def get_cython_table_params(ndframe, func_names_and_expected):
3058    """
3059    Combine frame, functions from SelectionMixin._cython_table
3060    keys and expected result.
3061
3062    Parameters
3063    ----------
3064    ndframe : DataFrame or Series
3065    func_names_and_expected : Sequence of two items
3066        The first item is a name of a NDFrame method ('sum', 'prod') etc.
3067        The second item is the expected return value.
3068
3069    Returns
3070    -------
3071    list
3072        List of three items (DataFrame, function, expected result)
3073    """
3074    results = []
3075    for func_name, expected in func_names_and_expected:
3076        results.append((ndframe, func_name, expected))
3077        results += [
3078            (ndframe, func, expected)
3079            for func, name in cython_table
3080            if name == func_name
3081        ]
3082    return results
3083
3084
3085def get_op_from_name(op_name: str) -> Callable:
3086    """
3087    The operator function for a given op name.
3088
3089    Parameters
3090    ----------
3091    op_name : string
3092        The op name, in form of "add" or "__add__".
3093
3094    Returns
3095    -------
3096    function
3097        A function performing the operation.
3098    """
3099    short_opname = op_name.strip("_")
3100    try:
3101        op = getattr(operator, short_opname)
3102    except AttributeError:
3103        # Assume it is the reverse operator
3104        rop = getattr(operator, short_opname[1:])
3105        op = lambda x, y: rop(y, x)
3106
3107    return op
3108