1from distutils.version import LooseVersion
2
3import numpy as np
4import pandas as pd
5import pytest
6
7import xarray as xr
8
9from . import (
10    assert_array_equal,
11    assert_chunks_equal,
12    assert_equal,
13    assert_identical,
14    raise_if_dask_computes,
15    requires_cftime,
16    requires_dask,
17)
18
19
class TestDatetimeAccessor:
    """Tests for the ``.dt`` accessor on ``datetime64`` data."""

    # Fields checked by both the eager and the dask-backed field-access tests.
    # Defined once so the two parametrizations cannot drift apart.
    _FIELDS = [
        "year",
        "month",
        "day",
        "hour",
        "minute",
        "second",
        "microsecond",
        "nanosecond",
        "week",
        "weekofyear",
        "dayofweek",
        "weekday",
        "dayofyear",
        "quarter",
        "date",
        "time",
        "is_month_start",
        "is_month_end",
        "is_quarter_start",
        "is_quarter_end",
        "is_year_start",
        "is_year_end",
        "is_leap_year",
    ]

    @pytest.fixture(autouse=True)
    def setup(self):
        """Build the hourly time index and the sample arrays used by the tests.

        Sets ``self.times`` (100 hourly timestamps), ``self.data`` (random
        floats on a lon/lat/time grid), ``self.times_arr`` (timestamps drawn
        at random from ``self.times``) and ``self.times_data`` (the latter
        wrapped in a DataArray on the same grid).
        """
        nt = 100
        data = np.random.rand(10, 10, nt)
        lons = np.linspace(0, 11, 10)
        lats = np.linspace(0, 20, 10)
        self.times = pd.date_range(start="2000/01/01", freq="H", periods=nt)

        self.data = xr.DataArray(
            data,
            coords=[lons, lats, self.times],
            dims=["lon", "lat", "time"],
            name="data",
        )

        self.times_arr = np.random.choice(self.times, size=(10, 10, nt))
        self.times_data = xr.DataArray(
            self.times_arr,
            coords=[lons, lats, self.times],
            dims=["lon", "lat", "time"],
            name="data",
        )

    def _dask_times_data(self):
        """Return ``self.times_arr`` as a dask-backed DataArray.

        The array is split into ``(5, 5, 50)`` chunks and carries the
        coordinates and dimensions of ``self.data``.
        """
        import dask.array as da

        dask_times_arr = da.from_array(self.times_arr, chunks=(5, 5, 50))
        return xr.DataArray(
            dask_times_arr, coords=self.data.coords, dims=self.data.dims, name="data"
        )

    @pytest.mark.parametrize("field", _FIELDS)
    def test_field_access(self, field) -> None:
        """Each ``.dt`` field matches the same attribute of the pandas index."""
        is_week_field = field in ["week", "weekofyear"]

        # The week fields are compared against pandas' isocalendar() output
        # rather than the attribute of the same name on the index.
        if is_week_field:
            data = self.times.isocalendar()["week"]
        else:
            data = getattr(self.times, field)

        expected = xr.DataArray(data, name=field, coords=[self.times], dims=["time"])

        if is_week_field:
            with pytest.warns(
                FutureWarning, match="dt.weekofyear and dt.week have been deprecated"
            ):
                actual = getattr(self.data.time.dt, field)
        else:
            actual = getattr(self.data.time.dt, field)

        assert_equal(expected, actual)

    @pytest.mark.parametrize(
        "field, pandas_field",
        [
            ("year", "year"),
            ("week", "week"),
            ("weekday", "day"),
        ],
    )
    def test_isocalendar(self, field, pandas_field) -> None:
        """``dt.isocalendar()`` variables match pandas' isocalendar columns."""
        # pandas isocalendar has dtype UInt32Dtype, convert to int64.
        # ``pd.Int64Index`` was deprecated in pandas 1.4 and removed in 2.0, so
        # cast the column explicitly and wrap it in a plain ``pd.Index``.
        iso_column = getattr(self.times.isocalendar(), pandas_field).astype("int64")
        expected = xr.DataArray(
            pd.Index(iso_column), name=field, coords=[self.times], dims=["time"]
        )

        actual = self.data.time.dt.isocalendar()[field]
        assert_equal(expected, actual)

    def test_strftime(self) -> None:
        """The second formatted timestamp is one hour after the start."""
        assert (
            "2000-01-01 01:00:00" == self.data.time.dt.strftime("%Y-%m-%d %H:%M:%S")[1]
        )

    def test_not_datetime_type(self) -> None:
        """Accessing ``.dt`` on an integer coordinate raises ``TypeError``."""
        nontime_data = self.data.copy()
        int_data = np.arange(len(self.data.time)).astype("int8")
        nontime_data = nontime_data.assign_coords(time=int_data)
        with pytest.raises(TypeError, match=r"dt"):
            nontime_data.time.dt

    @pytest.mark.filterwarnings("ignore:dt.weekofyear and dt.week have been deprecated")
    @requires_dask
    @pytest.mark.parametrize("field", _FIELDS)
    def test_dask_field_access(self, field) -> None:
        """Field access on dask-backed data stays lazy and matches eager results."""
        import dask.array as da

        expected = getattr(self.times_data.dt, field)
        dask_times_2d = self._dask_times_data()

        with raise_if_dask_computes():
            actual = getattr(dask_times_2d.dt, field)

        assert isinstance(actual.data, da.Array)
        assert_chunks_equal(actual, dask_times_2d)
        assert_equal(actual.compute(), expected.compute())

    @requires_dask
    @pytest.mark.parametrize(
        "field",
        [
            "year",
            "week",
            "weekday",
        ],
    )
    def test_isocalendar_dask(self, field) -> None:
        """``dt.isocalendar()`` on dask-backed data stays lazy and matches eager."""
        import dask.array as da

        expected = getattr(self.times_data.dt.isocalendar(), field)
        dask_times_2d = self._dask_times_data()

        with raise_if_dask_computes():
            actual = dask_times_2d.dt.isocalendar()[field]

        assert isinstance(actual.data, da.Array)
        assert_chunks_equal(actual, dask_times_2d)
        assert_equal(actual.compute(), expected.compute())

    @requires_dask
    @pytest.mark.parametrize(
        "method, parameters",
        [
            ("floor", "D"),
            ("ceil", "D"),
            ("round", "D"),
            ("strftime", "%Y-%m-%d %H:%M:%S"),
        ],
    )
    def test_dask_accessor_method(self, method, parameters) -> None:
        """Accessor methods on dask-backed data stay lazy and match eager results."""
        import dask.array as da

        expected = getattr(self.times_data.dt, method)(parameters)
        dask_times_2d = self._dask_times_data()

        with raise_if_dask_computes():
            actual = getattr(dask_times_2d.dt, method)(parameters)

        assert isinstance(actual.data, da.Array)
        assert_chunks_equal(actual, dask_times_2d)
        assert_equal(actual.compute(), expected.compute())

    def test_seasons(self) -> None:
        """``dt.season`` labels twelve month-end dates of 2000 correctly."""
        dates = pd.date_range(start="2000/01/01", freq="M", periods=12)
        dates = xr.DataArray(dates)
        seasons = xr.DataArray(
            [
                "DJF",
                "DJF",
                "MAM",
                "MAM",
                "MAM",
                "JJA",
                "JJA",
                "JJA",
                "SON",
                "SON",
                "SON",
                "DJF",
            ]
        )

        assert_array_equal(seasons.values, dates.dt.season.values)

    @pytest.mark.parametrize(
        "method, parameters", [("floor", "D"), ("ceil", "D"), ("round", "D")]
    )
    def test_accessor_method(self, method, parameters) -> None:
        """Rounding methods on ``.dt`` agree with the pandas index methods."""
        dates = pd.date_range("2014-01-01", "2014-05-01", freq="H")
        xdates = xr.DataArray(dates, dims=["time"])
        expected = getattr(dates, method)(parameters)
        actual = getattr(xdates.dt, method)(parameters)
        assert_array_equal(expected, actual)
