1""" isort:skip_file """
2import pickle
3
4import pytest
5
6dask = pytest.importorskip("dask")  # isort:skip
7distributed = pytest.importorskip("distributed")  # isort:skip
8
9from dask.distributed import Client, Lock
10from distributed.utils_test import cluster, gen_cluster
11from distributed.utils_test import loop
12from distributed.client import futures_of
13
14import xarray as xr
15from xarray.backends.locks import HDF5_LOCK, CombinedLock
16from xarray.tests.test_backends import (
17    ON_WINDOWS,
18    create_tmp_file,
19    create_tmp_geotiff,
20    open_example_dataset,
21)
22from xarray.tests.test_dataset import create_test_data
23
24from . import (
25    assert_allclose,
26    has_h5netcdf,
27    has_netCDF4,
28    requires_rasterio,
29    has_scipy,
30    requires_zarr,
31    requires_cfgrib,
32)
33
# This is to stop isort from throwing errors. In retrospect, it may have been
# easier to just use `isort:skip`.
36
37
# Import dask.array through pytest so that the tests here are skipped, rather
# than failing at import time, when it is unavailable.
da = pytest.importorskip("dask.array")
# Self-assignment of the imported ``loop`` fixture: it is requested by the
# tests below via their ``loop`` argument, and flake8 has trouble
# acknowledging that the import is used.
loop = loop  # loop is an imported fixture, which flake8 has issues ack-ing
40
41
@pytest.fixture
def tmp_netcdf_filename(tmpdir):
    """Return, as a string, the path ``testfile.nc`` joined onto ``tmpdir``."""
    target = tmpdir.join("testfile.nc")
    return str(target)
45
46
# Backend engines to exercise, in a fixed order; an engine is listed only when
# its corresponding ``has_*`` availability flag is truthy.
ENGINES = [
    engine_name
    for engine_name, is_available in (
        ("scipy", has_scipy),
        ("netcdf4", has_netCDF4),
        ("h5netcdf", has_h5netcdf),
    )
    if is_available
]
54
# Maps each engine name to the list of netCDF format strings associated with it.
NC_FORMATS = dict(
    netcdf4=[
        "NETCDF3_CLASSIC",
        "NETCDF3_64BIT_OFFSET",
        "NETCDF3_64BIT_DATA",
        "NETCDF4_CLASSIC",
        "NETCDF4",
    ],
    scipy=["NETCDF3_CLASSIC", "NETCDF3_64BIT"],
    h5netcdf=["NETCDF4"],
)
66
# (engine, format) pairs used to parametrize the netCDF tests below.
ENGINES_AND_FORMATS = [
    *(
        ("netcdf4", netcdf4_format)
        for netcdf4_format in ("NETCDF3_CLASSIC", "NETCDF4_CLASSIC", "NETCDF4")
    ),
    ("h5netcdf", "NETCDF4"),
    ("scipy", "NETCDF3_64BIT"),
]
74
75
@pytest.mark.parametrize("engine,nc_format", ENGINES_AND_FORMATS)
def test_dask_distributed_netcdf_roundtrip(
    loop, tmp_netcdf_filename, engine, nc_format
):
    """Write a chunked dataset to netCDF under a distributed client and read it back.

    The test is skipped when ``engine`` is not listed in ``ENGINES``. For the
    ``scipy`` engine the write is expected to raise ``NotImplementedError``
    and the test stops there. Otherwise the file is reopened with the same
    chunks and compared against the original dataset.
    """
    if engine not in ENGINES:
        pytest.skip("engine not available")

    chunks = {"dim1": 4, "dim2": 3, "dim3": 6}
    write_options = {"engine": engine, "format": nc_format}

    # The unpacking pattern expects the cluster to provide exactly two workers.
    with cluster() as (scheduler, [_worker_a, _worker_b]), Client(
        scheduler["address"], loop=loop
    ):
        original = create_test_data().chunk(chunks)

        if engine == "scipy":
            with pytest.raises(NotImplementedError):
                original.to_netcdf(tmp_netcdf_filename, **write_options)
            return

        original.to_netcdf(tmp_netcdf_filename, **write_options)

        reopened = xr.open_dataset(tmp_netcdf_filename, chunks=chunks, engine=engine)
        with reopened as restored:
            # Reading with ``chunks`` must give dask-backed variables.
            assert isinstance(restored.var1.data, da.Array)
            assert_allclose(original, restored.compute())
