1import numpy as np
2import pandas as pd
3import pytest
4
5import xarray as xr
6
7cp = pytest.importorskip("cupy")
8
9
@pytest.fixture
def toy_weather_data() -> xr.Dataset:
    """Construct the example Dataset from the toy weather data example.

    https://docs.xarray.dev/en/stable/examples/weather-data.html

    The Dataset is built exactly as shown in the example (same seed, same
    random draws). Its ``tmin`` and ``tmax`` data variables are then
    converted from numpy to cupy arrays.

    A local ``RandomState`` seeded with 123 is used instead of
    ``np.random.seed(123)``. It yields the same draws as the seeded global
    generator, but it does not reset the global numpy RNG for every other
    test that runs after this fixture.
    """
    rng = np.random.RandomState(123)
    times = pd.date_range("2000-01-01", "2001-12-31", name="time")
    annual_cycle = np.sin(2 * np.pi * (times.dayofyear.values / 365.25 - 0.28))

    # Column vector so it broadcasts against the (time, location) noise below.
    base = 10 + 15 * annual_cycle.reshape(-1, 1)
    tmin_values = base + 3 * rng.randn(annual_cycle.size, 3)
    tmax_values = base + 10 + 3 * rng.randn(annual_cycle.size, 3)

    ds = xr.Dataset(
        {
            "tmin": (("time", "location"), tmin_values),
            "tmax": (("time", "location"), tmax_values),
        },
        {"time": times, "location": ["IA", "IN", "IL"]},
    )

    # Move both data variables onto the GPU as cupy arrays.
    ds.tmax.data = cp.asarray(ds.tmax.data)
    ds.tmin.data = cp.asarray(ds.tmin.data)

    return ds
40
41
def test_cupy_import() -> None:
    """Confirm the module-level cupy import produced a module object."""
    assert cp is not None
45
46
def test_check_data_stays_on_gpu(toy_weather_data) -> None:
    """Run a comparison plus groupby-mean on cupy data; the result must stay cupy."""
    tmin = toy_weather_data["tmin"]
    below_freezing = tmin <= 0
    # Mean of a boolean mask per calendar month.
    monthly_freeze_fraction = below_freezing.groupby("time.month").mean("time")
    assert isinstance(monthly_freeze_fraction.data, cp.ndarray)
51
52
def test_where() -> None:
    """``duck_array_ops.where`` on cupy input should produce a cupy result."""
    from xarray.core.duck_array_ops import where

    zeros = cp.zeros(10)
    # Every element is < 1, so every slot is replaced by 1.
    replaced = where(zeros < 1, 1, zeros)
    all_true = replaced.all()

    assert all_true
    assert isinstance(all_true, cp.ndarray)
61