1from datetime import datetime
2from typing import List
3
4import numpy as np
5import pytest
6
7import pandas.util._test_decorators as td
8
9from pandas.core.dtypes.cast import astype_nansafe
10import pandas.core.dtypes.common as com
11from pandas.core.dtypes.dtypes import (
12    CategoricalDtype,
13    CategoricalDtypeType,
14    DatetimeTZDtype,
15    IntervalDtype,
16    PeriodDtype,
17)
18from pandas.core.dtypes.missing import isna
19
20import pandas as pd
21import pandas._testing as tm
22from pandas.arrays import SparseArray
23
24
25# EA & Actual Dtypes
def to_ea_dtypes(dtypes):
    """
    Map extension-dtype names to the pandas attributes named ``<name>Dtype``.

    For example ``"Int64"`` is looked up as ``pd.Int64Dtype``.
    """
    resolved = []
    for name in dtypes:
        resolved.append(getattr(pd, name + "Dtype"))
    return resolved
29
30
def to_numpy_dtypes(dtypes):
    """
    Map dtype names to the numpy attributes of the same name.

    Entries that are not strings are skipped.
    """
    names = (entry for entry in dtypes if isinstance(entry, str))
    return [getattr(np, name) for name in names]
34
35
class TestPandasDtype:
    """Tests for the dtype resolution done by ``com.pandas_dtype``."""

    # Passing invalid dtype, both as a string or object, must raise TypeError
    # Per issue GH15520
    @pytest.mark.parametrize("box", [pd.Timestamp, "pd.Timestamp", list])
    def test_invalid_dtype_error(self, box):
        expected_msg = "not understood"
        with pytest.raises(TypeError, match=expected_msg):
            com.pandas_dtype(box)

    @pytest.mark.parametrize(
        "dtype",
        [
            object,
            "float64",
            np.object_,
            np.dtype("object"),
            "O",
            np.float64,
            float,
            np.dtype("float64"),
        ],
    )
    def test_pandas_dtype_valid(self, dtype):
        # the resolved dtype must compare equal to the input itself
        resolved = com.pandas_dtype(dtype)
        assert resolved == dtype

    @pytest.mark.parametrize(
        "dtype", ["M8[ns]", "m8[ns]", "object", "float64", "int64"]
    )
    def test_numpy_dtype(self, dtype):
        expected = np.dtype(dtype)
        assert com.pandas_dtype(dtype) == expected

    def test_numpy_string_dtype(self):
        # do not parse freq-like string as period dtype
        for code in ("U", "S"):
            assert com.pandas_dtype(code) == np.dtype(code)

    @pytest.mark.parametrize(
        "dtype",
        [
            "datetime64[ns, US/Eastern]",
            "datetime64[ns, Asia/Tokyo]",
            "datetime64[ns, UTC]",
            # GH#33885 check that the M8 alias is understood
            "M8[ns, US/Eastern]",
            "M8[ns, Asia/Tokyo]",
            "M8[ns, UTC]",
        ],
    )
    def test_datetimetz_dtype(self, dtype):
        resolved = com.pandas_dtype(dtype)
        assert resolved == DatetimeTZDtype.construct_from_string(dtype)
        assert resolved == dtype

    def test_categorical_dtype(self):
        resolved = com.pandas_dtype("category")
        assert resolved == CategoricalDtype()

    @pytest.mark.parametrize(
        "dtype",
        [
            "period[D]",
            "period[3M]",
            "period[U]",
            "Period[D]",
            "Period[3M]",
            "Period[U]",
        ],
    )
    def test_period_dtype(self, dtype):
        resolved = com.pandas_dtype(dtype)
        expected = PeriodDtype(dtype)
        # identity, not just equality: both calls must hand back one object
        assert resolved is expected
        assert resolved == expected
        assert resolved == dtype