106
107
@pytest.mark.parametrize("engine,nc_format", ENGINES_AND_FORMATS)
def test_dask_distributed_read_netcdf_integration_test(
    loop, tmp_netcdf_filename, engine, nc_format
):
    """Write an unchunked dataset to netCDF, then read it back chunked.

    The test is skipped when ``engine`` is not listed in ``ENGINES``. The
    write and the chunked read both happen while a distributed ``Client`` is
    active, and the computed result is compared against the original dataset.
    """
    if engine not in ENGINES:
        pytest.skip("engine not available")

    chunks = {"dim1": 4, "dim2": 3, "dim3": 6}

    # The unpacking pattern expects the cluster to provide exactly two workers.
    with cluster() as (scheduler, [_worker_a, _worker_b]), Client(
        scheduler["address"], loop=loop
    ):
        original = create_test_data()
        original.to_netcdf(tmp_netcdf_filename, engine=engine, format=nc_format)

        reopened = xr.open_dataset(tmp_netcdf_filename, chunks=chunks, engine=engine)
        with reopened as restored:
            # Reading with ``chunks`` must give dask-backed variables.
            assert isinstance(restored.var1.data, da.Array)
            assert_allclose(original, restored.compute())
130
131
@requires_zarr
@pytest.mark.parametrize("consolidated", [True, False])
@pytest.mark.parametrize("compute", [True, False])
def test_dask_distributed_zarr_integration_test(loop, consolidated, compute) -> None:
    """Round-trip a chunked dataset through zarr under a distributed client.

    Parametrized over ``consolidated`` (which adds the matching write and
    read keyword arguments, and requires zarr >= 2.2.1.dev2) and ``compute``
    (when false, the object returned by ``to_zarr`` is computed explicitly
    before reading back).
    """
    write_kwargs: dict = {}
    read_kwargs: dict = {}
    if consolidated:
        pytest.importorskip("zarr", minversion="2.2.1.dev2")
        write_kwargs["consolidated"] = True
        read_kwargs["backend_kwargs"] = {"consolidated": True}

    chunks = {"dim1": 4, "dim2": 3, "dim3": 5}

    # The unpacking pattern expects the cluster to provide exactly two workers.
    with cluster() as (scheduler, [_worker_a, _worker_b]), Client(
        scheduler["address"], loop=loop
    ):
        original = create_test_data().chunk(chunks)
        with create_tmp_file(
            allow_cleanup_failure=ON_WINDOWS, suffix=".zarrc"
        ) as filename:
            write_result = original.to_zarr(filename, compute=compute, **write_kwargs)
            if not compute:
                # The write was deferred, so trigger it before reading back.
                write_result.compute()

            reopened = xr.open_dataset(
                filename, chunks="auto", engine="zarr", **read_kwargs
            )
            with reopened as restored:
                assert isinstance(restored.var1.data, da.Array)
                assert_allclose(original, restored.compute())
160
161
@requires_rasterio
def test_dask_distributed_rasterio_integration_test(loop) -> None:
    """Open a temporary GeoTIFF in chunks under a distributed client.

    A ``DeprecationWarning`` is expected somewhere within the client block.
    The computed array is compared with the expected data that
    ``create_tmp_geotiff`` yields alongside the file.
    """
    with create_tmp_geotiff() as (tmp_file, expected):
        # The unpacking pattern expects the cluster to provide exactly two workers.
        with cluster() as (scheduler, [_worker_a, _worker_b]):
            with pytest.warns(DeprecationWarning):
                with Client(scheduler["address"], loop=loop):
                    tiff_array = xr.open_rasterio(tmp_file, chunks={"band": 1})
                    assert isinstance(tiff_array.data, da.Array)
                    computed = tiff_array.compute()
                    assert_allclose(computed, expected)
