""" isort:skip_file """
import pickle

import pytest

dask = pytest.importorskip("dask")  # isort:skip
distributed = pytest.importorskip("distributed")  # isort:skip

from dask.distributed import Client, Lock
from distributed.utils_test import cluster, gen_cluster
from distributed.utils_test import loop
from distributed.client import futures_of

import xarray as xr
from xarray.backends.locks import HDF5_LOCK, CombinedLock
from xarray.tests.test_backends import (
    ON_WINDOWS,
    create_tmp_file,
    create_tmp_geotiff,
    open_example_dataset,
)
from xarray.tests.test_dataset import create_test_data

from . import (
    assert_allclose,
    has_h5netcdf,
    has_netCDF4,
    requires_rasterio,
    has_scipy,
    requires_zarr,
    requires_cfgrib,
)

# The imports above must come after the ``pytest.importorskip`` calls, so they
# are deliberately out of isort order; the ``isort:skip_file`` module docstring
# stops isort from reordering them.


da = pytest.importorskip("dask.array")
# ``loop`` is a pytest fixture imported from distributed.utils_test and
# requested by name in the tests below; re-binding it here keeps flake8 from
# reporting the import as unused.
loop = loop


@pytest.fixture
def tmp_netcdf_filename(tmpdir):
    """Return the path (as ``str``) of ``testfile.nc`` inside pytest's tmpdir."""
    return str(tmpdir.join("testfile.nc"))


# netCDF engines whose optional backend library is importable.
ENGINES = []
if has_scipy:
    ENGINES.append("scipy")
if has_netCDF4:
    ENGINES.append("netcdf4")
if has_h5netcdf:
    ENGINES.append("h5netcdf")

# Candidate file formats per engine (not referenced by the tests in this module).
NC_FORMATS = {
    "netcdf4": [
        "NETCDF3_CLASSIC",
        "NETCDF3_64BIT_OFFSET",
        "NETCDF3_64BIT_DATA",
        "NETCDF4_CLASSIC",
        "NETCDF4",
    ],
    "scipy": ["NETCDF3_CLASSIC", "NETCDF3_64BIT"],
    "h5netcdf": ["NETCDF4"],
}

# (engine, format) pairs exercised by the netCDF tests below.
ENGINES_AND_FORMATS = [
    ("netcdf4", "NETCDF3_CLASSIC"),
    ("netcdf4", "NETCDF4_CLASSIC"),
    ("netcdf4", "NETCDF4"),
    ("h5netcdf", "NETCDF4"),
    ("scipy", "NETCDF3_64BIT"),
]


@pytest.mark.parametrize("engine,nc_format", ENGINES_AND_FORMATS)
def test_dask_distributed_netcdf_roundtrip(
    loop, tmp_netcdf_filename, engine, nc_format
):
    """Write a dask-chunked dataset to netCDF under a distributed client and
    read it back lazily.

    For the ``scipy`` engine, writing the chunked dataset is expected to raise
    ``NotImplementedError`` and the test stops there.
    """
    if engine not in ENGINES:
        pytest.skip("engine not available")

    chunks = {"dim1": 4, "dim2": 3, "dim3": 6}

    with cluster() as (s, [a, b]):
        with Client(s["address"], loop=loop):

            original = create_test_data().chunk(chunks)

            if engine == "scipy":
                with pytest.raises(NotImplementedError):
                    original.to_netcdf(
                        tmp_netcdf_filename, engine=engine, format=nc_format
                    )
                return

            original.to_netcdf(tmp_netcdf_filename, engine=engine, format=nc_format)

            with xr.open_dataset(
                tmp_netcdf_filename, chunks=chunks, engine=engine
            ) as restored:
                assert isinstance(restored.var1.data, da.Array)
                computed = restored.compute()
                assert_allclose(original, computed)


@pytest.mark.parametrize("engine,nc_format", ENGINES_AND_FORMATS)
def test_dask_distributed_read_netcdf_integration_test(
    loop, tmp_netcdf_filename, engine, nc_format
):
    """Write an in-memory (unchunked) dataset to netCDF, then open it with
    dask chunks under a distributed client and compare after ``compute()``.
    """
    if engine not in ENGINES:
        pytest.skip("engine not available")

    chunks = {"dim1": 4, "dim2": 3, "dim3": 6}

    with cluster() as (s, [a, b]):
        with Client(s["address"], loop=loop):

            original = create_test_data()
            original.to_netcdf(tmp_netcdf_filename, engine=engine, format=nc_format)

            with xr.open_dataset(
                tmp_netcdf_filename, chunks=chunks, engine=engine
            ) as restored:
                assert isinstance(restored.var1.data, da.Array)
                computed = restored.compute()
                assert_allclose(original, computed)