106
107
# One sample dtype per label; test_dtype_equal is parametrized over every
# (label, dtype) pair of this mapping, in insertion order.
dtypes = {
    "datetime_tz": com.pandas_dtype("datetime64[ns, US/Eastern]"),
    "datetime": com.pandas_dtype("datetime64[ns]"),
    "timedelta": com.pandas_dtype("timedelta64[ns]"),
    "period": PeriodDtype("D"),
    "integer": np.dtype(np.int64),
    "float": np.dtype(np.float64),
    "object": np.dtype(object),
    "category": com.pandas_dtype("category"),
}
118
119
@pytest.mark.parametrize("name1,dtype1", list(dtypes.items()), ids=str)
@pytest.mark.parametrize("name2,dtype2", list(dtypes.items()), ids=str)
def test_dtype_equal(name1, dtype1, name2, dtype2):
    """Each sample dtype equals itself and differs from the other samples."""
    # match equal to self, but not equal to other
    assert com.is_dtype_equal(dtype1, dtype1)

    if name1 == name2:
        # same entry on both sides: nothing further to compare
        return
    assert not com.is_dtype_equal(dtype1, dtype2)
128
129
@pytest.mark.parametrize(
    "dtype1,dtype2",
    [
        (np.int8, np.int64),
        (np.int16, np.int64),
        (np.int32, np.int64),
        (np.float32, np.float64),
        (PeriodDtype("D"), PeriodDtype("2D")),  # PeriodType
        (
            com.pandas_dtype("datetime64[ns, US/Eastern]"),
            com.pandas_dtype("datetime64[ns, CET]"),
        ),  # Datetime
        (None, None),  # gh-15941: no exception should be raised.
    ],
)
def test_dtype_equal_strict(dtype1, dtype2):
    """Related-but-different dtype pairs must not compare as equal."""
    are_equal = com.is_dtype_equal(dtype1, dtype2)
    assert not are_equal
147
148
def get_is_dtype_funcs():
    """
    Collect the attributes of pandas.core.dtypes.common whose names
    begin with 'is_' and end with 'dtype'.

    The result follows the order in which ``dir`` lists the names.
    """
    funcs = []
    for attr_name in dir(com):
        if not attr_name.startswith("is_"):
            continue
        if attr_name.endswith("dtype"):
            funcs.append(getattr(com, attr_name))
    return funcs
157
158
@pytest.mark.parametrize("func", get_is_dtype_funcs(), ids=lambda x: x.__name__)
def test_get_dtype_error_catch(func):
    """Every collected ``is_*dtype`` function returns a falsy value for None."""
    # see gh-15941
    #
    # No exception should be raised.
    outcome = func(None)
    assert not outcome
166
167
def test_is_object():
    """Check com.is_object_dtype on object and non-object inputs."""
    for obj in (object, np.array([], dtype=object)):
        assert com.is_object_dtype(obj)

    for obj in (int, np.array([], dtype=int), [1, 2, 3]):
        assert not com.is_object_dtype(obj)
175
176
@pytest.mark.parametrize(
    "check_scipy", [False, pytest.param(True, marks=td.skip_if_no_scipy)]
)
def test_is_sparse(check_scipy):
    """Check com.is_sparse on a SparseArray, an ndarray and a scipy matrix."""
    sparse_arr = SparseArray([1, 2, 3])
    assert com.is_sparse(sparse_arr)

    dense_arr = np.array([1, 2, 3])
    assert not com.is_sparse(dense_arr)

    if not check_scipy:
        return

    # scipy is optional, so it is only imported for the scipy-marked case
    import scipy.sparse

    assert not com.is_sparse(scipy.sparse.bsr_matrix([1, 2, 3]))
189
190
@td.skip_if_no_scipy
def test_is_scipy_sparse():
    """Check com.is_scipy_sparse on a scipy matrix and a pandas SparseArray."""
    from scipy.sparse import bsr_matrix

    scipy_matrix = bsr_matrix([1, 2, 3])
    assert com.is_scipy_sparse(scipy_matrix)

    pandas_sparse = SparseArray([1, 2, 3])
    assert not com.is_scipy_sparse(pandas_sparse)
198
199
def test_is_categorical():
    """Check com.is_categorical results while expecting a FutureWarning."""
    cat = pd.Categorical([1, 2, 3])
    with tm.assert_produces_warning(FutureWarning):
        for obj in (cat, pd.Series(cat), pd.CategoricalIndex([1, 2, 3])):
            assert com.is_categorical(obj)

        assert not com.is_categorical([1, 2, 3])
208
209
def test_is_categorical_deprecation():
    """Calling com.is_categorical must emit a FutureWarning."""
    # GH#33385
    values = [1, 2, 3]
    with tm.assert_produces_warning(FutureWarning):
        com.is_categorical(values)
