1"""
2Tests for DatetimeArray
3"""
4import operator
5
6import numpy as np
7import pytest
8
9from pandas.core.dtypes.dtypes import DatetimeTZDtype
10
11import pandas as pd
12from pandas import NaT
13import pandas._testing as tm
14from pandas.core.arrays import DatetimeArray
15from pandas.core.arrays.datetimes import sequence_to_dt64ns
16
17
class TestDatetimeArrayConstructor:
    """Input validation performed by the DatetimeArray constructors."""

    def test_from_sequence_invalid_type(self):
        # A MultiIndex cannot be interpreted as a sequence of datetimes.
        multi = pd.MultiIndex.from_product([np.arange(5), np.arange(5)])
        with pytest.raises(TypeError, match="Cannot create a DatetimeArray"):
            DatetimeArray._from_sequence(multi)

    def test_only_1dim_accepted(self):
        values = np.array([0, 1, 2, 3], dtype="M8[h]").astype("M8[ns]")

        # 3-dim input is rejected; 2D is tolerated for ops purposes (GH#29853)
        cube = values.reshape(2, 2, 1)
        with pytest.raises(ValueError, match="Only 1-dimensional"):
            DatetimeArray(cube)

        # 0-dim input is rejected as well
        zero_dim = values[[0]].squeeze()
        with pytest.raises(ValueError, match="Only 1-dimensional"):
            DatetimeArray(zero_dim)

    def test_freq_validation(self):
        # GH#24623 the public constructor must not build an instance whose
        #  values contradict the requested freq (hourly data vs weekly freq)
        hourly = np.arange(5, dtype=np.int64) * 3600 * 10 ** 9

        msg = (
            "Inferred frequency H from passed values does not "
            "conform to passed frequency W-SUN"
        )
        with pytest.raises(ValueError, match=msg):
            DatetimeArray(hourly, freq="W")

    @pytest.mark.parametrize(
        "meth",
        [
            DatetimeArray._from_sequence,
            sequence_to_dt64ns,
            pd.to_datetime,
            pd.DatetimeIndex,
        ],
    )
    def test_mixing_naive_tzaware_raises(self, meth):
        # GH#24569
        naive_then_aware = np.array(
            [pd.Timestamp("2000"), pd.Timestamp("2000", tz="CET")]
        )

        msg = "|".join(
            [
                "Cannot mix tz-aware with tz-naive values",
                "Tz-aware datetime.datetime cannot be converted "
                "to datetime64 unless utc=True",
            ]
        )

        # the error must surface whichever kind of value is encountered first
        for values in (naive_then_aware, naive_then_aware[::-1]):
            with pytest.raises(ValueError, match=msg):
                meth(values)

    def test_from_pandas_array(self):
        int_values = pd.array(np.arange(5, dtype=np.int64)) * 3600 * 10 ** 9
        expected = pd.date_range("1970-01-01", periods=5, freq="H")._data

        result = DatetimeArray._from_sequence(int_values)._with_freq("infer")
        tm.assert_datetime_array_equal(result, expected)

    def test_mismatched_timezone_raises(self):
        central = DatetimeArray(
            np.array(["2000-01-01T06:00:00"], dtype="M8[ns]"),
            dtype=DatetimeTZDtype(tz="US/Central"),
        )
        eastern = DatetimeTZDtype(tz="US/Eastern")
        with pytest.raises(TypeError, match="Timezone of the array"):
            DatetimeArray(central, dtype=eastern)

    def test_non_array_raises(self):
        # a plain list is rejected by the (non-_from_sequence) constructor
        with pytest.raises(ValueError, match="list"):
            DatetimeArray([1, 2, 3])

    def test_bool_dtype_raises(self):
        bools = np.array([1, 2, 3], dtype="bool")

        with pytest.raises(
            ValueError, match="The dtype of 'values' is incorrect.*bool"
        ):
            DatetimeArray(bools)

        # every sequence-style entry point reports the same conversion error
        msg = r"dtype bool cannot be converted to datetime64\[ns\]"
        for convert in (
            DatetimeArray._from_sequence,
            sequence_to_dt64ns,
            pd.DatetimeIndex,
            pd.to_datetime,
        ):
            with pytest.raises(TypeError, match=msg):
                convert(bools)

    def test_incorrect_dtype_raises(self):
        int_values = np.array([1, 2, 3], dtype="i8")
        with pytest.raises(ValueError, match="Unexpected value for 'dtype'."):
            DatetimeArray(int_values, dtype="category")

    def test_freq_infer_raises(self):
        int_values = np.array([1, 2, 3], dtype="i8")
        with pytest.raises(ValueError, match="Frequency inference"):
            DatetimeArray(int_values, freq="infer")

    def test_copy(self):
        data = np.array([1, 2, 3], dtype="M8[ns]")

        # copy=False wraps the very same ndarray; copy=True must not
        assert DatetimeArray(data, copy=False)._data is data
        assert DatetimeArray(data, copy=True)._data is not data