252
253
class TestTimedeltaAccessor:
    """Checks of the ``.dt`` accessor against ``timedelta64`` data."""

    @pytest.fixture(autouse=True)
    def setup(self):
        """Populate ``self`` with timedelta-indexed sample arrays for each test."""
        n_steps = 100
        values = np.random.rand(10, 10, n_steps)
        lon = np.linspace(0, 11, 10)
        lat = np.linspace(0, 20, 10)
        self.times = pd.timedelta_range(start="1 day", freq="6H", periods=n_steps)

        dim_names = ["lon", "lat", "time"]
        self.data = xr.DataArray(
            values, coords=[lon, lat, self.times], dims=dim_names, name="data"
        )

        # Timedeltas drawn at random from the index, laid out on the full grid.
        self.times_arr = np.random.choice(self.times, size=(10, 10, n_steps))
        self.times_data = xr.DataArray(
            self.times_arr, coords=[lon, lat, self.times], dims=dim_names, name="data"
        )

    def test_not_datetime_type(self) -> None:
        """``.dt`` on an int8 ``time`` coordinate must raise ``TypeError``."""
        int_coord = np.arange(len(self.data.time)).astype("int8")
        relabelled = self.data.copy().assign_coords(time=int_coord)
        with pytest.raises(TypeError, match=r"dt"):
            relabelled.time.dt

    @pytest.mark.parametrize(
        "field", ["days", "seconds", "microseconds", "nanoseconds"]
    )
    def test_field_access(self, field) -> None:
        """Each timedelta component equals the pandas index attribute."""
        from_pandas = getattr(self.times, field)
        expected = xr.DataArray(
            from_pandas, name=field, coords=[self.times], dims=["time"]
        )
        from_accessor = getattr(self.data.time.dt, field)
        assert_equal(expected, from_accessor)

    @pytest.mark.parametrize(
        "method, parameters", [("floor", "D"), ("ceil", "D"), ("round", "D")]
    )
    def test_accessor_methods(self, method, parameters) -> None:
        """Rounding through ``.dt`` agrees with rounding the pandas index."""
        deltas = pd.timedelta_range(start="1 day", end="30 days", freq="6H")
        wrapped = xr.DataArray(deltas, dims=["time"])
        via_pandas = getattr(deltas, method)(parameters)
        via_accessor = getattr(wrapped.dt, method)(parameters)
        assert_array_equal(via_pandas, via_accessor)

    @requires_dask
    @pytest.mark.parametrize(
        "field", ["days", "seconds", "microseconds", "nanoseconds"]
    )
    def test_dask_field_access(self, field) -> None:
        """Component access on dask data is lazy and equals the eager result."""
        import dask.array as da

        eager = getattr(self.times_data.dt, field)

        lazy_input = xr.DataArray(
            da.from_array(self.times_arr, chunks=(5, 5, 50)),
            coords=self.data.coords,
            dims=self.data.dims,
            name="data",
        )

        with raise_if_dask_computes():
            lazy = getattr(lazy_input.dt, field)

        assert isinstance(lazy.data, da.Array)
        assert_chunks_equal(lazy, lazy_input)
        assert_equal(lazy, eager)

    @requires_dask
    @pytest.mark.parametrize(
        "method, parameters", [("floor", "D"), ("ceil", "D"), ("round", "D")]
    )
    def test_dask_accessor_method(self, method, parameters) -> None:
        """Rounding dask-backed timedeltas is lazy and equals the eager result."""
        import dask.array as da

        eager = getattr(self.times_data.dt, method)(parameters)
        lazy_input = xr.DataArray(
            da.from_array(self.times_arr, chunks=(5, 5, 50)),
            coords=self.data.coords,
            dims=self.data.dims,
            name="data",
        )

        with raise_if_dask_computes():
            lazy = getattr(lazy_input.dt, method)(parameters)

        assert isinstance(lazy.data, da.Array)
        assert_chunks_equal(lazy, lazy_input)
        assert_equal(lazy.compute(), eager.compute())