171
172
@requires_cfgrib
@pytest.mark.filterwarnings("ignore:deallocating CachingFileManager")
def test_dask_distributed_cfgrib_integration_test(loop) -> None:
    """Compare a chunked cfgrib read with an unchunked read of the same file.

    ``example.grib`` is opened twice under a distributed client: once with
    ``chunks={"time": 1}`` and once without chunks as the reference.
    """
    # The unpacking pattern expects the cluster to provide exactly two workers.
    with cluster() as (scheduler, [_worker_a, _worker_b]), Client(
        scheduler["address"], loop=loop
    ):
        chunked_cm = open_example_dataset(
            "example.grib", engine="cfgrib", chunks={"time": 1}
        )
        with chunked_cm as ds:
            with open_example_dataset("example.grib", engine="cfgrib") as expected:
                assert isinstance(ds["t"].data, da.Array)
                assert_allclose(ds.compute(), expected)
185
186
@gen_cluster(client=True)
async def test_async(c, s, a, b) -> None:
    """Exercise persist and asynchronous compute of a chunked dataset.

    Checks that chunking turns the dataset into a dask collection, that
    persisting shrinks the graph and yields futures, and that awaiting
    ``c.compute`` returns a non-dask result equal to the eager computation.
    """
    eager = create_test_data()
    assert not dask.is_dask_collection(eager)

    lazy = eager.chunk({"dim2": 4}) + 10
    for collection in (lazy, lazy.var1, lazy.var2):
        assert dask.is_dask_collection(collection)

    persisted = lazy.persist()
    # The string form of the persisted dataset must be non-empty.
    assert str(persisted)

    for collection in (persisted, persisted.var1, persisted.var2):
        assert dask.is_dask_collection(collection)

    # Persisting should leave a smaller graph than the lazy original.
    lazy_graph_size = len(lazy.__dask_graph__())
    persisted_graph_size = len(persisted.__dask_graph__())
    assert lazy_graph_size > persisted_graph_size

    assert not futures_of(lazy)
    assert futures_of(persisted)

    result = await c.compute(persisted)
    assert not dask.is_dask_collection(result)
    assert_allclose(eager + 10, result)

    assert s.tasks
213
214
def test_hdf5_lock() -> None:
    """Check that ``HDF5_LOCK`` is an instance of dask's ``SerializableLock``."""
    expected_lock_type = dask.utils.SerializableLock
    assert isinstance(HDF5_LOCK, expected_lock_type)
217
218
@gen_cluster(client=True)
async def test_serializable_locks(c, s, a, b) -> None:
    """Check that each supported lock works in submitted tasks and pickles.

    For every lock, ten tasks that each take the lock are mapped through the
    client and gathered. The lock is then round-tripped through pickle, and
    the copy must have exactly the same type as the original.
    """

    def f(x, lock=None):
        # Callers in this test always pass a lock; the None default is never used.
        with lock:
            return x + 1

    # note, the creation of Lock needs to be done inside a cluster
    for lock in [
        HDF5_LOCK,
        Lock(),
        Lock("filename.nc"),
        CombinedLock([HDF5_LOCK]),
        CombinedLock([HDF5_LOCK, Lock("filename.nc")]),
    ]:

        futures = c.map(f, list(range(10)), lock=lock)
        await c.gather(futures)

        lock2 = pickle.loads(pickle.dumps(lock))
        # Exact type identity is wanted here (not isinstance), so compare the
        # type objects with ``is`` rather than ``==``.
        assert type(lock) is type(lock2)
239