129
130
class TestDatetimeArrayComparisons:
    # TODO: merge this into tests/arithmetic/test_datetime64 once it is
    #  sufficiently robust

    def test_cmp_dt64_arraylike_tznaive(self, all_compare_operators):
        """Compare a tz-naive DatetimeArray with element-wise equal array-likes."""
        op_name = all_compare_operators.strip("_")
        op = getattr(operator, op_name)

        dti = pd.date_range("2016-01-1", freq="MS", periods=9, tz=None)
        arr = DatetimeArray(dti)
        assert arr.freq == dti.freq
        assert arr.tz == dti.tz

        # Every element equals its counterpart, so ne/gt/lt are all-False
        #  and the remaining comparisons are all-True.
        all_false = op_name in ("ne", "gt", "lt")
        expected = np.full(len(arr), not all_false, dtype=bool)

        tm.assert_numpy_array_equal(op(arr, arr), expected)

        # TODO: add list and tuple, and object-dtype once those
        #  are fixed in the constructor
        for other in (dti, np.array(dti)):
            tm.assert_numpy_array_equal(op(arr, other), expected)
            tm.assert_numpy_array_equal(op(other, arr), expected)
162
163
class TestDatetimeArray:
    """Behavior of DatetimeArray instances: astype, setitem, searchsorted, shift."""

    def test_astype_to_same(self):
        # astype to an identical dtype with copy=False returns the same object
        arr = DatetimeArray._from_sequence(
            ["2000"], dtype=DatetimeTZDtype(tz="US/Central")
        )
        result = arr.astype(DatetimeTZDtype(tz="US/Central"), copy=False)
        assert result is arr

    @pytest.mark.parametrize("dtype", ["datetime64[ns]", "datetime64[ns, UTC]"])
    @pytest.mark.parametrize(
        "other", ["datetime64[ns]", "datetime64[ns, UTC]", "datetime64[ns, CET]"]
    )
    def test_astype_copies(self, dtype, other):
        # https://github.com/pandas-dev/pandas/pull/32490
        # mutating the astype result must not write through to the original
        s = pd.Series([1, 2], dtype=dtype)
        orig = s.copy()
        t = s.astype(other)
        t[:] = pd.NaT
        tm.assert_series_equal(s, orig)

    @pytest.mark.parametrize("dtype", [int, np.int32, np.int64, "uint32", "uint64"])
    def test_astype_int(self, dtype):
        arr = DatetimeArray._from_sequence([pd.Timestamp("2000"), pd.Timestamp("2001")])
        result = arr.astype(dtype)

        # integer casts are widened to 64 bits, keeping signedness
        if np.dtype(dtype).kind == "u":
            expected_dtype = np.dtype("uint64")
        else:
            expected_dtype = np.dtype("int64")
        # Build the expectation from the raw i8 values rather than from
        #  arr.astype itself, so the comparison is not tautological.
        expected = arr.asi8.astype(expected_dtype)

        assert result.dtype == expected_dtype
        tm.assert_numpy_array_equal(result, expected)

    def test_tz_setter_raises(self):
        arr = DatetimeArray._from_sequence(
            ["2000"], dtype=DatetimeTZDtype(tz="US/Central")
        )
        with pytest.raises(AttributeError, match="tz_localize"):
            arr.tz = "UTC"

    def test_setitem_str_impute_tz(self, tz_naive_fixture):
        # Like for getitem, if we are passed a naive-like string, we impute
        #  our own timezone.
        tz = tz_naive_fixture

        data = np.array([1, 2, 3], dtype="M8[ns]")
        dtype = data.dtype if tz is None else DatetimeTZDtype(tz=tz)
        arr = DatetimeArray(data, dtype=dtype)
        expected = arr.copy()

        ts = pd.Timestamp("2020-09-08 16:50").tz_localize(tz)
        setter = str(ts.tz_localize(None))

        # Setting a scalar tznaive string
        expected[0] = ts
        arr[0] = setter
        tm.assert_equal(arr, expected)

        # Setting a listlike of tznaive strings
        expected[1] = ts
        arr[:2] = [setter, setter]
        tm.assert_equal(arr, expected)

    def test_setitem_different_tz_raises(self):
        data = np.array([1, 2, 3], dtype="M8[ns]")
        arr = DatetimeArray(data, copy=False, dtype=DatetimeTZDtype(tz="US/Central"))
        # tz-naive value into a tz-aware array
        with pytest.raises(TypeError, match="Cannot compare tz-naive and tz-aware"):
            arr[0] = pd.Timestamp("2000")

        # tz-aware value whose timezone differs from the array's
        with pytest.raises(ValueError, match="US/Central"):
            arr[0] = pd.Timestamp("2000", tz="US/Eastern")

    def test_setitem_clears_freq(self):
        a = DatetimeArray(pd.date_range("2000", periods=2, freq="D", tz="US/Central"))
        a[0] = pd.Timestamp("2000", tz="US/Central")
        assert a.freq is None

    @pytest.mark.parametrize(
        "obj",
        [
            pd.Timestamp.now(),
            pd.Timestamp.now().to_datetime64(),
            pd.Timestamp.now().to_pydatetime(),
        ],
    )
    def test_setitem_objects(self, obj):
        # make sure we accept datetime64 and datetime in addition to Timestamp
        dti = pd.date_range("2000", periods=2, freq="D")
        arr = dti._data

        arr[0] = obj
        assert arr[0] == obj

    def test_repeat_preserves_tz(self):
        dti = pd.date_range("2000", periods=2, freq="D", tz="US/Central")
        arr = DatetimeArray(dti)

        repeated = arr.repeat([1, 1])

        # preserves tz and values, but not freq
        expected = DatetimeArray(arr.asi8, freq=None, dtype=arr.dtype)
        tm.assert_equal(repeated, expected)

    def test_value_counts_preserves_tz(self):
        dti = pd.date_range("2000", periods=2, freq="D", tz="US/Central")
        arr = DatetimeArray(dti).repeat([4, 3])

        result = arr.value_counts()

        # Note: not tm.assert_index_equal, since `freq`s do not match
        assert result.index.equals(dti)

        arr[-2] = pd.NaT
        result = arr.value_counts()
        expected = pd.Series([1, 4, 2], index=[pd.NaT, dti[0], dti[1]])
        tm.assert_series_equal(result, expected)

    @pytest.mark.parametrize("method", ["pad", "backfill"])
    def test_fillna_preserves_tz(self, method):
        dti = pd.date_range("2000-01-01", periods=5, freq="D", tz="US/Central")
        arr = DatetimeArray(dti, copy=True)
        arr[2] = pd.NaT

        # pad fills from the previous value, backfill from the next one
        fill_val = dti[1] if method == "pad" else dti[3]
        expected = DatetimeArray._from_sequence(
            [dti[0], dti[1], fill_val, dti[3], dti[4]],
            dtype=DatetimeTZDtype(tz="US/Central"),
        )

        result = arr.fillna(method=method)
        tm.assert_extension_array_equal(result, expected)

        # assert that arr and dti were not modified in-place
        assert arr[2] is pd.NaT
        assert dti[2] == pd.Timestamp("2000-01-03", tz="US/Central")

    def test_array_interface_tz(self):
        tz = "US/Central"
        data = DatetimeArray(pd.date_range("2017", periods=2, tz=tz))
        result = np.asarray(data)

        # by default a tz-aware array converts to an object array of Timestamps
        expected = np.array(
            [
                pd.Timestamp("2017-01-01T00:00:00", tz=tz),
                pd.Timestamp("2017-01-02T00:00:00", tz=tz),
            ],
            dtype=object,
        )
        tm.assert_numpy_array_equal(result, expected)

        result = np.asarray(data, dtype=object)
        tm.assert_numpy_array_equal(result, expected)

        # requesting M8[ns] yields the UTC wall times (Central is UTC-6 here)
        result = np.asarray(data, dtype="M8[ns]")

        expected = np.array(
            ["2017-01-01T06:00:00", "2017-01-02T06:00:00"], dtype="M8[ns]"
        )
        tm.assert_numpy_array_equal(result, expected)

    def test_array_interface(self):
        data = DatetimeArray(pd.date_range("2017", periods=2))
        expected = np.array(
            ["2017-01-01T00:00:00", "2017-01-02T00:00:00"], dtype="datetime64[ns]"
        )

        result = np.asarray(data)
        tm.assert_numpy_array_equal(result, expected)

        result = np.asarray(data, dtype=object)
        expected = np.array(
            [pd.Timestamp("2017-01-01T00:00:00"), pd.Timestamp("2017-01-02T00:00:00")],
            dtype=object,
        )
        tm.assert_numpy_array_equal(result, expected)

    @pytest.mark.parametrize("index", [True, False])
    def test_searchsorted_different_tz(self, index):
        data = np.arange(10, dtype="i8") * 24 * 3600 * 10 ** 9
        arr = DatetimeArray(data, freq="D").tz_localize("Asia/Tokyo")
        if index:
            arr = pd.Index(arr)

        # the same instants expressed in another timezone sort identically
        expected = arr.searchsorted(arr[2])
        result = arr.searchsorted(arr[2].tz_convert("UTC"))
        assert result == expected

        expected = arr.searchsorted(arr[2:6])
        result = arr.searchsorted(arr[2:6].tz_convert("UTC"))
        tm.assert_equal(result, expected)

    @pytest.mark.parametrize("index", [True, False])
    def test_searchsorted_tzawareness_compat(self, index):
        data = np.arange(10, dtype="i8") * 24 * 3600 * 10 ** 9
        arr = DatetimeArray(data, freq="D")
        if index:
            arr = pd.Index(arr)

        mismatch = arr.tz_localize("Asia/Tokyo")

        msg = "Cannot compare tz-naive and tz-aware datetime-like objects"
        with pytest.raises(TypeError, match=msg):
            arr.searchsorted(mismatch[0])
        with pytest.raises(TypeError, match=msg):
            arr.searchsorted(mismatch)

        with pytest.raises(TypeError, match=msg):
            mismatch.searchsorted(arr[0])
        with pytest.raises(TypeError, match=msg):
            mismatch.searchsorted(arr)

    @pytest.mark.parametrize(
        "other",
        [
            1,
            np.int64(1),
            1.0,
            np.timedelta64("NaT"),
            pd.Timedelta(days=2),
            "invalid",
            np.arange(10, dtype="i8") * 24 * 3600 * 10 ** 9,
            # explicit i8 so the 8-byte timedelta64 view is well-defined on
            #  platforms whose default integer is 32-bit (e.g. Windows)
            np.arange(10, dtype="i8").view("timedelta64[ns]") * 24 * 3600 * 10 ** 9,
            pd.Timestamp.now().to_period("D"),
        ],
    )
    @pytest.mark.parametrize("index", [True, False])
    def test_searchsorted_invalid_types(self, other, index):
        data = np.arange(10, dtype="i8") * 24 * 3600 * 10 ** 9
        arr = DatetimeArray(data, freq="D")
        if index:
            arr = pd.Index(arr)

        msg = "|".join(
            [
                "searchsorted requires compatible dtype or scalar",
                "value should be a 'Timestamp', 'NaT', or array of those. Got",
            ]
        )
        with pytest.raises(TypeError, match=msg):
            arr.searchsorted(other)

    def test_shift_fill_value(self):
        dti = pd.date_range("2016-01-01", periods=3)

        dta = dti._data
        # shifting by one and filling with the last element is a rotation
        expected = DatetimeArray(np.roll(dta._data, 1))

        fv = dta[-1]
        for fill_value in [fv, fv.to_pydatetime(), fv.to_datetime64()]:
            result = dta.shift(1, fill_value=fill_value)
            tm.assert_datetime_array_equal(result, expected)

        # tz-aware: datetime64 is naive and so is not a valid fill here
        dta = dta.tz_localize("UTC")
        expected = expected.tz_localize("UTC")
        fv = dta[-1]
        for fill_value in [fv, fv.to_pydatetime()]:
            result = dta.shift(1, fill_value=fill_value)
            tm.assert_datetime_array_equal(result, expected)

    def test_shift_value_tzawareness_mismatch(self):
        dti = pd.date_range("2016-01-01", periods=3)

        dta = dti._data

        # tz-aware fill into a tz-naive array
        fv = dta[-1].tz_localize("UTC")
        for invalid in [fv, fv.to_pydatetime()]:
            with pytest.raises(TypeError, match="Cannot compare"):
                dta.shift(1, fill_value=invalid)

        # tz-naive fill into a tz-aware array
        dta = dta.tz_localize("UTC")
        fv = dta[-1].tz_localize(None)
        for invalid in [fv, fv.to_pydatetime(), fv.to_datetime64()]:
            with pytest.raises(TypeError, match="Cannot compare"):
                dta.shift(1, fill_value=invalid)

    def test_shift_requires_tzmatch(self):
        # since filling is setitem-like, we require a matching timezone,
        #  not just matching tzawareness
        dti = pd.date_range("2016-01-01", periods=3, tz="UTC")
        dta = dti._data

        fill_value = pd.Timestamp("2020-10-18 18:44", tz="US/Pacific")

        msg = "Timezones don't match. 'UTC' != 'US/Pacific'"
        with pytest.raises(ValueError, match=msg):
            dta.shift(1, fill_value=fill_value)