345
346
347_CFTIME_CALENDARS = [
348    "365_day",
349    "360_day",
350    "julian",
351    "all_leap",
352    "366_day",
353    "gregorian",
354    "proleptic_gregorian",
355]
356_NT = 100
357
358
@pytest.fixture(params=_CFTIME_CALENDARS)
def calendar(request):
    """Provide one entry of ``_CFTIME_CALENDARS`` per parametrized run."""
    calendar_name = request.param
    return calendar_name
362
363
@pytest.fixture()
def times(calendar):
    """Decode ``_NT`` consecutive hourly offsets from 2000-01-01 with cftime.

    The offsets are decoded in the requested ``calendar`` and cftime datetime
    objects are requested via ``only_use_cftime_datetimes=True``.
    """
    from cftime import num2date

    hour_offsets = np.arange(_NT)
    return num2date(
        hour_offsets,
        units="hours since 2000-01-01",
        calendar=calendar,
        only_use_cftime_datetimes=True,
    )
374
375
@pytest.fixture()
def data(times):
    """Return random values on a 10 x 10 x ``_NT`` lon/lat/time grid.

    The ``time`` coordinate is the array supplied by the ``times`` fixture.
    """
    values = np.random.rand(10, 10, _NT)
    grid_coords = [np.linspace(0, 11, 10), np.linspace(0, 20, 10), times]
    return xr.DataArray(
        values, coords=grid_coords, dims=["lon", "lat", "time"], name="data"
    )