214
215
def test_is_datetime64_dtype():
    """Check com.is_datetime64_dtype on datetime64 and other inputs."""
    for obj in (object, [1, 2, 3], np.array([], dtype=int)):
        assert not com.is_datetime64_dtype(obj)

    for obj in (np.datetime64, np.array([], dtype=np.datetime64)):
        assert com.is_datetime64_dtype(obj)
223
224
def test_is_datetime64tz_dtype():
    """Check com.is_datetime64tz_dtype on tz-aware and other inputs."""
    for obj in (object, [1, 2, 3], pd.DatetimeIndex([1, 2, 3])):
        assert not com.is_datetime64tz_dtype(obj)

    tz_aware_index = pd.DatetimeIndex(["2000"], tz="US/Eastern")
    assert com.is_datetime64tz_dtype(tz_aware_index)
230
231
def test_is_timedelta64_dtype():
    """Check com.is_timedelta64_dtype on timedelta64 and other inputs."""
    not_timedelta = (
        object,
        None,
        [1, 2, 3],
        np.array([], dtype=np.datetime64),
        "0 days",
        "0 days 00:00:00",
        ["0 days 00:00:00"],
        "NO DATE",
    )
    for obj in not_timedelta:
        assert not com.is_timedelta64_dtype(obj)

    timedelta_like = (
        np.timedelta64,
        pd.Series([], dtype="timedelta64[ns]"),
        pd.to_timedelta(["0 days", "1 days"]),
    )
    for obj in timedelta_like:
        assert com.is_timedelta64_dtype(obj)
245
246
def test_is_period_dtype():
    """Check com.is_period_dtype on period and other inputs."""
    for obj in (object, [1, 2, 3], pd.Period("2017-01-01")):
        assert not com.is_period_dtype(obj)

    for obj in (PeriodDtype(freq="D"), pd.PeriodIndex([], freq="A")):
        assert com.is_period_dtype(obj)
254
255
def test_is_interval_dtype():
    """Check com.is_interval_dtype on interval and other inputs."""
    interval = pd.Interval(1, 2, closed="right")

    # a scalar Interval is not itself an interval dtype
    for obj in (object, [1, 2, 3], interval):
        assert not com.is_interval_dtype(obj)

    for obj in (IntervalDtype(), pd.IntervalIndex([interval])):
        assert com.is_interval_dtype(obj)
265
266
def test_is_categorical_dtype():
    """Check com.is_categorical_dtype on categorical and other inputs."""
    for obj in (object, [1, 2, 3]):
        assert not com.is_categorical_dtype(obj)

    categorical_like = (
        CategoricalDtype(),
        pd.Categorical([1, 2, 3]),
        pd.CategoricalIndex([1, 2, 3]),
    )
    for obj in categorical_like:
        assert com.is_categorical_dtype(obj)
274
275
def test_is_string_dtype():
    """Check com.is_string_dtype on string-like and other inputs."""
    for obj in (int, pd.Series([1, 2])):
        assert not com.is_string_dtype(obj)

    string_like = (
        str,
        object,
        np.array(["a", "b"]),
        pd.StringDtype(),
        pd.array(["a", "b"], dtype="string"),
    )
    for obj in string_like:
        assert com.is_string_dtype(obj)
285
286
integer_dtypes: List = []


@pytest.mark.parametrize(
    "dtype",
    integer_dtypes
    + [pd.Series([1, 2])]
    + tm.ALL_INT_DTYPES
    + to_numpy_dtypes(tm.ALL_INT_DTYPES)
    + tm.ALL_EA_INT_DTYPES
    + to_ea_dtypes(tm.ALL_EA_INT_DTYPES),
)
def test_is_integer_dtype(dtype):
    """Every parametrized integer input is accepted by com.is_integer_dtype."""
    is_integer = com.is_integer_dtype(dtype)
    assert is_integer
301
302
@pytest.mark.parametrize(
    "dtype",
    [
        str,
        float,
        np.datetime64,
        np.timedelta64,
        pd.Index([1, 2.0]),
        np.array(["a", "b"]),
        np.array([], dtype=np.timedelta64),
    ],
)
def test_is_not_integer_dtype(dtype):
    """Every parametrized non-integer input is rejected by com.is_integer_dtype."""
    is_integer = com.is_integer_dtype(dtype)
    assert not is_integer
317
318
signed_integer_dtypes: List = []