451
452
class TestSequenceToDT64NS:
    """sequence_to_dt64ns handling of already tz-aware DatetimeArray input."""

    def test_tz_dtype_mismatch_raises(self):
        central = DatetimeTZDtype(tz="US/Central")
        utc = DatetimeTZDtype(tz="UTC")
        aware = DatetimeArray._from_sequence(["2000"], dtype=central)

        # tz-aware data cannot be reinterpreted under a different timezone
        with pytest.raises(TypeError, match="data is already tz-aware"):
            sequence_to_dt64ns(aware, dtype=utc)

    def test_tz_dtype_matches(self):
        central = DatetimeTZDtype(tz="US/Central")
        aware = DatetimeArray._from_sequence(["2000"], dtype=central)

        # with a matching tz the underlying values come back unchanged
        values, _, _ = sequence_to_dt64ns(aware, dtype=central)
        tm.assert_numpy_array_equal(aware._data, values)
467
468
class TestReductions:
    """min / max / median / mean reductions on 1D and 2D DatetimeArrays."""

    @pytest.fixture
    def arr1d(self, tz_naive_fixture):
        """
        Six-element DatetimeArray with a single NaT at index 2.

        The dtype is tz-aware when ``tz_naive_fixture`` supplies a timezone
        and ``datetime64[ns]`` otherwise. The non-NaT values are
        Jan 2, 3, 3, 4 and 5 of 2000.
        """
        tz = tz_naive_fixture
        dtype = DatetimeTZDtype(tz=tz) if tz is not None else np.dtype("M8[ns]")
        arr = DatetimeArray._from_sequence(
            [
                "2000-01-03",
                "2000-01-03",
                "NaT",
                "2000-01-02",
                "2000-01-05",
                "2000-01-04",
            ],
            dtype=dtype,
        )
        return arr

    def test_min_max(self, arr1d):
        arr = arr1d
        tz = arr.tz

        result = arr.min()
        expected = pd.Timestamp("2000-01-02", tz=tz)
        assert result == expected

        result = arr.max()
        expected = pd.Timestamp("2000-01-05", tz=tz)
        assert result == expected

        # the NaT entry poisons the result when it is not skipped
        result = arr.min(skipna=False)
        assert result is pd.NaT

        result = arr.max(skipna=False)
        assert result is pd.NaT

    @pytest.mark.parametrize("tz", [None, "US/Central"])
    @pytest.mark.parametrize("skipna", [True, False])
    def test_min_max_empty(self, skipna, tz):
        dtype = DatetimeTZDtype(tz=tz) if tz is not None else np.dtype("M8[ns]")
        arr = DatetimeArray._from_sequence([], dtype=dtype)
        result = arr.min(skipna=skipna)
        assert result is pd.NaT

        result = arr.max(skipna=skipna)
        assert result is pd.NaT

    @pytest.mark.parametrize("tz", [None, "US/Central"])
    @pytest.mark.parametrize("skipna", [True, False])
    def test_median_empty(self, skipna, tz):
        dtype = DatetimeTZDtype(tz=tz) if tz is not None else np.dtype("M8[ns]")
        arr = DatetimeArray._from_sequence([], dtype=dtype)
        result = arr.median(skipna=skipna)
        assert result is pd.NaT

        # empty 2D: reducing over the length-0 axis gives one NaT per column,
        #  reducing over the length-3 axis gives an empty result
        arr = arr.reshape(0, 3)
        result = arr.median(axis=0, skipna=skipna)
        expected = type(arr)._from_sequence([pd.NaT, pd.NaT, pd.NaT], dtype=arr.dtype)
        tm.assert_equal(result, expected)

        result = arr.median(axis=1, skipna=skipna)
        expected = type(arr)._from_sequence([], dtype=arr.dtype)
        tm.assert_equal(result, expected)

    def test_median(self, arr1d):
        arr = arr1d

        # sorted non-NaT values are Jan 2, 3, 3, 4, 5 -> median is Jan 3 == arr[0]
        result = arr.median()
        assert result == arr[0]
        result = arr.median(skipna=False)
        assert result is pd.NaT

        result = arr.dropna().median(skipna=False)
        assert result == arr[0]

        result = arr.median(axis=0)
        assert result == arr[0]

    def test_median_axis(self, arr1d):
        arr = arr1d
        assert arr.median(axis=0) == arr.median()
        assert arr.median(axis=0, skipna=False) is pd.NaT

        msg = r"abs\(axis\) must be less than ndim"
        with pytest.raises(ValueError, match=msg):
            arr.median(axis=1)

    @pytest.mark.filterwarnings("ignore:All-NaN slice encountered:RuntimeWarning")
    def test_median_2d(self, arr1d):
        # a single row, so each column holds exactly one element
        arr = arr1d.reshape(1, -1)

        # axis = None
        assert arr.median() == arr1d.median()
        assert arr.median(skipna=False) is pd.NaT

        # axis = 0
        result = arr.median(axis=0)
        expected = arr1d
        tm.assert_equal(result, expected)

        # The column at index 2 is all-NaT, so it is NaT with or without
        #  skipna; every other column is its single value.
        result = arr.median(axis=0, skipna=False)
        expected = arr1d
        tm.assert_equal(result, expected)

        # axis = 1
        result = arr.median(axis=1)
        expected = type(arr)._from_sequence([arr1d.median()], dtype=arr.dtype)
        tm.assert_equal(result, expected)

        result = arr.median(axis=1, skipna=False)
        expected = type(arr)._from_sequence([pd.NaT], dtype=arr.dtype)
        tm.assert_equal(result, expected)

    def test_mean(self, arr1d):
        arr = arr1d

        # Non-NaT values are offset -1, 0, 0, +1, +2 days from Jan 3 (arr[0]);
        #  their mean offset is 2 / 5 = 0.4 days.
        expected = arr[0] + 0.4 * pd.Timedelta(days=1)

        result = arr.mean()
        assert result == expected
        result = arr.mean(skipna=False)
        assert result is pd.NaT

        result = arr.dropna().mean(skipna=False)
        assert result == expected

        result = arr.mean(axis=0)
        assert result == expected

    def test_mean_2d(self):
        dti = pd.date_range("2016-01-01", periods=6, tz="US/Pacific")
        # rows are [d0, d1], [d2, d3], [d4, d5] for consecutive days d0..d5
        dta = dti._data.reshape(3, 2)

        # column means land on the middle row
        result = dta.mean(axis=0)
        expected = dta[1]
        tm.assert_datetime_array_equal(result, expected)

        # row means fall halfway between two consecutive days
        result = dta.mean(axis=1)
        expected = dta[:, 0] + pd.Timedelta(hours=12)
        tm.assert_datetime_array_equal(result, expected)

        result = dta.mean(axis=None)
        expected = dti.mean()
        assert result == expected

    @pytest.mark.parametrize("skipna", [True, False])
    def test_mean_empty(self, arr1d, skipna):
        arr = arr1d[:0]

        assert arr.mean(skipna=skipna) is NaT

        arr2d = arr.reshape(0, 3)
        result = arr2d.mean(axis=0, skipna=skipna)
        expected = DatetimeArray._from_sequence([NaT, NaT, NaT], dtype=arr.dtype)
        tm.assert_datetime_array_equal(result, expected)

        result = arr2d.mean(axis=1, skipna=skipna)
        expected = arr  # i.e. 1D, empty
        tm.assert_datetime_array_equal(result, expected)

        result = arr2d.mean(axis=None, skipna=skipna)
        assert result is NaT
633