384
385
@pytest.fixture()
def times_3d(times):
    """Return a 10 x 10 x ``_NT`` DataArray of dates sampled from ``times``.

    The values are drawn at random (with ``np.random.choice``) from ``times``;
    the ``time`` coordinate is ``times`` itself.
    """
    grid_coords = [np.linspace(0, 11, 10), np.linspace(0, 20, 10), times]
    sampled = np.random.choice(times, size=(10, 10, _NT))
    return xr.DataArray(
        sampled, coords=grid_coords, dims=["lon", "lat", "time"], name="data"
    )
394
395
@requires_cftime
@pytest.mark.parametrize(
    "field", ["year", "month", "day", "hour", "dayofyear", "dayofweek"]
)
def test_field_access(data, field) -> None:
    """``.dt.<field>`` on cftime data equals the matching ``CFTimeIndex`` field."""
    # The day-of-* fields are only checked with cftime >= 1.0.2.1; otherwise
    # the test is skipped.
    if field in ("dayofyear", "dayofweek"):
        pytest.importorskip("cftime", minversion="1.0.2.1")

    time_coord = data.time
    result = getattr(time_coord.dt, field)

    index = xr.coding.cftimeindex.CFTimeIndex(time_coord.values)
    expected = xr.DataArray(
        getattr(index, field),
        name=field,
        coords=time_coord.coords,
        dims=time_coord.dims,
    )

    assert_equal(result, expected)
412
413
@requires_cftime
def test_isocalendar_cftime(data) -> None:
    """Calling ``dt.isocalendar()`` on cftime data raises ``AttributeError``."""
    expected_message = r"'CFTimeIndex' object has no attribute 'isocalendar'"
    with pytest.raises(AttributeError, match=expected_message):
        data.time.dt.isocalendar()
421
422
@requires_cftime
def test_date_cftime(data) -> None:
    """``dt.date`` on cftime data raises an error pointing at ``floor``."""
    expected_message = r"'CFTimeIndex' object has no attribute `date`. Consider using the floor method instead, for instance: `.time.dt.floor\('D'\)`."
    with pytest.raises(AttributeError, match=expected_message):
        data.time.dt.date()