@pytest.mark.parametrize(
    "dtype",
    signed_integer_dtypes
    + [pd.Series([1, 2])]
    + tm.SIGNED_INT_DTYPES
    + to_numpy_dtypes(tm.SIGNED_INT_DTYPES)
    + tm.SIGNED_EA_INT_DTYPES
    + to_ea_dtypes(tm.SIGNED_EA_INT_DTYPES),
)
def test_is_signed_integer_dtype(dtype):
    """
    Every parametrized signed-integer input is accepted by
    com.is_signed_integer_dtype.
    """
    # Exercise the signed predicate itself; asserting the generic
    # com.is_integer_dtype here would leave is_signed_integer_dtype
    # without any positive-case coverage.
    assert com.is_signed_integer_dtype(dtype)
333
334
@pytest.mark.parametrize(
    "dtype",
    [
        str,
        float,
        np.datetime64,
        np.timedelta64,
        pd.Index([1, 2.0]),
        np.array(["a", "b"]),
        np.array([], dtype=np.timedelta64),
    ]
    + tm.UNSIGNED_INT_DTYPES
    + to_numpy_dtypes(tm.UNSIGNED_INT_DTYPES)
    + tm.UNSIGNED_EA_INT_DTYPES
    + to_ea_dtypes(tm.UNSIGNED_EA_INT_DTYPES),
)
def test_is_not_signed_integer_dtype(dtype):
    """
    Non-integer and unsigned-integer inputs are rejected by
    com.is_signed_integer_dtype.
    """
    is_signed = com.is_signed_integer_dtype(dtype)
    assert not is_signed
353
354
unsigned_integer_dtypes: List = []


@pytest.mark.parametrize(
    "dtype",
    unsigned_integer_dtypes
    + [pd.Series([1, 2], dtype=np.uint32)]
    + tm.UNSIGNED_INT_DTYPES
    + to_numpy_dtypes(tm.UNSIGNED_INT_DTYPES)
    + tm.UNSIGNED_EA_INT_DTYPES
    + to_ea_dtypes(tm.UNSIGNED_EA_INT_DTYPES),
)
def test_is_unsigned_integer_dtype(dtype):
    """
    Every parametrized unsigned-integer input is accepted by
    com.is_unsigned_integer_dtype.
    """
    is_unsigned = com.is_unsigned_integer_dtype(dtype)
    assert is_unsigned
369
370
@pytest.mark.parametrize(
    "dtype",
    [
        str,
        float,
        np.datetime64,
        np.timedelta64,
        pd.Index([1, 2.0]),
        np.array(["a", "b"]),
        np.array([], dtype=np.timedelta64),
    ]
    + tm.SIGNED_INT_DTYPES
    + to_numpy_dtypes(tm.SIGNED_INT_DTYPES)
    + tm.SIGNED_EA_INT_DTYPES
    + to_ea_dtypes(tm.SIGNED_EA_INT_DTYPES),
)
def test_is_not_unsigned_integer_dtype(dtype):
    """
    Non-integer and signed-integer inputs are rejected by
    com.is_unsigned_integer_dtype.
    """
    is_unsigned = com.is_unsigned_integer_dtype(dtype)
    assert not is_unsigned
389
390
@pytest.mark.parametrize(
    "dtype", [np.int64, np.array([1, 2], dtype=np.int64), "Int64", pd.Int64Dtype]
)
def test_is_int64_dtype(dtype):
    """Every parametrized int64 input is accepted by com.is_int64_dtype."""
    is_int64 = com.is_int64_dtype(dtype)
    assert is_int64
396
397
@pytest.mark.parametrize(
    "dtype",
    [
        str,
        float,
        np.int32,
        np.uint64,
        pd.Index([1, 2.0]),
        np.array(["a", "b"]),
        np.array([1, 2], dtype=np.uint32),
        "int8",
        "Int8",
        pd.Int8Dtype,
    ],
)
def test_is_not_int64_dtype(dtype):
    """Every parametrized non-int64 input is rejected by com.is_int64_dtype."""
    is_int64 = com.is_int64_dtype(dtype)
    assert not is_int64
