1import operator
2import pickle
3import sys
4from contextlib import suppress
5from textwrap import dedent
6
7import numpy as np
8import pandas as pd
9import pytest
10
11import xarray as xr
12import xarray.ufuncs as xu
13from xarray import DataArray, Dataset, Variable
14from xarray.core import duck_array_ops
15from xarray.core.pycompat import dask_version
16from xarray.testing import assert_chunks_equal
17from xarray.tests import mock
18
19from ..core.duck_array_ops import lazy_array_equiv
20from . import (
21    assert_allclose,
22    assert_array_equal,
23    assert_equal,
24    assert_frame_equal,
25    assert_identical,
26    raise_if_dask_computes,
27    requires_pint,
28    requires_scipy_or_netCDF4,
29)
30from .test_backends import create_tmp_file
31
32dask = pytest.importorskip("dask")
33da = pytest.importorskip("dask.array")
34dd = pytest.importorskip("dask.dataframe")
35
36ON_WINDOWS = sys.platform == "win32"
37
38
39def test_raise_if_dask_computes():
40    data = da.from_array(np.random.RandomState(0).randn(4, 6), chunks=(2, 2))
41    with pytest.raises(RuntimeError, match=r"Too many computes"):
42        with raise_if_dask_computes():
43            data.compute()
44
45
46class DaskTestCase:
47    def assertLazyAnd(self, expected, actual, test):
48        with dask.config.set(scheduler="synchronous"):
49            test(actual, expected)
50
51        if isinstance(actual, Dataset):
52            for k, v in actual.variables.items():
53                if k in actual.dims:
54                    assert isinstance(v.data, np.ndarray)
55                else:
56                    assert isinstance(v.data, da.Array)
57        elif isinstance(actual, DataArray):
58            assert isinstance(actual.data, da.Array)
59            for k, v in actual.coords.items():
60                if k in actual.dims:
61                    assert isinstance(v.data, np.ndarray)
62                else:
63                    assert isinstance(v.data, da.Array)
64        elif isinstance(actual, Variable):
65            assert isinstance(actual.data, da.Array)
66        else:
67            assert False
68
69
70class TestVariable(DaskTestCase):
71    def assertLazyAndIdentical(self, expected, actual):
72        self.assertLazyAnd(expected, actual, assert_identical)
73
74    def assertLazyAndAllClose(self, expected, actual):
75        self.assertLazyAnd(expected, actual, assert_allclose)
76
77    @pytest.fixture(autouse=True)
78    def setUp(self):
79        self.values = np.random.RandomState(0).randn(4, 6)
80        self.data = da.from_array(self.values, chunks=(2, 2))
81
82        self.eager_var = Variable(("x", "y"), self.values)
83        self.lazy_var = Variable(("x", "y"), self.data)
84
85    def test_basics(self):
86        v = self.lazy_var
87        assert self.data is v.data
88        assert self.data.chunks == v.chunks
89        assert_array_equal(self.values, v)
90
91    def test_copy(self):
92        self.assertLazyAndIdentical(self.eager_var, self.lazy_var.copy())
93        self.assertLazyAndIdentical(self.eager_var, self.lazy_var.copy(deep=True))
94
95    def test_chunk(self):
96        for chunks, expected in [
97            ({}, ((2, 2), (2, 2, 2))),
98            (3, ((3, 1), (3, 3))),
99            ({"x": 3, "y": 3}, ((3, 1), (3, 3))),
100            ({"x": 3}, ((3, 1), (2, 2, 2))),
101            ({"x": (3, 1)}, ((3, 1), (2, 2, 2))),
102        ]:
103            rechunked = self.lazy_var.chunk(chunks)
104            assert rechunked.chunks == expected
105            self.assertLazyAndIdentical(self.eager_var, rechunked)
106
107            expected_chunksizes = {
108                dim: chunks for dim, chunks in zip(self.lazy_var.dims, expected)
109            }
110            assert rechunked.chunksizes == expected_chunksizes
111
112    def test_indexing(self):
113        u = self.eager_var
114        v = self.lazy_var
115        self.assertLazyAndIdentical(u[0], v[0])
116        self.assertLazyAndIdentical(u[:1], v[:1])
117        self.assertLazyAndIdentical(u[[0, 1], [0, 1, 2]], v[[0, 1], [0, 1, 2]])
118
119    @pytest.mark.skipif(dask_version < "2021.04.1", reason="Requires dask >= 2021.04.1")
120    @pytest.mark.parametrize(
121        "expected_data, index",
122        [
123            (da.array([99, 2, 3, 4]), 0),
124            (da.array([99, 99, 99, 4]), slice(2, None, -1)),
125            (da.array([99, 99, 3, 99]), [0, -1, 1]),
126            (da.array([99, 99, 99, 4]), np.arange(3)),
127            (da.array([1, 99, 99, 99]), [False, True, True, True]),
128            (da.array([1, 99, 99, 99]), np.arange(4) > 0),
129            (da.array([99, 99, 99, 99]), Variable(("x"), da.array([1, 2, 3, 4])) > 0),
130        ],
131    )
132    def test_setitem_dask_array(self, expected_data, index):
133        arr = Variable(("x"), da.array([1, 2, 3, 4]))
134        expected = Variable(("x"), expected_data)
135        arr[index] = 99
136        assert_identical(arr, expected)
137
138    @pytest.mark.skipif(dask_version >= "2021.04.1", reason="Requires dask < 2021.04.1")
139    def test_setitem_dask_array_error(self):
140        with pytest.raises(TypeError, match=r"stored in a dask array"):
141            v = self.lazy_var
142            v[:1] = 0
143
144    def test_squeeze(self):
145        u = self.eager_var
146        v = self.lazy_var
147        self.assertLazyAndIdentical(u[0].squeeze(), v[0].squeeze())
148
149    def test_equals(self):
150        v = self.lazy_var
151        assert v.equals(v)
152        assert isinstance(v.data, da.Array)
153        assert v.identical(v)
154        assert isinstance(v.data, da.Array)
155
156    def test_transpose(self):
157        u = self.eager_var
158        v = self.lazy_var
159        self.assertLazyAndIdentical(u.T, v.T)
160
161    def test_shift(self):
162        u = self.eager_var
163        v = self.lazy_var
164        self.assertLazyAndIdentical(u.shift(x=2), v.shift(x=2))
165        self.assertLazyAndIdentical(u.shift(x=-2), v.shift(x=-2))
166        assert v.data.chunks == v.shift(x=1).data.chunks
167
168    def test_roll(self):
169        u = self.eager_var
170        v = self.lazy_var
171        self.assertLazyAndIdentical(u.roll(x=2), v.roll(x=2))
172        assert v.data.chunks == v.roll(x=1).data.chunks
173
174    def test_unary_op(self):
175        u = self.eager_var
176        v = self.lazy_var
177        self.assertLazyAndIdentical(-u, -v)
178        self.assertLazyAndIdentical(abs(u), abs(v))
179        self.assertLazyAndIdentical(u.round(), v.round())
180
181    def test_binary_op(self):
182        u = self.eager_var
183        v = self.lazy_var
184        self.assertLazyAndIdentical(2 * u, 2 * v)
185        self.assertLazyAndIdentical(u + u, v + v)
186        self.assertLazyAndIdentical(u[0] + u, v[0] + v)
187
188    def test_repr(self):
189        expected = dedent(
190            """\
191            <xarray.Variable (x: 4, y: 6)>
192            {!r}""".format(
193                self.lazy_var.data
194            )
195        )
196        assert expected == repr(self.lazy_var)
197
198    def test_pickle(self):
199        # Test that pickling/unpickling does not convert the dask
200        # backend to numpy
201        a1 = Variable(["x"], build_dask_array("x"))
202        a1.compute()
203        assert not a1._in_memory
204        assert kernel_call_count == 1
205        a2 = pickle.loads(pickle.dumps(a1))
206        assert kernel_call_count == 1
207        assert_identical(a1, a2)
208        assert not a1._in_memory
209        assert not a2._in_memory
210
211    def test_reduce(self):
212        u = self.eager_var
213        v = self.lazy_var
214        self.assertLazyAndAllClose(u.mean(), v.mean())
215        self.assertLazyAndAllClose(u.std(), v.std())
216        with raise_if_dask_computes():
217            actual = v.argmax(dim="x")
218        self.assertLazyAndAllClose(u.argmax(dim="x"), actual)
219        with raise_if_dask_computes():
220            actual = v.argmin(dim="x")
221        self.assertLazyAndAllClose(u.argmin(dim="x"), actual)
222        self.assertLazyAndAllClose((u > 1).any(), (v > 1).any())
223        self.assertLazyAndAllClose((u < 1).all("x"), (v < 1).all("x"))
224        with pytest.raises(NotImplementedError, match=r"only works along an axis"):
225            v.median()
226        with pytest.raises(NotImplementedError, match=r"only works along an axis"):
227            v.median(v.dims)
228        with raise_if_dask_computes():
229            v.reduce(duck_array_ops.mean)
230
231    def test_missing_values(self):
232        values = np.array([0, 1, np.nan, 3])
233        data = da.from_array(values, chunks=(2,))
234
235        eager_var = Variable("x", values)
236        lazy_var = Variable("x", data)
237        self.assertLazyAndIdentical(eager_var, lazy_var.fillna(lazy_var))
238        self.assertLazyAndIdentical(Variable("x", range(4)), lazy_var.fillna(2))
239        self.assertLazyAndIdentical(eager_var.count(), lazy_var.count())
240
241    def test_concat(self):
242        u = self.eager_var
243        v = self.lazy_var
244        self.assertLazyAndIdentical(u, Variable.concat([v[:2], v[2:]], "x"))
245        self.assertLazyAndIdentical(u[:2], Variable.concat([v[0], v[1]], "x"))
246        self.assertLazyAndIdentical(u[:2], Variable.concat([u[0], v[1]], "x"))
247        self.assertLazyAndIdentical(u[:2], Variable.concat([v[0], u[1]], "x"))
248        self.assertLazyAndIdentical(
249            u[:3], Variable.concat([v[[0, 2]], v[[1]]], "x", positions=[[0, 2], [1]])
250        )
251
252    def test_missing_methods(self):
253        v = self.lazy_var
254        try:
255            v.argsort()
256        except NotImplementedError as err:
257            assert "dask" in str(err)
258        try:
259            v[0].item()
260        except NotImplementedError as err:
261            assert "dask" in str(err)
262
263    @pytest.mark.filterwarnings("ignore::FutureWarning")
264    def test_univariate_ufunc(self):
265        u = self.eager_var
266        v = self.lazy_var
267        self.assertLazyAndAllClose(np.sin(u), xu.sin(v))
268
269    @pytest.mark.filterwarnings("ignore::FutureWarning")
270    def test_bivariate_ufunc(self):
271        u = self.eager_var
272        v = self.lazy_var
273        self.assertLazyAndAllClose(np.maximum(u, 0), xu.maximum(v, 0))
274        self.assertLazyAndAllClose(np.maximum(u, 0), xu.maximum(0, v))
275
276    def test_compute(self):
277        u = self.eager_var
278        v = self.lazy_var
279
280        assert dask.is_dask_collection(v)
281        (v2,) = dask.compute(v + 1)
282        assert not dask.is_dask_collection(v2)
283
284        assert ((u + 1).data == v2.data).all()
285
286    def test_persist(self):
287        u = self.eager_var
288        v = self.lazy_var + 1
289
290        (v2,) = dask.persist(v)
291        assert v is not v2
292        assert len(v2.__dask_graph__()) < len(v.__dask_graph__())
293        assert v2.__dask_keys__() == v.__dask_keys__()
294        assert dask.is_dask_collection(v)
295        assert dask.is_dask_collection(v2)
296
297        self.assertLazyAndAllClose(u + 1, v)
298        self.assertLazyAndAllClose(u + 1, v2)
299
300    @requires_pint
301    def test_tokenize_duck_dask_array(self):
302        import pint
303
304        unit_registry = pint.UnitRegistry()
305
306        q = unit_registry.Quantity(self.data, "meter")
307        variable = xr.Variable(("x", "y"), q)
308
309        token = dask.base.tokenize(variable)
310        post_op = variable + 5 * unit_registry.meter
311
312        assert dask.base.tokenize(variable) != dask.base.tokenize(post_op)
313        # Immutability check
314        assert dask.base.tokenize(variable) == token
315
316
317class TestDataArrayAndDataset(DaskTestCase):
318    def assertLazyAndIdentical(self, expected, actual):
319        self.assertLazyAnd(expected, actual, assert_identical)
320
321    def assertLazyAndAllClose(self, expected, actual):
322        self.assertLazyAnd(expected, actual, assert_allclose)
323
324    def assertLazyAndEqual(self, expected, actual):
325        self.assertLazyAnd(expected, actual, assert_equal)
326
327    @pytest.fixture(autouse=True)
328    def setUp(self):
329        self.values = np.random.randn(4, 6)
330        self.data = da.from_array(self.values, chunks=(2, 2))
331        self.eager_array = DataArray(
332            self.values, coords={"x": range(4)}, dims=("x", "y"), name="foo"
333        )
334        self.lazy_array = DataArray(
335            self.data, coords={"x": range(4)}, dims=("x", "y"), name="foo"
336        )
337
338    def test_chunk(self):
339        for chunks, expected in [
340            ({}, ((2, 2), (2, 2, 2))),
341            (3, ((3, 1), (3, 3))),
342            ({"x": 3, "y": 3}, ((3, 1), (3, 3))),
343            ({"x": 3}, ((3, 1), (2, 2, 2))),
344            ({"x": (3, 1)}, ((3, 1), (2, 2, 2))),
345        ]:
346            # Test DataArray
347            rechunked = self.lazy_array.chunk(chunks)
348            assert rechunked.chunks == expected
349            self.assertLazyAndIdentical(self.eager_array, rechunked)
350
351            expected_chunksizes = {
352                dim: chunks for dim, chunks in zip(self.lazy_array.dims, expected)
353            }
354            assert rechunked.chunksizes == expected_chunksizes
355
356            # Test Dataset
357            lazy_dataset = self.lazy_array.to_dataset()
358            eager_dataset = self.eager_array.to_dataset()
359            expected_chunksizes = {
360                dim: chunks for dim, chunks in zip(lazy_dataset.dims, expected)
361            }
362            rechunked = lazy_dataset.chunk(chunks)
363
364            # Dataset.chunks has a different return type to DataArray.chunks - see issue #5843
365            assert rechunked.chunks == expected_chunksizes
366            self.assertLazyAndIdentical(eager_dataset, rechunked)
367
368            assert rechunked.chunksizes == expected_chunksizes
369
370    def test_rechunk(self):
371        chunked = self.eager_array.chunk({"x": 2}).chunk({"y": 2})
372        assert chunked.chunks == ((2,) * 2, (2,) * 3)
373        self.assertLazyAndIdentical(self.lazy_array, chunked)
374
375    def test_new_chunk(self):
376        chunked = self.eager_array.chunk()
377        assert chunked.data.name.startswith("xarray-<this-array>")
378
379    def test_lazy_dataset(self):
380        lazy_ds = Dataset({"foo": (("x", "y"), self.data)})
381        assert isinstance(lazy_ds.foo.variable.data, da.Array)
382
383    def test_lazy_array(self):
384        u = self.eager_array
385        v = self.lazy_array
386
387        self.assertLazyAndAllClose(u, v)
388        self.assertLazyAndAllClose(-u, -v)
389        self.assertLazyAndAllClose(u.T, v.T)
390        self.assertLazyAndAllClose(u.mean(), v.mean())
391        self.assertLazyAndAllClose(1 + u, 1 + v)
392
393        actual = xr.concat([v[:2], v[2:]], "x")
394        self.assertLazyAndAllClose(u, actual)
395
396    def test_compute(self):
397        u = self.eager_array
398        v = self.lazy_array
399
400        assert dask.is_dask_collection(v)
401        (v2,) = dask.compute(v + 1)
402        assert not dask.is_dask_collection(v2)
403
404        assert ((u + 1).data == v2.data).all()
405
406    def test_persist(self):
407        u = self.eager_array
408        v = self.lazy_array + 1
409
410        (v2,) = dask.persist(v)
411        assert v is not v2
412        assert len(v2.__dask_graph__()) < len(v.__dask_graph__())
413        assert v2.__dask_keys__() == v.__dask_keys__()
414        assert dask.is_dask_collection(v)
415        assert dask.is_dask_collection(v2)
416
417        self.assertLazyAndAllClose(u + 1, v)
418        self.assertLazyAndAllClose(u + 1, v2)
419
420    def test_concat_loads_variables(self):
421        # Test that concat() computes not-in-memory variables at most once
422        # and loads them in the output, while leaving the input unaltered.
423        d1 = build_dask_array("d1")
424        c1 = build_dask_array("c1")
425        d2 = build_dask_array("d2")
426        c2 = build_dask_array("c2")
427        d3 = build_dask_array("d3")
428        c3 = build_dask_array("c3")
429        # Note: c is a non-index coord.
430        # Index coords are loaded by IndexVariable.__init__.
431        ds1 = Dataset(data_vars={"d": ("x", d1)}, coords={"c": ("x", c1)})
432        ds2 = Dataset(data_vars={"d": ("x", d2)}, coords={"c": ("x", c2)})
433        ds3 = Dataset(data_vars={"d": ("x", d3)}, coords={"c": ("x", c3)})
434
435        assert kernel_call_count == 0
436        out = xr.concat(
437            [ds1, ds2, ds3], dim="n", data_vars="different", coords="different"
438        )
439        # each kernel is computed exactly once
440        assert kernel_call_count == 6
441        # variables are loaded in the output
442        assert isinstance(out["d"].data, np.ndarray)
443        assert isinstance(out["c"].data, np.ndarray)
444
445        out = xr.concat([ds1, ds2, ds3], dim="n", data_vars="all", coords="all")
446        # no extra kernel calls
447        assert kernel_call_count == 6
448        assert isinstance(out["d"].data, dask.array.Array)
449        assert isinstance(out["c"].data, dask.array.Array)
450
451        out = xr.concat([ds1, ds2, ds3], dim="n", data_vars=["d"], coords=["c"])
452        # no extra kernel calls
453        assert kernel_call_count == 6
454        assert isinstance(out["d"].data, dask.array.Array)
455        assert isinstance(out["c"].data, dask.array.Array)
456
457        out = xr.concat([ds1, ds2, ds3], dim="n", data_vars=[], coords=[])
458        # variables are loaded once as we are validing that they're identical
459        assert kernel_call_count == 12
460        assert isinstance(out["d"].data, np.ndarray)
461        assert isinstance(out["c"].data, np.ndarray)
462
463        out = xr.concat(
464            [ds1, ds2, ds3],
465            dim="n",
466            data_vars="different",
467            coords="different",
468            compat="identical",
469        )
470        # compat=identical doesn't do any more kernel calls than compat=equals
471        assert kernel_call_count == 18
472        assert isinstance(out["d"].data, np.ndarray)
473        assert isinstance(out["c"].data, np.ndarray)
474
475        # When the test for different turns true halfway through,
476        # stop computing variables as it would not have any benefit
477        ds4 = Dataset(data_vars={"d": ("x", [2.0])}, coords={"c": ("x", [2.0])})
478        out = xr.concat(
479            [ds1, ds2, ds4, ds3], dim="n", data_vars="different", coords="different"
480        )
481        # the variables of ds1 and ds2 were computed, but those of ds3 didn't
482        assert kernel_call_count == 22
483        assert isinstance(out["d"].data, dask.array.Array)
484        assert isinstance(out["c"].data, dask.array.Array)
485        # the data of ds1 and ds2 was loaded into numpy and then
486        # concatenated to the data of ds3. Thus, only ds3 is computed now.
487        out.compute()
488        assert kernel_call_count == 24
489
490        # Finally, test that originals are unaltered
491        assert ds1["d"].data is d1
492        assert ds1["c"].data is c1
493        assert ds2["d"].data is d2
494        assert ds2["c"].data is c2
495        assert ds3["d"].data is d3
496        assert ds3["c"].data is c3
497
498        # now check that concat() is correctly using dask name equality to skip loads
499        out = xr.concat(
500            [ds1, ds1, ds1], dim="n", data_vars="different", coords="different"
501        )
502        assert kernel_call_count == 24
503        # variables are not loaded in the output
504        assert isinstance(out["d"].data, dask.array.Array)
505        assert isinstance(out["c"].data, dask.array.Array)
506
507        out = xr.concat(
508            [ds1, ds1, ds1], dim="n", data_vars=[], coords=[], compat="identical"
509        )
510        assert kernel_call_count == 24
511        # variables are not loaded in the output
512        assert isinstance(out["d"].data, dask.array.Array)
513        assert isinstance(out["c"].data, dask.array.Array)
514
515        out = xr.concat(
516            [ds1, ds2.compute(), ds3],
517            dim="n",
518            data_vars="all",
519            coords="different",
520            compat="identical",
521        )
522        # c1,c3 must be computed for comparison since c2 is numpy;
523        # d2 is computed too
524        assert kernel_call_count == 28
525
526        out = xr.concat(
527            [ds1, ds2.compute(), ds3],
528            dim="n",
529            data_vars="all",
530            coords="all",
531            compat="identical",
532        )
533        # no extra computes
534        assert kernel_call_count == 30
535
536        # Finally, test that originals are unaltered
537        assert ds1["d"].data is d1
538        assert ds1["c"].data is c1
539        assert ds2["d"].data is d2
540        assert ds2["c"].data is c2
541        assert ds3["d"].data is d3
542        assert ds3["c"].data is c3
543
544    def test_groupby(self):
545        u = self.eager_array
546        v = self.lazy_array
547
548        expected = u.groupby("x").mean(...)
549        with raise_if_dask_computes():
550            actual = v.groupby("x").mean(...)
551        self.assertLazyAndAllClose(expected, actual)
552
553    def test_rolling(self):
554        u = self.eager_array
555        v = self.lazy_array
556
557        expected = u.rolling(x=2).mean()
558        with raise_if_dask_computes():
559            actual = v.rolling(x=2).mean()
560        self.assertLazyAndAllClose(expected, actual)
561
562    def test_groupby_first(self):
563        u = self.eager_array
564        v = self.lazy_array
565
566        for coords in [u.coords, v.coords]:
567            coords["ab"] = ("x", ["a", "a", "b", "b"])
568        with pytest.raises(NotImplementedError, match=r"dask"):
569            v.groupby("ab").first()
570        expected = u.groupby("ab").first()
571        with raise_if_dask_computes():
572            actual = v.groupby("ab").first(skipna=False)
573        self.assertLazyAndAllClose(expected, actual)
574
575    def test_reindex(self):
576        u = self.eager_array.assign_coords(y=range(6))
577        v = self.lazy_array.assign_coords(y=range(6))
578
579        for kwargs in [
580            {"x": [2, 3, 4]},
581            {"x": [1, 100, 2, 101, 3]},
582            {"x": [2.5, 3, 3.5], "y": [2, 2.5, 3]},
583        ]:
584            expected = u.reindex(**kwargs)
585            actual = v.reindex(**kwargs)
586            self.assertLazyAndAllClose(expected, actual)
587
588    def test_to_dataset_roundtrip(self):
589        u = self.eager_array
590        v = self.lazy_array
591
592        expected = u.assign_coords(x=u["x"])
593        self.assertLazyAndEqual(expected, v.to_dataset("x").to_array("x"))
594
595    def test_merge(self):
596        def duplicate_and_merge(array):
597            return xr.merge([array, array.rename("bar")]).to_array()
598
599        expected = duplicate_and_merge(self.eager_array)
600        actual = duplicate_and_merge(self.lazy_array)
601        self.assertLazyAndEqual(expected, actual)
602
603    @pytest.mark.filterwarnings("ignore::FutureWarning")
604    def test_ufuncs(self):
605        u = self.eager_array
606        v = self.lazy_array
607        self.assertLazyAndAllClose(np.sin(u), xu.sin(v))
608
609    def test_where_dispatching(self):
610        a = np.arange(10)
611        b = a > 3
612        x = da.from_array(a, 5)
613        y = da.from_array(b, 5)
614        expected = DataArray(a).where(b)
615        self.assertLazyAndEqual(expected, DataArray(a).where(y))
616        self.assertLazyAndEqual(expected, DataArray(x).where(b))
617        self.assertLazyAndEqual(expected, DataArray(x).where(y))
618
619    def test_simultaneous_compute(self):
620        ds = Dataset({"foo": ("x", range(5)), "bar": ("x", range(5))}).chunk()
621
622        count = [0]
623
624        def counting_get(*args, **kwargs):
625            count[0] += 1
626            return dask.get(*args, **kwargs)
627
628        ds.load(scheduler=counting_get)
629
630        assert count[0] == 1
631
632    def test_stack(self):
633        data = da.random.normal(size=(2, 3, 4), chunks=(1, 3, 4))
634        arr = DataArray(data, dims=("w", "x", "y"))
635        stacked = arr.stack(z=("x", "y"))
636        z = pd.MultiIndex.from_product([np.arange(3), np.arange(4)], names=["x", "y"])
637        expected = DataArray(data.reshape(2, -1), {"z": z}, dims=["w", "z"])
638        assert stacked.data.chunks == expected.data.chunks
639        self.assertLazyAndEqual(expected, stacked)
640
641    def test_dot(self):
642        eager = self.eager_array.dot(self.eager_array[0])
643        lazy = self.lazy_array.dot(self.lazy_array[0])
644        self.assertLazyAndAllClose(eager, lazy)
645
646    def test_dataarray_repr(self):
647        data = build_dask_array("data")
648        nonindex_coord = build_dask_array("coord")
649        a = DataArray(data, dims=["x"], coords={"y": ("x", nonindex_coord)})
650        expected = dedent(
651            """\
652            <xarray.DataArray 'data' (x: 1)>
653            {!r}
654            Coordinates:
655                y        (x) int64 dask.array<chunksize=(1,), meta=np.ndarray>
656            Dimensions without coordinates: x""".format(
657                data
658            )
659        )
660        assert expected == repr(a)
661        assert kernel_call_count == 0  # should not evaluate dask array
662
663    def test_dataset_repr(self):
664        data = build_dask_array("data")
665        nonindex_coord = build_dask_array("coord")
666        ds = Dataset(data_vars={"a": ("x", data)}, coords={"y": ("x", nonindex_coord)})
667        expected = dedent(
668            """\
669            <xarray.Dataset>
670            Dimensions:  (x: 1)
671            Coordinates:
672                y        (x) int64 dask.array<chunksize=(1,), meta=np.ndarray>
673            Dimensions without coordinates: x
674            Data variables:
675                a        (x) int64 dask.array<chunksize=(1,), meta=np.ndarray>"""
676        )
677        assert expected == repr(ds)
678        assert kernel_call_count == 0  # should not evaluate dask array
679
680    def test_dataarray_pickle(self):
681        # Test that pickling/unpickling converts the dask backend
682        # to numpy in neither the data variable nor the non-index coords
683        data = build_dask_array("data")
684        nonindex_coord = build_dask_array("coord")
685        a1 = DataArray(data, dims=["x"], coords={"y": ("x", nonindex_coord)})
686        a1.compute()
687        assert not a1._in_memory
688        assert not a1.coords["y"]._in_memory
689        assert kernel_call_count == 2
690        a2 = pickle.loads(pickle.dumps(a1))
691        assert kernel_call_count == 2
692        assert_identical(a1, a2)
693        assert not a1._in_memory
694        assert not a2._in_memory
695        assert not a1.coords["y"]._in_memory
696        assert not a2.coords["y"]._in_memory
697
698    def test_dataset_pickle(self):
699        # Test that pickling/unpickling converts the dask backend
700        # to numpy in neither the data variables nor the non-index coords
701        data = build_dask_array("data")
702        nonindex_coord = build_dask_array("coord")
703        ds1 = Dataset(data_vars={"a": ("x", data)}, coords={"y": ("x", nonindex_coord)})
704        ds1.compute()
705        assert not ds1["a"]._in_memory
706        assert not ds1["y"]._in_memory
707        assert kernel_call_count == 2
708        ds2 = pickle.loads(pickle.dumps(ds1))
709        assert kernel_call_count == 2
710        assert_identical(ds1, ds2)
711        assert not ds1["a"]._in_memory
712        assert not ds2["a"]._in_memory
713        assert not ds1["y"]._in_memory
714        assert not ds2["y"]._in_memory
715
716    def test_dataarray_getattr(self):
717        # ipython/jupyter does a long list of getattr() calls to when trying to
718        # represent an object.
719        # Make sure we're not accidentally computing dask variables.
720        data = build_dask_array("data")
721        nonindex_coord = build_dask_array("coord")
722        a = DataArray(data, dims=["x"], coords={"y": ("x", nonindex_coord)})
723        with suppress(AttributeError):
724            getattr(a, "NOTEXIST")
725        assert kernel_call_count == 0
726
727    def test_dataset_getattr(self):
728        # Test that pickling/unpickling converts the dask backend
729        # to numpy in neither the data variables nor the non-index coords
730        data = build_dask_array("data")
731        nonindex_coord = build_dask_array("coord")
732        ds = Dataset(data_vars={"a": ("x", data)}, coords={"y": ("x", nonindex_coord)})
733        with suppress(AttributeError):
734            getattr(ds, "NOTEXIST")
735        assert kernel_call_count == 0
736
737    def test_values(self):
738        # Test that invoking the values property does not convert the dask
739        # backend to numpy
740        a = DataArray([1, 2]).chunk()
741        assert not a._in_memory
742        assert a.values.tolist() == [1, 2]
743        assert not a._in_memory
744
745    def test_from_dask_variable(self):
746        # Test array creation from Variable with dask backend.
747        # This is used e.g. in broadcast()
748        a = DataArray(self.lazy_array.variable, coords={"x": range(4)}, name="foo")
749        self.assertLazyAndIdentical(self.lazy_array, a)
750
751    @requires_pint
752    def test_tokenize_duck_dask_array(self):
753        import pint
754
755        unit_registry = pint.UnitRegistry()
756
757        q = unit_registry.Quantity(self.data, unit_registry.meter)
758        data_array = xr.DataArray(
759            data=q, coords={"x": range(4)}, dims=("x", "y"), name="foo"
760        )
761
762        token = dask.base.tokenize(data_array)
763        post_op = data_array + 5 * unit_registry.meter
764
765        assert dask.base.tokenize(data_array) != dask.base.tokenize(post_op)
766        # Immutability check
767        assert dask.base.tokenize(data_array) == token
768
769
770class TestToDaskDataFrame:
771    def test_to_dask_dataframe(self):
772        # Test conversion of Datasets to dask DataFrames
773        x = np.random.randn(10)
774        y = np.arange(10, dtype="uint8")
775        t = list("abcdefghij")
776
777        ds = Dataset(
778            {"a": ("t", da.from_array(x, chunks=4)), "b": ("t", y), "t": ("t", t)}
779        )
780
781        expected_pd = pd.DataFrame({"a": x, "b": y}, index=pd.Index(t, name="t"))
782
783        # test if 1-D index is correctly set up
784        expected = dd.from_pandas(expected_pd, chunksize=4)
785        actual = ds.to_dask_dataframe(set_index=True)
786        # test if we have dask dataframes
787        assert isinstance(actual, dd.DataFrame)
788
789        # use the .equals from pandas to check dataframes are equivalent
790        assert_frame_equal(expected.compute(), actual.compute())
791
792        # test if no index is given
793        expected = dd.from_pandas(expected_pd.reset_index(drop=False), chunksize=4)
794
795        actual = ds.to_dask_dataframe(set_index=False)
796
797        assert isinstance(actual, dd.DataFrame)
798        assert_frame_equal(expected.compute(), actual.compute())
799
800    def test_to_dask_dataframe_2D(self):
801        # Test if 2-D dataset is supplied
802        w = np.random.randn(2, 3)
803        ds = Dataset({"w": (("x", "y"), da.from_array(w, chunks=(1, 2)))})
804        ds["x"] = ("x", np.array([0, 1], np.int64))
805        ds["y"] = ("y", list("abc"))
806
807        # dask dataframes do not (yet) support multiindex,
808        # but when it does, this would be the expected index:
809        exp_index = pd.MultiIndex.from_arrays(
810            [[0, 0, 0, 1, 1, 1], ["a", "b", "c", "a", "b", "c"]], names=["x", "y"]
811        )
812        expected = pd.DataFrame({"w": w.reshape(-1)}, index=exp_index)
813        # so for now, reset the index
814        expected = expected.reset_index(drop=False)
815        actual = ds.to_dask_dataframe(set_index=False)
816
817        assert isinstance(actual, dd.DataFrame)
818        assert_frame_equal(expected, actual.compute())
819
820    @pytest.mark.xfail(raises=NotImplementedError)
821    def test_to_dask_dataframe_2D_set_index(self):
822        # This will fail until dask implements MultiIndex support
823        w = da.from_array(np.random.randn(2, 3), chunks=(1, 2))
824        ds = Dataset({"w": (("x", "y"), w)})
825        ds["x"] = ("x", np.array([0, 1], np.int64))
826        ds["y"] = ("y", list("abc"))
827
828        expected = ds.compute().to_dataframe()
829        actual = ds.to_dask_dataframe(set_index=True)
830        assert isinstance(actual, dd.DataFrame)
831        assert_frame_equal(expected, actual.compute())
832
833    def test_to_dask_dataframe_coordinates(self):
834        # Test if coordinate is also a dask array
835        x = np.random.randn(10)
836        t = np.arange(10) * 2
837
838        ds = Dataset(
839            {
840                "a": ("t", da.from_array(x, chunks=4)),
841                "t": ("t", da.from_array(t, chunks=4)),
842            }
843        )
844
845        expected_pd = pd.DataFrame({"a": x}, index=pd.Index(t, name="t"))
846        expected = dd.from_pandas(expected_pd, chunksize=4)
847        actual = ds.to_dask_dataframe(set_index=True)
848        assert isinstance(actual, dd.DataFrame)
849        assert_frame_equal(expected.compute(), actual.compute())
850
851    def test_to_dask_dataframe_not_daskarray(self):
852        # Test if DataArray is not a dask array
853        x = np.random.randn(10)
854        y = np.arange(10, dtype="uint8")
855        t = list("abcdefghij")
856
857        ds = Dataset({"a": ("t", x), "b": ("t", y), "t": ("t", t)})
858
859        expected = pd.DataFrame({"a": x, "b": y}, index=pd.Index(t, name="t"))
860
861        actual = ds.to_dask_dataframe(set_index=True)
862        assert isinstance(actual, dd.DataFrame)
863        assert_frame_equal(expected, actual.compute())
864
865    def test_to_dask_dataframe_no_coordinate(self):
866        x = da.from_array(np.random.randn(10), chunks=4)
867        ds = Dataset({"x": ("dim_0", x)})
868
869        expected = ds.compute().to_dataframe().reset_index()
870        actual = ds.to_dask_dataframe()
871        assert isinstance(actual, dd.DataFrame)
872        assert_frame_equal(expected, actual.compute())
873
874        expected = ds.compute().to_dataframe()
875        actual = ds.to_dask_dataframe(set_index=True)
876        assert isinstance(actual, dd.DataFrame)
877        assert_frame_equal(expected, actual.compute())
878
879    def test_to_dask_dataframe_dim_order(self):
880        values = np.array([[1, 2], [3, 4]], dtype=np.int64)
881        ds = Dataset({"w": (("x", "y"), values)}).chunk(1)
882
883        expected = ds["w"].to_series().reset_index()
884        actual = ds.to_dask_dataframe(dim_order=["x", "y"])
885        assert isinstance(actual, dd.DataFrame)
886        assert_frame_equal(expected, actual.compute())
887
888        expected = ds["w"].T.to_series().reset_index()
889        actual = ds.to_dask_dataframe(dim_order=["y", "x"])
890        assert isinstance(actual, dd.DataFrame)
891        assert_frame_equal(expected, actual.compute())
892
893        with pytest.raises(ValueError, match=r"does not match the set of dimensions"):
894            ds.to_dask_dataframe(dim_order=["x"])
895
896
897@pytest.mark.parametrize("method", ["load", "compute"])
898def test_dask_kwargs_variable(method):
899    x = Variable("y", da.from_array(np.arange(3), chunks=(2,)))
900    # args should be passed on to da.Array.compute()
901    with mock.patch.object(
902        da.Array, "compute", return_value=np.arange(3)
903    ) as mock_compute:
904        getattr(x, method)(foo="bar")
905    mock_compute.assert_called_with(foo="bar")
906
907
908@pytest.mark.parametrize("method", ["load", "compute", "persist"])
909def test_dask_kwargs_dataarray(method):
910    data = da.from_array(np.arange(3), chunks=(2,))
911    x = DataArray(data)
912    if method in ["load", "compute"]:
913        dask_func = "dask.array.compute"
914    else:
915        dask_func = "dask.persist"
916    # args should be passed on to "dask_func"
917    with mock.patch(dask_func) as mock_func:
918        getattr(x, method)(foo="bar")
919    mock_func.assert_called_with(data, foo="bar")
920
921
922@pytest.mark.parametrize("method", ["load", "compute", "persist"])
923def test_dask_kwargs_dataset(method):
924    data = da.from_array(np.arange(3), chunks=(2,))
925    x = Dataset({"x": (("y"), data)})
926    if method in ["load", "compute"]:
927        dask_func = "dask.array.compute"
928    else:
929        dask_func = "dask.persist"
930    # args should be passed on to "dask_func"
931    with mock.patch(dask_func) as mock_func:
932        getattr(x, method)(foo="bar")
933    mock_func.assert_called_with(data, foo="bar")
934
935
936kernel_call_count = 0
937
938
939def kernel(name):
940    """Dask kernel to test pickling/unpickling and __repr__.
941    Must be global to make it pickleable.
942    """
943    global kernel_call_count
944    kernel_call_count += 1
945    return np.ones(1, dtype=np.int64)
946
947
948def build_dask_array(name):
949    global kernel_call_count
950    kernel_call_count = 0
951    return dask.array.Array(
952        dask={(name, 0): (kernel, name)}, name=name, chunks=((1,),), dtype=np.int64
953    )
954
955
956@pytest.mark.parametrize(
957    "persist", [lambda x: x.persist(), lambda x: dask.persist(x)[0]]
958)
959def test_persist_Dataset(persist):
960    ds = Dataset({"foo": ("x", range(5)), "bar": ("x", range(5))}).chunk()
961    ds = ds + 1
962    n = len(ds.foo.data.dask)
963
964    ds2 = persist(ds)
965
966    assert len(ds2.foo.data.dask) == 1
967    assert len(ds.foo.data.dask) == n  # doesn't mutate in place
968
969
970@pytest.mark.parametrize(
971    "persist", [lambda x: x.persist(), lambda x: dask.persist(x)[0]]
972)
973def test_persist_DataArray(persist):
974    x = da.arange(10, chunks=(5,))
975    y = DataArray(x)
976    z = y + 1
977    n = len(z.data.dask)
978
979    zz = persist(z)
980
981    assert len(z.data.dask) == n
982    assert len(zz.data.dask) == zz.data.npartitions
983
984
985def test_dataarray_with_dask_coords():
986    import toolz
987
988    x = xr.Variable("x", da.arange(8, chunks=(4,)))
989    y = xr.Variable("y", da.arange(8, chunks=(4,)) * 2)
990    data = da.random.random((8, 8), chunks=(4, 4)) + 1
991    array = xr.DataArray(data, dims=["x", "y"])
992    array.coords["xx"] = x
993    array.coords["yy"] = y
994
995    assert dict(array.__dask_graph__()) == toolz.merge(
996        data.__dask_graph__(), x.__dask_graph__(), y.__dask_graph__()
997    )
998
999    (array2,) = dask.compute(array)
1000    assert not dask.is_dask_collection(array2)
1001
1002    assert all(isinstance(v._variable.data, np.ndarray) for v in array2.coords.values())
1003
1004
1005def test_basic_compute():
1006    ds = Dataset({"foo": ("x", range(5)), "bar": ("x", range(5))}).chunk({"x": 2})
1007    for get in [dask.threaded.get, dask.multiprocessing.get, dask.local.get_sync, None]:
1008        with dask.config.set(scheduler=get):
1009            ds.compute()
1010            ds.foo.compute()
1011            ds.foo.variable.compute()
1012
1013
1014def test_dask_layers_and_dependencies():
1015    ds = Dataset({"foo": ("x", range(5)), "bar": ("x", range(5))}).chunk()
1016
1017    x = dask.delayed(ds)
1018    assert set(x.__dask_graph__().dependencies).issuperset(
1019        ds.__dask_graph__().dependencies
1020    )
1021    assert set(x.foo.__dask_graph__().dependencies).issuperset(
1022        ds.__dask_graph__().dependencies
1023    )
1024
1025
1026def make_da():
1027    da = xr.DataArray(
1028        np.ones((10, 20)),
1029        dims=["x", "y"],
1030        coords={"x": np.arange(10), "y": np.arange(100, 120)},
1031        name="a",
1032    ).chunk({"x": 4, "y": 5})
1033    da.x.attrs["long_name"] = "x"
1034    da.attrs["test"] = "test"
1035    da.coords["c2"] = 0.5
1036    da.coords["ndcoord"] = da.x * 2
1037    da.coords["cxy"] = (da.x * da.y).chunk({"x": 4, "y": 5})
1038
1039    return da
1040
1041
1042def make_ds():
1043    map_ds = xr.Dataset()
1044    map_ds["a"] = make_da()
1045    map_ds["b"] = map_ds.a + 50
1046    map_ds["c"] = map_ds.x + 20
1047    map_ds = map_ds.chunk({"x": 4, "y": 5})
1048    map_ds["d"] = ("z", [1, 1, 1, 1])
1049    map_ds["z"] = [0, 1, 2, 3]
1050    map_ds["e"] = map_ds.x + map_ds.y
1051    map_ds.coords["c1"] = 0.5
1052    map_ds.coords["cx"] = ("x", np.arange(len(map_ds.x)))
1053    map_ds.coords["cx"].attrs["test2"] = "test2"
1054    map_ds.attrs["test"] = "test"
1055    map_ds.coords["xx"] = map_ds["a"] * map_ds.y
1056
1057    map_ds.x.attrs["long_name"] = "x"
1058    map_ds.y.attrs["long_name"] = "y"
1059
1060    return map_ds
1061
1062
1063# fixtures cannot be used in parametrize statements
1064# instead use this workaround
1065# https://docs.pytest.org/en/latest/deprecations.html#calling-fixtures-directly
1066@pytest.fixture
1067def map_da():
1068    return make_da()
1069
1070
1071@pytest.fixture
1072def map_ds():
1073    return make_ds()
1074
1075
1076def test_unify_chunks(map_ds):
1077    ds_copy = map_ds.copy()
1078    ds_copy["cxy"] = ds_copy.cxy.chunk({"y": 10})
1079
1080    with pytest.raises(ValueError, match=r"inconsistent chunks"):
1081        ds_copy.chunks
1082
1083    expected_chunks = {"x": (4, 4, 2), "y": (5, 5, 5, 5)}
1084    with raise_if_dask_computes():
1085        actual_chunks = ds_copy.unify_chunks().chunks
1086    assert actual_chunks == expected_chunks
1087    assert_identical(map_ds, ds_copy.unify_chunks())
1088
1089    out_a, out_b = xr.unify_chunks(ds_copy.cxy, ds_copy.drop_vars("cxy"))
1090    assert out_a.chunks == ((4, 4, 2), (5, 5, 5, 5))
1091    assert out_b.chunks == expected_chunks
1092
1093    # Test unordered dims
1094    da = ds_copy["cxy"]
1095    out_a, out_b = xr.unify_chunks(da.chunk({"x": -1}), da.T.chunk({"y": -1}))
1096    assert out_a.chunks == ((4, 4, 2), (5, 5, 5, 5))
1097    assert out_b.chunks == ((5, 5, 5, 5), (4, 4, 2))
1098
1099    # Test mismatch
1100    with pytest.raises(ValueError, match=r"Dimension 'x' size mismatch: 10 != 2"):
1101        xr.unify_chunks(da, da.isel(x=slice(2)))
1102
1103
1104@pytest.mark.parametrize("obj", [make_ds(), make_da()])
1105@pytest.mark.parametrize(
1106    "transform", [lambda x: x.compute(), lambda x: x.unify_chunks()]
1107)
1108def test_unify_chunks_shallow_copy(obj, transform):
1109    obj = transform(obj)
1110    unified = obj.unify_chunks()
1111    assert_identical(obj, unified) and obj is not obj.unify_chunks()
1112
1113
1114@pytest.mark.parametrize("obj", [make_da()])
1115def test_auto_chunk_da(obj):
1116    actual = obj.chunk("auto").data
1117    expected = obj.data.rechunk("auto")
1118    np.testing.assert_array_equal(actual, expected)
1119    assert actual.chunks == expected.chunks
1120
1121
1122def test_map_blocks_error(map_da, map_ds):
1123    def bad_func(darray):
1124        return (darray * darray.x + 5 * darray.y)[:1, :1]
1125
1126    with pytest.raises(ValueError, match=r"Received dimension 'x' of length 1"):
1127        xr.map_blocks(bad_func, map_da).compute()
1128
1129    def returns_numpy(darray):
1130        return (darray * darray.x + 5 * darray.y).values
1131
1132    with pytest.raises(TypeError, match=r"Function must return an xarray DataArray"):
1133        xr.map_blocks(returns_numpy, map_da)
1134
1135    with pytest.raises(TypeError, match=r"args must be"):
1136        xr.map_blocks(operator.add, map_da, args=10)
1137
1138    with pytest.raises(TypeError, match=r"kwargs must be"):
1139        xr.map_blocks(operator.add, map_da, args=[10], kwargs=[20])
1140
1141    def really_bad_func(darray):
1142        raise ValueError("couldn't do anything.")
1143
1144    with pytest.raises(Exception, match=r"Cannot infer"):
1145        xr.map_blocks(really_bad_func, map_da)
1146
1147    ds_copy = map_ds.copy()
1148    ds_copy["cxy"] = ds_copy.cxy.chunk({"y": 10})
1149
1150    with pytest.raises(ValueError, match=r"inconsistent chunks"):
1151        xr.map_blocks(bad_func, ds_copy)
1152
1153    with pytest.raises(TypeError, match=r"Cannot pass dask collections"):
1154        xr.map_blocks(bad_func, map_da, kwargs=dict(a=map_da.chunk()))
1155
1156
1157@pytest.mark.parametrize("obj", [make_da(), make_ds()])
1158def test_map_blocks(obj):
1159    def func(obj):
1160        result = obj + obj.x + 5 * obj.y
1161        return result
1162
1163    with raise_if_dask_computes():
1164        actual = xr.map_blocks(func, obj)
1165    expected = func(obj)
1166    assert_chunks_equal(expected.chunk(), actual)
1167    assert_identical(actual, expected)
1168
1169
1170@pytest.mark.parametrize("obj", [make_da(), make_ds()])
1171def test_map_blocks_convert_args_to_list(obj):
1172    expected = obj + 10
1173    with raise_if_dask_computes():
1174        actual = xr.map_blocks(operator.add, obj, [10])
1175    assert_chunks_equal(expected.chunk(), actual)
1176    assert_identical(actual, expected)
1177
1178
1179def test_map_blocks_dask_args():
1180    da1 = xr.DataArray(
1181        np.ones((10, 20)),
1182        dims=["x", "y"],
1183        coords={"x": np.arange(10), "y": np.arange(20)},
1184    ).chunk({"x": 5, "y": 4})
1185
1186    # check that block shapes are the same
1187    def sumda(da1, da2):
1188        assert da1.shape == da2.shape
1189        return da1 + da2
1190
1191    da2 = da1 + 1
1192    with raise_if_dask_computes():
1193        mapped = xr.map_blocks(sumda, da1, args=[da2])
1194    xr.testing.assert_equal(da1 + da2, mapped)
1195
1196    # one dimension in common
1197    da2 = (da1 + 1).isel(x=1, drop=True)
1198    with raise_if_dask_computes():
1199        mapped = xr.map_blocks(operator.add, da1, args=[da2])
1200    xr.testing.assert_equal(da1 + da2, mapped)
1201
1202    # test that everything works when dimension names are different
1203    da2 = (da1 + 1).isel(x=1, drop=True).rename({"y": "k"})
1204    with raise_if_dask_computes():
1205        mapped = xr.map_blocks(operator.add, da1, args=[da2])
1206    xr.testing.assert_equal(da1 + da2, mapped)
1207
1208    with pytest.raises(ValueError, match=r"Chunk sizes along dimension 'x'"):
1209        xr.map_blocks(operator.add, da1, args=[da1.chunk({"x": 1})])
1210
1211    with pytest.raises(ValueError, match=r"indexes along dimension 'x' are not equal"):
1212        xr.map_blocks(operator.add, da1, args=[da1.reindex(x=np.arange(20))])
1213
1214    # reduction
1215    da1 = da1.chunk({"x": -1})
1216    da2 = da1 + 1
1217    with raise_if_dask_computes():
1218        mapped = xr.map_blocks(lambda a, b: (a + b).sum("x"), da1, args=[da2])
1219    xr.testing.assert_equal((da1 + da2).sum("x"), mapped)
1220
1221    # reduction with template
1222    da1 = da1.chunk({"x": -1})
1223    da2 = da1 + 1
1224    with raise_if_dask_computes():
1225        mapped = xr.map_blocks(
1226            lambda a, b: (a + b).sum("x"), da1, args=[da2], template=da1.sum("x")
1227        )
1228    xr.testing.assert_equal((da1 + da2).sum("x"), mapped)
1229
1230
1231@pytest.mark.parametrize("obj", [make_da(), make_ds()])
1232def test_map_blocks_add_attrs(obj):
1233    def add_attrs(obj):
1234        obj = obj.copy(deep=True)
1235        obj.attrs["new"] = "new"
1236        obj.cxy.attrs["new2"] = "new2"
1237        return obj
1238
1239    expected = add_attrs(obj)
1240    with raise_if_dask_computes():
1241        actual = xr.map_blocks(add_attrs, obj)
1242
1243    assert_identical(actual, expected)
1244
1245    # when template is specified, attrs are copied from template, not set by function
1246    with raise_if_dask_computes():
1247        actual = xr.map_blocks(add_attrs, obj, template=obj)
1248    assert_identical(actual, obj)
1249
1250
1251def test_map_blocks_change_name(map_da):
1252    def change_name(obj):
1253        obj = obj.copy(deep=True)
1254        obj.name = "new"
1255        return obj
1256
1257    expected = change_name(map_da)
1258    with raise_if_dask_computes():
1259        actual = xr.map_blocks(change_name, map_da)
1260
1261    assert_identical(actual, expected)
1262
1263
1264@pytest.mark.parametrize("obj", [make_da(), make_ds()])
1265def test_map_blocks_kwargs(obj):
1266    expected = xr.full_like(obj, fill_value=np.nan)
1267    with raise_if_dask_computes():
1268        actual = xr.map_blocks(xr.full_like, obj, kwargs=dict(fill_value=np.nan))
1269    assert_chunks_equal(expected.chunk(), actual)
1270    assert_identical(actual, expected)
1271
1272
1273def test_map_blocks_to_array(map_ds):
1274    with raise_if_dask_computes():
1275        actual = xr.map_blocks(lambda x: x.to_array(), map_ds)
1276
1277    # to_array does not preserve name, so cannot use assert_identical
1278    assert_equal(actual, map_ds.to_array())
1279
1280
1281@pytest.mark.parametrize(
1282    "func",
1283    [
1284        lambda x: x,
1285        lambda x: x.to_dataset(),
1286        lambda x: x.drop_vars("x"),
1287        lambda x: x.expand_dims(k=[1, 2, 3]),
1288        lambda x: x.expand_dims(k=3),
1289        lambda x: x.assign_coords(new_coord=("y", x.y.data * 2)),
1290        lambda x: x.astype(np.int32),
1291        lambda x: x.x,
1292    ],
1293)
1294def test_map_blocks_da_transformations(func, map_da):
1295    with raise_if_dask_computes():
1296        actual = xr.map_blocks(func, map_da)
1297
1298    assert_identical(actual, func(map_da))
1299
1300
1301@pytest.mark.parametrize(
1302    "func",
1303    [
1304        lambda x: x,
1305        lambda x: x.drop_vars("cxy"),
1306        lambda x: x.drop_vars("a"),
1307        lambda x: x.drop_vars("x"),
1308        lambda x: x.expand_dims(k=[1, 2, 3]),
1309        lambda x: x.expand_dims(k=3),
1310        lambda x: x.rename({"a": "new1", "b": "new2"}),
1311        lambda x: x.x,
1312    ],
1313)
1314def test_map_blocks_ds_transformations(func, map_ds):
1315    with raise_if_dask_computes():
1316        actual = xr.map_blocks(func, map_ds)
1317
1318    assert_identical(actual, func(map_ds))
1319
1320
1321@pytest.mark.parametrize("obj", [make_da(), make_ds()])
1322def test_map_blocks_da_ds_with_template(obj):
1323    func = lambda x: x.isel(x=[1])
1324    template = obj.isel(x=[1, 5, 9])
1325    with raise_if_dask_computes():
1326        actual = xr.map_blocks(func, obj, template=template)
1327    assert_identical(actual, template)
1328
1329    with raise_if_dask_computes():
1330        actual = obj.map_blocks(func, template=template)
1331    assert_identical(actual, template)
1332
1333
1334def test_map_blocks_template_convert_object():
1335    da = make_da()
1336    func = lambda x: x.to_dataset().isel(x=[1])
1337    template = da.to_dataset().isel(x=[1, 5, 9])
1338    with raise_if_dask_computes():
1339        actual = xr.map_blocks(func, da, template=template)
1340    assert_identical(actual, template)
1341
1342    ds = da.to_dataset()
1343    func = lambda x: x.to_array().isel(x=[1])
1344    template = ds.to_array().isel(x=[1, 5, 9])
1345    with raise_if_dask_computes():
1346        actual = xr.map_blocks(func, ds, template=template)
1347    assert_identical(actual, template)
1348
1349
1350@pytest.mark.parametrize("obj", [make_da(), make_ds()])
1351def test_map_blocks_errors_bad_template(obj):
1352    with pytest.raises(ValueError, match=r"unexpected coordinate variables"):
1353        xr.map_blocks(lambda x: x.assign_coords(a=10), obj, template=obj).compute()
1354    with pytest.raises(ValueError, match=r"does not contain coordinate variables"):
1355        xr.map_blocks(lambda x: x.drop_vars("cxy"), obj, template=obj).compute()
1356    with pytest.raises(ValueError, match=r"Dimensions {'x'} missing"):
1357        xr.map_blocks(lambda x: x.isel(x=1), obj, template=obj).compute()
1358    with pytest.raises(ValueError, match=r"Received dimension 'x' of length 1"):
1359        xr.map_blocks(lambda x: x.isel(x=[1]), obj, template=obj).compute()
1360    with pytest.raises(TypeError, match=r"must be a DataArray"):
1361        xr.map_blocks(lambda x: x.isel(x=[1]), obj, template=(obj,)).compute()
1362    with pytest.raises(ValueError, match=r"map_blocks requires that one block"):
1363        xr.map_blocks(
1364            lambda x: x.isel(x=[1]).assign_coords(x=10), obj, template=obj.isel(x=[1])
1365        ).compute()
1366    with pytest.raises(ValueError, match=r"Expected index 'x' to be"):
1367        xr.map_blocks(
1368            lambda a: a.isel(x=[1]).assign_coords(x=[120]),  # assign bad index values
1369            obj,
1370            template=obj.isel(x=[1, 5, 9]),
1371        ).compute()
1372
1373
1374def test_map_blocks_errors_bad_template_2(map_ds):
1375    with pytest.raises(ValueError, match=r"unexpected data variables {'xyz'}"):
1376        xr.map_blocks(lambda x: x.assign(xyz=1), map_ds, template=map_ds).compute()
1377
1378
1379@pytest.mark.parametrize("obj", [make_da(), make_ds()])
1380def test_map_blocks_object_method(obj):
1381    def func(obj):
1382        result = obj + obj.x + 5 * obj.y
1383        return result
1384
1385    with raise_if_dask_computes():
1386        expected = xr.map_blocks(func, obj)
1387        actual = obj.map_blocks(func)
1388
1389    assert_identical(expected, actual)
1390
1391
1392def test_map_blocks_hlg_layers():
1393    # regression test for #3599
1394    ds = xr.Dataset(
1395        {
1396            "x": (("a",), dask.array.ones(10, chunks=(5,))),
1397            "z": (("b",), dask.array.ones(10, chunks=(5,))),
1398        }
1399    )
1400    mapped = ds.map_blocks(lambda x: x)
1401
1402    xr.testing.assert_equal(mapped, ds)
1403
1404
1405def test_make_meta(map_ds):
1406    from ..core.parallel import make_meta
1407
1408    meta = make_meta(map_ds)
1409
1410    for variable in map_ds._coord_names:
1411        assert variable in meta._coord_names
1412        assert meta.coords[variable].shape == (0,) * meta.coords[variable].ndim
1413
1414    for variable in map_ds.data_vars:
1415        assert variable in meta.data_vars
1416        assert meta.data_vars[variable].shape == (0,) * meta.data_vars[variable].ndim
1417
1418
1419def test_identical_coords_no_computes():
1420    lons2 = xr.DataArray(da.zeros((10, 10), chunks=2), dims=("y", "x"))
1421    a = xr.DataArray(
1422        da.zeros((10, 10), chunks=2), dims=("y", "x"), coords={"lons": lons2}
1423    )
1424    b = xr.DataArray(
1425        da.zeros((10, 10), chunks=2), dims=("y", "x"), coords={"lons": lons2}
1426    )
1427    with raise_if_dask_computes():
1428        c = a + b
1429    assert_identical(c, a)
1430
1431
1432@pytest.mark.parametrize(
1433    "obj", [make_da(), make_da().compute(), make_ds(), make_ds().compute()]
1434)
1435@pytest.mark.parametrize(
1436    "transform",
1437    [
1438        lambda x: x.reset_coords(),
1439        lambda x: x.reset_coords(drop=True),
1440        lambda x: x.isel(x=1),
1441        lambda x: x.attrs.update(new_attrs=1),
1442        lambda x: x.assign_coords(cxy=1),
1443        lambda x: x.rename({"x": "xnew"}),
1444        lambda x: x.rename({"cxy": "cxynew"}),
1445    ],
1446)
1447def test_token_changes_on_transform(obj, transform):
1448    with raise_if_dask_computes():
1449        assert dask.base.tokenize(obj) != dask.base.tokenize(transform(obj))
1450
1451
1452@pytest.mark.parametrize(
1453    "obj", [make_da(), make_da().compute(), make_ds(), make_ds().compute()]
1454)
1455def test_token_changes_when_data_changes(obj):
1456    with raise_if_dask_computes():
1457        t1 = dask.base.tokenize(obj)
1458
1459    # Change data_var
1460    if isinstance(obj, DataArray):
1461        obj *= 2
1462    else:
1463        obj["a"] *= 2
1464    with raise_if_dask_computes():
1465        t2 = dask.base.tokenize(obj)
1466    assert t2 != t1
1467
1468    # Change non-index coord
1469    obj.coords["ndcoord"] *= 2
1470    with raise_if_dask_computes():
1471        t3 = dask.base.tokenize(obj)
1472    assert t3 != t2
1473
1474    # Change IndexVariable
1475    obj = obj.assign_coords(x=obj.x * 2)
1476    with raise_if_dask_computes():
1477        t4 = dask.base.tokenize(obj)
1478    assert t4 != t3
1479
1480
1481@pytest.mark.parametrize("obj", [make_da().compute(), make_ds().compute()])
1482def test_token_changes_when_buffer_changes(obj):
1483    with raise_if_dask_computes():
1484        t1 = dask.base.tokenize(obj)
1485
1486    if isinstance(obj, DataArray):
1487        obj[0, 0] = 123
1488    else:
1489        obj["a"][0, 0] = 123
1490    with raise_if_dask_computes():
1491        t2 = dask.base.tokenize(obj)
1492    assert t2 != t1
1493
1494    obj.coords["ndcoord"][0] = 123
1495    with raise_if_dask_computes():
1496        t3 = dask.base.tokenize(obj)
1497    assert t3 != t2
1498
1499
1500@pytest.mark.parametrize(
1501    "transform",
1502    [lambda x: x, lambda x: x.copy(deep=False), lambda x: x.copy(deep=True)],
1503)
1504@pytest.mark.parametrize("obj", [make_da(), make_ds(), make_ds().variables["a"]])
1505def test_token_identical(obj, transform):
1506    with raise_if_dask_computes():
1507        assert dask.base.tokenize(obj) == dask.base.tokenize(transform(obj))
1508    assert dask.base.tokenize(obj.compute()) == dask.base.tokenize(
1509        transform(obj.compute())
1510    )
1511
1512
1513def test_recursive_token():
1514    """Test that tokenization is invoked recursively, and doesn't just rely on the
1515    output of str()
1516    """
1517    a = np.ones(10000)
1518    b = np.ones(10000)
1519    b[5000] = 2
1520    assert str(a) == str(b)
1521    assert dask.base.tokenize(a) != dask.base.tokenize(b)
1522
1523    # Test DataArray and Variable
1524    da_a = DataArray(a)
1525    da_b = DataArray(b)
1526    assert dask.base.tokenize(da_a) != dask.base.tokenize(da_b)
1527
1528    # Test Dataset
1529    ds_a = da_a.to_dataset(name="x")
1530    ds_b = da_b.to_dataset(name="x")
1531    assert dask.base.tokenize(ds_a) != dask.base.tokenize(ds_b)
1532
1533    # Test IndexVariable
1534    da_a = DataArray(a, dims=["x"], coords={"x": a})
1535    da_b = DataArray(a, dims=["x"], coords={"x": b})
1536    assert dask.base.tokenize(da_a) != dask.base.tokenize(da_b)
1537
1538
1539@requires_scipy_or_netCDF4
1540def test_normalize_token_with_backend(map_ds):
1541    with create_tmp_file(allow_cleanup_failure=ON_WINDOWS) as tmp_file:
1542        map_ds.to_netcdf(tmp_file)
1543        read = xr.open_dataset(tmp_file)
1544        assert not dask.base.tokenize(map_ds) == dask.base.tokenize(read)
1545        read.close()
1546
1547
1548@pytest.mark.parametrize(
1549    "compat", ["broadcast_equals", "equals", "identical", "no_conflicts"]
1550)
1551def test_lazy_array_equiv_variables(compat):
1552    var1 = xr.Variable(("y", "x"), da.zeros((10, 10), chunks=2))
1553    var2 = xr.Variable(("y", "x"), da.zeros((10, 10), chunks=2))
1554    var3 = xr.Variable(("y", "x"), da.zeros((20, 10), chunks=2))
1555
1556    with raise_if_dask_computes():
1557        assert getattr(var1, compat)(var2, equiv=lazy_array_equiv)
1558    # values are actually equal, but we don't know that till we compute, return None
1559    with raise_if_dask_computes():
1560        assert getattr(var1, compat)(var2 / 2, equiv=lazy_array_equiv) is None
1561
1562    # shapes are not equal, return False without computes
1563    with raise_if_dask_computes():
1564        assert getattr(var1, compat)(var3, equiv=lazy_array_equiv) is False
1565
1566    # if one or both arrays are numpy, return None
1567    assert getattr(var1, compat)(var2.compute(), equiv=lazy_array_equiv) is None
1568    assert (
1569        getattr(var1.compute(), compat)(var2.compute(), equiv=lazy_array_equiv) is None
1570    )
1571
1572    with raise_if_dask_computes():
1573        assert getattr(var1, compat)(var2.transpose("y", "x"))
1574
1575
1576@pytest.mark.parametrize(
1577    "compat", ["broadcast_equals", "equals", "identical", "no_conflicts"]
1578)
1579def test_lazy_array_equiv_merge(compat):
1580    da1 = xr.DataArray(da.zeros((10, 10), chunks=2), dims=("y", "x"))
1581    da2 = xr.DataArray(da.zeros((10, 10), chunks=2), dims=("y", "x"))
1582    da3 = xr.DataArray(da.ones((20, 10), chunks=2), dims=("y", "x"))
1583
1584    with raise_if_dask_computes():
1585        xr.merge([da1, da2], compat=compat)
1586    # shapes are not equal; no computes necessary
1587    with raise_if_dask_computes(max_computes=0):
1588        with pytest.raises(ValueError):
1589            xr.merge([da1, da3], compat=compat)
1590    with raise_if_dask_computes(max_computes=2):
1591        xr.merge([da1, da2 / 2], compat=compat)
1592
1593
1594@pytest.mark.filterwarnings("ignore::FutureWarning")  # transpose_coords
1595@pytest.mark.parametrize("obj", [make_da(), make_ds()])
1596@pytest.mark.parametrize(
1597    "transform",
1598    [
1599        lambda a: a.assign_attrs(new_attr="anew"),
1600        lambda a: a.assign_coords(cxy=a.cxy),
1601        lambda a: a.copy(),
1602        lambda a: a.isel(x=np.arange(a.sizes["x"])),
1603        lambda a: a.isel(x=slice(None)),
1604        lambda a: a.loc[dict(x=slice(None))],
1605        lambda a: a.loc[dict(x=np.arange(a.sizes["x"]))],
1606        lambda a: a.loc[dict(x=a.x)],
1607        lambda a: a.sel(x=a.x),
1608        lambda a: a.sel(x=a.x.values),
1609        lambda a: a.transpose(...),
1610        lambda a: a.squeeze(),  # no dimensions to squeeze
1611        lambda a: a.sortby("x"),  # "x" is already sorted
1612        lambda a: a.reindex(x=a.x),
1613        lambda a: a.reindex_like(a),
1614        lambda a: a.rename({"cxy": "cnew"}).rename({"cnew": "cxy"}),
1615        lambda a: a.pipe(lambda x: x),
1616        lambda a: xr.align(a, xr.zeros_like(a))[0],
1617        # assign
1618        # swap_dims
1619        # set_index / reset_index
1620    ],
1621)
1622def test_transforms_pass_lazy_array_equiv(obj, transform):
1623    with raise_if_dask_computes():
1624        assert_equal(obj, transform(obj))
1625
1626
1627def test_more_transforms_pass_lazy_array_equiv(map_da, map_ds):
1628    with raise_if_dask_computes():
1629        assert_equal(map_ds.cxy.broadcast_like(map_ds.cxy), map_ds.cxy)
1630        assert_equal(xr.broadcast(map_ds.cxy, map_ds.cxy)[0], map_ds.cxy)
1631        assert_equal(map_ds.map(lambda x: x), map_ds)
1632        assert_equal(map_ds.set_coords("a").reset_coords("a"), map_ds)
1633        assert_equal(map_ds.assign({"a": map_ds.a}), map_ds)
1634
1635        # fails because of index error
1636        # assert_equal(
1637        #     map_ds.rename_dims({"x": "xnew"}).rename_dims({"xnew": "x"}), map_ds
1638        # )
1639
1640        assert_equal(
1641            map_ds.rename_vars({"cxy": "cnew"}).rename_vars({"cnew": "cxy"}), map_ds
1642        )
1643
1644        assert_equal(map_da._from_temp_dataset(map_da._to_temp_dataset()), map_da)
1645        assert_equal(map_da.astype(map_da.dtype), map_da)
1646        assert_equal(map_da.transpose("y", "x", transpose_coords=False).cxy, map_da.cxy)
1647
1648
1649def test_optimize():
1650    # https://github.com/pydata/xarray/issues/3698
1651    a = dask.array.ones((10, 4), chunks=(5, 2))
1652    arr = xr.DataArray(a).chunk(5)
1653    (arr2,) = dask.optimize(arr)
1654    arr2.compute()
1655
1656
1657# The graph_manipulation module is in dask since 2021.2 but it became usable with
1658# xarray only since 2021.3
1659@pytest.mark.skipif(dask_version <= "2021.02.0", reason="new module")
1660def test_graph_manipulation():
1661    """dask.graph_manipulation passes an optional parameter, "rename", to the rebuilder
1662    function returned by __dask_postperist__; also, the dsk passed to the rebuilder is
1663    a HighLevelGraph whereas with dask.persist() and dask.optimize() it's a plain dict.
1664    """
1665    import dask.graph_manipulation as gm
1666
1667    v = Variable(["x"], [1, 2]).chunk(-1).chunk(1) * 2
1668    da = DataArray(v)
1669    ds = Dataset({"d1": v[0], "d2": v[1], "d3": ("x", [3, 4])})
1670
1671    v2, da2, ds2 = gm.clone(v, da, ds)
1672
1673    assert_equal(v2, v)
1674    assert_equal(da2, da)
1675    assert_equal(ds2, ds)
1676
1677    for a, b in ((v, v2), (da, da2), (ds, ds2)):
1678        assert a.__dask_layers__() != b.__dask_layers__()
1679        assert len(a.__dask_layers__()) == len(b.__dask_layers__())
1680        assert a.__dask_graph__().keys() != b.__dask_graph__().keys()
1681        assert len(a.__dask_graph__()) == len(b.__dask_graph__())
1682        assert a.__dask_graph__().layers.keys() != b.__dask_graph__().layers.keys()
1683        assert len(a.__dask_graph__().layers) == len(b.__dask_graph__().layers)
1684
1685    # Above we performed a slice operation; adding the two slices back together creates
1686    # a diamond-shaped dependency graph, which in turn will trigger a collision in layer
1687    # names if we were to use HighLevelGraph.cull() instead of
1688    # HighLevelGraph.cull_layers() in Dataset.__dask_postpersist__().
1689    assert_equal(ds2.d1 + ds2.d2, ds.d1 + ds.d2)
1690