431
432
@requires_cftime
@pytest.mark.filterwarnings("ignore::RuntimeWarning")
def test_cftime_strftime_access(data) -> None:
    """Formatting cftime dates gives the same strings as formatting datetimes."""
    date_format = "%Y%m%d%H"
    time_coord = data.time
    result = time_coord.dt.strftime(date_format)

    # Convert the same dates to a pandas DatetimeIndex and format those.
    as_datetimes = xr.coding.cftimeindex.CFTimeIndex(
        time_coord.values
    ).to_datetimeindex()
    datetime_array = xr.DataArray(
        as_datetimes,
        name="stftime",
        coords=time_coord.coords,
        dims=time_coord.dims,
    )
    expected = datetime_array.dt.strftime(date_format)

    assert_equal(result, expected)
447
448
@requires_cftime
@requires_dask
@pytest.mark.parametrize(
    "field", ["year", "month", "day", "hour", "dayofyear", "dayofweek"]
)
def test_dask_field_access_1d(data, field) -> None:
    """Field access on a chunked 1-D cftime array is lazy and correct."""
    import dask.array as da

    # The day-of-* fields are only checked with cftime >= 1.0.2.1.
    if field in ("dayofyear", "dayofweek"):
        pytest.importorskip("cftime", minversion="1.0.2.1")

    index = xr.coding.cftimeindex.CFTimeIndex(data.time.values)
    expected = xr.DataArray(getattr(index, field), name=field, dims=["time"])

    chunked_times = xr.DataArray(data.time.values, dims=["time"]).chunk({"time": 50})
    result = getattr(chunked_times.dt, field)

    assert isinstance(result.data, da.Array)
    assert result.chunks == chunked_times.chunks
    assert_equal(result.compute(), expected)
469
470
@requires_cftime
@requires_dask
@pytest.mark.parametrize(
    "field", ["year", "month", "day", "hour", "dayofyear", "dayofweek"]
)
def test_dask_field_access(times_3d, data, field) -> None:
    """Field access on a chunked 3-D cftime array is lazy and correct."""
    import dask.array as da

    # The day-of-* fields are only checked with cftime >= 1.0.2.1.
    if field in ("dayofyear", "dayofweek"):
        pytest.importorskip("cftime", minversion="1.0.2.1")

    # CFTimeIndex is built from the flattened values; the field is then
    # reshaped back to the 3-D layout.
    flat_index = xr.coding.cftimeindex.CFTimeIndex(times_3d.values.ravel())
    field_values = getattr(flat_index, field).reshape(times_3d.shape)
    expected = xr.DataArray(
        field_values,
        name=field,
        coords=times_3d.coords,
        dims=times_3d.dims,
    )

    chunked = times_3d.chunk({"lon": 5, "lat": 5, "time": 50})
    result = getattr(chunked.dt, field)

    assert isinstance(result.data, da.Array)
    assert result.chunks == chunked.chunks
    assert_equal(result.compute(), expected)
494
495
@pytest.fixture()
def cftime_date_type(calendar):
    """Return the ``_all_cftime_date_types()`` entry keyed by ``calendar``."""
    from .test_coding_times import _all_cftime_date_types

    date_types = _all_cftime_date_types()
    return date_types[calendar]
501
502
@requires_cftime
def test_seasons(cftime_date_type) -> None:
    """``dt.season`` labels the 15th of each month of year 2000 correctly."""
    mid_month_dates = [cftime_date_type(2000, month, 15) for month in range(1, 13)]
    dates = xr.DataArray(np.array(mid_month_dates))

    # One label per month, January through December.
    expected_labels = [
        "DJF", "DJF",
        "MAM", "MAM", "MAM",
        "JJA", "JJA", "JJA",
        "SON", "SON", "SON",
        "DJF",
    ]  # fmt: skip
    seasons = xr.DataArray(expected_labels)

    assert_array_equal(seasons.values, dates.dt.season.values)