415
416
def test_is_datetime64_any_dtype():
    """Check com.is_datetime64_any_dtype on datetime and other inputs."""
    for obj in (int, str, np.array([1, 2]), np.array(["a", "b"])):
        assert not com.is_datetime64_any_dtype(obj)

    datetime_like = (
        np.datetime64,
        np.array([], dtype=np.datetime64),
        DatetimeTZDtype("ns", "US/Eastern"),
        pd.DatetimeIndex([1, 2, 3], dtype="datetime64[ns]"),
    )
    for obj in datetime_like:
        assert com.is_datetime64_any_dtype(obj)
429
430
def test_is_datetime64_ns_dtype():
    """Check com.is_datetime64_ns_dtype on ns-unit datetimes and other inputs."""
    not_datetime_ns = (
        int,
        str,
        np.datetime64,
        np.array([1, 2]),
        np.array(["a", "b"]),
        np.array([], dtype=np.datetime64),
        # This datetime array has the wrong unit (ps instead of ns)
        np.array([], dtype="datetime64[ps]"),
    )
    for obj in not_datetime_ns:
        assert not com.is_datetime64_ns_dtype(obj)

    datetime_ns = (
        DatetimeTZDtype("ns", "US/Eastern"),
        pd.DatetimeIndex([1, 2, 3], dtype=np.dtype("datetime64[ns]")),
    )
    for obj in datetime_ns:
        assert com.is_datetime64_ns_dtype(obj)
446
447
def test_is_timedelta64_ns_dtype():
    """Check com.is_timedelta64_ns_dtype on ns-unit and other timedeltas."""
    wrong_unit = (np.dtype("m8[ps]"), np.array([1, 2], dtype=np.timedelta64))
    for obj in wrong_unit:
        assert not com.is_timedelta64_ns_dtype(obj)

    ns_unit = (np.dtype("m8[ns]"), np.array([1, 2], dtype="m8[ns]"))
    for obj in ns_unit:
        assert com.is_timedelta64_ns_dtype(obj)
454
455
def test_is_datetime_or_timedelta_dtype():
    """Check com.is_datetime_or_timedelta_dtype on a range of inputs."""
    rejected = (
        int,
        str,
        pd.Series([1, 2]),
        np.array(["a", "b"]),
        # TODO(jreback), this is slightly suspect
        DatetimeTZDtype("ns", "US/Eastern"),
    )
    for obj in rejected:
        assert not com.is_datetime_or_timedelta_dtype(obj)

    accepted = (
        np.datetime64,
        np.timedelta64,
        np.array([], dtype=np.timedelta64),
        np.array([], dtype=np.datetime64),
    )
    for obj in accepted:
        assert com.is_datetime_or_timedelta_dtype(obj)
469
470
def test_is_numeric_v_string_like():
    """Check com.is_numeric_v_string_like on same-kind and mixed pairs."""
    same_kind_pairs = (
        (1, 1),
        (1, "foo"),
        ("foo", "foo"),
        (np.array([1]), np.array([2])),
        (np.array(["foo"]), np.array(["foo"])),
    )
    for left, right in same_kind_pairs:
        assert not com.is_numeric_v_string_like(left, right)

    mixed_pairs = (
        (np.array([1]), "foo"),
        ("foo", np.array([1])),
        (np.array([1, 2]), np.array(["foo"])),
        (np.array(["foo"]), np.array([1, 2])),
    )
    for left, right in mixed_pairs:
        assert com.is_numeric_v_string_like(left, right)
482
483
def test_is_datetimelike_v_numeric():
    """
    Check com.is_datetimelike_v_numeric on datetime-like/numeric pairs.

    Mixed pairs are checked in both argument orders, for scalars and arrays.
    """
    dt = np.datetime64(datetime(2017, 1, 1))

    # pairs of the same kind are not a datetime-like vs numeric comparison
    assert not com.is_datetimelike_v_numeric(1, 1)
    assert not com.is_datetimelike_v_numeric(dt, dt)
    assert not com.is_datetimelike_v_numeric(np.array([1]), np.array([2]))
    assert not com.is_datetimelike_v_numeric(np.array([dt]), np.array([dt]))

    # mixed pairs; the scalar case is checked in both orders (the second
    # scalar assertion used to repeat (1, dt) instead of covering (dt, 1))
    assert com.is_datetimelike_v_numeric(1, dt)
    assert com.is_datetimelike_v_numeric(dt, 1)
    assert com.is_datetimelike_v_numeric(np.array([dt]), 1)
    assert com.is_datetimelike_v_numeric(np.array([1]), dt)
    assert com.is_datetimelike_v_numeric(np.array([dt]), np.array([1]))
