import numpy as np
import pandas as pd
import pytest

import xarray as xr

# Skip every test in this module when cupy is not installed.
cp = pytest.importorskip("cupy")


@pytest.fixture
def toy_weather_data() -> xr.Dataset:
    """Construct the example Dataset from the Toy weather data example.

    http://xarray.pydata.org/en/stable/examples/weather-data.html

    The Dataset is built from numpy arrays as in the example, and the
    ``tmin`` / ``tmax`` variables are then converted to cupy arrays.

    Returns
    -------
    xr.Dataset
        Dataset with cupy-backed ``tmin`` and ``tmax`` variables on the
        ``("time", "location")`` dimensions.
    """
    # A local RandomState seeded with 123 yields the same legacy stream as
    # ``np.random.seed(123)`` followed by ``np.random.randn`` calls, but
    # does not reseed numpy's global generator as a side effect.
    rng = np.random.RandomState(123)

    times = pd.date_range("2000-01-01", "2001-12-31", name="time")
    annual_cycle = np.sin(2 * np.pi * (times.dayofyear.values / 365.25 - 0.28))

    # Column vector so it broadcasts against the three location columns.
    base = 10 + 15 * annual_cycle.reshape(-1, 1)
    tmin_values = base + 3 * rng.randn(annual_cycle.size, 3)
    tmax_values = base + 10 + 3 * rng.randn(annual_cycle.size, 3)

    ds = xr.Dataset(
        {
            "tmin": (("time", "location"), tmin_values),
            "tmax": (("time", "location"), tmax_values),
        },
        {"time": times, "location": ["IA", "IN", "IL"]},
    )

    # Swap the numpy buffers for cupy arrays.
    ds.tmax.data = cp.asarray(ds.tmax.data)
    ds.tmin.data = cp.asarray(ds.tmin.data)

    return ds


def test_cupy_import() -> None:
    """Check the import worked."""
    assert cp


def test_check_data_stays_on_gpu(toy_weather_data) -> None:
    """Perform some operations and check the data stays on the GPU."""
    below_freezing = toy_weather_data["tmin"] <= 0
    freeze = below_freezing.groupby("time.month").mean("time")
    # The comparison, groupby and reduction must all yield a cupy array.
    assert isinstance(freeze.data, cp.ndarray)


def test_where() -> None:
    """Check that ``duck_array_ops.where`` returns a cupy array for cupy input."""
    from xarray.core.duck_array_ops import where

    data = cp.zeros(10)

    # Every element is < 1, so each one is replaced by 1 and ``all()`` is true.
    output = where(data < 1, 1, data).all()
    assert output
    assert isinstance(output, cp.ndarray)