@requires_zarr
@pytest.mark.parametrize("consolidated", [True, False])
@pytest.mark.parametrize("compute", [True, False])
def test_dask_distributed_zarr_integration_test(loop, consolidated, compute) -> None:
    """Round-trip a chunked dataset through zarr under a distributed client,
    with and without consolidated metadata and with eager or deferred writes.
    """
    # Two independent dicts (no aliasing) that stay empty unless
    # consolidated metadata is requested.
    write_kwargs: dict = {}
    read_kwargs: dict = {}
    if consolidated:
        pytest.importorskip("zarr", minversion="2.2.1.dev2")
        write_kwargs = {"consolidated": True}
        read_kwargs = {"backend_kwargs": {"consolidated": True}}
    chunks = {"dim1": 4, "dim2": 3, "dim3": 5}
    with cluster() as (s, [a, b]):
        with Client(s["address"], loop=loop):
            original = create_test_data().chunk(chunks)
            with create_tmp_file(
                allow_cleanup_failure=ON_WINDOWS, suffix=".zarrc"
            ) as filename:
                maybe_futures = original.to_zarr(
                    filename, compute=compute, **write_kwargs
                )
                # With compute=False, to_zarr returns a delayed object that
                # must be computed before the store can be read back.
                if not compute:
                    maybe_futures.compute()
                with xr.open_dataset(
                    filename, chunks="auto", engine="zarr", **read_kwargs
                ) as restored:
                    assert isinstance(restored.var1.data, da.Array)
                    computed = restored.compute()
                    assert_allclose(original, computed)


@requires_rasterio
def test_dask_distributed_rasterio_integration_test(loop) -> None:
    """Open a temporary GeoTIFF with ``xr.open_rasterio`` (expected to emit a
    DeprecationWarning) under a distributed client and check lazy loading.
    """
    with create_tmp_geotiff() as (tmp_file, expected):
        with cluster() as (s, [a, b]):
            with pytest.warns(DeprecationWarning), Client(s["address"], loop=loop):
                da_tiff = xr.open_rasterio(tmp_file, chunks={"band": 1})
                assert isinstance(da_tiff.data, da.Array)
                actual = da_tiff.compute()
                assert_allclose(actual, expected)


@requires_cfgrib
@pytest.mark.filterwarnings("ignore:deallocating CachingFileManager")
def test_dask_distributed_cfgrib_integration_test(loop) -> None:
    """Open the example GRIB file chunked under a distributed client and
    compare against an eagerly opened copy.
    """
    with cluster() as (s, [a, b]):
        with Client(s["address"], loop=loop):
            with open_example_dataset(
                "example.grib", engine="cfgrib", chunks={"time": 1}
            ) as ds:
                with open_example_dataset("example.grib", engine="cfgrib") as expected:
                    assert isinstance(ds["t"].data, da.Array)
                    actual = ds.compute()
                    assert_allclose(actual, expected)


@gen_cluster(client=True)
async def test_async(c, s, a, b) -> None:
    """Exercise persist/compute of xarray objects with an async dask client."""
    x = create_test_data()
    assert not dask.is_dask_collection(x)
    y = x.chunk({"dim2": 4}) + 10
    assert dask.is_dask_collection(y)
    assert dask.is_dask_collection(y.var1)
    assert dask.is_dask_collection(y.var2)

    z = y.persist()
    # Formatting a persisted dataset must not fail.
    assert str(z)

    assert dask.is_dask_collection(z)
    assert dask.is_dask_collection(z.var1)
    assert dask.is_dask_collection(z.var2)
    # Persisting should collapse the task graph.
    assert len(y.__dask_graph__()) > len(z.__dask_graph__())

    assert not futures_of(y)
    assert futures_of(z)

    future = c.compute(z)
    w = await future
    assert not dask.is_dask_collection(w)
    assert_allclose(x + 10, w)

    assert s.tasks


def test_hdf5_lock() -> None:
    """HDF5_LOCK must be a dask SerializableLock."""
    assert isinstance(HDF5_LOCK, dask.utils.SerializableLock)


@gen_cluster(client=True)
async def test_serializable_locks(c, s, a, b) -> None:
    """Each lock kind can be shipped to workers and survives a pickle round trip."""

    def f(x, lock=None):
        with lock:
            return x + 1

    # note, the creation of Lock needs to be done inside a cluster
    for lock in [
        HDF5_LOCK,
        Lock(),
        Lock("filename.nc"),
        CombinedLock([HDF5_LOCK]),
        CombinedLock([HDF5_LOCK, Lock("filename.nc")]),
    ]:

        futures = c.map(f, list(range(10)), lock=lock)
        await c.gather(futures)

        lock2 = pickle.loads(pickle.dumps(lock))
        # Exact type identity (not isinstance): unpickling must not change class.
        assert type(lock) is type(lock2)