497
498
def test_needs_i8_conversion():
    """Check com.needs_i8_conversion on datetime-like and other inputs."""
    no_conversion = (str, np.int64, pd.Series([1, 2]), np.array(["a", "b"]))
    for obj in no_conversion:
        assert not com.needs_i8_conversion(obj)

    needs_conversion = (
        np.datetime64,
        pd.Series([], dtype="timedelta64[ns]"),
        pd.DatetimeIndex(["2000"], tz="US/Eastern"),
    )
    for obj in needs_conversion:
        assert com.needs_i8_conversion(obj)
508
509
def test_is_numeric_dtype():
    """Check com.is_numeric_dtype on numeric and other inputs."""
    non_numeric = (
        str,
        np.datetime64,
        np.timedelta64,
        np.array(["a", "b"]),
        np.array([], dtype=np.timedelta64),
    )
    for obj in non_numeric:
        assert not com.is_numeric_dtype(obj)

    numeric = (int, float, np.uint64, pd.Series([1, 2]), pd.Index([1, 2.0]))
    for obj in numeric:
        assert com.is_numeric_dtype(obj)
522
523
def test_is_string_like_dtype():
    """Check com.is_string_like_dtype on string and other inputs."""
    for obj in (object, pd.Series([1, 2])):
        assert not com.is_string_like_dtype(obj)

    for obj in (str, np.array(["a", "b"])):
        assert com.is_string_like_dtype(obj)
530
531
def test_is_float_dtype():
    """Check com.is_float_dtype on float and other inputs."""
    for obj in (str, int, pd.Series([1, 2]), np.array(["a", "b"])):
        assert not com.is_float_dtype(obj)

    for obj in (float, pd.Index([1, 2.0])):
        assert com.is_float_dtype(obj)
540
541
def test_is_bool_dtype():
    """Check com.is_bool_dtype on numpy, pandas and extension bool inputs."""
    non_bool = (
        int,
        str,
        pd.Series([1, 2]),
        np.array(["a", "b"]),
        pd.Index(["a", "b"]),
        "Int64",
    )
    for obj in non_bool:
        assert not com.is_bool_dtype(obj)

    numpy_backed_bool = (
        bool,
        np.bool_,
        np.array([True, False]),
        pd.Index([True, False]),
    )
    for obj in numpy_backed_bool:
        assert com.is_bool_dtype(obj)

    extension_bool = (
        pd.BooleanDtype(),
        pd.array([True, False, None], dtype="boolean"),
        "boolean",
    )
    for obj in extension_bool:
        assert com.is_bool_dtype(obj)
558
559
def test_is_bool_dtype_numpy_error():
    """A string that is not a dtype name is reported as non-bool."""
    # GH39010
    not_a_dtype = "0 - Name"
    assert not com.is_bool_dtype(not_a_dtype)
563
564
@pytest.mark.filterwarnings("ignore:'is_extension_type' is deprecated:FutureWarning")
@pytest.mark.parametrize(
    "check_scipy", [False, pytest.param(True, marks=td.skip_if_no_scipy)]
)
def test_is_extension_type(check_scipy):
    """Check com.is_extension_type on extension and non-extension inputs."""
    non_extension = ([1, 2, 3], np.array([1, 2, 3]), pd.DatetimeIndex([1, 2, 3]))
    for obj in non_extension:
        assert not com.is_extension_type(obj)

    cat = pd.Categorical([1, 2, 3])
    tz_dtype = DatetimeTZDtype("ns", tz="US/Eastern")
    extension = (
        cat,
        pd.Series(cat),
        SparseArray([1, 2, 3]),
        pd.DatetimeIndex(["2000"], tz="US/Eastern"),
        pd.Series([], dtype=tz_dtype),
    )
    for obj in extension:
        assert com.is_extension_type(obj)

    if not check_scipy:
        return

    # scipy is optional, so it is only imported for the scipy-marked case
    import scipy.sparse

    assert not com.is_extension_type(scipy.sparse.bsr_matrix([1, 2, 3]))
588
589
def test_is_extension_type_deprecation():
    """Calling com.is_extension_type must emit a FutureWarning."""
    values = [1, 2, 3]
    with tm.assert_produces_warning(FutureWarning):
        com.is_extension_type(values)