526
527
@pytest.fixture
def cftime_rounding_dataarray(cftime_date_type):
    """Return a 2 x 2 DataArray of dates used as input by the rounding tests."""
    make_date = cftime_date_type
    first_row = [make_date(1, 1, 1, 1), make_date(1, 1, 1, 15)]
    second_row = [make_date(1, 1, 1, 23), make_date(1, 1, 2, 1)]
    return xr.DataArray([first_row, second_row])
536
537
@requires_cftime
@requires_dask
@pytest.mark.parametrize("use_dask", [False, True])
def test_cftime_floor_accessor(
    cftime_rounding_dataarray, cftime_date_type, use_dask
) -> None:
    """``dt.floor("D")`` on cftime dates, for in-memory and chunked input."""
    import dask.array as da

    freq = "D"
    make_date = cftime_date_type
    expected = xr.DataArray(
        [
            [make_date(1, 1, 1, 0), make_date(1, 1, 1, 0)],
            [make_date(1, 1, 1, 0), make_date(1, 1, 2, 0)],
        ],
        name="floor",
    )

    if not use_dask:
        result = cftime_rounding_dataarray.dt.floor(freq)
    else:
        chunks = {"dim_0": 1}
        # One compute is allowed: for object-dtype arrays a single value is
        # inspected to check that it is a cftime.datetime (otherwise the dt
        # accessor raises an error).
        with raise_if_dask_computes(max_computes=1):
            result = cftime_rounding_dataarray.chunk(chunks).dt.floor(freq)
        expected = expected.chunk(chunks)
        assert isinstance(result.data, da.Array)
        assert result.chunks == expected.chunks

    assert_identical(result, expected)
569
570
@requires_cftime
@requires_dask
@pytest.mark.parametrize("use_dask", [False, True])
def test_cftime_ceil_accessor(
    cftime_rounding_dataarray, cftime_date_type, use_dask
) -> None:
    """``dt.ceil("D")`` on cftime dates, for in-memory and chunked input."""
    import dask.array as da

    freq = "D"
    make_date = cftime_date_type
    expected = xr.DataArray(
        [
            [make_date(1, 1, 2, 0), make_date(1, 1, 2, 0)],
            [make_date(1, 1, 2, 0), make_date(1, 1, 3, 0)],
        ],
        name="ceil",
    )

    if not use_dask:
        result = cftime_rounding_dataarray.dt.ceil(freq)
    else:
        chunks = {"dim_0": 1}
        # One compute is allowed: for object-dtype arrays a single value is
        # inspected to check that it is a cftime.datetime (otherwise the dt
        # accessor raises an error).
        with raise_if_dask_computes(max_computes=1):
            result = cftime_rounding_dataarray.chunk(chunks).dt.ceil(freq)
        expected = expected.chunk(chunks)
        assert isinstance(result.data, da.Array)
        assert result.chunks == expected.chunks

    assert_identical(result, expected)
602
603
@requires_cftime
@requires_dask
@pytest.mark.parametrize("use_dask", [False, True])
def test_cftime_round_accessor(
    cftime_rounding_dataarray, cftime_date_type, use_dask
) -> None:
    """``dt.round("D")`` on cftime dates, for in-memory and chunked input."""
    import dask.array as da

    freq = "D"
    make_date = cftime_date_type
    expected = xr.DataArray(
        [
            [make_date(1, 1, 1, 0), make_date(1, 1, 2, 0)],
            [make_date(1, 1, 2, 0), make_date(1, 1, 2, 0)],
        ],
        name="round",
    )

    if not use_dask:
        result = cftime_rounding_dataarray.dt.round(freq)
    else:
        chunks = {"dim_0": 1}
        # One compute is allowed: for object-dtype arrays a single value is
        # inspected to check that it is a cftime.datetime (otherwise the dt
        # accessor raises an error).
        with raise_if_dask_computes(max_computes=1):
            result = cftime_rounding_dataarray.chunk(chunks).dt.round(freq)
        expected = expected.chunk(chunks)
        assert isinstance(result.data, da.Array)
        assert result.chunks == expected.chunks

    assert_identical(result, expected)
635