593
594
@pytest.mark.parametrize(
    "check_scipy", [False, pytest.param(True, marks=td.skip_if_no_scipy)]
)
def test_is_extension_array_dtype(check_scipy):
    """Check com.is_extension_array_dtype on extension and other inputs."""
    non_extension = ([1, 2, 3], np.array([1, 2, 3]), pd.DatetimeIndex([1, 2, 3]))
    for obj in non_extension:
        assert not com.is_extension_array_dtype(obj)

    cat = pd.Categorical([1, 2, 3])
    tz_dtype = DatetimeTZDtype("ns", tz="US/Eastern")
    extension = (
        cat,
        pd.Series(cat),
        SparseArray([1, 2, 3]),
        pd.DatetimeIndex(["2000"], tz="US/Eastern"),
        pd.Series([], dtype=tz_dtype),
    )
    for obj in extension:
        assert com.is_extension_array_dtype(obj)

    if not check_scipy:
        return

    # scipy is optional, so it is only imported for the scipy-marked case
    import scipy.sparse

    assert not com.is_extension_array_dtype(scipy.sparse.bsr_matrix([1, 2, 3]))
617
618
def test_is_complex_dtype():
    """Check com.is_complex_dtype on complex and other inputs."""
    assert not com.is_complex_dtype(int)
    assert not com.is_complex_dtype(str)
    assert not com.is_complex_dtype(pd.Series([1, 2]))
    assert not com.is_complex_dtype(np.array(["a", "b"]))

    # np.complex128 is the type the np.complex_ alias referred to; the alias
    # itself was removed in NumPy 2.0, so the canonical name is used here.
    assert com.is_complex_dtype(np.complex128)
    assert com.is_complex_dtype(complex)
    assert com.is_complex_dtype(np.array([1 + 1j, 5]))
628
629
@pytest.mark.parametrize(
    "input_param,result",
    [
        (int, np.dtype(int)),
        ("int32", np.dtype("int32")),
        (float, np.dtype(float)),
        ("float64", np.dtype("float64")),
        (np.dtype("float64"), np.dtype("float64")),
        (str, np.dtype(str)),
        (pd.Series([1, 2], dtype=np.dtype("int16")), np.dtype("int16")),
        (pd.Series(["a", "b"]), np.dtype(object)),
        (pd.Index([1, 2]), np.dtype("int64")),
        (pd.Index(["a", "b"]), np.dtype(object)),
        ("category", "category"),
        (pd.Categorical(["a", "b"]).dtype, CategoricalDtype(["a", "b"])),
        (pd.Categorical(["a", "b"]), CategoricalDtype(["a", "b"])),
        (pd.CategoricalIndex(["a", "b"]).dtype, CategoricalDtype(["a", "b"])),
        (pd.CategoricalIndex(["a", "b"]), CategoricalDtype(["a", "b"])),
        (CategoricalDtype(), CategoricalDtype()),
        (CategoricalDtype(["a", "b"]), CategoricalDtype()),
        (pd.DatetimeIndex([1, 2]), np.dtype("=M8[ns]")),
        (pd.DatetimeIndex([1, 2]).dtype, np.dtype("=M8[ns]")),
        ("<M8[ns]", np.dtype("<M8[ns]")),
        ("datetime64[ns, Europe/London]", DatetimeTZDtype("ns", "Europe/London")),
        (PeriodDtype(freq="D"), PeriodDtype(freq="D")),
        ("period[D]", PeriodDtype(freq="D")),
        (IntervalDtype(), IntervalDtype()),
    ],
)
def test_get_dtype(input_param, result):
    """com.get_dtype returns a dtype equal to the expected one."""
    resolved = com.get_dtype(input_param)
    assert resolved == result
661
662
@pytest.mark.parametrize(
    "input_param,expected_error_message",
    [
        (None, "Cannot deduce dtype from null object"),
        (1, "data type not understood"),
        (1.2, "data type not understood"),
        # numpy dev changed from double-quotes to single quotes
        ("random string", "data type [\"']random string[\"'] not understood"),
        (pd.DataFrame([1, 2]), "data type not understood"),
    ],
)
def test_get_dtype_fails(input_param, expected_error_message):
    """com.get_dtype raises TypeError for inputs it cannot interpret."""
    # python objects
    # 2020-02-02 npdev changed error message
    # either wording is accepted, so the two are joined as regex alternatives
    newer_wording = f"|Cannot interpret '{input_param}' as a data type"
    pattern = expected_error_message + newer_wording
    with pytest.raises(TypeError, match=pattern):
        com.get_dtype(input_param)
680
681
@pytest.mark.parametrize(
    "input_param,result",
    [
        (int, np.dtype(int).type),
        ("int32", np.int32),
        (float, np.dtype(float).type),
        ("float64", np.float64),
        (np.dtype("float64"), np.float64),
        (str, np.dtype(str).type),
        (pd.Series([1, 2], dtype=np.dtype("int16")), np.int16),
        (pd.Series(["a", "b"]), np.object_),
        (pd.Index([1, 2], dtype="int64"), np.int64),
        (pd.Index(["a", "b"]), np.object_),
        ("category", CategoricalDtypeType),
        (pd.Categorical(["a", "b"]).dtype, CategoricalDtypeType),
        (pd.Categorical(["a", "b"]), CategoricalDtypeType),
        (pd.CategoricalIndex(["a", "b"]).dtype, CategoricalDtypeType),
        (pd.CategoricalIndex(["a", "b"]), CategoricalDtypeType),
        (pd.DatetimeIndex([1, 2]), np.datetime64),
        (pd.DatetimeIndex([1, 2]).dtype, np.datetime64),
        ("<M8[ns]", np.datetime64),
        (pd.DatetimeIndex(["2000"], tz="Europe/London"), pd.Timestamp),
        (pd.DatetimeIndex(["2000"], tz="Europe/London").dtype, pd.Timestamp),
        ("datetime64[ns, Europe/London]", pd.Timestamp),
        (PeriodDtype(freq="D"), pd.Period),
        ("period[D]", pd.Period),
        (IntervalDtype(), pd.Interval),
        (None, type(None)),
        (1, type(None)),
        (1.2, type(None)),
        (pd.DataFrame([1, 2]), type(None)),  # composite dtype
    ],
)
def test__is_dtype_type(input_param, result):
    """com._is_dtype_type hands the expected type to its condition callback."""

    def matches_expected(tipo):
        return tipo == result

    assert com._is_dtype_type(input_param, matches_expected)
717
718
@pytest.mark.parametrize("val", [np.datetime64("NaT"), np.timedelta64("NaT")])
@pytest.mark.parametrize("typ", [np.int64])
def test_astype_nansafe(val, typ):
    """Casting an array holding NaT to an integer dtype raises ValueError."""
    nat_arr = np.array([val])

    with pytest.raises(ValueError, match="Cannot convert NaT values to integer"):
        astype_nansafe(nat_arr, dtype=typ)
727
728
@pytest.mark.parametrize("from_type", [np.datetime64, np.timedelta64])
@pytest.mark.parametrize(
    "to_type",
    [
        np.uint8,
        np.uint16,
        np.uint32,
        np.int8,
        np.int16,
        np.int32,
        np.float16,
        np.float32,
    ],
)
def test_astype_datetime64_bad_dtype_raises(from_type, to_type):
    """Casting datetime64/timedelta64 data to the listed dtypes raises TypeError."""
    source = np.array([from_type("2018")])
    expected_msg = "cannot astype"

    with pytest.raises(TypeError, match=expected_msg):
        astype_nansafe(source, dtype=to_type)
748
749
@pytest.mark.parametrize("from_type", [np.datetime64, np.timedelta64])
def test_astype_object_preserves_datetime_na(from_type):
    """A NaT element is still reported as missing after casting to object."""
    nat_arr = np.array([from_type("NaT")])
    as_object = astype_nansafe(nat_arr, dtype="object")

    missing_mask = isna(as_object)
    assert missing_mask[0]
756
757
def test_validate_allhashable():
    """
    com.validate_all_hashable returns None for hashable arguments and raises
    TypeError for an unhashable one, using ``error_name`` in the message.
    """
    outcome = com.validate_all_hashable(1, "a")
    assert outcome is None

    unhashable = []

    with pytest.raises(TypeError, match="All elements must be hashable"):
        com.validate_all_hashable(unhashable)

    with pytest.raises(TypeError, match="list must be a hashable type"):
        com.validate_all_hashable(unhashable, error_name="list")
766