1import contextlib
2import gzip
3import itertools
4import math
5import os.path
6import pickle
7import re
8import shutil
9import sys
10import tempfile
11import warnings
12from contextlib import ExitStack
13from io import BytesIO
14from pathlib import Path
15from typing import Optional
16
17import numpy as np
18import pandas as pd
19import pytest
20from pandas.errors import OutOfBoundsDatetime
21
22import xarray as xr
23from xarray import (
24    DataArray,
25    Dataset,
26    backends,
27    load_dataarray,
28    load_dataset,
29    open_dataarray,
30    open_dataset,
31    open_mfdataset,
32    save_mfdataset,
33)
34from xarray.backends.common import robust_getitem
35from xarray.backends.h5netcdf_ import H5netcdfBackendEntrypoint
36from xarray.backends.netcdf3 import _nc3_dtype_coercions
37from xarray.backends.netCDF4_ import (
38    NetCDF4BackendEntrypoint,
39    _extract_nc4_variable_encoding,
40)
41from xarray.backends.pydap_ import PydapDataStore
42from xarray.backends.scipy_ import ScipyBackendEntrypoint
43from xarray.coding.variables import SerializationWarning
44from xarray.conventions import encode_dataset_coordinates
45from xarray.core import indexing
46from xarray.core.options import set_options
47from xarray.core.pycompat import dask_array_type
48from xarray.tests import LooseVersion, mock
49
50from . import (
51    arm_xfail,
52    assert_allclose,
53    assert_array_equal,
54    assert_equal,
55    assert_identical,
56    has_dask,
57    has_netCDF4,
58    has_scipy,
59    network,
60    requires_cfgrib,
61    requires_cftime,
62    requires_dask,
63    requires_fsspec,
64    requires_h5netcdf,
65    requires_iris,
66    requires_netCDF4,
67    requires_pseudonetcdf,
68    requires_pydap,
69    requires_pynio,
70    requires_rasterio,
71    requires_scipy,
72    requires_scipy_or_netCDF4,
73    requires_zarr,
74)
75from .test_coding_times import (
76    _ALL_CALENDARS,
77    _NON_STANDARD_CALENDARS,
78    _STANDARD_CALENDARS,
79)
80from .test_dataset import create_append_test_data, create_test_data
81
82try:
83    import netCDF4 as nc4
84except ImportError:
85    pass
86
87try:
88    import dask
89    import dask.array as da
90except ImportError:
91    pass
92
93ON_WINDOWS = sys.platform == "win32"
94default_value = object()
95
96
97def open_example_dataset(name, *args, **kwargs):
98    return open_dataset(
99        os.path.join(os.path.dirname(__file__), "data", name), *args, **kwargs
100    )
101
102
103def open_example_mfdataset(names, *args, **kwargs):
104    return open_mfdataset(
105        [os.path.join(os.path.dirname(__file__), "data", name) for name in names],
106        *args,
107        **kwargs,
108    )
109
110
111def create_masked_and_scaled_data():
112    x = np.array([np.nan, np.nan, 10, 10.1, 10.2], dtype=np.float32)
113    encoding = {
114        "_FillValue": -1,
115        "add_offset": 10,
116        "scale_factor": np.float32(0.1),
117        "dtype": "i2",
118    }
119    return Dataset({"x": ("t", x, {}, encoding)})
120
121
122def create_encoded_masked_and_scaled_data():
123    attributes = {"_FillValue": -1, "add_offset": 10, "scale_factor": np.float32(0.1)}
124    return Dataset({"x": ("t", np.int16([-1, -1, 0, 1, 2]), attributes)})
125
126
127def create_unsigned_masked_scaled_data():
128    encoding = {
129        "_FillValue": 255,
130        "_Unsigned": "true",
131        "dtype": "i1",
132        "add_offset": 10,
133        "scale_factor": np.float32(0.1),
134    }
135    x = np.array([10.0, 10.1, 22.7, 22.8, np.nan], dtype=np.float32)
136    return Dataset({"x": ("t", x, {}, encoding)})
137
138
139def create_encoded_unsigned_masked_scaled_data():
140    # These are values as written to the file: the _FillValue will
141    # be represented in the signed form.
142    attributes = {
143        "_FillValue": -1,
144        "_Unsigned": "true",
145        "add_offset": 10,
146        "scale_factor": np.float32(0.1),
147    }
148    # Create unsigned data corresponding to [0, 1, 127, 128, 255] unsigned
149    sb = np.asarray([0, 1, 127, -128, -1], dtype="i1")
150    return Dataset({"x": ("t", sb, attributes)})
151
152
153def create_bad_unsigned_masked_scaled_data():
154    encoding = {
155        "_FillValue": 255,
156        "_Unsigned": True,
157        "dtype": "i1",
158        "add_offset": 10,
159        "scale_factor": np.float32(0.1),
160    }
161    x = np.array([10.0, 10.1, 22.7, 22.8, np.nan], dtype=np.float32)
162    return Dataset({"x": ("t", x, {}, encoding)})
163
164
165def create_bad_encoded_unsigned_masked_scaled_data():
166    # These are values as written to the file: the _FillValue will
167    # be represented in the signed form.
168    attributes = {
169        "_FillValue": -1,
170        "_Unsigned": True,
171        "add_offset": 10,
172        "scale_factor": np.float32(0.1),
173    }
174    # Create signed data corresponding to [0, 1, 127, 128, 255] unsigned
175    sb = np.asarray([0, 1, 127, -128, -1], dtype="i1")
176    return Dataset({"x": ("t", sb, attributes)})
177
178
179def create_signed_masked_scaled_data():
180    encoding = {
181        "_FillValue": -127,
182        "_Unsigned": "false",
183        "dtype": "i1",
184        "add_offset": 10,
185        "scale_factor": np.float32(0.1),
186    }
187    x = np.array([-1.0, 10.1, 22.7, np.nan], dtype=np.float32)
188    return Dataset({"x": ("t", x, {}, encoding)})
189
190
191def create_encoded_signed_masked_scaled_data():
192    # These are values as written to the file: the _FillValue will
193    # be represented in the signed form.
194    attributes = {
195        "_FillValue": -127,
196        "_Unsigned": "false",
197        "add_offset": 10,
198        "scale_factor": np.float32(0.1),
199    }
200    # Create signed data corresponding to [0, 1, 127, 128, 255] unsigned
201    sb = np.asarray([-110, 1, 127, -127], dtype="i1")
202    return Dataset({"x": ("t", sb, attributes)})
203
204
205def create_boolean_data():
206    attributes = {"units": "-"}
207    return Dataset({"x": ("t", [True, False, False, True], attributes)})
208
209
210class TestCommon:
211    def test_robust_getitem(self):
212        class UnreliableArrayFailure(Exception):
213            pass
214
215        class UnreliableArray:
216            def __init__(self, array, failures=1):
217                self.array = array
218                self.failures = failures
219
220            def __getitem__(self, key):
221                if self.failures > 0:
222                    self.failures -= 1
223                    raise UnreliableArrayFailure
224                return self.array[key]
225
226        array = UnreliableArray([0])
227        with pytest.raises(UnreliableArrayFailure):
228            array[0]
229        assert array[0] == 0
230
231        actual = robust_getitem(array, 0, catch=UnreliableArrayFailure, initial_delay=0)
232        assert actual == 0
233
234
235class NetCDF3Only:
236    netcdf3_formats = ("NETCDF3_CLASSIC", "NETCDF3_64BIT")
237
238    @requires_scipy
239    def test_dtype_coercion_error(self):
240        """Failing dtype coercion should lead to an error"""
241        for dtype, format in itertools.product(
242            _nc3_dtype_coercions, self.netcdf3_formats
243        ):
244            if dtype == "bool":
245                # coerced upcast (bool to int8) ==> can never fail
246                continue
247
248            # Using the largest representable value, create some data that will
249            # no longer compare equal after the coerced downcast
250            maxval = np.iinfo(dtype).max
251            x = np.array([0, 1, 2, maxval], dtype=dtype)
252            ds = Dataset({"x": ("t", x, {})})
253
254            with create_tmp_file(allow_cleanup_failure=False) as path:
255                with pytest.raises(ValueError, match="could not safely cast"):
256                    ds.to_netcdf(path, format=format)
257
258
259class DatasetIOBase:
260    engine: Optional[str] = None
261    file_format: Optional[str] = None
262
263    def create_store(self):
264        raise NotImplementedError()
265
266    @contextlib.contextmanager
267    def roundtrip(
268        self, data, save_kwargs=None, open_kwargs=None, allow_cleanup_failure=False
269    ):
270        if save_kwargs is None:
271            save_kwargs = {}
272        if open_kwargs is None:
273            open_kwargs = {}
274        with create_tmp_file(allow_cleanup_failure=allow_cleanup_failure) as path:
275            self.save(data, path, **save_kwargs)
276            with self.open(path, **open_kwargs) as ds:
277                yield ds
278
279    @contextlib.contextmanager
280    def roundtrip_append(
281        self, data, save_kwargs=None, open_kwargs=None, allow_cleanup_failure=False
282    ):
283        if save_kwargs is None:
284            save_kwargs = {}
285        if open_kwargs is None:
286            open_kwargs = {}
287        with create_tmp_file(allow_cleanup_failure=allow_cleanup_failure) as path:
288            for i, key in enumerate(data.variables):
289                mode = "a" if i > 0 else "w"
290                self.save(data[[key]], path, mode=mode, **save_kwargs)
291            with self.open(path, **open_kwargs) as ds:
292                yield ds
293
294    # The save/open methods may be overwritten below
295    def save(self, dataset, path, **kwargs):
296        return dataset.to_netcdf(
297            path, engine=self.engine, format=self.file_format, **kwargs
298        )
299
300    @contextlib.contextmanager
301    def open(self, path, **kwargs):
302        with open_dataset(path, engine=self.engine, **kwargs) as ds:
303            yield ds
304
305    def test_zero_dimensional_variable(self):
306        expected = create_test_data()
307        expected["float_var"] = ([], 1.0e9, {"units": "units of awesome"})
308        expected["bytes_var"] = ([], b"foobar")
309        expected["string_var"] = ([], "foobar")
310        with self.roundtrip(expected) as actual:
311            assert_identical(expected, actual)
312
313    def test_write_store(self):
314        expected = create_test_data()
315        with self.create_store() as store:
316            expected.dump_to_store(store)
317            # we need to cf decode the store because it has time and
318            # non-dimension coordinates
319            with xr.decode_cf(store) as actual:
320                assert_allclose(expected, actual)
321
322    def check_dtypes_roundtripped(self, expected, actual):
323        for k in expected.variables:
324            expected_dtype = expected.variables[k].dtype
325
326            # For NetCDF3, the backend should perform dtype coercion
327            if (
328                isinstance(self, NetCDF3Only)
329                and str(expected_dtype) in _nc3_dtype_coercions
330            ):
331                expected_dtype = np.dtype(_nc3_dtype_coercions[str(expected_dtype)])
332
333            actual_dtype = actual.variables[k].dtype
334            # TODO: check expected behavior for string dtypes more carefully
335            string_kinds = {"O", "S", "U"}
336            assert expected_dtype == actual_dtype or (
337                expected_dtype.kind in string_kinds
338                and actual_dtype.kind in string_kinds
339            )
340
341    def test_roundtrip_test_data(self):
342        expected = create_test_data()
343        with self.roundtrip(expected) as actual:
344            self.check_dtypes_roundtripped(expected, actual)
345            assert_identical(expected, actual)
346
347    def test_load(self):
348        expected = create_test_data()
349
350        @contextlib.contextmanager
351        def assert_loads(vars=None):
352            if vars is None:
353                vars = expected
354            with self.roundtrip(expected) as actual:
355                for k, v in actual.variables.items():
356                    # IndexVariables are eagerly loaded into memory
357                    assert v._in_memory == (k in actual.dims)
358                yield actual
359                for k, v in actual.variables.items():
360                    if k in vars:
361                        assert v._in_memory
362                assert_identical(expected, actual)
363
364        with pytest.raises(AssertionError):
365            # make sure the contextmanager works!
366            with assert_loads() as ds:
367                pass
368
369        with assert_loads() as ds:
370            ds.load()
371
372        with assert_loads(["var1", "dim1", "dim2"]) as ds:
373            ds["var1"].load()
374
375        # verify we can read data even after closing the file
376        with self.roundtrip(expected) as ds:
377            actual = ds.load()
378        assert_identical(expected, actual)
379
380    def test_dataset_compute(self):
381        expected = create_test_data()
382
383        with self.roundtrip(expected) as actual:
384            # Test Dataset.compute()
385            for k, v in actual.variables.items():
386                # IndexVariables are eagerly cached
387                assert v._in_memory == (k in actual.dims)
388
389            computed = actual.compute()
390
391            for k, v in actual.variables.items():
392                assert v._in_memory == (k in actual.dims)
393            for v in computed.variables.values():
394                assert v._in_memory
395
396            assert_identical(expected, actual)
397            assert_identical(expected, computed)
398
399    def test_pickle(self):
400        if not has_dask:
401            pytest.xfail("pickling requires dask for SerializableLock")
402        expected = Dataset({"foo": ("x", [42])})
403        with self.roundtrip(expected, allow_cleanup_failure=ON_WINDOWS) as roundtripped:
404            with roundtripped:
405                # Windows doesn't like reopening an already open file
406                raw_pickle = pickle.dumps(roundtripped)
407            with pickle.loads(raw_pickle) as unpickled_ds:
408                assert_identical(expected, unpickled_ds)
409
410    @pytest.mark.filterwarnings("ignore:deallocating CachingFileManager")
411    def test_pickle_dataarray(self):
412        if not has_dask:
413            pytest.xfail("pickling requires dask for SerializableLock")
414        expected = Dataset({"foo": ("x", [42])})
415        with self.roundtrip(expected, allow_cleanup_failure=ON_WINDOWS) as roundtripped:
416            with roundtripped:
417                raw_pickle = pickle.dumps(roundtripped["foo"])
418            # TODO: figure out how to explicitly close the file for the
419            # unpickled DataArray?
420            unpickled = pickle.loads(raw_pickle)
421            assert_identical(expected["foo"], unpickled)
422
423    def test_dataset_caching(self):
424        expected = Dataset({"foo": ("x", [5, 6, 7])})
425        with self.roundtrip(expected) as actual:
426            assert isinstance(actual.foo.variable._data, indexing.MemoryCachedArray)
427            assert not actual.foo.variable._in_memory
428            actual.foo.values  # cache
429            assert actual.foo.variable._in_memory
430
431        with self.roundtrip(expected, open_kwargs={"cache": False}) as actual:
432            assert isinstance(actual.foo.variable._data, indexing.CopyOnWriteArray)
433            assert not actual.foo.variable._in_memory
434            actual.foo.values  # no caching
435            assert not actual.foo.variable._in_memory
436
437    def test_roundtrip_None_variable(self):
438        expected = Dataset({None: (("x", "y"), [[0, 1], [2, 3]])})
439        with self.roundtrip(expected) as actual:
440            assert_identical(expected, actual)
441
442    def test_roundtrip_object_dtype(self):
443        floats = np.array([0.0, 0.0, 1.0, 2.0, 3.0], dtype=object)
444        floats_nans = np.array([np.nan, np.nan, 1.0, 2.0, 3.0], dtype=object)
445        bytes_ = np.array([b"ab", b"cdef", b"g"], dtype=object)
446        bytes_nans = np.array([b"ab", b"cdef", np.nan], dtype=object)
447        strings = np.array(["ab", "cdef", "g"], dtype=object)
448        strings_nans = np.array(["ab", "cdef", np.nan], dtype=object)
449        all_nans = np.array([np.nan, np.nan], dtype=object)
450        original = Dataset(
451            {
452                "floats": ("a", floats),
453                "floats_nans": ("a", floats_nans),
454                "bytes": ("b", bytes_),
455                "bytes_nans": ("b", bytes_nans),
456                "strings": ("b", strings),
457                "strings_nans": ("b", strings_nans),
458                "all_nans": ("c", all_nans),
459                "nan": ([], np.nan),
460            }
461        )
462        expected = original.copy(deep=True)
463        with self.roundtrip(original) as actual:
464            try:
465                assert_identical(expected, actual)
466            except AssertionError:
467                # Most stores use '' for nans in strings, but some don't.
468                # First try the ideal case (where the store returns exactly)
469                # the original Dataset), then try a more realistic case.
470                # This currently includes all netCDF files when encoding is not
471                # explicitly set.
472                # https://github.com/pydata/xarray/issues/1647
473                expected["bytes_nans"][-1] = b""
474                expected["strings_nans"][-1] = ""
475                assert_identical(expected, actual)
476
477    def test_roundtrip_string_data(self):
478        expected = Dataset({"x": ("t", ["ab", "cdef"])})
479        with self.roundtrip(expected) as actual:
480            assert_identical(expected, actual)
481
482    def test_roundtrip_string_encoded_characters(self):
483        expected = Dataset({"x": ("t", ["ab", "cdef"])})
484        expected["x"].encoding["dtype"] = "S1"
485        with self.roundtrip(expected) as actual:
486            assert_identical(expected, actual)
487            assert actual["x"].encoding["_Encoding"] == "utf-8"
488
489        expected["x"].encoding["_Encoding"] = "ascii"
490        with self.roundtrip(expected) as actual:
491            assert_identical(expected, actual)
492            assert actual["x"].encoding["_Encoding"] == "ascii"
493
494    @arm_xfail
495    def test_roundtrip_numpy_datetime_data(self):
496        times = pd.to_datetime(["2000-01-01", "2000-01-02", "NaT"])
497        expected = Dataset({"t": ("t", times), "t0": times[0]})
498        kwargs = {"encoding": {"t0": {"units": "days since 1950-01-01"}}}
499        with self.roundtrip(expected, save_kwargs=kwargs) as actual:
500            assert_identical(expected, actual)
501            assert actual.t0.encoding["units"] == "days since 1950-01-01"
502
503    @requires_cftime
504    def test_roundtrip_cftime_datetime_data(self):
505        from .test_coding_times import _all_cftime_date_types
506
507        date_types = _all_cftime_date_types()
508        for date_type in date_types.values():
509            times = [date_type(1, 1, 1), date_type(1, 1, 2)]
510            expected = Dataset({"t": ("t", times), "t0": times[0]})
511            kwargs = {"encoding": {"t0": {"units": "days since 0001-01-01"}}}
512            expected_decoded_t = np.array(times)
513            expected_decoded_t0 = np.array([date_type(1, 1, 1)])
514            expected_calendar = times[0].calendar
515
516            with warnings.catch_warnings():
517                if expected_calendar in {"proleptic_gregorian", "gregorian"}:
518                    warnings.filterwarnings("ignore", "Unable to decode time axis")
519
520                with self.roundtrip(expected, save_kwargs=kwargs) as actual:
521                    abs_diff = abs(actual.t.values - expected_decoded_t)
522                    assert (abs_diff <= np.timedelta64(1, "s")).all()
523                    assert (
524                        actual.t.encoding["units"]
525                        == "days since 0001-01-01 00:00:00.000000"
526                    )
527                    assert actual.t.encoding["calendar"] == expected_calendar
528
529                    abs_diff = abs(actual.t0.values - expected_decoded_t0)
530                    assert (abs_diff <= np.timedelta64(1, "s")).all()
531                    assert actual.t0.encoding["units"] == "days since 0001-01-01"
532                    assert actual.t.encoding["calendar"] == expected_calendar
533
534    def test_roundtrip_timedelta_data(self):
535        time_deltas = pd.to_timedelta(["1h", "2h", "NaT"])
536        expected = Dataset({"td": ("td", time_deltas), "td0": time_deltas[0]})
537        with self.roundtrip(expected) as actual:
538            assert_identical(expected, actual)
539
540    def test_roundtrip_float64_data(self):
541        expected = Dataset({"x": ("y", np.array([1.0, 2.0, np.pi], dtype="float64"))})
542        with self.roundtrip(expected) as actual:
543            assert_identical(expected, actual)
544
545    def test_roundtrip_example_1_netcdf(self):
546        with open_example_dataset("example_1.nc") as expected:
547            with self.roundtrip(expected) as actual:
548                # we allow the attributes to differ since that
549                # will depend on the encoding used.  For example,
550                # without CF encoding 'actual' will end up with
551                # a dtype attribute.
552                assert_equal(expected, actual)
553
554    def test_roundtrip_coordinates(self):
555        original = Dataset(
556            {"foo": ("x", [0, 1])}, {"x": [2, 3], "y": ("a", [42]), "z": ("x", [4, 5])}
557        )
558
559        with self.roundtrip(original) as actual:
560            assert_identical(original, actual)
561
562        original["foo"].encoding["coordinates"] = "y"
563        with self.roundtrip(original, open_kwargs={"decode_coords": False}) as expected:
564            # check roundtripping when decode_coords=False
565            with self.roundtrip(
566                expected, open_kwargs={"decode_coords": False}
567            ) as actual:
568                assert_identical(expected, actual)
569
570    def test_roundtrip_global_coordinates(self):
571        original = Dataset(
572            {"foo": ("x", [0, 1])}, {"x": [2, 3], "y": ("a", [42]), "z": ("x", [4, 5])}
573        )
574        with self.roundtrip(original) as actual:
575            assert_identical(original, actual)
576
577        # test that global "coordinates" is as expected
578        _, attrs = encode_dataset_coordinates(original)
579        assert attrs["coordinates"] == "y"
580
581        # test warning when global "coordinates" is already set
582        original.attrs["coordinates"] = "foo"
583        with pytest.warns(SerializationWarning):
584            _, attrs = encode_dataset_coordinates(original)
585            assert attrs["coordinates"] == "foo"
586
587    def test_roundtrip_coordinates_with_space(self):
588        original = Dataset(coords={"x": 0, "y z": 1})
589        expected = Dataset({"y z": 1}, {"x": 0})
590        with pytest.warns(SerializationWarning):
591            with self.roundtrip(original) as actual:
592                assert_identical(expected, actual)
593
594    def test_roundtrip_boolean_dtype(self):
595        original = create_boolean_data()
596        assert original["x"].dtype == "bool"
597        with self.roundtrip(original) as actual:
598            assert_identical(original, actual)
599            assert actual["x"].dtype == "bool"
600
601    def test_orthogonal_indexing(self):
602        in_memory = create_test_data()
603        with self.roundtrip(in_memory) as on_disk:
604            indexers = {"dim1": [1, 2, 0], "dim2": [3, 2, 0, 3], "dim3": np.arange(5)}
605            expected = in_memory.isel(**indexers)
606            actual = on_disk.isel(**indexers)
607            # make sure the array is not yet loaded into memory
608            assert not actual["var1"].variable._in_memory
609            assert_identical(expected, actual)
610            # do it twice, to make sure we're switched from orthogonal -> numpy
611            # when we cached the values
612            actual = on_disk.isel(**indexers)
613            assert_identical(expected, actual)
614
615    def test_vectorized_indexing(self):
616        in_memory = create_test_data()
617        with self.roundtrip(in_memory) as on_disk:
618            indexers = {
619                "dim1": DataArray([0, 2, 0], dims="a"),
620                "dim2": DataArray([0, 2, 3], dims="a"),
621            }
622            expected = in_memory.isel(**indexers)
623            actual = on_disk.isel(**indexers)
624            # make sure the array is not yet loaded into memory
625            assert not actual["var1"].variable._in_memory
626            assert_identical(expected, actual.load())
627            # do it twice, to make sure we're switched from
628            # vectorized -> numpy when we cached the values
629            actual = on_disk.isel(**indexers)
630            assert_identical(expected, actual)
631
632        def multiple_indexing(indexers):
633            # make sure a sequence of lazy indexings certainly works.
634            with self.roundtrip(in_memory) as on_disk:
635                actual = on_disk["var3"]
636                expected = in_memory["var3"]
637                for ind in indexers:
638                    actual = actual.isel(**ind)
639                    expected = expected.isel(**ind)
640                    # make sure the array is not yet loaded into memory
641                    assert not actual.variable._in_memory
642                assert_identical(expected, actual.load())
643
644        # two-staged vectorized-indexing
645        indexers = [
646            {
647                "dim1": DataArray([[0, 7], [2, 6], [3, 5]], dims=["a", "b"]),
648                "dim3": DataArray([[0, 4], [1, 3], [2, 2]], dims=["a", "b"]),
649            },
650            {"a": DataArray([0, 1], dims=["c"]), "b": DataArray([0, 1], dims=["c"])},
651        ]
652        multiple_indexing(indexers)
653
654        # vectorized-slice mixed
655        indexers = [
656            {
657                "dim1": DataArray([[0, 7], [2, 6], [3, 5]], dims=["a", "b"]),
658                "dim3": slice(None, 10),
659            }
660        ]
661        multiple_indexing(indexers)
662
663        # vectorized-integer mixed
664        indexers = [
665            {"dim3": 0},
666            {"dim1": DataArray([[0, 7], [2, 6], [3, 5]], dims=["a", "b"])},
667            {"a": slice(None, None, 2)},
668        ]
669        multiple_indexing(indexers)
670
671        # vectorized-integer mixed
672        indexers = [
673            {"dim3": 0},
674            {"dim1": DataArray([[0, 7], [2, 6], [3, 5]], dims=["a", "b"])},
675            {"a": 1, "b": 0},
676        ]
677        multiple_indexing(indexers)
678
679    @pytest.mark.xfail(
680        reason="zarr without dask handles negative steps in slices incorrectly",
681    )
682    def test_vectorized_indexing_negative_step(self):
683        # use dask explicitly when present
684        if has_dask:
685            open_kwargs = {"chunks": {}}
686        else:
687            open_kwargs = None
688        in_memory = create_test_data()
689
690        def multiple_indexing(indexers):
691            # make sure a sequence of lazy indexings certainly works.
692            with self.roundtrip(in_memory, open_kwargs=open_kwargs) as on_disk:
693                actual = on_disk["var3"]
694                expected = in_memory["var3"]
695                for ind in indexers:
696                    actual = actual.isel(**ind)
697                    expected = expected.isel(**ind)
698                    # make sure the array is not yet loaded into memory
699                    assert not actual.variable._in_memory
700                assert_identical(expected, actual.load())
701
702        # with negative step slice.
703        indexers = [
704            {
705                "dim1": DataArray([[0, 7], [2, 6], [3, 5]], dims=["a", "b"]),
706                "dim3": slice(-1, 1, -1),
707            }
708        ]
709        multiple_indexing(indexers)
710
711        # with negative step slice.
712        indexers = [
713            {
714                "dim1": DataArray([[0, 7], [2, 6], [3, 5]], dims=["a", "b"]),
715                "dim3": slice(-1, 1, -2),
716            }
717        ]
718        multiple_indexing(indexers)
719
720    def test_isel_dataarray(self):
721        # Make sure isel works lazily. GH:issue:1688
722        in_memory = create_test_data()
723        with self.roundtrip(in_memory) as on_disk:
724            expected = in_memory.isel(dim2=in_memory["dim2"] < 3)
725            actual = on_disk.isel(dim2=on_disk["dim2"] < 3)
726            assert_identical(expected, actual)
727
728    def validate_array_type(self, ds):
729        # Make sure that only NumpyIndexingAdapter stores a bare np.ndarray.
730        def find_and_validate_array(obj):
731            # recursively called function. obj: array or array wrapper.
732            if hasattr(obj, "array"):
733                if isinstance(obj.array, indexing.ExplicitlyIndexed):
734                    find_and_validate_array(obj.array)
735                else:
736                    if isinstance(obj.array, np.ndarray):
737                        assert isinstance(obj, indexing.NumpyIndexingAdapter)
738                    elif isinstance(obj.array, dask_array_type):
739                        assert isinstance(obj, indexing.DaskIndexingAdapter)
740                    elif isinstance(obj.array, pd.Index):
741                        assert isinstance(obj, indexing.PandasIndexingAdapter)
742                    else:
743                        raise TypeError(
744                            "{} is wrapped by {}".format(type(obj.array), type(obj))
745                        )
746
747        for k, v in ds.variables.items():
748            find_and_validate_array(v._data)
749
750    def test_array_type_after_indexing(self):
751        in_memory = create_test_data()
752        with self.roundtrip(in_memory) as on_disk:
753            self.validate_array_type(on_disk)
754            indexers = {"dim1": [1, 2, 0], "dim2": [3, 2, 0, 3], "dim3": np.arange(5)}
755            expected = in_memory.isel(**indexers)
756            actual = on_disk.isel(**indexers)
757            assert_identical(expected, actual)
758            self.validate_array_type(actual)
759            # do it twice, to make sure we're switched from orthogonal -> numpy
760            # when we cached the values
761            actual = on_disk.isel(**indexers)
762            assert_identical(expected, actual)
763            self.validate_array_type(actual)
764
765    def test_dropna(self):
766        # regression test for GH:issue:1694
767        a = np.random.randn(4, 3)
768        a[1, 1] = np.NaN
769        in_memory = xr.Dataset(
770            {"a": (("y", "x"), a)}, coords={"y": np.arange(4), "x": np.arange(3)}
771        )
772
773        assert_identical(
774            in_memory.dropna(dim="x"), in_memory.isel(x=slice(None, None, 2))
775        )
776
777        with self.roundtrip(in_memory) as on_disk:
778            self.validate_array_type(on_disk)
779            expected = in_memory.dropna(dim="x")
780            actual = on_disk.dropna(dim="x")
781            assert_identical(expected, actual)
782
783    def test_ondisk_after_print(self):
784        """Make sure print does not load file into memory"""
785        in_memory = create_test_data()
786        with self.roundtrip(in_memory) as on_disk:
787            repr(on_disk)
788            assert not on_disk["var1"]._in_memory
789
790
791class CFEncodedBase(DatasetIOBase):
792    def test_roundtrip_bytes_with_fill_value(self):
793        values = np.array([b"ab", b"cdef", np.nan], dtype=object)
794        encoding = {"_FillValue": b"X", "dtype": "S1"}
795        original = Dataset({"x": ("t", values, {}, encoding)})
796        expected = original.copy(deep=True)
797        with self.roundtrip(original) as actual:
798            assert_identical(expected, actual)
799
800        original = Dataset({"x": ("t", values, {}, {"_FillValue": b""})})
801        with self.roundtrip(original) as actual:
802            assert_identical(expected, actual)
803
804    def test_roundtrip_string_with_fill_value_nchar(self):
805        values = np.array(["ab", "cdef", np.nan], dtype=object)
806        expected = Dataset({"x": ("t", values)})
807
808        encoding = {"dtype": "S1", "_FillValue": b"X"}
809        original = Dataset({"x": ("t", values, {}, encoding)})
810        # Not supported yet.
811        with pytest.raises(NotImplementedError):
812            with self.roundtrip(original) as actual:
813                assert_identical(expected, actual)
814
815    @pytest.mark.parametrize(
816        "decoded_fn, encoded_fn",
817        [
818            (
819                create_unsigned_masked_scaled_data,
820                create_encoded_unsigned_masked_scaled_data,
821            ),
822            pytest.param(
823                create_bad_unsigned_masked_scaled_data,
824                create_bad_encoded_unsigned_masked_scaled_data,
825                marks=pytest.mark.xfail(reason="Bad _Unsigned attribute."),
826            ),
827            (
828                create_signed_masked_scaled_data,
829                create_encoded_signed_masked_scaled_data,
830            ),
831            (create_masked_and_scaled_data, create_encoded_masked_and_scaled_data),
832        ],
833    )
834    def test_roundtrip_mask_and_scale(self, decoded_fn, encoded_fn):
835        decoded = decoded_fn()
836        encoded = encoded_fn()
837
838        with self.roundtrip(decoded) as actual:
839            for k in decoded.variables:
840                assert decoded.variables[k].dtype == actual.variables[k].dtype
841            assert_allclose(decoded, actual, decode_bytes=False)
842
843        with self.roundtrip(decoded, open_kwargs=dict(decode_cf=False)) as actual:
844            # TODO: this assumes that all roundtrips will first
845            # encode.  Is that something we want to test for?
846            for k in encoded.variables:
847                assert encoded.variables[k].dtype == actual.variables[k].dtype
848            assert_allclose(encoded, actual, decode_bytes=False)
849
850        with self.roundtrip(encoded, open_kwargs=dict(decode_cf=False)) as actual:
851            for k in encoded.variables:
852                assert encoded.variables[k].dtype == actual.variables[k].dtype
853            assert_allclose(encoded, actual, decode_bytes=False)
854
855        # make sure roundtrip encoding didn't change the
856        # original dataset.
857        assert_allclose(encoded, encoded_fn(), decode_bytes=False)
858
859        with self.roundtrip(encoded) as actual:
860            for k in decoded.variables:
861                assert decoded.variables[k].dtype == actual.variables[k].dtype
862            assert_allclose(decoded, actual, decode_bytes=False)
863
864    @staticmethod
865    def _create_cf_dataset():
866        original = Dataset(
867            dict(
868                variable=(
869                    ("ln_p", "latitude", "longitude"),
870                    np.arange(8, dtype="f4").reshape(2, 2, 2),
871                    {"ancillary_variables": "std_devs det_lim"},
872                ),
873                std_devs=(
874                    ("ln_p", "latitude", "longitude"),
875                    np.arange(0.1, 0.9, 0.1).reshape(2, 2, 2),
876                    {"standard_name": "standard_error"},
877                ),
878                det_lim=(
879                    (),
880                    0.1,
881                    {"standard_name": "detection_minimum"},
882                ),
883            ),
884            dict(
885                latitude=("latitude", [0, 1], {"units": "degrees_north"}),
886                longitude=("longitude", [0, 1], {"units": "degrees_east"}),
887                latlon=((), -1, {"grid_mapping_name": "latitude_longitude"}),
888                latitude_bnds=(("latitude", "bnds2"), [[0, 1], [1, 2]]),
889                longitude_bnds=(("longitude", "bnds2"), [[0, 1], [1, 2]]),
890                areas=(
891                    ("latitude", "longitude"),
892                    [[1, 1], [1, 1]],
893                    {"units": "degree^2"},
894                ),
895                ln_p=(
896                    "ln_p",
897                    [1.0, 0.5],
898                    {
899                        "standard_name": "atmosphere_ln_pressure_coordinate",
900                        "computed_standard_name": "air_pressure",
901                    },
902                ),
903                P0=((), 1013.25, {"units": "hPa"}),
904            ),
905        )
906        original["variable"].encoding.update(
907            {"cell_measures": "area: areas", "grid_mapping": "latlon"},
908        )
909        original.coords["latitude"].encoding.update(
910            dict(grid_mapping="latlon", bounds="latitude_bnds")
911        )
912        original.coords["longitude"].encoding.update(
913            dict(grid_mapping="latlon", bounds="longitude_bnds")
914        )
915        original.coords["ln_p"].encoding.update({"formula_terms": "p0: P0 lev : ln_p"})
916        return original
917
918    def test_grid_mapping_and_bounds_are_not_coordinates_in_file(self):
919        original = self._create_cf_dataset()
920        with create_tmp_file() as tmp_file:
921            original.to_netcdf(tmp_file)
922            with open_dataset(tmp_file, decode_coords=False) as ds:
923                assert ds.coords["latitude"].attrs["bounds"] == "latitude_bnds"
924                assert ds.coords["longitude"].attrs["bounds"] == "longitude_bnds"
925                assert "coordinates" not in ds["variable"].attrs
926                assert "coordinates" not in ds.attrs
927
928    def test_coordinate_variables_after_dataset_roundtrip(self):
929        original = self._create_cf_dataset()
930        with self.roundtrip(original, open_kwargs={"decode_coords": "all"}) as actual:
931            assert_identical(actual, original)
932
933        with self.roundtrip(original) as actual:
934            expected = original.reset_coords(
935                ["latitude_bnds", "longitude_bnds", "areas", "P0", "latlon"]
936            )
937            # equal checks that coords and data_vars are equal which
938            # should be enough
939            # identical would require resetting a number of attributes
940            # skip that.
941            assert_equal(actual, expected)
942
943    def test_grid_mapping_and_bounds_are_coordinates_after_dataarray_roundtrip(self):
944        original = self._create_cf_dataset()
945        # The DataArray roundtrip should have the same warnings as the
946        # Dataset, but we already tested for those, so just go for the
947        # new warnings.  It would appear that there is no way to tell
948        # pytest "This warning and also this warning should both be
949        # present".
950        # xarray/tests/test_conventions.py::TestCFEncodedDataStore
951        # needs the to_dataset. The other backends should be fine
952        # without it.
953        with pytest.warns(
954            UserWarning,
955            match=(
956                r"Variable\(s\) referenced in bounds not in variables: "
957                r"\['l(at|ong)itude_bnds'\]"
958            ),
959        ):
960            with self.roundtrip(
961                original["variable"].to_dataset(), open_kwargs={"decode_coords": "all"}
962            ) as actual:
963                assert_identical(actual, original["variable"].to_dataset())
964
965    @requires_iris
966    def test_coordinate_variables_after_iris_roundtrip(self):
967        original = self._create_cf_dataset()
968        iris_cube = original["variable"].to_iris()
969        actual = DataArray.from_iris(iris_cube)
970        # Bounds will be missing (xfail)
971        del original.coords["latitude_bnds"], original.coords["longitude_bnds"]
972        # Ancillary vars will be missing
973        # Those are data_vars, and will be dropped when grabbing the variable
974        assert_identical(actual, original["variable"])
975
976    def test_coordinates_encoding(self):
977        def equals_latlon(obj):
978            return obj == "lat lon" or obj == "lon lat"
979
980        original = Dataset(
981            {"temp": ("x", [0, 1]), "precip": ("x", [0, -1])},
982            {"lat": ("x", [2, 3]), "lon": ("x", [4, 5])},
983        )
984        with self.roundtrip(original) as actual:
985            assert_identical(actual, original)
986        with create_tmp_file() as tmp_file:
987            original.to_netcdf(tmp_file)
988            with open_dataset(tmp_file, decode_coords=False) as ds:
989                assert equals_latlon(ds["temp"].attrs["coordinates"])
990                assert equals_latlon(ds["precip"].attrs["coordinates"])
991                assert "coordinates" not in ds.attrs
992                assert "coordinates" not in ds["lat"].attrs
993                assert "coordinates" not in ds["lon"].attrs
994
995        modified = original.drop_vars(["temp", "precip"])
996        with self.roundtrip(modified) as actual:
997            assert_identical(actual, modified)
998        with create_tmp_file() as tmp_file:
999            modified.to_netcdf(tmp_file)
1000            with open_dataset(tmp_file, decode_coords=False) as ds:
1001                assert equals_latlon(ds.attrs["coordinates"])
1002                assert "coordinates" not in ds["lat"].attrs
1003                assert "coordinates" not in ds["lon"].attrs
1004
1005        original["temp"].encoding["coordinates"] = "lat"
1006        with self.roundtrip(original) as actual:
1007            assert_identical(actual, original)
1008        original["precip"].encoding["coordinates"] = "lat"
1009        with create_tmp_file() as tmp_file:
1010            original.to_netcdf(tmp_file)
1011            with open_dataset(tmp_file, decode_coords=True) as ds:
1012                assert "lon" not in ds["temp"].encoding["coordinates"]
1013                assert "lon" not in ds["precip"].encoding["coordinates"]
1014                assert "coordinates" not in ds["lat"].encoding
1015                assert "coordinates" not in ds["lon"].encoding
1016
1017    def test_roundtrip_endian(self):
1018        ds = Dataset(
1019            {
1020                "x": np.arange(3, 10, dtype=">i2"),
1021                "y": np.arange(3, 20, dtype="<i4"),
1022                "z": np.arange(3, 30, dtype="=i8"),
1023                "w": ("x", np.arange(3, 10, dtype=float)),
1024            }
1025        )
1026
1027        with self.roundtrip(ds) as actual:
1028            # technically these datasets are slightly different,
1029            # one hold mixed endian data (ds) the other should be
1030            # all big endian (actual).  assertDatasetIdentical
1031            # should still pass though.
1032            assert_identical(ds, actual)
1033
1034        if self.engine == "netcdf4":
1035            ds["z"].encoding["endian"] = "big"
1036            with pytest.raises(NotImplementedError):
1037                with self.roundtrip(ds) as actual:
1038                    pass
1039
1040    def test_invalid_dataarray_names_raise(self):
1041        te = (TypeError, "string or None")
1042        ve = (ValueError, "string must be length 1 or")
1043        data = np.random.random((2, 2))
1044        da = xr.DataArray(data)
1045        for name, (error, msg) in zip([0, (4, 5), True, ""], [te, te, te, ve]):
1046            ds = Dataset({name: da})
1047            with pytest.raises(error) as excinfo:
1048                with self.roundtrip(ds):
1049                    pass
1050            excinfo.match(msg)
1051            excinfo.match(repr(name))
1052
1053    def test_encoding_kwarg(self):
1054        ds = Dataset({"x": ("y", np.arange(10.0))})
1055        kwargs = dict(encoding={"x": {"dtype": "f4"}})
1056        with self.roundtrip(ds, save_kwargs=kwargs) as actual:
1057            encoded_dtype = actual.x.encoding["dtype"]
1058            # On OS X, dtype sometimes switches endianness for unclear reasons
1059            assert encoded_dtype.kind == "f" and encoded_dtype.itemsize == 4
1060        assert ds.x.encoding == {}
1061
1062        kwargs = dict(encoding={"x": {"foo": "bar"}})
1063        with pytest.raises(ValueError, match=r"unexpected encoding"):
1064            with self.roundtrip(ds, save_kwargs=kwargs) as actual:
1065                pass
1066
1067        kwargs = dict(encoding={"x": "foo"})
1068        with pytest.raises(ValueError, match=r"must be castable"):
1069            with self.roundtrip(ds, save_kwargs=kwargs) as actual:
1070                pass
1071
1072        kwargs = dict(encoding={"invalid": {}})
1073        with pytest.raises(KeyError):
1074            with self.roundtrip(ds, save_kwargs=kwargs) as actual:
1075                pass
1076
1077    def test_encoding_kwarg_dates(self):
1078        ds = Dataset({"t": pd.date_range("2000-01-01", periods=3)})
1079        units = "days since 1900-01-01"
1080        kwargs = dict(encoding={"t": {"units": units}})
1081        with self.roundtrip(ds, save_kwargs=kwargs) as actual:
1082            assert actual.t.encoding["units"] == units
1083            assert_identical(actual, ds)
1084
1085    def test_encoding_kwarg_fixed_width_string(self):
1086        # regression test for GH2149
1087        for strings in [[b"foo", b"bar", b"baz"], ["foo", "bar", "baz"]]:
1088            ds = Dataset({"x": strings})
1089            kwargs = dict(encoding={"x": {"dtype": "S1"}})
1090            with self.roundtrip(ds, save_kwargs=kwargs) as actual:
1091                assert actual["x"].encoding["dtype"] == "S1"
1092                assert_identical(actual, ds)
1093
1094    def test_default_fill_value(self):
1095        # Test default encoding for float:
1096        ds = Dataset({"x": ("y", np.arange(10.0))})
1097        kwargs = dict(encoding={"x": {"dtype": "f4"}})
1098        with self.roundtrip(ds, save_kwargs=kwargs) as actual:
1099            assert math.isnan(actual.x.encoding["_FillValue"])
1100        assert ds.x.encoding == {}
1101
1102        # Test default encoding for int:
1103        ds = Dataset({"x": ("y", np.arange(10.0))})
1104        kwargs = dict(encoding={"x": {"dtype": "int16"}})
1105        with warnings.catch_warnings():
1106            warnings.filterwarnings("ignore", ".*floating point data as an integer")
1107            with self.roundtrip(ds, save_kwargs=kwargs) as actual:
1108                assert "_FillValue" not in actual.x.encoding
1109        assert ds.x.encoding == {}
1110
1111        # Test default encoding for implicit int:
1112        ds = Dataset({"x": ("y", np.arange(10, dtype="int16"))})
1113        with self.roundtrip(ds) as actual:
1114            assert "_FillValue" not in actual.x.encoding
1115        assert ds.x.encoding == {}
1116
1117    def test_explicitly_omit_fill_value(self):
1118        ds = Dataset({"x": ("y", [np.pi, -np.pi])})
1119        ds.x.encoding["_FillValue"] = None
1120        with self.roundtrip(ds) as actual:
1121            assert "_FillValue" not in actual.x.encoding
1122
1123    def test_explicitly_omit_fill_value_via_encoding_kwarg(self):
1124        ds = Dataset({"x": ("y", [np.pi, -np.pi])})
1125        kwargs = dict(encoding={"x": {"_FillValue": None}})
1126        with self.roundtrip(ds, save_kwargs=kwargs) as actual:
1127            assert "_FillValue" not in actual.x.encoding
1128        assert ds.y.encoding == {}
1129
1130    def test_explicitly_omit_fill_value_in_coord(self):
1131        ds = Dataset({"x": ("y", [np.pi, -np.pi])}, coords={"y": [0.0, 1.0]})
1132        ds.y.encoding["_FillValue"] = None
1133        with self.roundtrip(ds) as actual:
1134            assert "_FillValue" not in actual.y.encoding
1135
1136    def test_explicitly_omit_fill_value_in_coord_via_encoding_kwarg(self):
1137        ds = Dataset({"x": ("y", [np.pi, -np.pi])}, coords={"y": [0.0, 1.0]})
1138        kwargs = dict(encoding={"y": {"_FillValue": None}})
1139        with self.roundtrip(ds, save_kwargs=kwargs) as actual:
1140            assert "_FillValue" not in actual.y.encoding
1141        assert ds.y.encoding == {}
1142
1143    def test_encoding_same_dtype(self):
1144        ds = Dataset({"x": ("y", np.arange(10.0, dtype="f4"))})
1145        kwargs = dict(encoding={"x": {"dtype": "f4"}})
1146        with self.roundtrip(ds, save_kwargs=kwargs) as actual:
1147            encoded_dtype = actual.x.encoding["dtype"]
1148            # On OS X, dtype sometimes switches endianness for unclear reasons
1149            assert encoded_dtype.kind == "f" and encoded_dtype.itemsize == 4
1150        assert ds.x.encoding == {}
1151
1152    def test_append_write(self):
1153        # regression for GH1215
1154        data = create_test_data()
1155        with self.roundtrip_append(data) as actual:
1156            assert_identical(data, actual)
1157
1158    def test_append_overwrite_values(self):
1159        # regression for GH1215
1160        data = create_test_data()
1161        with create_tmp_file(allow_cleanup_failure=False) as tmp_file:
1162            self.save(data, tmp_file, mode="w")
1163            data["var2"][:] = -999
1164            data["var9"] = data["var2"] * 3
1165            self.save(data[["var2", "var9"]], tmp_file, mode="a")
1166            with self.open(tmp_file) as actual:
1167                assert_identical(data, actual)
1168
1169    def test_append_with_invalid_dim_raises(self):
1170        data = create_test_data()
1171        with create_tmp_file(allow_cleanup_failure=False) as tmp_file:
1172            self.save(data, tmp_file, mode="w")
1173            data["var9"] = data["var2"] * 3
1174            data = data.isel(dim1=slice(2, 6))  # modify one dimension
1175            with pytest.raises(
1176                ValueError, match=r"Unable to update size for existing dimension"
1177            ):
1178                self.save(data, tmp_file, mode="a")
1179
1180    def test_multiindex_not_implemented(self):
1181        ds = Dataset(coords={"y": ("x", [1, 2]), "z": ("x", ["a", "b"])}).set_index(
1182            x=["y", "z"]
1183        )
1184        with pytest.raises(NotImplementedError, match=r"MultiIndex"):
1185            with self.roundtrip(ds):
1186                pass
1187
1188
1189_counter = itertools.count()
1190
1191
1192@contextlib.contextmanager
1193def create_tmp_file(suffix=".nc", allow_cleanup_failure=False):
1194    temp_dir = tempfile.mkdtemp()
1195    path = os.path.join(temp_dir, "temp-{}{}".format(next(_counter), suffix))
1196    try:
1197        yield path
1198    finally:
1199        try:
1200            shutil.rmtree(temp_dir)
1201        except OSError:
1202            if not allow_cleanup_failure:
1203                raise
1204
1205
1206@contextlib.contextmanager
1207def create_tmp_files(nfiles, suffix=".nc", allow_cleanup_failure=False):
1208    with ExitStack() as stack:
1209        files = [
1210            stack.enter_context(create_tmp_file(suffix, allow_cleanup_failure))
1211            for apath in np.arange(nfiles)
1212        ]
1213        yield files
1214
1215
1216class NetCDF4Base(CFEncodedBase):
1217    """Tests for both netCDF4-python and h5netcdf."""
1218
1219    engine = "netcdf4"
1220
1221    def test_open_group(self):
1222        # Create a netCDF file with a dataset stored within a group
1223        with create_tmp_file() as tmp_file:
1224            with nc4.Dataset(tmp_file, "w") as rootgrp:
1225                foogrp = rootgrp.createGroup("foo")
1226                ds = foogrp
1227                ds.createDimension("time", size=10)
1228                x = np.arange(10)
1229                ds.createVariable("x", np.int32, dimensions=("time",))
1230                ds.variables["x"][:] = x
1231
1232            expected = Dataset()
1233            expected["x"] = ("time", x)
1234
1235            # check equivalent ways to specify group
1236            for group in "foo", "/foo", "foo/", "/foo/":
1237                with self.open(tmp_file, group=group) as actual:
1238                    assert_equal(actual["x"], expected["x"])
1239
1240            # check that missing group raises appropriate exception
1241            with pytest.raises(OSError):
1242                open_dataset(tmp_file, group="bar")
1243            with pytest.raises(ValueError, match=r"must be a string"):
1244                open_dataset(tmp_file, group=(1, 2, 3))
1245
1246    def test_open_subgroup(self):
1247        # Create a netCDF file with a dataset stored within a group within a
1248        # group
1249        with create_tmp_file() as tmp_file:
1250            rootgrp = nc4.Dataset(tmp_file, "w")
1251            foogrp = rootgrp.createGroup("foo")
1252            bargrp = foogrp.createGroup("bar")
1253            ds = bargrp
1254            ds.createDimension("time", size=10)
1255            x = np.arange(10)
1256            ds.createVariable("x", np.int32, dimensions=("time",))
1257            ds.variables["x"][:] = x
1258            rootgrp.close()
1259
1260            expected = Dataset()
1261            expected["x"] = ("time", x)
1262
1263            # check equivalent ways to specify group
1264            for group in "foo/bar", "/foo/bar", "foo/bar/", "/foo/bar/":
1265                with self.open(tmp_file, group=group) as actual:
1266                    assert_equal(actual["x"], expected["x"])
1267
1268    def test_write_groups(self):
1269        data1 = create_test_data()
1270        data2 = data1 * 2
1271        with create_tmp_file() as tmp_file:
1272            self.save(data1, tmp_file, group="data/1")
1273            self.save(data2, tmp_file, group="data/2", mode="a")
1274            with self.open(tmp_file, group="data/1") as actual1:
1275                assert_identical(data1, actual1)
1276            with self.open(tmp_file, group="data/2") as actual2:
1277                assert_identical(data2, actual2)
1278
1279    def test_encoding_kwarg_vlen_string(self):
1280        for input_strings in [[b"foo", b"bar", b"baz"], ["foo", "bar", "baz"]]:
1281            original = Dataset({"x": input_strings})
1282            expected = Dataset({"x": ["foo", "bar", "baz"]})
1283            kwargs = dict(encoding={"x": {"dtype": str}})
1284            with self.roundtrip(original, save_kwargs=kwargs) as actual:
1285                assert actual["x"].encoding["dtype"] is str
1286                assert_identical(actual, expected)
1287
1288    def test_roundtrip_string_with_fill_value_vlen(self):
1289        values = np.array(["ab", "cdef", np.nan], dtype=object)
1290        expected = Dataset({"x": ("t", values)})
1291
1292        # netCDF4-based backends don't support an explicit fillvalue
1293        # for variable length strings yet.
1294        # https://github.com/Unidata/netcdf4-python/issues/730
1295        # https://github.com/shoyer/h5netcdf/issues/37
1296        original = Dataset({"x": ("t", values, {}, {"_FillValue": "XXX"})})
1297        with pytest.raises(NotImplementedError):
1298            with self.roundtrip(original) as actual:
1299                assert_identical(expected, actual)
1300
1301        original = Dataset({"x": ("t", values, {}, {"_FillValue": ""})})
1302        with pytest.raises(NotImplementedError):
1303            with self.roundtrip(original) as actual:
1304                assert_identical(expected, actual)
1305
1306    def test_roundtrip_character_array(self):
1307        with create_tmp_file() as tmp_file:
1308            values = np.array([["a", "b", "c"], ["d", "e", "f"]], dtype="S")
1309
1310            with nc4.Dataset(tmp_file, mode="w") as nc:
1311                nc.createDimension("x", 2)
1312                nc.createDimension("string3", 3)
1313                v = nc.createVariable("x", np.dtype("S1"), ("x", "string3"))
1314                v[:] = values
1315
1316            values = np.array(["abc", "def"], dtype="S")
1317            expected = Dataset({"x": ("x", values)})
1318            with open_dataset(tmp_file) as actual:
1319                assert_identical(expected, actual)
1320                # regression test for #157
1321                with self.roundtrip(actual) as roundtripped:
1322                    assert_identical(expected, roundtripped)
1323
1324    def test_default_to_char_arrays(self):
1325        data = Dataset({"x": np.array(["foo", "zzzz"], dtype="S")})
1326        with self.roundtrip(data) as actual:
1327            assert_identical(data, actual)
1328            assert actual["x"].dtype == np.dtype("S4")
1329
1330    def test_open_encodings(self):
1331        # Create a netCDF file with explicit time units
1332        # and make sure it makes it into the encodings
1333        # and survives a round trip
1334        with create_tmp_file() as tmp_file:
1335            with nc4.Dataset(tmp_file, "w") as ds:
1336                ds.createDimension("time", size=10)
1337                ds.createVariable("time", np.int32, dimensions=("time",))
1338                units = "days since 1999-01-01"
1339                ds.variables["time"].setncattr("units", units)
1340                ds.variables["time"][:] = np.arange(10) + 4
1341
1342            expected = Dataset()
1343
1344            time = pd.date_range("1999-01-05", periods=10)
1345            encoding = {"units": units, "dtype": np.dtype("int32")}
1346            expected["time"] = ("time", time, {}, encoding)
1347
1348            with open_dataset(tmp_file) as actual:
1349                assert_equal(actual["time"], expected["time"])
1350                actual_encoding = {
1351                    k: v
1352                    for k, v in actual["time"].encoding.items()
1353                    if k in expected["time"].encoding
1354                }
1355                assert actual_encoding == expected["time"].encoding
1356
1357    def test_dump_encodings(self):
1358        # regression test for #709
1359        ds = Dataset({"x": ("y", np.arange(10.0))})
1360        kwargs = dict(encoding={"x": {"zlib": True}})
1361        with self.roundtrip(ds, save_kwargs=kwargs) as actual:
1362            assert actual.x.encoding["zlib"]
1363
1364    def test_dump_and_open_encodings(self):
1365        # Create a netCDF file with explicit time units
1366        # and make sure it makes it into the encodings
1367        # and survives a round trip
1368        with create_tmp_file() as tmp_file:
1369            with nc4.Dataset(tmp_file, "w") as ds:
1370                ds.createDimension("time", size=10)
1371                ds.createVariable("time", np.int32, dimensions=("time",))
1372                units = "days since 1999-01-01"
1373                ds.variables["time"].setncattr("units", units)
1374                ds.variables["time"][:] = np.arange(10) + 4
1375
1376            with open_dataset(tmp_file) as xarray_dataset:
1377                with create_tmp_file() as tmp_file2:
1378                    xarray_dataset.to_netcdf(tmp_file2)
1379                    with nc4.Dataset(tmp_file2, "r") as ds:
1380                        assert ds.variables["time"].getncattr("units") == units
1381                        assert_array_equal(ds.variables["time"], np.arange(10) + 4)
1382
1383    def test_compression_encoding(self):
1384        data = create_test_data()
1385        data["var2"].encoding.update(
1386            {
1387                "zlib": True,
1388                "chunksizes": (5, 5),
1389                "fletcher32": True,
1390                "shuffle": True,
1391                "original_shape": data.var2.shape,
1392            }
1393        )
1394        with self.roundtrip(data) as actual:
1395            for k, v in data["var2"].encoding.items():
1396                assert v == actual["var2"].encoding[k]
1397
1398        # regression test for #156
1399        expected = data.isel(dim1=0)
1400        with self.roundtrip(expected) as actual:
1401            assert_equal(expected, actual)
1402
1403    def test_encoding_kwarg_compression(self):
1404        ds = Dataset({"x": np.arange(10.0)})
1405        encoding = dict(
1406            dtype="f4",
1407            zlib=True,
1408            complevel=9,
1409            fletcher32=True,
1410            chunksizes=(5,),
1411            shuffle=True,
1412        )
1413        kwargs = dict(encoding=dict(x=encoding))
1414
1415        with self.roundtrip(ds, save_kwargs=kwargs) as actual:
1416            assert_equal(actual, ds)
1417            assert actual.x.encoding["dtype"] == "f4"
1418            assert actual.x.encoding["zlib"]
1419            assert actual.x.encoding["complevel"] == 9
1420            assert actual.x.encoding["fletcher32"]
1421            assert actual.x.encoding["chunksizes"] == (5,)
1422            assert actual.x.encoding["shuffle"]
1423
1424        assert ds.x.encoding == {}
1425
1426    def test_keep_chunksizes_if_no_original_shape(self):
1427        ds = Dataset({"x": [1, 2, 3]})
1428        chunksizes = (2,)
1429        ds.variables["x"].encoding = {"chunksizes": chunksizes}
1430
1431        with self.roundtrip(ds) as actual:
1432            assert_identical(ds, actual)
1433            assert_array_equal(
1434                ds["x"].encoding["chunksizes"], actual["x"].encoding["chunksizes"]
1435            )
1436
1437    def test_encoding_chunksizes_unlimited(self):
1438        # regression test for GH1225
1439        ds = Dataset({"x": [1, 2, 3], "y": ("x", [2, 3, 4])})
1440        ds.variables["x"].encoding = {
1441            "zlib": False,
1442            "shuffle": False,
1443            "complevel": 0,
1444            "fletcher32": False,
1445            "contiguous": False,
1446            "chunksizes": (2 ** 20,),
1447            "original_shape": (3,),
1448        }
1449        with self.roundtrip(ds) as actual:
1450            assert_equal(ds, actual)
1451
1452    def test_mask_and_scale(self):
1453        with create_tmp_file() as tmp_file:
1454            with nc4.Dataset(tmp_file, mode="w") as nc:
1455                nc.createDimension("t", 5)
1456                nc.createVariable("x", "int16", ("t",), fill_value=-1)
1457                v = nc.variables["x"]
1458                v.set_auto_maskandscale(False)
1459                v.add_offset = 10
1460                v.scale_factor = 0.1
1461                v[:] = np.array([-1, -1, 0, 1, 2])
1462
1463            # first make sure netCDF4 reads the masked and scaled data
1464            # correctly
1465            with nc4.Dataset(tmp_file, mode="r") as nc:
1466                expected = np.ma.array(
1467                    [-1, -1, 10, 10.1, 10.2], mask=[True, True, False, False, False]
1468                )
1469                actual = nc.variables["x"][:]
1470                assert_array_equal(expected, actual)
1471
1472            # now check xarray
1473            with open_dataset(tmp_file) as ds:
1474                expected = create_masked_and_scaled_data()
1475                assert_identical(expected, ds)
1476
1477    def test_0dimensional_variable(self):
1478        # This fix verifies our work-around to this netCDF4-python bug:
1479        # https://github.com/Unidata/netcdf4-python/pull/220
1480        with create_tmp_file() as tmp_file:
1481            with nc4.Dataset(tmp_file, mode="w") as nc:
1482                v = nc.createVariable("x", "int16")
1483                v[...] = 123
1484
1485            with open_dataset(tmp_file) as ds:
1486                expected = Dataset({"x": ((), 123)})
1487                assert_identical(expected, ds)
1488
1489    def test_read_variable_len_strings(self):
1490        with create_tmp_file() as tmp_file:
1491            values = np.array(["foo", "bar", "baz"], dtype=object)
1492
1493            with nc4.Dataset(tmp_file, mode="w") as nc:
1494                nc.createDimension("x", 3)
1495                v = nc.createVariable("x", str, ("x",))
1496                v[:] = values
1497
1498            expected = Dataset({"x": ("x", values)})
1499            for kwargs in [{}, {"decode_cf": True}]:
1500                with open_dataset(tmp_file, **kwargs) as actual:
1501                    assert_identical(expected, actual)
1502
1503    def test_encoding_unlimited_dims(self):
1504        ds = Dataset({"x": ("y", np.arange(10.0))})
1505        with self.roundtrip(ds, save_kwargs=dict(unlimited_dims=["y"])) as actual:
1506            assert actual.encoding["unlimited_dims"] == set("y")
1507            assert_equal(ds, actual)
1508        ds.encoding = {"unlimited_dims": ["y"]}
1509        with self.roundtrip(ds) as actual:
1510            assert actual.encoding["unlimited_dims"] == set("y")
1511            assert_equal(ds, actual)
1512
1513
1514@requires_netCDF4
1515class TestNetCDF4Data(NetCDF4Base):
1516    @contextlib.contextmanager
1517    def create_store(self):
1518        with create_tmp_file() as tmp_file:
1519            with backends.NetCDF4DataStore.open(tmp_file, mode="w") as store:
1520                yield store
1521
1522    def test_variable_order(self):
1523        # doesn't work with scipy or h5py :(
1524        ds = Dataset()
1525        ds["a"] = 1
1526        ds["z"] = 2
1527        ds["b"] = 3
1528        ds.coords["c"] = 4
1529
1530        with self.roundtrip(ds) as actual:
1531            assert list(ds.variables) == list(actual.variables)
1532
1533    def test_unsorted_index_raises(self):
1534        # should be fixed in netcdf4 v1.2.1
1535        random_data = np.random.random(size=(4, 6))
1536        dim0 = [0, 1, 2, 3]
1537        dim1 = [0, 2, 1, 3, 5, 4]  # We will sort this in a later step
1538        da = xr.DataArray(
1539            data=random_data,
1540            dims=("dim0", "dim1"),
1541            coords={"dim0": dim0, "dim1": dim1},
1542            name="randovar",
1543        )
1544        ds = da.to_dataset()
1545
1546        with self.roundtrip(ds) as ondisk:
1547            inds = np.argsort(dim1)
1548            ds2 = ondisk.isel(dim1=inds)
1549            # Older versions of NetCDF4 raise an exception here, and if so we
1550            # want to ensure we improve (that is, replace) the error message
1551            try:
1552                ds2.randovar.values
1553            except IndexError as err:
1554                assert "first by calling .load" in str(err)
1555
1556    def test_setncattr_string(self):
1557        list_of_strings = ["list", "of", "strings"]
1558        one_element_list_of_strings = ["one element"]
1559        one_string = "one string"
1560        attrs = {
1561            "foo": list_of_strings,
1562            "bar": one_element_list_of_strings,
1563            "baz": one_string,
1564        }
1565        ds = Dataset({"x": ("y", [1, 2, 3], attrs)}, attrs=attrs)
1566
1567        with self.roundtrip(ds) as actual:
1568            for totest in [actual, actual["x"]]:
1569                assert_array_equal(list_of_strings, totest.attrs["foo"])
1570                assert_array_equal(one_element_list_of_strings, totest.attrs["bar"])
1571                assert one_string == totest.attrs["baz"]
1572
1573
1574@requires_netCDF4
1575class TestNetCDF4AlreadyOpen:
1576    def test_base_case(self):
1577        with create_tmp_file() as tmp_file:
1578            with nc4.Dataset(tmp_file, mode="w") as nc:
1579                v = nc.createVariable("x", "int")
1580                v[...] = 42
1581
1582            nc = nc4.Dataset(tmp_file, mode="r")
1583            store = backends.NetCDF4DataStore(nc)
1584            with open_dataset(store) as ds:
1585                expected = Dataset({"x": ((), 42)})
1586                assert_identical(expected, ds)
1587
1588    def test_group(self):
1589        with create_tmp_file() as tmp_file:
1590            with nc4.Dataset(tmp_file, mode="w") as nc:
1591                group = nc.createGroup("g")
1592                v = group.createVariable("x", "int")
1593                v[...] = 42
1594
1595            nc = nc4.Dataset(tmp_file, mode="r")
1596            store = backends.NetCDF4DataStore(nc.groups["g"])
1597            with open_dataset(store) as ds:
1598                expected = Dataset({"x": ((), 42)})
1599                assert_identical(expected, ds)
1600
1601            nc = nc4.Dataset(tmp_file, mode="r")
1602            store = backends.NetCDF4DataStore(nc, group="g")
1603            with open_dataset(store) as ds:
1604                expected = Dataset({"x": ((), 42)})
1605                assert_identical(expected, ds)
1606
1607            with nc4.Dataset(tmp_file, mode="r") as nc:
1608                with pytest.raises(ValueError, match="must supply a root"):
1609                    backends.NetCDF4DataStore(nc.groups["g"], group="g")
1610
1611    def test_deepcopy(self):
1612        # regression test for https://github.com/pydata/xarray/issues/4425
1613        with create_tmp_file() as tmp_file:
1614            with nc4.Dataset(tmp_file, mode="w") as nc:
1615                nc.createDimension("x", 10)
1616                v = nc.createVariable("y", np.int32, ("x",))
1617                v[:] = np.arange(10)
1618
1619            h5 = nc4.Dataset(tmp_file, mode="r")
1620            store = backends.NetCDF4DataStore(h5)
1621            with open_dataset(store) as ds:
1622                copied = ds.copy(deep=True)
1623                expected = Dataset({"y": ("x", np.arange(10))})
1624                assert_identical(expected, copied)
1625
1626
1627@requires_netCDF4
1628@requires_dask
1629@pytest.mark.filterwarnings("ignore:deallocating CachingFileManager")
1630class TestNetCDF4ViaDaskData(TestNetCDF4Data):
1631    @contextlib.contextmanager
1632    def roundtrip(
1633        self, data, save_kwargs=None, open_kwargs=None, allow_cleanup_failure=False
1634    ):
1635        if open_kwargs is None:
1636            open_kwargs = {}
1637        if save_kwargs is None:
1638            save_kwargs = {}
1639        open_kwargs.setdefault("chunks", -1)
1640        with TestNetCDF4Data.roundtrip(
1641            self, data, save_kwargs, open_kwargs, allow_cleanup_failure
1642        ) as ds:
1643            yield ds
1644
1645    def test_unsorted_index_raises(self):
1646        # Skip when using dask because dask rewrites indexers to getitem,
1647        # dask first pulls items by block.
1648        pass
1649
1650    def test_dataset_caching(self):
1651        # caching behavior differs for dask
1652        pass
1653
1654    def test_write_inconsistent_chunks(self):
1655        # Construct two variables with the same dimensions, but different
1656        # chunk sizes.
1657        x = da.zeros((100, 100), dtype="f4", chunks=(50, 100))
1658        x = DataArray(data=x, dims=("lat", "lon"), name="x")
1659        x.encoding["chunksizes"] = (50, 100)
1660        x.encoding["original_shape"] = (100, 100)
1661        y = da.ones((100, 100), dtype="f4", chunks=(100, 50))
1662        y = DataArray(data=y, dims=("lat", "lon"), name="y")
1663        y.encoding["chunksizes"] = (100, 50)
1664        y.encoding["original_shape"] = (100, 100)
1665        # Put them both into the same dataset
1666        ds = Dataset({"x": x, "y": y})
1667        with self.roundtrip(ds) as actual:
1668            assert actual["x"].encoding["chunksizes"] == (50, 100)
1669            assert actual["y"].encoding["chunksizes"] == (100, 50)
1670
1671
1672@requires_zarr
1673class ZarrBase(CFEncodedBase):
1674
1675    DIMENSION_KEY = "_ARRAY_DIMENSIONS"
1676
1677    def create_zarr_target(self):
1678        raise NotImplementedError
1679
1680    @contextlib.contextmanager
1681    def create_store(self):
1682        with self.create_zarr_target() as store_target:
1683            yield backends.ZarrStore.open_group(store_target, mode="w")
1684
1685    def save(self, dataset, store_target, **kwargs):
1686        return dataset.to_zarr(store=store_target, **kwargs)
1687
1688    @contextlib.contextmanager
1689    def open(self, store_target, **kwargs):
1690        with xr.open_dataset(store_target, engine="zarr", **kwargs) as ds:
1691            yield ds
1692
1693    @contextlib.contextmanager
1694    def roundtrip(
1695        self, data, save_kwargs=None, open_kwargs=None, allow_cleanup_failure=False
1696    ):
1697        if save_kwargs is None:
1698            save_kwargs = {}
1699        if open_kwargs is None:
1700            open_kwargs = {}
1701        with self.create_zarr_target() as store_target:
1702            self.save(data, store_target, **save_kwargs)
1703            with self.open(store_target, **open_kwargs) as ds:
1704                yield ds
1705
1706    @pytest.mark.parametrize("consolidated", [False, True, None])
1707    def test_roundtrip_consolidated(self, consolidated):
1708        expected = create_test_data()
1709        with self.roundtrip(
1710            expected,
1711            save_kwargs={"consolidated": True},
1712            open_kwargs={"backend_kwargs": {"consolidated": True}},
1713        ) as actual:
1714            self.check_dtypes_roundtripped(expected, actual)
1715            assert_identical(expected, actual)
1716
1717    def test_read_non_consolidated_warning(self):
1718        expected = create_test_data()
1719        with self.create_zarr_target() as store:
1720            expected.to_zarr(store, consolidated=False)
1721            with pytest.warns(
1722                RuntimeWarning,
1723                match="Failed to open Zarr store with consolidated",
1724            ):
1725                with xr.open_zarr(store) as ds:
1726                    assert_identical(ds, expected)
1727
1728    def test_with_chunkstore(self):
1729        expected = create_test_data()
1730        with self.create_zarr_target() as store_target, self.create_zarr_target() as chunk_store:
1731            save_kwargs = {"chunk_store": chunk_store}
1732            self.save(expected, store_target, **save_kwargs)
1733            open_kwargs = {"backend_kwargs": {"chunk_store": chunk_store}}
1734            with self.open(store_target, **open_kwargs) as ds:
1735                assert_equal(ds, expected)
1736
1737    @requires_dask
1738    def test_auto_chunk(self):
1739        original = create_test_data().chunk()
1740
1741        with self.roundtrip(original, open_kwargs={"chunks": None}) as actual:
1742            for k, v in actual.variables.items():
1743                # only index variables should be in memory
1744                assert v._in_memory == (k in actual.dims)
1745                # there should be no chunks
1746                assert v.chunks is None
1747
1748        with self.roundtrip(original, open_kwargs={"chunks": {}}) as actual:
1749            for k, v in actual.variables.items():
1750                # only index variables should be in memory
1751                assert v._in_memory == (k in actual.dims)
1752                # chunk size should be the same as original
1753                assert v.chunks == original[k].chunks
1754
1755    @requires_dask
1756    @pytest.mark.filterwarnings("ignore:Specified Dask chunks")
1757    def test_manual_chunk(self):
1758        original = create_test_data().chunk({"dim1": 3, "dim2": 4, "dim3": 3})
1759
1760        # Using chunks = None should return non-chunked arrays
1761        open_kwargs = {"chunks": None}
1762        with self.roundtrip(original, open_kwargs=open_kwargs) as actual:
1763            for k, v in actual.variables.items():
1764                # only index variables should be in memory
1765                assert v._in_memory == (k in actual.dims)
1766                # there should be no chunks
1767                assert v.chunks is None
1768
1769        # uniform arrays
1770        for i in range(2, 6):
1771            rechunked = original.chunk(chunks=i)
1772            open_kwargs = {"chunks": i}
1773            with self.roundtrip(original, open_kwargs=open_kwargs) as actual:
1774                for k, v in actual.variables.items():
1775                    # only index variables should be in memory
1776                    assert v._in_memory == (k in actual.dims)
1777                    # chunk size should be the same as rechunked
1778                    assert v.chunks == rechunked[k].chunks
1779
1780        chunks = {"dim1": 2, "dim2": 3, "dim3": 5}
1781        rechunked = original.chunk(chunks=chunks)
1782
1783        open_kwargs = {
1784            "chunks": chunks,
1785            "backend_kwargs": {"overwrite_encoded_chunks": True},
1786        }
1787        with self.roundtrip(original, open_kwargs=open_kwargs) as actual:
1788            for k, v in actual.variables.items():
1789                assert v.chunks == rechunked[k].chunks
1790
1791            with self.roundtrip(actual) as auto:
1792                # encoding should have changed
1793                for k, v in actual.variables.items():
1794                    assert v.chunks == rechunked[k].chunks
1795
1796                assert_identical(actual, auto)
1797                assert_identical(actual.load(), auto.load())
1798
1799    @requires_dask
1800    def test_warning_on_bad_chunks(self):
1801        original = create_test_data().chunk({"dim1": 4, "dim2": 3, "dim3": 3})
1802
1803        bad_chunks = (2, {"dim2": (3, 3, 2, 1)})
1804        for chunks in bad_chunks:
1805            kwargs = {"chunks": chunks}
1806            with pytest.warns(UserWarning):
1807                with self.roundtrip(original, open_kwargs=kwargs) as actual:
1808                    for k, v in actual.variables.items():
1809                        # only index variables should be in memory
1810                        assert v._in_memory == (k in actual.dims)
1811
1812        good_chunks = ({"dim2": 3}, {"dim3": (6, 4)}, {})
1813        for chunks in good_chunks:
1814            kwargs = {"chunks": chunks}
1815            with pytest.warns(None) as record:
1816                with self.roundtrip(original, open_kwargs=kwargs) as actual:
1817                    for k, v in actual.variables.items():
1818                        # only index variables should be in memory
1819                        assert v._in_memory == (k in actual.dims)
1820            assert len(record) == 0
1821
1822    @requires_dask
1823    def test_deprecate_auto_chunk(self):
1824        original = create_test_data().chunk()
1825        with pytest.raises(TypeError):
1826            with self.roundtrip(original, open_kwargs={"auto_chunk": True}) as actual:
1827                for k, v in actual.variables.items():
1828                    # only index variables should be in memory
1829                    assert v._in_memory == (k in actual.dims)
1830                    # chunk size should be the same as original
1831                    assert v.chunks == original[k].chunks
1832
1833        with pytest.raises(TypeError):
1834            with self.roundtrip(original, open_kwargs={"auto_chunk": False}) as actual:
1835                for k, v in actual.variables.items():
1836                    # only index variables should be in memory
1837                    assert v._in_memory == (k in actual.dims)
1838                    # there should be no chunks
1839                    assert v.chunks is None
1840
1841    @requires_dask
1842    def test_write_uneven_dask_chunks(self):
1843        # regression for GH#2225
1844        original = create_test_data().chunk({"dim1": 3, "dim2": 4, "dim3": 3})
1845        with self.roundtrip(original, open_kwargs={"chunks": {}}) as actual:
1846            for k, v in actual.data_vars.items():
1847                print(k)
1848                assert v.chunks == actual[k].chunks
1849
1850    def test_chunk_encoding(self):
1851        # These datasets have no dask chunks. All chunking specified in
1852        # encoding
1853        data = create_test_data()
1854        chunks = (5, 5)
1855        data["var2"].encoding.update({"chunks": chunks})
1856
1857        with self.roundtrip(data) as actual:
1858            assert chunks == actual["var2"].encoding["chunks"]
1859
1860        # expect an error with non-integer chunks
1861        data["var2"].encoding.update({"chunks": (5, 4.5)})
1862        with pytest.raises(TypeError):
1863            with self.roundtrip(data) as actual:
1864                pass
1865
1866    @requires_dask
1867    def test_chunk_encoding_with_dask(self):
1868        # These datasets DO have dask chunks. Need to check for various
1869        # interactions between dask and zarr chunks
1870        ds = xr.DataArray((np.arange(12)), dims="x", name="var1").to_dataset()
1871
1872        # - no encoding specified -
1873        # zarr automatically gets chunk information from dask chunks
1874        ds_chunk4 = ds.chunk({"x": 4})
1875        with self.roundtrip(ds_chunk4) as actual:
1876            assert (4,) == actual["var1"].encoding["chunks"]
1877
1878        # should fail if dask_chunks are irregular...
1879        ds_chunk_irreg = ds.chunk({"x": (5, 4, 3)})
1880        with pytest.raises(ValueError, match=r"uniform chunk sizes."):
1881            with self.roundtrip(ds_chunk_irreg) as actual:
1882                pass
1883
1884        # should fail if encoding["chunks"] clashes with dask_chunks
1885        badenc = ds.chunk({"x": 4})
1886        badenc.var1.encoding["chunks"] = (6,)
1887        with pytest.raises(NotImplementedError, match=r"named 'var1' would overlap"):
1888            with self.roundtrip(badenc) as actual:
1889                pass
1890
1891        # unless...
1892        with self.roundtrip(badenc, save_kwargs={"safe_chunks": False}) as actual:
1893            # don't actually check equality because the data could be corrupted
1894            pass
1895
1896        badenc.var1.encoding["chunks"] = (2,)
1897        with pytest.raises(NotImplementedError, match=r"Specified Zarr chunk encoding"):
1898            with self.roundtrip(badenc) as actual:
1899                pass
1900
1901        badenc = badenc.chunk({"x": (3, 3, 6)})
1902        badenc.var1.encoding["chunks"] = (3,)
1903        with pytest.raises(
1904            NotImplementedError, match=r"incompatible with this encoding"
1905        ):
1906            with self.roundtrip(badenc) as actual:
1907                pass
1908
1909        # ... except if the last chunk is smaller than the first
1910        ds_chunk_irreg = ds.chunk({"x": (5, 5, 2)})
1911        with self.roundtrip(ds_chunk_irreg) as actual:
1912            assert (5,) == actual["var1"].encoding["chunks"]
1913        # re-save Zarr arrays
1914        with self.roundtrip(ds_chunk_irreg) as original:
1915            with self.roundtrip(original) as actual:
1916                assert_identical(original, actual)
1917
1918        # - encoding specified  -
1919        # specify compatible encodings
1920        for chunk_enc in 4, (4,):
1921            ds_chunk4["var1"].encoding.update({"chunks": chunk_enc})
1922            with self.roundtrip(ds_chunk4) as actual:
1923                assert (4,) == actual["var1"].encoding["chunks"]
1924
1925        # TODO: remove this failure once syncronized overlapping writes are
1926        # supported by xarray
1927        ds_chunk4["var1"].encoding.update({"chunks": 5})
1928        with pytest.raises(NotImplementedError, match=r"named 'var1' would overlap"):
1929            with self.roundtrip(ds_chunk4) as actual:
1930                pass
1931        # override option
1932        with self.roundtrip(ds_chunk4, save_kwargs={"safe_chunks": False}) as actual:
1933            # don't actually check equality because the data could be corrupted
1934            pass
1935
1936    def test_hidden_zarr_keys(self):
1937        expected = create_test_data()
1938        with self.create_store() as store:
1939            expected.dump_to_store(store)
1940            zarr_group = store.ds
1941
1942            # check that a variable hidden attribute is present and correct
1943            # JSON only has a single array type, which maps to list in Python.
1944            # In contrast, dims in xarray is always a tuple.
1945            for var in expected.variables.keys():
1946                dims = zarr_group[var].attrs[self.DIMENSION_KEY]
1947                assert dims == list(expected[var].dims)
1948
1949            with xr.decode_cf(store):
1950                # make sure it is hidden
1951                for var in expected.variables.keys():
1952                    assert self.DIMENSION_KEY not in expected[var].attrs
1953
1954            # put it back and try removing from a variable
1955            del zarr_group.var2.attrs[self.DIMENSION_KEY]
1956            with pytest.raises(KeyError):
1957                with xr.decode_cf(store):
1958                    pass
1959
1960    @pytest.mark.parametrize("group", [None, "group1"])
1961    def test_write_persistence_modes(self, group):
1962        original = create_test_data()
1963
1964        # overwrite mode
1965        with self.roundtrip(
1966            original,
1967            save_kwargs={"mode": "w", "group": group},
1968            open_kwargs={"group": group},
1969        ) as actual:
1970            assert_identical(original, actual)
1971
1972        # don't overwrite mode
1973        with self.roundtrip(
1974            original,
1975            save_kwargs={"mode": "w-", "group": group},
1976            open_kwargs={"group": group},
1977        ) as actual:
1978            assert_identical(original, actual)
1979
1980        # make sure overwriting works as expected
1981        with self.create_zarr_target() as store:
1982            self.save(original, store)
1983            # should overwrite with no error
1984            self.save(original, store, mode="w", group=group)
1985            with self.open(store, group=group) as actual:
1986                assert_identical(original, actual)
1987                with pytest.raises(ValueError):
1988                    self.save(original, store, mode="w-")
1989
1990        # check append mode for normal write
1991        with self.roundtrip(
1992            original,
1993            save_kwargs={"mode": "a", "group": group},
1994            open_kwargs={"group": group},
1995        ) as actual:
1996            assert_identical(original, actual)
1997
1998        # check append mode for append write
1999        ds, ds_to_append, _ = create_append_test_data()
2000        with self.create_zarr_target() as store_target:
2001            ds.to_zarr(store_target, mode="w", group=group)
2002            ds_to_append.to_zarr(store_target, append_dim="time", group=group)
2003            original = xr.concat([ds, ds_to_append], dim="time")
2004            actual = xr.open_dataset(store_target, group=group, engine="zarr")
2005            assert_identical(original, actual)
2006
2007    def test_compressor_encoding(self):
2008        original = create_test_data()
2009        # specify a custom compressor
2010        import zarr
2011
2012        blosc_comp = zarr.Blosc(cname="zstd", clevel=3, shuffle=2)
2013        save_kwargs = dict(encoding={"var1": {"compressor": blosc_comp}})
2014        with self.roundtrip(original, save_kwargs=save_kwargs) as ds:
2015            actual = ds["var1"].encoding["compressor"]
2016            # get_config returns a dictionary of compressor attributes
2017            assert actual.get_config() == blosc_comp.get_config()
2018
2019    def test_group(self):
2020        original = create_test_data()
2021        group = "some/random/path"
2022        with self.roundtrip(
2023            original, save_kwargs={"group": group}, open_kwargs={"group": group}
2024        ) as actual:
2025            assert_identical(original, actual)
2026
2027    def test_encoding_kwarg_fixed_width_string(self):
2028        # not relevant for zarr, since we don't use EncodedStringCoder
2029        pass
2030
2031    # TODO: someone who understand caching figure out whether caching
2032    # makes sense for Zarr backend
2033    @pytest.mark.xfail(reason="Zarr caching not implemented")
2034    def test_dataset_caching(self):
2035        super().test_dataset_caching()
2036
2037    def test_append_write(self):
2038        super().test_append_write()
2039
2040    def test_append_with_mode_rplus_success(self):
2041        original = Dataset({"foo": ("x", [1])})
2042        modified = Dataset({"foo": ("x", [2])})
2043        with self.create_zarr_target() as store:
2044            original.to_zarr(store)
2045            modified.to_zarr(store, mode="r+")
2046            with self.open(store) as actual:
2047                assert_identical(actual, modified)
2048
2049    def test_append_with_mode_rplus_fails(self):
2050        original = Dataset({"foo": ("x", [1])})
2051        modified = Dataset({"bar": ("x", [2])})
2052        with self.create_zarr_target() as store:
2053            original.to_zarr(store)
2054            with pytest.raises(
2055                ValueError, match="dataset contains non-pre-existing variables"
2056            ):
2057                modified.to_zarr(store, mode="r+")
2058
2059    def test_append_with_invalid_dim_raises(self):
2060        ds, ds_to_append, _ = create_append_test_data()
2061        with self.create_zarr_target() as store_target:
2062            ds.to_zarr(store_target, mode="w")
2063            with pytest.raises(
2064                ValueError, match="does not match any existing dataset dimensions"
2065            ):
2066                ds_to_append.to_zarr(store_target, append_dim="notvalid")
2067
2068    def test_append_with_no_dims_raises(self):
2069        with self.create_zarr_target() as store_target:
2070            Dataset({"foo": ("x", [1])}).to_zarr(store_target, mode="w")
2071            with pytest.raises(ValueError, match="different dimension names"):
2072                Dataset({"foo": ("y", [2])}).to_zarr(store_target, mode="a")
2073
2074    def test_append_with_append_dim_not_set_raises(self):
2075        ds, ds_to_append, _ = create_append_test_data()
2076        with self.create_zarr_target() as store_target:
2077            ds.to_zarr(store_target, mode="w")
2078            with pytest.raises(ValueError, match="different dimension sizes"):
2079                ds_to_append.to_zarr(store_target, mode="a")
2080
2081    def test_append_with_mode_not_a_raises(self):
2082        ds, ds_to_append, _ = create_append_test_data()
2083        with self.create_zarr_target() as store_target:
2084            ds.to_zarr(store_target, mode="w")
2085            with pytest.raises(ValueError, match="cannot set append_dim unless"):
2086                ds_to_append.to_zarr(store_target, mode="w", append_dim="time")
2087
2088    def test_append_with_existing_encoding_raises(self):
2089        ds, ds_to_append, _ = create_append_test_data()
2090        with self.create_zarr_target() as store_target:
2091            ds.to_zarr(store_target, mode="w")
2092            with pytest.raises(ValueError, match="but encoding was provided"):
2093                ds_to_append.to_zarr(
2094                    store_target,
2095                    append_dim="time",
2096                    encoding={"da": {"compressor": None}},
2097                )
2098
2099    def test_check_encoding_is_consistent_after_append(self):
2100
2101        ds, ds_to_append, _ = create_append_test_data()
2102
2103        # check encoding consistency
2104        with self.create_zarr_target() as store_target:
2105            import zarr
2106
2107            compressor = zarr.Blosc()
2108            encoding = {"da": {"compressor": compressor}}
2109            ds.to_zarr(store_target, mode="w", encoding=encoding)
2110            ds_to_append.to_zarr(store_target, append_dim="time")
2111            actual_ds = xr.open_dataset(store_target, engine="zarr")
2112            actual_encoding = actual_ds["da"].encoding["compressor"]
2113            assert actual_encoding.get_config() == compressor.get_config()
2114            assert_identical(
2115                xr.open_dataset(store_target, engine="zarr").compute(),
2116                xr.concat([ds, ds_to_append], dim="time"),
2117            )
2118
2119    def test_append_with_new_variable(self):
2120
2121        ds, ds_to_append, ds_with_new_var = create_append_test_data()
2122
2123        # check append mode for new variable
2124        with self.create_zarr_target() as store_target:
2125            xr.concat([ds, ds_to_append], dim="time").to_zarr(store_target, mode="w")
2126            ds_with_new_var.to_zarr(store_target, mode="a")
2127            combined = xr.concat([ds, ds_to_append], dim="time")
2128            combined["new_var"] = ds_with_new_var["new_var"]
2129            assert_identical(combined, xr.open_dataset(store_target, engine="zarr"))
2130
2131    @requires_dask
2132    def test_to_zarr_compute_false_roundtrip(self):
2133        from dask.delayed import Delayed
2134
2135        original = create_test_data().chunk()
2136
2137        with self.create_zarr_target() as store:
2138            delayed_obj = self.save(original, store, compute=False)
2139            assert isinstance(delayed_obj, Delayed)
2140
2141            # make sure target store has not been written to yet
2142            with pytest.raises(AssertionError):
2143                with self.open(store) as actual:
2144                    assert_identical(original, actual)
2145
2146            delayed_obj.compute()
2147
2148            with self.open(store) as actual:
2149                assert_identical(original, actual)
2150
2151    @requires_dask
2152    def test_to_zarr_append_compute_false_roundtrip(self):
2153        from dask.delayed import Delayed
2154
2155        ds, ds_to_append, _ = create_append_test_data()
2156        ds, ds_to_append = ds.chunk(), ds_to_append.chunk()
2157
2158        with pytest.warns(SerializationWarning):
2159            with self.create_zarr_target() as store:
2160                delayed_obj = self.save(ds, store, compute=False, mode="w")
2161                assert isinstance(delayed_obj, Delayed)
2162
2163                with pytest.raises(AssertionError):
2164                    with self.open(store) as actual:
2165                        assert_identical(ds, actual)
2166
2167                delayed_obj.compute()
2168
2169                with self.open(store) as actual:
2170                    assert_identical(ds, actual)
2171
2172                delayed_obj = self.save(
2173                    ds_to_append, store, compute=False, append_dim="time"
2174                )
2175                assert isinstance(delayed_obj, Delayed)
2176
2177                with pytest.raises(AssertionError):
2178                    with self.open(store) as actual:
2179                        assert_identical(
2180                            xr.concat([ds, ds_to_append], dim="time"), actual
2181                        )
2182
2183                delayed_obj.compute()
2184
2185                with self.open(store) as actual:
2186                    assert_identical(xr.concat([ds, ds_to_append], dim="time"), actual)
2187
2188    @pytest.mark.parametrize("chunk", [False, True])
2189    def test_save_emptydim(self, chunk):
2190        if chunk and not has_dask:
2191            pytest.skip("requires dask")
2192        ds = Dataset({"x": (("a", "b"), np.empty((5, 0))), "y": ("a", [1, 2, 5, 8, 9])})
2193        if chunk:
2194            ds = ds.chunk({})  # chunk dataset to save dask array
2195        with self.roundtrip(ds) as ds_reload:
2196            assert_identical(ds, ds_reload)
2197
2198    @pytest.mark.parametrize("consolidated", [False, True])
2199    @pytest.mark.parametrize("compute", [False, True])
2200    @pytest.mark.parametrize("use_dask", [False, True])
2201    def test_write_region(self, consolidated, compute, use_dask):
2202        if (use_dask or not compute) and not has_dask:
2203            pytest.skip("requires dask")
2204
2205        zeros = Dataset({"u": (("x",), np.zeros(10))})
2206        nonzeros = Dataset({"u": (("x",), np.arange(1, 11))})
2207
2208        if use_dask:
2209            zeros = zeros.chunk(2)
2210            nonzeros = nonzeros.chunk(2)
2211
2212        with self.create_zarr_target() as store:
2213            zeros.to_zarr(
2214                store,
2215                consolidated=consolidated,
2216                compute=compute,
2217                encoding={"u": dict(chunks=2)},
2218            )
2219            if compute:
2220                with xr.open_zarr(store, consolidated=consolidated) as actual:
2221                    assert_identical(actual, zeros)
2222            for i in range(0, 10, 2):
2223                region = {"x": slice(i, i + 2)}
2224                nonzeros.isel(region).to_zarr(
2225                    store, region=region, consolidated=consolidated
2226                )
2227            with xr.open_zarr(store, consolidated=consolidated) as actual:
2228                assert_identical(actual, nonzeros)
2229
2230    @pytest.mark.parametrize("mode", [None, "r+", "a"])
2231    def test_write_region_mode(self, mode):
2232        zeros = Dataset({"u": (("x",), np.zeros(10))})
2233        nonzeros = Dataset({"u": (("x",), np.arange(1, 11))})
2234        with self.create_zarr_target() as store:
2235            zeros.to_zarr(store)
2236            for region in [{"x": slice(5)}, {"x": slice(5, 10)}]:
2237                nonzeros.isel(region).to_zarr(store, region=region, mode=mode)
2238            with xr.open_zarr(store) as actual:
2239                assert_identical(actual, nonzeros)
2240
2241    @requires_dask
2242    def test_write_preexisting_override_metadata(self):
2243        """Metadata should be overriden if mode="a" but not in mode="r+"."""
2244        original = Dataset(
2245            {"u": (("x",), np.zeros(10), {"variable": "original"})},
2246            attrs={"global": "original"},
2247        )
2248        both_modified = Dataset(
2249            {"u": (("x",), np.ones(10), {"variable": "modified"})},
2250            attrs={"global": "modified"},
2251        )
2252        global_modified = Dataset(
2253            {"u": (("x",), np.ones(10), {"variable": "original"})},
2254            attrs={"global": "modified"},
2255        )
2256        only_new_data = Dataset(
2257            {"u": (("x",), np.ones(10), {"variable": "original"})},
2258            attrs={"global": "original"},
2259        )
2260
2261        with self.create_zarr_target() as store:
2262            original.to_zarr(store, compute=False)
2263            both_modified.to_zarr(store, mode="a")
2264            with self.open(store) as actual:
2265                # NOTE: this arguably incorrect -- we should probably be
2266                # overriding the variable metadata, too. See the TODO note in
2267                # ZarrStore.set_variables.
2268                assert_identical(actual, global_modified)
2269
2270        with self.create_zarr_target() as store:
2271            original.to_zarr(store, compute=False)
2272            both_modified.to_zarr(store, mode="r+")
2273            with self.open(store) as actual:
2274                assert_identical(actual, only_new_data)
2275
2276        with self.create_zarr_target() as store:
2277            original.to_zarr(store, compute=False)
2278            # with region, the default mode becomes r+
2279            both_modified.to_zarr(store, region={"x": slice(None)})
2280            with self.open(store) as actual:
2281                assert_identical(actual, only_new_data)
2282
2283    def test_write_region_errors(self):
2284        data = Dataset({"u": (("x",), np.arange(5))})
2285        data2 = Dataset({"u": (("x",), np.array([10, 11]))})
2286
2287        @contextlib.contextmanager
2288        def setup_and_verify_store(expected=data):
2289            with self.create_zarr_target() as store:
2290                data.to_zarr(store)
2291                yield store
2292                with self.open(store) as actual:
2293                    assert_identical(actual, expected)
2294
2295        # verify the base case works
2296        expected = Dataset({"u": (("x",), np.array([10, 11, 2, 3, 4]))})
2297        with setup_and_verify_store(expected) as store:
2298            data2.to_zarr(store, region={"x": slice(2)})
2299
2300        with setup_and_verify_store() as store:
2301            with pytest.raises(
2302                ValueError,
2303                match=re.escape(
2304                    "cannot set region unless mode='a', mode='r+' or mode=None"
2305                ),
2306            ):
2307                data.to_zarr(store, region={"x": slice(None)}, mode="w")
2308
2309        with setup_and_verify_store() as store:
2310            with pytest.raises(TypeError, match=r"must be a dict"):
2311                data.to_zarr(store, region=slice(None))
2312
2313        with setup_and_verify_store() as store:
2314            with pytest.raises(TypeError, match=r"must be slice objects"):
2315                data2.to_zarr(store, region={"x": [0, 1]})
2316
2317        with setup_and_verify_store() as store:
2318            with pytest.raises(ValueError, match=r"step on all slices"):
2319                data2.to_zarr(store, region={"x": slice(None, None, 2)})
2320
2321        with setup_and_verify_store() as store:
2322            with pytest.raises(
2323                ValueError,
2324                match=r"all keys in ``region`` are not in Dataset dimensions",
2325            ):
2326                data.to_zarr(store, region={"y": slice(None)})
2327
2328        with setup_and_verify_store() as store:
2329            with pytest.raises(
2330                ValueError,
2331                match=r"all variables in the dataset to write must have at least one dimension in common",
2332            ):
2333                data2.assign(v=2).to_zarr(store, region={"x": slice(2)})
2334
2335        with setup_and_verify_store() as store:
2336            with pytest.raises(
2337                ValueError, match=r"cannot list the same dimension in both"
2338            ):
2339                data.to_zarr(store, region={"x": slice(None)}, append_dim="x")
2340
2341        with setup_and_verify_store() as store:
2342            with pytest.raises(
2343                ValueError,
2344                match=r"variable 'u' already exists with different dimension sizes",
2345            ):
2346                data2.to_zarr(store, region={"x": slice(3)})
2347
2348    @requires_dask
2349    def test_encoding_chunksizes(self):
2350        # regression test for GH2278
2351        # see also test_encoding_chunksizes_unlimited
2352        nx, ny, nt = 4, 4, 5
2353        original = xr.Dataset(
2354            {}, coords={"x": np.arange(nx), "y": np.arange(ny), "t": np.arange(nt)}
2355        )
2356        original["v"] = xr.Variable(("x", "y", "t"), np.zeros((nx, ny, nt)))
2357        original = original.chunk({"t": 1, "x": 2, "y": 2})
2358
2359        with self.roundtrip(original) as ds1:
2360            assert_equal(ds1, original)
2361            with self.roundtrip(ds1.isel(t=0)) as ds2:
2362                assert_equal(ds2, original.isel(t=0))
2363
2364    @requires_dask
2365    def test_chunk_encoding_with_partial_dask_chunks(self):
2366        original = xr.Dataset(
2367            {"x": xr.DataArray(np.random.random(size=(6, 8)), dims=("a", "b"))}
2368        ).chunk({"a": 3})
2369
2370        with self.roundtrip(
2371            original, save_kwargs={"encoding": {"x": {"chunks": [3, 2]}}}
2372        ) as ds1:
2373            assert_equal(ds1, original)
2374
2375    @requires_cftime
2376    def test_open_zarr_use_cftime(self):
2377        ds = create_test_data()
2378        with self.create_zarr_target() as store_target:
2379            ds.to_zarr(store_target)
2380            ds_a = xr.open_zarr(store_target)
2381            assert_identical(ds, ds_a)
2382            ds_b = xr.open_zarr(store_target, use_cftime=True)
2383            assert xr.coding.times.contains_cftime_datetimes(ds_b.time)
2384
2385
2386@requires_zarr
2387class TestZarrDictStore(ZarrBase):
2388    @contextlib.contextmanager
2389    def create_zarr_target(self):
2390        yield {}
2391
2392
2393@requires_zarr
2394class TestZarrDirectoryStore(ZarrBase):
2395    @contextlib.contextmanager
2396    def create_zarr_target(self):
2397        with create_tmp_file(suffix=".zarr") as tmp:
2398            yield tmp
2399
2400
2401@requires_zarr
2402@requires_fsspec
2403def test_zarr_storage_options():
2404    pytest.importorskip("aiobotocore")
2405    ds = create_test_data()
2406    store_target = "memory://test.zarr"
2407    ds.to_zarr(store_target, storage_options={"test": "zarr_write"})
2408    ds_a = xr.open_zarr(store_target, storage_options={"test": "zarr_read"})
2409    assert_identical(ds, ds_a)
2410
2411
2412@requires_scipy
2413class TestScipyInMemoryData(CFEncodedBase, NetCDF3Only):
2414    engine = "scipy"
2415
2416    @contextlib.contextmanager
2417    def create_store(self):
2418        fobj = BytesIO()
2419        yield backends.ScipyDataStore(fobj, "w")
2420
2421    def test_to_netcdf_explicit_engine(self):
2422        # regression test for GH1321
2423        Dataset({"foo": 42}).to_netcdf(engine="scipy")
2424
2425    def test_bytes_pickle(self):
2426        data = Dataset({"foo": ("x", [1, 2, 3])})
2427        fobj = data.to_netcdf()
2428        with self.open(fobj) as ds:
2429            unpickled = pickle.loads(pickle.dumps(ds))
2430            assert_identical(unpickled, data)
2431
2432
2433@requires_scipy
2434class TestScipyFileObject(CFEncodedBase, NetCDF3Only):
2435    engine = "scipy"
2436
2437    @contextlib.contextmanager
2438    def create_store(self):
2439        fobj = BytesIO()
2440        yield backends.ScipyDataStore(fobj, "w")
2441
2442    @contextlib.contextmanager
2443    def roundtrip(
2444        self, data, save_kwargs=None, open_kwargs=None, allow_cleanup_failure=False
2445    ):
2446        if save_kwargs is None:
2447            save_kwargs = {}
2448        if open_kwargs is None:
2449            open_kwargs = {}
2450        with create_tmp_file() as tmp_file:
2451            with open(tmp_file, "wb") as f:
2452                self.save(data, f, **save_kwargs)
2453            with open(tmp_file, "rb") as f:
2454                with self.open(f, **open_kwargs) as ds:
2455                    yield ds
2456
2457    @pytest.mark.skip(reason="cannot pickle file objects")
2458    def test_pickle(self):
2459        pass
2460
2461    @pytest.mark.skip(reason="cannot pickle file objects")
2462    def test_pickle_dataarray(self):
2463        pass
2464
2465
2466@requires_scipy
2467class TestScipyFilePath(CFEncodedBase, NetCDF3Only):
2468    engine = "scipy"
2469
2470    @contextlib.contextmanager
2471    def create_store(self):
2472        with create_tmp_file() as tmp_file:
2473            with backends.ScipyDataStore(tmp_file, mode="w") as store:
2474                yield store
2475
2476    def test_array_attrs(self):
2477        ds = Dataset(attrs={"foo": [[1, 2], [3, 4]]})
2478        with pytest.raises(ValueError, match=r"must be 1-dimensional"):
2479            with self.roundtrip(ds):
2480                pass
2481
2482    def test_roundtrip_example_1_netcdf_gz(self):
2483        with open_example_dataset("example_1.nc.gz") as expected:
2484            with open_example_dataset("example_1.nc") as actual:
2485                assert_identical(expected, actual)
2486
2487    def test_netcdf3_endianness(self):
2488        # regression test for GH416
2489        with open_example_dataset("bears.nc", engine="scipy") as expected:
2490            for var in expected.variables.values():
2491                assert var.dtype.isnative
2492
2493    @requires_netCDF4
2494    def test_nc4_scipy(self):
2495        with create_tmp_file(allow_cleanup_failure=True) as tmp_file:
2496            with nc4.Dataset(tmp_file, "w", format="NETCDF4") as rootgrp:
2497                rootgrp.createGroup("foo")
2498
2499            with pytest.raises(TypeError, match=r"pip install netcdf4"):
2500                open_dataset(tmp_file, engine="scipy")
2501
2502
2503@requires_netCDF4
2504class TestNetCDF3ViaNetCDF4Data(CFEncodedBase, NetCDF3Only):
2505    engine = "netcdf4"
2506    file_format = "NETCDF3_CLASSIC"
2507
2508    @contextlib.contextmanager
2509    def create_store(self):
2510        with create_tmp_file() as tmp_file:
2511            with backends.NetCDF4DataStore.open(
2512                tmp_file, mode="w", format="NETCDF3_CLASSIC"
2513            ) as store:
2514                yield store
2515
2516    def test_encoding_kwarg_vlen_string(self):
2517        original = Dataset({"x": ["foo", "bar", "baz"]})
2518        kwargs = dict(encoding={"x": {"dtype": str}})
2519        with pytest.raises(ValueError, match=r"encoding dtype=str for vlen"):
2520            with self.roundtrip(original, save_kwargs=kwargs):
2521                pass
2522
2523
2524@requires_netCDF4
2525class TestNetCDF4ClassicViaNetCDF4Data(CFEncodedBase, NetCDF3Only):
2526    engine = "netcdf4"
2527    file_format = "NETCDF4_CLASSIC"
2528
2529    @contextlib.contextmanager
2530    def create_store(self):
2531        with create_tmp_file() as tmp_file:
2532            with backends.NetCDF4DataStore.open(
2533                tmp_file, mode="w", format="NETCDF4_CLASSIC"
2534            ) as store:
2535                yield store
2536
2537
2538@requires_scipy_or_netCDF4
2539class TestGenericNetCDFData(CFEncodedBase, NetCDF3Only):
2540    # verify that we can read and write netCDF3 files as long as we have scipy
2541    # or netCDF4-python installed
2542    file_format = "netcdf3_64bit"
2543
2544    def test_write_store(self):
2545        # there's no specific store to test here
2546        pass
2547
2548    @requires_scipy
2549    def test_engine(self):
2550        data = create_test_data()
2551        with pytest.raises(ValueError, match=r"unrecognized engine"):
2552            data.to_netcdf("foo.nc", engine="foobar")
2553        with pytest.raises(ValueError, match=r"invalid engine"):
2554            data.to_netcdf(engine="netcdf4")
2555
2556        with create_tmp_file() as tmp_file:
2557            data.to_netcdf(tmp_file)
2558            with pytest.raises(ValueError, match=r"unrecognized engine"):
2559                open_dataset(tmp_file, engine="foobar")
2560
2561        netcdf_bytes = data.to_netcdf()
2562        with pytest.raises(ValueError, match=r"unrecognized engine"):
2563            open_dataset(BytesIO(netcdf_bytes), engine="foobar")
2564
2565    def test_cross_engine_read_write_netcdf3(self):
2566        data = create_test_data()
2567        valid_engines = set()
2568        if has_netCDF4:
2569            valid_engines.add("netcdf4")
2570        if has_scipy:
2571            valid_engines.add("scipy")
2572
2573        for write_engine in valid_engines:
2574            for format in self.netcdf3_formats:
2575                with create_tmp_file() as tmp_file:
2576                    data.to_netcdf(tmp_file, format=format, engine=write_engine)
2577                    for read_engine in valid_engines:
2578                        with open_dataset(tmp_file, engine=read_engine) as actual:
2579                            # hack to allow test to work:
2580                            # coord comes back as DataArray rather than coord,
2581                            # and so need to loop through here rather than in
2582                            # the test function (or we get recursion)
2583                            [
2584                                assert_allclose(data[k].variable, actual[k].variable)
2585                                for k in data.variables
2586                            ]
2587
2588    def test_encoding_unlimited_dims(self):
2589        ds = Dataset({"x": ("y", np.arange(10.0))})
2590        with self.roundtrip(ds, save_kwargs=dict(unlimited_dims=["y"])) as actual:
2591            assert actual.encoding["unlimited_dims"] == set("y")
2592            assert_equal(ds, actual)
2593
2594        # Regression test for https://github.com/pydata/xarray/issues/2134
2595        with self.roundtrip(ds, save_kwargs=dict(unlimited_dims="y")) as actual:
2596            assert actual.encoding["unlimited_dims"] == set("y")
2597            assert_equal(ds, actual)
2598
2599        ds.encoding = {"unlimited_dims": ["y"]}
2600        with self.roundtrip(ds) as actual:
2601            assert actual.encoding["unlimited_dims"] == set("y")
2602            assert_equal(ds, actual)
2603
2604        # Regression test for https://github.com/pydata/xarray/issues/2134
2605        ds.encoding = {"unlimited_dims": "y"}
2606        with self.roundtrip(ds) as actual:
2607            assert actual.encoding["unlimited_dims"] == set("y")
2608            assert_equal(ds, actual)
2609
2610
2611@requires_h5netcdf
2612@requires_netCDF4
2613@pytest.mark.filterwarnings("ignore:use make_scale(name) instead")
2614class TestH5NetCDFData(NetCDF4Base):
2615    engine = "h5netcdf"
2616
2617    @contextlib.contextmanager
2618    def create_store(self):
2619        with create_tmp_file() as tmp_file:
2620            yield backends.H5NetCDFStore.open(tmp_file, "w")
2621
2622    @pytest.mark.filterwarnings("ignore:complex dtypes are supported by h5py")
2623    @pytest.mark.parametrize(
2624        "invalid_netcdf, warntype, num_warns",
2625        [(None, FutureWarning, 1), (False, FutureWarning, 1), (True, None, 0)],
2626    )
2627    def test_complex(self, invalid_netcdf, warntype, num_warns):
2628        expected = Dataset({"x": ("y", np.ones(5) + 1j * np.ones(5))})
2629        save_kwargs = {"invalid_netcdf": invalid_netcdf}
2630        with pytest.warns(warntype) as record:
2631            with self.roundtrip(expected, save_kwargs=save_kwargs) as actual:
2632                assert_equal(expected, actual)
2633
2634        recorded_num_warns = 0
2635        if warntype:
2636            for warning in record:
2637                if issubclass(warning.category, warntype) and (
2638                    "complex dtypes" in str(warning.message)
2639                ):
2640                    recorded_num_warns += 1
2641
2642        assert recorded_num_warns == num_warns
2643
2644    def test_numpy_bool_(self):
2645        # h5netcdf loads booleans as numpy.bool_, this type needs to be supported
2646        # when writing invalid_netcdf datasets in order to support a roundtrip
2647        expected = Dataset({"x": ("y", np.ones(5), {"numpy_bool": np.bool_(True)})})
2648        save_kwargs = {"invalid_netcdf": True}
2649        with self.roundtrip(expected, save_kwargs=save_kwargs) as actual:
2650            assert_identical(expected, actual)
2651
2652    def test_cross_engine_read_write_netcdf4(self):
2653        # Drop dim3, because its labels include strings. These appear to be
2654        # not properly read with python-netCDF4, which converts them into
2655        # unicode instead of leaving them as bytes.
2656        data = create_test_data().drop_vars("dim3")
2657        data.attrs["foo"] = "bar"
2658        valid_engines = ["netcdf4", "h5netcdf"]
2659        for write_engine in valid_engines:
2660            with create_tmp_file() as tmp_file:
2661                data.to_netcdf(tmp_file, engine=write_engine)
2662                for read_engine in valid_engines:
2663                    with open_dataset(tmp_file, engine=read_engine) as actual:
2664                        assert_identical(data, actual)
2665
2666    def test_read_byte_attrs_as_unicode(self):
2667        with create_tmp_file() as tmp_file:
2668            with nc4.Dataset(tmp_file, "w") as nc:
2669                nc.foo = b"bar"
2670            with open_dataset(tmp_file) as actual:
2671                expected = Dataset(attrs={"foo": "bar"})
2672                assert_identical(expected, actual)
2673
2674    def test_encoding_unlimited_dims(self):
2675        ds = Dataset({"x": ("y", np.arange(10.0))})
2676        with self.roundtrip(ds, save_kwargs=dict(unlimited_dims=["y"])) as actual:
2677            assert actual.encoding["unlimited_dims"] == set("y")
2678            assert_equal(ds, actual)
2679        ds.encoding = {"unlimited_dims": ["y"]}
2680        with self.roundtrip(ds) as actual:
2681            assert actual.encoding["unlimited_dims"] == set("y")
2682            assert_equal(ds, actual)
2683
2684    def test_compression_encoding_h5py(self):
2685        ENCODINGS = (
2686            # h5py style compression with gzip codec will be converted to
2687            # NetCDF4-Python style on round-trip
2688            (
2689                {"compression": "gzip", "compression_opts": 9},
2690                {"zlib": True, "complevel": 9},
2691            ),
2692            # What can't be expressed in NetCDF4-Python style is
2693            # round-tripped unaltered
2694            (
2695                {"compression": "lzf", "compression_opts": None},
2696                {"compression": "lzf", "compression_opts": None},
2697            ),
2698            # If both styles are used together, h5py format takes precedence
2699            (
2700                {
2701                    "compression": "lzf",
2702                    "compression_opts": None,
2703                    "zlib": True,
2704                    "complevel": 9,
2705                },
2706                {"compression": "lzf", "compression_opts": None},
2707            ),
2708        )
2709
2710        for compr_in, compr_out in ENCODINGS:
2711            data = create_test_data()
2712            compr_common = {
2713                "chunksizes": (5, 5),
2714                "fletcher32": True,
2715                "shuffle": True,
2716                "original_shape": data.var2.shape,
2717            }
2718            data["var2"].encoding.update(compr_in)
2719            data["var2"].encoding.update(compr_common)
2720            compr_out.update(compr_common)
2721            data["scalar"] = ("scalar_dim", np.array([2.0]))
2722            data["scalar"] = data["scalar"][0]
2723            with self.roundtrip(data) as actual:
2724                for k, v in compr_out.items():
2725                    assert v == actual["var2"].encoding[k]
2726
2727    def test_compression_check_encoding_h5py(self):
2728        """When mismatched h5py and NetCDF4-Python encodings are expressed
2729        in to_netcdf(encoding=...), must raise ValueError
2730        """
2731        data = Dataset({"x": ("y", np.arange(10.0))})
2732        # Compatible encodings are graciously supported
2733        with create_tmp_file() as tmp_file:
2734            data.to_netcdf(
2735                tmp_file,
2736                engine="h5netcdf",
2737                encoding={
2738                    "x": {
2739                        "compression": "gzip",
2740                        "zlib": True,
2741                        "compression_opts": 6,
2742                        "complevel": 6,
2743                    }
2744                },
2745            )
2746            with open_dataset(tmp_file, engine="h5netcdf") as actual:
2747                assert actual.x.encoding["zlib"] is True
2748                assert actual.x.encoding["complevel"] == 6
2749
2750        # Incompatible encodings cause a crash
2751        with create_tmp_file() as tmp_file:
2752            with pytest.raises(
2753                ValueError, match=r"'zlib' and 'compression' encodings mismatch"
2754            ):
2755                data.to_netcdf(
2756                    tmp_file,
2757                    engine="h5netcdf",
2758                    encoding={"x": {"compression": "lzf", "zlib": True}},
2759                )
2760
2761        with create_tmp_file() as tmp_file:
2762            with pytest.raises(
2763                ValueError,
2764                match=r"'complevel' and 'compression_opts' encodings mismatch",
2765            ):
2766                data.to_netcdf(
2767                    tmp_file,
2768                    engine="h5netcdf",
2769                    encoding={
2770                        "x": {
2771                            "compression": "gzip",
2772                            "compression_opts": 5,
2773                            "complevel": 6,
2774                        }
2775                    },
2776                )
2777
2778    def test_dump_encodings_h5py(self):
2779        # regression test for #709
2780        ds = Dataset({"x": ("y", np.arange(10.0))})
2781
2782        kwargs = {"encoding": {"x": {"compression": "gzip", "compression_opts": 9}}}
2783        with self.roundtrip(ds, save_kwargs=kwargs) as actual:
2784            assert actual.x.encoding["zlib"]
2785            assert actual.x.encoding["complevel"] == 9
2786
2787        kwargs = {"encoding": {"x": {"compression": "lzf", "compression_opts": None}}}
2788        with self.roundtrip(ds, save_kwargs=kwargs) as actual:
2789            assert actual.x.encoding["compression"] == "lzf"
2790            assert actual.x.encoding["compression_opts"] is None
2791
2792
2793@requires_h5netcdf
2794@requires_netCDF4
2795class TestH5NetCDFAlreadyOpen:
2796    def test_open_dataset_group(self):
2797        import h5netcdf
2798
2799        with create_tmp_file() as tmp_file:
2800            with nc4.Dataset(tmp_file, mode="w") as nc:
2801                group = nc.createGroup("g")
2802                v = group.createVariable("x", "int")
2803                v[...] = 42
2804
2805            kwargs = {}
2806            if LooseVersion(h5netcdf.__version__) >= LooseVersion(
2807                "0.10.0"
2808            ) and LooseVersion(h5netcdf.core.h5py.__version__) >= LooseVersion("3.0.0"):
2809                kwargs = dict(decode_vlen_strings=True)
2810
2811            h5 = h5netcdf.File(tmp_file, mode="r", **kwargs)
2812            store = backends.H5NetCDFStore(h5["g"])
2813            with open_dataset(store) as ds:
2814                expected = Dataset({"x": ((), 42)})
2815                assert_identical(expected, ds)
2816
2817            h5 = h5netcdf.File(tmp_file, mode="r", **kwargs)
2818            store = backends.H5NetCDFStore(h5, group="g")
2819            with open_dataset(store) as ds:
2820                expected = Dataset({"x": ((), 42)})
2821                assert_identical(expected, ds)
2822
2823    def test_deepcopy(self):
2824        import h5netcdf
2825
2826        with create_tmp_file() as tmp_file:
2827            with nc4.Dataset(tmp_file, mode="w") as nc:
2828                nc.createDimension("x", 10)
2829                v = nc.createVariable("y", np.int32, ("x",))
2830                v[:] = np.arange(10)
2831
2832            kwargs = {}
2833            if LooseVersion(h5netcdf.__version__) >= LooseVersion(
2834                "0.10.0"
2835            ) and LooseVersion(h5netcdf.core.h5py.__version__) >= LooseVersion("3.0.0"):
2836                kwargs = dict(decode_vlen_strings=True)
2837
2838            h5 = h5netcdf.File(tmp_file, mode="r", **kwargs)
2839            store = backends.H5NetCDFStore(h5)
2840            with open_dataset(store) as ds:
2841                copied = ds.copy(deep=True)
2842                expected = Dataset({"y": ("x", np.arange(10))})
2843                assert_identical(expected, copied)
2844
2845
2846@requires_h5netcdf
2847class TestH5NetCDFFileObject(TestH5NetCDFData):
2848    engine = "h5netcdf"
2849
2850    def test_open_badbytes(self):
2851        with pytest.raises(ValueError, match=r"HDF5 as bytes"):
2852            with open_dataset(b"\211HDF\r\n\032\n", engine="h5netcdf"):
2853                pass
2854        with pytest.raises(
2855            ValueError, match=r"match in any of xarray's currently installed IO"
2856        ):
2857            with open_dataset(b"garbage"):
2858                pass
2859        with pytest.raises(ValueError, match=r"can only read bytes"):
2860            with open_dataset(b"garbage", engine="netcdf4"):
2861                pass
2862        with pytest.raises(
2863            ValueError, match=r"not the signature of a valid netCDF4 file"
2864        ):
2865            with open_dataset(BytesIO(b"garbage"), engine="h5netcdf"):
2866                pass
2867
2868    def test_open_twice(self):
2869        expected = create_test_data()
2870        expected.attrs["foo"] = "bar"
2871        with pytest.raises(ValueError, match=r"read/write pointer not at the start"):
2872            with create_tmp_file() as tmp_file:
2873                expected.to_netcdf(tmp_file, engine="h5netcdf")
2874                with open(tmp_file, "rb") as f:
2875                    with open_dataset(f, engine="h5netcdf"):
2876                        with open_dataset(f, engine="h5netcdf"):
2877                            pass
2878
2879    @requires_scipy
2880    def test_open_fileobj(self):
2881        # open in-memory datasets instead of local file paths
2882        expected = create_test_data().drop_vars("dim3")
2883        expected.attrs["foo"] = "bar"
2884        with create_tmp_file() as tmp_file:
2885            expected.to_netcdf(tmp_file, engine="h5netcdf")
2886
2887            with open(tmp_file, "rb") as f:
2888                with open_dataset(f, engine="h5netcdf") as actual:
2889                    assert_identical(expected, actual)
2890
2891                f.seek(0)
2892                with open_dataset(f) as actual:
2893                    assert_identical(expected, actual)
2894
2895                f.seek(0)
2896                with BytesIO(f.read()) as bio:
2897                    with open_dataset(bio, engine="h5netcdf") as actual:
2898                        assert_identical(expected, actual)
2899
2900                f.seek(0)
2901                with pytest.raises(TypeError, match="not a valid NetCDF 3"):
2902                    open_dataset(f, engine="scipy")
2903
2904            # TOOD: this additional open is required since scipy seems to close the file
2905            # when it fails on the TypeError (though didn't when we used
2906            # `raises_regex`?). Ref https://github.com/pydata/xarray/pull/5191
2907            with open(tmp_file, "rb") as f:
2908                f.seek(8)
2909                with pytest.raises(
2910                    ValueError,
2911                    match="match in any of xarray's currently installed IO",
2912                ):
2913                    with pytest.warns(
2914                        RuntimeWarning,
2915                        match=re.escape("'h5netcdf' fails while guessing"),
2916                    ):
2917                        open_dataset(f)
2918
2919
2920@requires_h5netcdf
2921@requires_dask
2922@pytest.mark.filterwarnings("ignore:deallocating CachingFileManager")
2923class TestH5NetCDFViaDaskData(TestH5NetCDFData):
2924    @contextlib.contextmanager
2925    def roundtrip(
2926        self, data, save_kwargs=None, open_kwargs=None, allow_cleanup_failure=False
2927    ):
2928        if save_kwargs is None:
2929            save_kwargs = {}
2930        if open_kwargs is None:
2931            open_kwargs = {}
2932        open_kwargs.setdefault("chunks", -1)
2933        with TestH5NetCDFData.roundtrip(
2934            self, data, save_kwargs, open_kwargs, allow_cleanup_failure
2935        ) as ds:
2936            yield ds
2937
2938    def test_dataset_caching(self):
2939        # caching behavior differs for dask
2940        pass
2941
2942    def test_write_inconsistent_chunks(self):
2943        # Construct two variables with the same dimensions, but different
2944        # chunk sizes.
2945        x = da.zeros((100, 100), dtype="f4", chunks=(50, 100))
2946        x = DataArray(data=x, dims=("lat", "lon"), name="x")
2947        x.encoding["chunksizes"] = (50, 100)
2948        x.encoding["original_shape"] = (100, 100)
2949        y = da.ones((100, 100), dtype="f4", chunks=(100, 50))
2950        y = DataArray(data=y, dims=("lat", "lon"), name="y")
2951        y.encoding["chunksizes"] = (100, 50)
2952        y.encoding["original_shape"] = (100, 100)
2953        # Put them both into the same dataset
2954        ds = Dataset({"x": x, "y": y})
2955        with self.roundtrip(ds) as actual:
2956            assert actual["x"].encoding["chunksizes"] == (50, 100)
2957            assert actual["y"].encoding["chunksizes"] == (100, 50)
2958
2959
2960@pytest.fixture(params=["scipy", "netcdf4", "h5netcdf", "pynio", "zarr"])
2961def readengine(request):
2962    return request.param
2963
2964
2965@pytest.fixture(params=[1, 20])
2966def nfiles(request):
2967    return request.param
2968
2969
2970@pytest.fixture(params=[5, None])
2971def file_cache_maxsize(request):
2972    maxsize = request.param
2973    if maxsize is not None:
2974        with set_options(file_cache_maxsize=maxsize):
2975            yield maxsize
2976    else:
2977        yield maxsize
2978
2979
2980@pytest.fixture(params=[True, False])
2981def parallel(request):
2982    return request.param
2983
2984
2985@pytest.fixture(params=[None, 5])
2986def chunks(request):
2987    return request.param
2988
2989
2990# using pytest.mark.skipif does not work so this a work around
2991def skip_if_not_engine(engine):
2992    if engine == "netcdf4":
2993        pytest.importorskip("netCDF4")
2994    elif engine == "pynio":
2995        pytest.importorskip("Nio")
2996    else:
2997        pytest.importorskip(engine)
2998
2999
3000@requires_dask
3001@pytest.mark.filterwarnings("ignore:use make_scale(name) instead")
3002def test_open_mfdataset_manyfiles(
3003    readengine, nfiles, parallel, chunks, file_cache_maxsize
3004):
3005
3006    # skip certain combinations
3007    skip_if_not_engine(readengine)
3008
3009    if ON_WINDOWS:
3010        pytest.skip("Skipping on Windows")
3011
3012    randdata = np.random.randn(nfiles)
3013    original = Dataset({"foo": ("x", randdata)})
3014    # test standard open_mfdataset approach with too many files
3015    with create_tmp_files(nfiles) as tmpfiles:
3016        writeengine = readengine if readengine != "pynio" else "netcdf4"
3017        # split into multiple sets of temp files
3018        for ii in original.x.values:
3019            subds = original.isel(x=slice(ii, ii + 1))
3020            if writeengine != "zarr":
3021                subds.to_netcdf(tmpfiles[ii], engine=writeengine)
3022            else:  # if writeengine == "zarr":
3023                subds.to_zarr(store=tmpfiles[ii])
3024
3025        # check that calculation on opened datasets works properly
3026        with open_mfdataset(
3027            tmpfiles,
3028            combine="nested",
3029            concat_dim="x",
3030            engine=readengine,
3031            parallel=parallel,
3032            chunks=chunks if (not chunks and readengine != "zarr") else "auto",
3033        ) as actual:
3034
3035            # check that using open_mfdataset returns dask arrays for variables
3036            assert isinstance(actual["foo"].data, dask_array_type)
3037
3038            assert_identical(original, actual)
3039
3040
3041@requires_netCDF4
3042@requires_dask
3043def test_open_mfdataset_can_open_path_objects():
3044    dataset = os.path.join(os.path.dirname(__file__), "data", "example_1.nc")
3045    with open_mfdataset(Path(dataset)) as actual:
3046        assert isinstance(actual, Dataset)
3047
3048
3049@requires_netCDF4
3050@requires_dask
3051def test_open_mfdataset_list_attr():
3052    """
3053    Case when an attribute of type list differs across the multiple files
3054    """
3055    from netCDF4 import Dataset
3056
3057    with create_tmp_files(2) as nfiles:
3058        for i in range(2):
3059            f = Dataset(nfiles[i], "w")
3060            f.createDimension("x", 3)
3061            vlvar = f.createVariable("test_var", np.int32, ("x"))
3062            # here create an attribute as a list
3063            vlvar.test_attr = [f"string a {i}", f"string b {i}"]
3064            vlvar[:] = np.arange(3)
3065            f.close()
3066        ds1 = open_dataset(nfiles[0])
3067        ds2 = open_dataset(nfiles[1])
3068        original = xr.concat([ds1, ds2], dim="x")
3069        with xr.open_mfdataset(
3070            [nfiles[0], nfiles[1]], combine="nested", concat_dim="x"
3071        ) as actual:
3072            assert_identical(actual, original)
3073
3074
3075@requires_scipy_or_netCDF4
3076@requires_dask
3077class TestOpenMFDatasetWithDataVarsAndCoordsKw:
3078    coord_name = "lon"
3079    var_name = "v1"
3080
3081    @contextlib.contextmanager
3082    def setup_files_and_datasets(self, fuzz=0):
3083        ds1, ds2 = self.gen_datasets_with_common_coord_and_time()
3084
3085        # to test join='exact'
3086        ds1["x"] = ds1.x + fuzz
3087
3088        with create_tmp_file() as tmpfile1:
3089            with create_tmp_file() as tmpfile2:
3090
3091                # save data to the temporary files
3092                ds1.to_netcdf(tmpfile1)
3093                ds2.to_netcdf(tmpfile2)
3094
3095                yield [tmpfile1, tmpfile2], [ds1, ds2]
3096
3097    def gen_datasets_with_common_coord_and_time(self):
3098        # create coordinate data
3099        nx = 10
3100        nt = 10
3101        x = np.arange(nx)
3102        t1 = np.arange(nt)
3103        t2 = np.arange(nt, 2 * nt, 1)
3104
3105        v1 = np.random.randn(nt, nx)
3106        v2 = np.random.randn(nt, nx)
3107
3108        ds1 = Dataset(
3109            data_vars={self.var_name: (["t", "x"], v1), self.coord_name: ("x", 2 * x)},
3110            coords={"t": (["t"], t1), "x": (["x"], x)},
3111        )
3112
3113        ds2 = Dataset(
3114            data_vars={self.var_name: (["t", "x"], v2), self.coord_name: ("x", 2 * x)},
3115            coords={"t": (["t"], t2), "x": (["x"], x)},
3116        )
3117
3118        return ds1, ds2
3119
3120    @pytest.mark.parametrize(
3121        "combine, concat_dim", [("nested", "t"), ("by_coords", None)]
3122    )
3123    @pytest.mark.parametrize("opt", ["all", "minimal", "different"])
3124    @pytest.mark.parametrize("join", ["outer", "inner", "left", "right"])
3125    def test_open_mfdataset_does_same_as_concat(self, combine, concat_dim, opt, join):
3126        with self.setup_files_and_datasets() as (files, [ds1, ds2]):
3127            if combine == "by_coords":
3128                files.reverse()
3129            with open_mfdataset(
3130                files, data_vars=opt, combine=combine, concat_dim=concat_dim, join=join
3131            ) as ds:
3132                ds_expect = xr.concat([ds1, ds2], data_vars=opt, dim="t", join=join)
3133                assert_identical(ds, ds_expect)
3134
3135    @pytest.mark.parametrize(
3136        ["combine_attrs", "attrs", "expected", "expect_error"],
3137        (
3138            pytest.param("drop", [{"a": 1}, {"a": 2}], {}, False, id="drop"),
3139            pytest.param(
3140                "override", [{"a": 1}, {"a": 2}], {"a": 1}, False, id="override"
3141            ),
3142            pytest.param(
3143                "no_conflicts", [{"a": 1}, {"a": 2}], None, True, id="no_conflicts"
3144            ),
3145            pytest.param(
3146                "identical",
3147                [{"a": 1, "b": 2}, {"a": 1, "c": 3}],
3148                None,
3149                True,
3150                id="identical",
3151            ),
3152            pytest.param(
3153                "drop_conflicts",
3154                [{"a": 1, "b": 2}, {"b": -1, "c": 3}],
3155                {"a": 1, "c": 3},
3156                False,
3157                id="drop_conflicts",
3158            ),
3159        ),
3160    )
3161    def test_open_mfdataset_dataset_combine_attrs(
3162        self, combine_attrs, attrs, expected, expect_error
3163    ):
3164        with self.setup_files_and_datasets() as (files, [ds1, ds2]):
3165            # Give the files an inconsistent attribute
3166            for i, f in enumerate(files):
3167                ds = open_dataset(f).load()
3168                ds.attrs = attrs[i]
3169                ds.close()
3170                ds.to_netcdf(f)
3171
3172            if expect_error:
3173                with pytest.raises(xr.MergeError):
3174                    xr.open_mfdataset(
3175                        files,
3176                        combine="nested",
3177                        concat_dim="t",
3178                        combine_attrs=combine_attrs,
3179                    )
3180            else:
3181                with xr.open_mfdataset(
3182                    files,
3183                    combine="nested",
3184                    concat_dim="t",
3185                    combine_attrs=combine_attrs,
3186                ) as ds:
3187                    assert ds.attrs == expected
3188
3189    def test_open_mfdataset_dataset_attr_by_coords(self):
3190        """
3191        Case when an attribute differs across the multiple files
3192        """
3193        with self.setup_files_and_datasets() as (files, [ds1, ds2]):
3194            # Give the files an inconsistent attribute
3195            for i, f in enumerate(files):
3196                ds = open_dataset(f).load()
3197                ds.attrs["test_dataset_attr"] = 10 + i
3198                ds.close()
3199                ds.to_netcdf(f)
3200
3201            with xr.open_mfdataset(files, combine="nested", concat_dim="t") as ds:
3202                assert ds.test_dataset_attr == 10
3203
3204    def test_open_mfdataset_dataarray_attr_by_coords(self):
3205        """
3206        Case when an attribute of a member DataArray differs across the multiple files
3207        """
3208        with self.setup_files_and_datasets() as (files, [ds1, ds2]):
3209            # Give the files an inconsistent attribute
3210            for i, f in enumerate(files):
3211                ds = open_dataset(f).load()
3212                ds["v1"].attrs["test_dataarray_attr"] = i
3213                ds.close()
3214                ds.to_netcdf(f)
3215
3216            with xr.open_mfdataset(files, combine="nested", concat_dim="t") as ds:
3217                assert ds["v1"].test_dataarray_attr == 0
3218
3219    @pytest.mark.parametrize(
3220        "combine, concat_dim", [("nested", "t"), ("by_coords", None)]
3221    )
3222    @pytest.mark.parametrize("opt", ["all", "minimal", "different"])
3223    def test_open_mfdataset_exact_join_raises_error(self, combine, concat_dim, opt):
3224        with self.setup_files_and_datasets(fuzz=0.1) as (files, [ds1, ds2]):
3225            if combine == "by_coords":
3226                files.reverse()
3227            with pytest.raises(ValueError, match=r"indexes along dimension"):
3228                open_mfdataset(
3229                    files,
3230                    data_vars=opt,
3231                    combine=combine,
3232                    concat_dim=concat_dim,
3233                    join="exact",
3234                )
3235
3236    def test_common_coord_when_datavars_all(self):
3237        opt = "all"
3238
3239        with self.setup_files_and_datasets() as (files, [ds1, ds2]):
3240            # open the files with the data_var option
3241            with open_mfdataset(
3242                files, data_vars=opt, combine="nested", concat_dim="t"
3243            ) as ds:
3244
3245                coord_shape = ds[self.coord_name].shape
3246                coord_shape1 = ds1[self.coord_name].shape
3247                coord_shape2 = ds2[self.coord_name].shape
3248
3249                var_shape = ds[self.var_name].shape
3250
3251                assert var_shape == coord_shape
3252                assert coord_shape1 != coord_shape
3253                assert coord_shape2 != coord_shape
3254
3255    def test_common_coord_when_datavars_minimal(self):
3256        opt = "minimal"
3257
3258        with self.setup_files_and_datasets() as (files, [ds1, ds2]):
3259            # open the files using data_vars option
3260            with open_mfdataset(
3261                files, data_vars=opt, combine="nested", concat_dim="t"
3262            ) as ds:
3263
3264                coord_shape = ds[self.coord_name].shape
3265                coord_shape1 = ds1[self.coord_name].shape
3266                coord_shape2 = ds2[self.coord_name].shape
3267
3268                var_shape = ds[self.var_name].shape
3269
3270                assert var_shape != coord_shape
3271                assert coord_shape1 == coord_shape
3272                assert coord_shape2 == coord_shape
3273
3274    def test_invalid_data_vars_value_should_fail(self):
3275
3276        with self.setup_files_and_datasets() as (files, _):
3277            with pytest.raises(ValueError):
3278                with open_mfdataset(files, data_vars="minimum", combine="by_coords"):
3279                    pass
3280
3281            # test invalid coord parameter
3282            with pytest.raises(ValueError):
3283                with open_mfdataset(files, coords="minimum", combine="by_coords"):
3284                    pass
3285
3286
3287@requires_dask
3288@requires_scipy
3289@requires_netCDF4
3290class TestDask(DatasetIOBase):
3291    @contextlib.contextmanager
3292    def create_store(self):
3293        yield Dataset()
3294
3295    @contextlib.contextmanager
3296    def roundtrip(
3297        self, data, save_kwargs=None, open_kwargs=None, allow_cleanup_failure=False
3298    ):
3299        yield data.chunk()
3300
3301    # Override methods in DatasetIOBase - not applicable to dask
3302    def test_roundtrip_string_encoded_characters(self):
3303        pass
3304
3305    def test_roundtrip_coordinates_with_space(self):
3306        pass
3307
3308    def test_roundtrip_numpy_datetime_data(self):
3309        # Override method in DatasetIOBase - remove not applicable
3310        # save_kwargs
3311        times = pd.to_datetime(["2000-01-01", "2000-01-02", "NaT"])
3312        expected = Dataset({"t": ("t", times), "t0": times[0]})
3313        with self.roundtrip(expected) as actual:
3314            assert_identical(expected, actual)
3315
3316    def test_roundtrip_cftime_datetime_data(self):
3317        # Override method in DatasetIOBase - remove not applicable
3318        # save_kwargs
3319        from .test_coding_times import _all_cftime_date_types
3320
3321        date_types = _all_cftime_date_types()
3322        for date_type in date_types.values():
3323            times = [date_type(1, 1, 1), date_type(1, 1, 2)]
3324            expected = Dataset({"t": ("t", times), "t0": times[0]})
3325            expected_decoded_t = np.array(times)
3326            expected_decoded_t0 = np.array([date_type(1, 1, 1)])
3327
3328            with self.roundtrip(expected) as actual:
3329                abs_diff = abs(actual.t.values - expected_decoded_t)
3330                assert (abs_diff <= np.timedelta64(1, "s")).all()
3331
3332                abs_diff = abs(actual.t0.values - expected_decoded_t0)
3333                assert (abs_diff <= np.timedelta64(1, "s")).all()
3334
3335    def test_write_store(self):
3336        # Override method in DatasetIOBase - not applicable to dask
3337        pass
3338
3339    def test_dataset_caching(self):
3340        expected = Dataset({"foo": ("x", [5, 6, 7])})
3341        with self.roundtrip(expected) as actual:
3342            assert not actual.foo.variable._in_memory
3343            actual.foo.values  # no caching
3344            assert not actual.foo.variable._in_memory
3345
3346    def test_open_mfdataset(self):
3347        original = Dataset({"foo": ("x", np.random.randn(10))})
3348        with create_tmp_file() as tmp1:
3349            with create_tmp_file() as tmp2:
3350                original.isel(x=slice(5)).to_netcdf(tmp1)
3351                original.isel(x=slice(5, 10)).to_netcdf(tmp2)
3352                with open_mfdataset(
3353                    [tmp1, tmp2], concat_dim="x", combine="nested"
3354                ) as actual:
3355                    assert isinstance(actual.foo.variable.data, da.Array)
3356                    assert actual.foo.variable.data.chunks == ((5, 5),)
3357                    assert_identical(original, actual)
3358                with open_mfdataset(
3359                    [tmp1, tmp2], concat_dim="x", combine="nested", chunks={"x": 3}
3360                ) as actual:
3361                    assert actual.foo.variable.data.chunks == ((3, 2, 3, 2),)
3362
3363        with pytest.raises(OSError, match=r"no files to open"):
3364            open_mfdataset("foo-bar-baz-*.nc")
3365        with pytest.raises(ValueError, match=r"wild-card"):
3366            open_mfdataset("http://some/remote/uri")
3367
3368    @requires_fsspec
3369    def test_open_mfdataset_no_files(self):
3370        pytest.importorskip("aiobotocore")
3371
3372        # glob is attempted as of #4823, but finds no files
3373        with pytest.raises(OSError, match=r"no files"):
3374            open_mfdataset("http://some/remote/uri", engine="zarr")
3375
3376    def test_open_mfdataset_2d(self):
3377        original = Dataset({"foo": (["x", "y"], np.random.randn(10, 8))})
3378        with create_tmp_file() as tmp1:
3379            with create_tmp_file() as tmp2:
3380                with create_tmp_file() as tmp3:
3381                    with create_tmp_file() as tmp4:
3382                        original.isel(x=slice(5), y=slice(4)).to_netcdf(tmp1)
3383                        original.isel(x=slice(5, 10), y=slice(4)).to_netcdf(tmp2)
3384                        original.isel(x=slice(5), y=slice(4, 8)).to_netcdf(tmp3)
3385                        original.isel(x=slice(5, 10), y=slice(4, 8)).to_netcdf(tmp4)
3386                        with open_mfdataset(
3387                            [[tmp1, tmp2], [tmp3, tmp4]],
3388                            combine="nested",
3389                            concat_dim=["y", "x"],
3390                        ) as actual:
3391                            assert isinstance(actual.foo.variable.data, da.Array)
3392                            assert actual.foo.variable.data.chunks == ((5, 5), (4, 4))
3393                            assert_identical(original, actual)
3394                        with open_mfdataset(
3395                            [[tmp1, tmp2], [tmp3, tmp4]],
3396                            combine="nested",
3397                            concat_dim=["y", "x"],
3398                            chunks={"x": 3, "y": 2},
3399                        ) as actual:
3400                            assert actual.foo.variable.data.chunks == (
3401                                (3, 2, 3, 2),
3402                                (2, 2, 2, 2),
3403                            )
3404
3405    def test_open_mfdataset_pathlib(self):
3406        original = Dataset({"foo": ("x", np.random.randn(10))})
3407        with create_tmp_file() as tmp1:
3408            with create_tmp_file() as tmp2:
3409                tmp1 = Path(tmp1)
3410                tmp2 = Path(tmp2)
3411                original.isel(x=slice(5)).to_netcdf(tmp1)
3412                original.isel(x=slice(5, 10)).to_netcdf(tmp2)
3413                with open_mfdataset(
3414                    [tmp1, tmp2], concat_dim="x", combine="nested"
3415                ) as actual:
3416                    assert_identical(original, actual)
3417
3418    def test_open_mfdataset_2d_pathlib(self):
3419        original = Dataset({"foo": (["x", "y"], np.random.randn(10, 8))})
3420        with create_tmp_file() as tmp1:
3421            with create_tmp_file() as tmp2:
3422                with create_tmp_file() as tmp3:
3423                    with create_tmp_file() as tmp4:
3424                        tmp1 = Path(tmp1)
3425                        tmp2 = Path(tmp2)
3426                        tmp3 = Path(tmp3)
3427                        tmp4 = Path(tmp4)
3428                        original.isel(x=slice(5), y=slice(4)).to_netcdf(tmp1)
3429                        original.isel(x=slice(5, 10), y=slice(4)).to_netcdf(tmp2)
3430                        original.isel(x=slice(5), y=slice(4, 8)).to_netcdf(tmp3)
3431                        original.isel(x=slice(5, 10), y=slice(4, 8)).to_netcdf(tmp4)
3432                        with open_mfdataset(
3433                            [[tmp1, tmp2], [tmp3, tmp4]],
3434                            combine="nested",
3435                            concat_dim=["y", "x"],
3436                        ) as actual:
3437                            assert_identical(original, actual)
3438
3439    def test_open_mfdataset_2(self):
3440        original = Dataset({"foo": ("x", np.random.randn(10))})
3441        with create_tmp_file() as tmp1:
3442            with create_tmp_file() as tmp2:
3443                original.isel(x=slice(5)).to_netcdf(tmp1)
3444                original.isel(x=slice(5, 10)).to_netcdf(tmp2)
3445
3446                with open_mfdataset(
3447                    [tmp1, tmp2], concat_dim="x", combine="nested"
3448                ) as actual:
3449                    assert_identical(original, actual)
3450
3451    def test_attrs_mfdataset(self):
3452        original = Dataset({"foo": ("x", np.random.randn(10))})
3453        with create_tmp_file() as tmp1:
3454            with create_tmp_file() as tmp2:
3455                ds1 = original.isel(x=slice(5))
3456                ds2 = original.isel(x=slice(5, 10))
3457                ds1.attrs["test1"] = "foo"
3458                ds2.attrs["test2"] = "bar"
3459                ds1.to_netcdf(tmp1)
3460                ds2.to_netcdf(tmp2)
3461                with open_mfdataset(
3462                    [tmp1, tmp2], concat_dim="x", combine="nested"
3463                ) as actual:
3464                    # presumes that attributes inherited from
3465                    # first dataset loaded
3466                    assert actual.test1 == ds1.test1
3467                    # attributes from ds2 are not retained, e.g.,
3468                    with pytest.raises(AttributeError, match=r"no attribute"):
3469                        actual.test2
3470
3471    def test_open_mfdataset_attrs_file(self):
3472        original = Dataset({"foo": ("x", np.random.randn(10))})
3473        with create_tmp_files(2) as (tmp1, tmp2):
3474            ds1 = original.isel(x=slice(5))
3475            ds2 = original.isel(x=slice(5, 10))
3476            ds1.attrs["test1"] = "foo"
3477            ds2.attrs["test2"] = "bar"
3478            ds1.to_netcdf(tmp1)
3479            ds2.to_netcdf(tmp2)
3480            with open_mfdataset(
3481                [tmp1, tmp2], concat_dim="x", combine="nested", attrs_file=tmp2
3482            ) as actual:
3483                # attributes are inherited from the master file
3484                assert actual.attrs["test2"] == ds2.attrs["test2"]
3485                # attributes from ds1 are not retained, e.g.,
3486                assert "test1" not in actual.attrs
3487
3488    def test_open_mfdataset_attrs_file_path(self):
3489        original = Dataset({"foo": ("x", np.random.randn(10))})
3490        with create_tmp_files(2) as (tmp1, tmp2):
3491            tmp1 = Path(tmp1)
3492            tmp2 = Path(tmp2)
3493            ds1 = original.isel(x=slice(5))
3494            ds2 = original.isel(x=slice(5, 10))
3495            ds1.attrs["test1"] = "foo"
3496            ds2.attrs["test2"] = "bar"
3497            ds1.to_netcdf(tmp1)
3498            ds2.to_netcdf(tmp2)
3499            with open_mfdataset(
3500                [tmp1, tmp2], concat_dim="x", combine="nested", attrs_file=tmp2
3501            ) as actual:
3502                # attributes are inherited from the master file
3503                assert actual.attrs["test2"] == ds2.attrs["test2"]
3504                # attributes from ds1 are not retained, e.g.,
3505                assert "test1" not in actual.attrs
3506
3507    def test_open_mfdataset_auto_combine(self):
3508        original = Dataset({"foo": ("x", np.random.randn(10)), "x": np.arange(10)})
3509        with create_tmp_file() as tmp1:
3510            with create_tmp_file() as tmp2:
3511                original.isel(x=slice(5)).to_netcdf(tmp1)
3512                original.isel(x=slice(5, 10)).to_netcdf(tmp2)
3513
3514                with open_mfdataset([tmp2, tmp1], combine="by_coords") as actual:
3515                    assert_identical(original, actual)
3516
3517    def test_open_mfdataset_raise_on_bad_combine_args(self):
3518        # Regression test for unhelpful error shown in #5230
3519        original = Dataset({"foo": ("x", np.random.randn(10)), "x": np.arange(10)})
3520        with create_tmp_file() as tmp1:
3521            with create_tmp_file() as tmp2:
3522                original.isel(x=slice(5)).to_netcdf(tmp1)
3523                original.isel(x=slice(5, 10)).to_netcdf(tmp2)
3524                with pytest.raises(ValueError, match="`concat_dim` has no effect"):
3525                    open_mfdataset([tmp1, tmp2], concat_dim="x")
3526
3527    @pytest.mark.xfail(reason="mfdataset loses encoding currently.")
3528    def test_encoding_mfdataset(self):
3529        original = Dataset(
3530            {
3531                "foo": ("t", np.random.randn(10)),
3532                "t": ("t", pd.date_range(start="2010-01-01", periods=10, freq="1D")),
3533            }
3534        )
3535        original.t.encoding["units"] = "days since 2010-01-01"
3536
3537        with create_tmp_file() as tmp1:
3538            with create_tmp_file() as tmp2:
3539                ds1 = original.isel(t=slice(5))
3540                ds2 = original.isel(t=slice(5, 10))
3541                ds1.t.encoding["units"] = "days since 2010-01-01"
3542                ds2.t.encoding["units"] = "days since 2000-01-01"
3543                ds1.to_netcdf(tmp1)
3544                ds2.to_netcdf(tmp2)
3545                with open_mfdataset([tmp1, tmp2], combine="nested") as actual:
3546                    assert actual.t.encoding["units"] == original.t.encoding["units"]
3547                    assert actual.t.encoding["units"] == ds1.t.encoding["units"]
3548                    assert actual.t.encoding["units"] != ds2.t.encoding["units"]
3549
3550    def test_preprocess_mfdataset(self):
3551        original = Dataset({"foo": ("x", np.random.randn(10))})
3552        with create_tmp_file() as tmp:
3553            original.to_netcdf(tmp)
3554
3555            def preprocess(ds):
3556                return ds.assign_coords(z=0)
3557
3558            expected = preprocess(original)
3559            with open_mfdataset(
3560                tmp, preprocess=preprocess, combine="by_coords"
3561            ) as actual:
3562                assert_identical(expected, actual)
3563
3564    def test_save_mfdataset_roundtrip(self):
3565        original = Dataset({"foo": ("x", np.random.randn(10))})
3566        datasets = [original.isel(x=slice(5)), original.isel(x=slice(5, 10))]
3567        with create_tmp_file() as tmp1:
3568            with create_tmp_file() as tmp2:
3569                save_mfdataset(datasets, [tmp1, tmp2])
3570                with open_mfdataset(
3571                    [tmp1, tmp2], concat_dim="x", combine="nested"
3572                ) as actual:
3573                    assert_identical(actual, original)
3574
3575    def test_save_mfdataset_invalid(self):
3576        ds = Dataset()
3577        with pytest.raises(ValueError, match=r"cannot use mode"):
3578            save_mfdataset([ds, ds], ["same", "same"])
3579        with pytest.raises(ValueError, match=r"same length"):
3580            save_mfdataset([ds, ds], ["only one path"])
3581
3582    def test_save_mfdataset_invalid_dataarray(self):
3583        # regression test for GH1555
3584        da = DataArray([1, 2])
3585        with pytest.raises(TypeError, match=r"supports writing Dataset"):
3586            save_mfdataset([da], ["dataarray"])
3587
3588    def test_save_mfdataset_pathlib_roundtrip(self):
3589        original = Dataset({"foo": ("x", np.random.randn(10))})
3590        datasets = [original.isel(x=slice(5)), original.isel(x=slice(5, 10))]
3591        with create_tmp_file() as tmp1:
3592            with create_tmp_file() as tmp2:
3593                tmp1 = Path(tmp1)
3594                tmp2 = Path(tmp2)
3595                save_mfdataset(datasets, [tmp1, tmp2])
3596                with open_mfdataset(
3597                    [tmp1, tmp2], concat_dim="x", combine="nested"
3598                ) as actual:
3599                    assert_identical(actual, original)
3600
3601    def test_open_and_do_math(self):
3602        original = Dataset({"foo": ("x", np.random.randn(10))})
3603        with create_tmp_file() as tmp:
3604            original.to_netcdf(tmp)
3605            with open_mfdataset(tmp, combine="by_coords") as ds:
3606                actual = 1.0 * ds
3607                assert_allclose(original, actual, decode_bytes=False)
3608
3609    def test_open_mfdataset_concat_dim_none(self):
3610        with create_tmp_file() as tmp1:
3611            with create_tmp_file() as tmp2:
3612                data = Dataset({"x": 0})
3613                data.to_netcdf(tmp1)
3614                Dataset({"x": np.nan}).to_netcdf(tmp2)
3615                with open_mfdataset(
3616                    [tmp1, tmp2], concat_dim=None, combine="nested"
3617                ) as actual:
3618                    assert_identical(data, actual)
3619
3620    def test_open_mfdataset_concat_dim_default_none(self):
3621        with create_tmp_file() as tmp1:
3622            with create_tmp_file() as tmp2:
3623                data = Dataset({"x": 0})
3624                data.to_netcdf(tmp1)
3625                Dataset({"x": np.nan}).to_netcdf(tmp2)
3626                with open_mfdataset([tmp1, tmp2], combine="nested") as actual:
3627                    assert_identical(data, actual)
3628
3629    def test_open_dataset(self):
3630        original = Dataset({"foo": ("x", np.random.randn(10))})
3631        with create_tmp_file() as tmp:
3632            original.to_netcdf(tmp)
3633            with open_dataset(tmp, chunks={"x": 5}) as actual:
3634                assert isinstance(actual.foo.variable.data, da.Array)
3635                assert actual.foo.variable.data.chunks == ((5, 5),)
3636                assert_identical(original, actual)
3637            with open_dataset(tmp, chunks=5) as actual:
3638                assert_identical(original, actual)
3639            with open_dataset(tmp) as actual:
3640                assert isinstance(actual.foo.variable.data, np.ndarray)
3641                assert_identical(original, actual)
3642
3643    def test_open_single_dataset(self):
3644        # Test for issue GH #1988. This makes sure that the
3645        # concat_dim is utilized when specified in open_mfdataset().
3646        rnddata = np.random.randn(10)
3647        original = Dataset({"foo": ("x", rnddata)})
3648        dim = DataArray([100], name="baz", dims="baz")
3649        expected = Dataset(
3650            {"foo": (("baz", "x"), rnddata[np.newaxis, :])}, {"baz": [100]}
3651        )
3652        with create_tmp_file() as tmp:
3653            original.to_netcdf(tmp)
3654            with open_mfdataset([tmp], concat_dim=dim, combine="nested") as actual:
3655                assert_identical(expected, actual)
3656
3657    def test_open_multi_dataset(self):
3658        # Test for issue GH #1988 and #2647. This makes sure that the
3659        # concat_dim is utilized when specified in open_mfdataset().
3660        # The additional wrinkle is to ensure that a length greater
3661        # than one is tested as well due to numpy's implicit casting
3662        # of 1-length arrays to booleans in tests, which allowed
3663        # #2647 to still pass the test_open_single_dataset(),
3664        # which is itself still needed as-is because the original
3665        # bug caused one-length arrays to not be used correctly
3666        # in concatenation.
3667        rnddata = np.random.randn(10)
3668        original = Dataset({"foo": ("x", rnddata)})
3669        dim = DataArray([100, 150], name="baz", dims="baz")
3670        expected = Dataset(
3671            {"foo": (("baz", "x"), np.tile(rnddata[np.newaxis, :], (2, 1)))},
3672            {"baz": [100, 150]},
3673        )
3674        with create_tmp_file() as tmp1, create_tmp_file() as tmp2:
3675            original.to_netcdf(tmp1)
3676            original.to_netcdf(tmp2)
3677            with open_mfdataset(
3678                [tmp1, tmp2], concat_dim=dim, combine="nested"
3679            ) as actual:
3680                assert_identical(expected, actual)
3681
3682    def test_dask_roundtrip(self):
3683        with create_tmp_file() as tmp:
3684            data = create_test_data()
3685            data.to_netcdf(tmp)
3686            chunks = {"dim1": 4, "dim2": 4, "dim3": 4, "time": 10}
3687            with open_dataset(tmp, chunks=chunks) as dask_ds:
3688                assert_identical(data, dask_ds)
3689                with create_tmp_file() as tmp2:
3690                    dask_ds.to_netcdf(tmp2)
3691                    with open_dataset(tmp2) as on_disk:
3692                        assert_identical(data, on_disk)
3693
3694    def test_deterministic_names(self):
3695        with create_tmp_file() as tmp:
3696            data = create_test_data()
3697            data.to_netcdf(tmp)
3698            with open_mfdataset(tmp, combine="by_coords") as ds:
3699                original_names = {k: v.data.name for k, v in ds.data_vars.items()}
3700            with open_mfdataset(tmp, combine="by_coords") as ds:
3701                repeat_names = {k: v.data.name for k, v in ds.data_vars.items()}
3702            for var_name, dask_name in original_names.items():
3703                assert var_name in dask_name
3704                assert dask_name[:13] == "open_dataset-"
3705            assert original_names == repeat_names
3706
3707    def test_dataarray_compute(self):
3708        # Test DataArray.compute() on dask backend.
3709        # The test for Dataset.compute() is already in DatasetIOBase;
3710        # however dask is the only tested backend which supports DataArrays
3711        actual = DataArray([1, 2]).chunk()
3712        computed = actual.compute()
3713        assert not actual._in_memory
3714        assert computed._in_memory
3715        assert_allclose(actual, computed, decode_bytes=False)
3716
3717    @pytest.mark.xfail
3718    def test_save_mfdataset_compute_false_roundtrip(self):
3719        from dask.delayed import Delayed
3720
3721        original = Dataset({"foo": ("x", np.random.randn(10))}).chunk()
3722        datasets = [original.isel(x=slice(5)), original.isel(x=slice(5, 10))]
3723        with create_tmp_file(allow_cleanup_failure=ON_WINDOWS) as tmp1:
3724            with create_tmp_file(allow_cleanup_failure=ON_WINDOWS) as tmp2:
3725                delayed_obj = save_mfdataset(
3726                    datasets, [tmp1, tmp2], engine=self.engine, compute=False
3727                )
3728                assert isinstance(delayed_obj, Delayed)
3729                delayed_obj.compute()
3730                with open_mfdataset(
3731                    [tmp1, tmp2], combine="nested", concat_dim="x"
3732                ) as actual:
3733                    assert_identical(actual, original)
3734
3735    def test_load_dataset(self):
3736        with create_tmp_file() as tmp:
3737            original = Dataset({"foo": ("x", np.random.randn(10))})
3738            original.to_netcdf(tmp)
3739            ds = load_dataset(tmp)
3740            # this would fail if we used open_dataset instead of load_dataset
3741            ds.to_netcdf(tmp)
3742
3743    def test_load_dataarray(self):
3744        with create_tmp_file() as tmp:
3745            original = Dataset({"foo": ("x", np.random.randn(10))})
3746            original.to_netcdf(tmp)
3747            ds = load_dataarray(tmp)
3748            # this would fail if we used open_dataarray instead of
3749            # load_dataarray
3750            ds.to_netcdf(tmp)
3751
3752
3753@requires_scipy_or_netCDF4
3754@requires_pydap
3755@pytest.mark.filterwarnings("ignore:The binary mode of fromstring is deprecated")
3756class TestPydap:
3757    def convert_to_pydap_dataset(self, original):
3758        from pydap.model import BaseType, DatasetType, GridType
3759
3760        ds = DatasetType("bears", **original.attrs)
3761        for key, var in original.data_vars.items():
3762            v = GridType(key)
3763            v[key] = BaseType(key, var.values, dimensions=var.dims, **var.attrs)
3764            for d in var.dims:
3765                v[d] = BaseType(d, var[d].values)
3766            ds[key] = v
3767        # check all dims are stored in ds
3768        for d in original.coords:
3769            ds[d] = BaseType(
3770                d, original[d].values, dimensions=(d,), **original[d].attrs
3771            )
3772        return ds
3773
3774    @contextlib.contextmanager
3775    def create_datasets(self, **kwargs):
3776        with open_example_dataset("bears.nc") as expected:
3777            pydap_ds = self.convert_to_pydap_dataset(expected)
3778            actual = open_dataset(PydapDataStore(pydap_ds))
3779            # TODO solve this workaround:
3780            # netcdf converts string to byte not unicode
3781            expected["bears"] = expected["bears"].astype(str)
3782            yield actual, expected
3783
3784    def test_cmp_local_file(self):
3785        with self.create_datasets() as (actual, expected):
3786            assert_equal(actual, expected)
3787
3788            # global attributes should be global attributes on the dataset
3789            assert "NC_GLOBAL" not in actual.attrs
3790            assert "history" in actual.attrs
3791
3792            # we don't check attributes exactly with assertDatasetIdentical()
3793            # because the test DAP server seems to insert some extra
3794            # attributes not found in the netCDF file.
3795            assert actual.attrs.keys() == expected.attrs.keys()
3796
3797        with self.create_datasets() as (actual, expected):
3798            assert_equal(actual[{"l": 2}], expected[{"l": 2}])
3799
3800        with self.create_datasets() as (actual, expected):
3801            assert_equal(actual.isel(i=0, j=-1), expected.isel(i=0, j=-1))
3802
3803        with self.create_datasets() as (actual, expected):
3804            assert_equal(actual.isel(j=slice(1, 2)), expected.isel(j=slice(1, 2)))
3805
3806        with self.create_datasets() as (actual, expected):
3807            indexers = {"i": [1, 0, 0], "j": [1, 2, 0, 1]}
3808            assert_equal(actual.isel(**indexers), expected.isel(**indexers))
3809
3810        with self.create_datasets() as (actual, expected):
3811            indexers = {
3812                "i": DataArray([0, 1, 0], dims="a"),
3813                "j": DataArray([0, 2, 1], dims="a"),
3814            }
3815            assert_equal(actual.isel(**indexers), expected.isel(**indexers))
3816
3817    def test_compatible_to_netcdf(self):
3818        # make sure it can be saved as a netcdf
3819        with self.create_datasets() as (actual, expected):
3820            with create_tmp_file() as tmp_file:
3821                actual.to_netcdf(tmp_file)
3822                with open_dataset(tmp_file) as actual2:
3823                    actual2["bears"] = actual2["bears"].astype(str)
3824                    assert_equal(actual2, expected)
3825
3826    @requires_dask
3827    def test_dask(self):
3828        with self.create_datasets(chunks={"j": 2}) as (actual, expected):
3829            assert_equal(actual, expected)
3830
3831
3832@network
3833@requires_scipy_or_netCDF4
3834@requires_pydap
3835class TestPydapOnline(TestPydap):
3836    @contextlib.contextmanager
3837    def create_datasets(self, **kwargs):
3838        url = "http://test.opendap.org/opendap/hyrax/data/nc/bears.nc"
3839        actual = open_dataset(url, engine="pydap", **kwargs)
3840        with open_example_dataset("bears.nc") as expected:
3841            # workaround to restore string which is converted to byte
3842            expected["bears"] = expected["bears"].astype(str)
3843            yield actual, expected
3844
3845    def test_session(self):
3846        from pydap.cas.urs import setup_session
3847
3848        session = setup_session("XarrayTestUser", "Xarray2017")
3849        with mock.patch("pydap.client.open_url") as mock_func:
3850            xr.backends.PydapDataStore.open("http://test.url", session=session)
3851        mock_func.assert_called_with("http://test.url", session=session)
3852
3853
3854@requires_scipy
3855@requires_pynio
3856class TestPyNio(CFEncodedBase, NetCDF3Only):
3857    def test_write_store(self):
3858        # pynio is read-only for now
3859        pass
3860
3861    @contextlib.contextmanager
3862    def open(self, path, **kwargs):
3863        with open_dataset(path, engine="pynio", **kwargs) as ds:
3864            yield ds
3865
3866    def test_kwargs(self):
3867        kwargs = {"format": "grib"}
3868        path = os.path.join(os.path.dirname(__file__), "data", "example")
3869        with backends.NioDataStore(path, **kwargs) as store:
3870            assert store._manager._kwargs["format"] == "grib"
3871
3872    def save(self, dataset, path, **kwargs):
3873        return dataset.to_netcdf(path, engine="scipy", **kwargs)
3874
3875    def test_weakrefs(self):
3876        example = Dataset({"foo": ("x", np.arange(5.0))})
3877        expected = example.rename({"foo": "bar", "x": "y"})
3878
3879        with create_tmp_file() as tmp_file:
3880            example.to_netcdf(tmp_file, engine="scipy")
3881            on_disk = open_dataset(tmp_file, engine="pynio")
3882            actual = on_disk.rename({"foo": "bar", "x": "y"})
3883            del on_disk  # trigger garbage collection
3884            assert_identical(actual, expected)
3885
3886
3887@requires_cfgrib
3888class TestCfGrib:
3889    def test_read(self):
3890        expected = {
3891            "number": 2,
3892            "time": 3,
3893            "isobaricInhPa": 2,
3894            "latitude": 3,
3895            "longitude": 4,
3896        }
3897        with open_example_dataset("example.grib", engine="cfgrib") as ds:
3898            assert ds.dims == expected
3899            assert list(ds.data_vars) == ["z", "t"]
3900            assert ds["z"].min() == 12660.0
3901
3902    def test_read_filter_by_keys(self):
3903        kwargs = {"filter_by_keys": {"shortName": "t"}}
3904        expected = {
3905            "number": 2,
3906            "time": 3,
3907            "isobaricInhPa": 2,
3908            "latitude": 3,
3909            "longitude": 4,
3910        }
3911        with open_example_dataset(
3912            "example.grib", engine="cfgrib", backend_kwargs=kwargs
3913        ) as ds:
3914            assert ds.dims == expected
3915            assert list(ds.data_vars) == ["t"]
3916            assert ds["t"].min() == 231.0
3917
3918    def test_read_outer(self):
3919        expected = {
3920            "number": 2,
3921            "time": 3,
3922            "isobaricInhPa": 2,
3923            "latitude": 2,
3924            "longitude": 3,
3925        }
3926        with open_example_dataset("example.grib", engine="cfgrib") as ds:
3927            res = ds.isel(latitude=[0, 2], longitude=[0, 1, 2])
3928            assert res.dims == expected
3929            assert res["t"].min() == 231.0
3930
3931
3932@requires_pseudonetcdf
3933@pytest.mark.filterwarnings("ignore:IOAPI_ISPH is assumed to be 6370000")
3934class TestPseudoNetCDFFormat:
3935    def open(self, path, **kwargs):
3936        return open_dataset(path, engine="pseudonetcdf", **kwargs)
3937
3938    @contextlib.contextmanager
3939    def roundtrip(
3940        self, data, save_kwargs=None, open_kwargs=None, allow_cleanup_failure=False
3941    ):
3942        if save_kwargs is None:
3943            save_kwargs = {}
3944        if open_kwargs is None:
3945            open_kwargs = {}
3946        with create_tmp_file(allow_cleanup_failure=allow_cleanup_failure) as path:
3947            self.save(data, path, **save_kwargs)
3948            with self.open(path, **open_kwargs) as ds:
3949                yield ds
3950
3951    def test_ict_format(self):
3952        """
3953        Open a CAMx file and test data variables
3954        """
3955        stdattr = {
3956            "fill_value": -9999.0,
3957            "missing_value": -9999,
3958            "scale": 1,
3959            "llod_flag": -8888,
3960            "llod_value": "N/A",
3961            "ulod_flag": -7777,
3962            "ulod_value": "N/A",
3963        }
3964
3965        def myatts(**attrs):
3966            outattr = stdattr.copy()
3967            outattr.update(attrs)
3968            return outattr
3969
3970        input = {
3971            "coords": {},
3972            "attrs": {
3973                "fmt": "1001",
3974                "n_header_lines": 29,
3975                "PI_NAME": "Henderson, Barron",
3976                "ORGANIZATION_NAME": "U.S. EPA",
3977                "SOURCE_DESCRIPTION": "Example file with artificial data",
3978                "MISSION_NAME": "JUST_A_TEST",
3979                "VOLUME_INFO": "1, 1",
3980                "SDATE": "2018, 04, 27",
3981                "WDATE": "2018, 04, 27",
3982                "TIME_INTERVAL": "0",
3983                "INDEPENDENT_VARIABLE_DEFINITION": "Start_UTC",
3984                "INDEPENDENT_VARIABLE": "Start_UTC",
3985                "INDEPENDENT_VARIABLE_UNITS": "Start_UTC",
3986                "ULOD_FLAG": "-7777",
3987                "ULOD_VALUE": "N/A",
3988                "LLOD_FLAG": "-8888",
3989                "LLOD_VALUE": ("N/A, N/A, N/A, N/A, 0.025"),
3990                "OTHER_COMMENTS": (
3991                    "www-air.larc.nasa.gov/missions/etc/" + "IcarttDataFormat.htm"
3992                ),
3993                "REVISION": "R0",
3994                "R0": "No comments for this revision.",
3995                "TFLAG": "Start_UTC",
3996            },
3997            "dims": {"POINTS": 4},
3998            "data_vars": {
3999                "Start_UTC": {
4000                    "data": [43200.0, 46800.0, 50400.0, 50400.0],
4001                    "dims": ("POINTS",),
4002                    "attrs": myatts(units="Start_UTC", standard_name="Start_UTC"),
4003                },
4004                "lat": {
4005                    "data": [41.0, 42.0, 42.0, 42.0],
4006                    "dims": ("POINTS",),
4007                    "attrs": myatts(units="degrees_north", standard_name="lat"),
4008                },
4009                "lon": {
4010                    "data": [-71.0, -72.0, -73.0, -74.0],
4011                    "dims": ("POINTS",),
4012                    "attrs": myatts(units="degrees_east", standard_name="lon"),
4013                },
4014                "elev": {
4015                    "data": [5.0, 15.0, 20.0, 25.0],
4016                    "dims": ("POINTS",),
4017                    "attrs": myatts(units="meters", standard_name="elev"),
4018                },
4019                "TEST_ppbv": {
4020                    "data": [1.2345, 2.3456, 3.4567, 4.5678],
4021                    "dims": ("POINTS",),
4022                    "attrs": myatts(units="ppbv", standard_name="TEST_ppbv"),
4023                },
4024                "TESTM_ppbv": {
4025                    "data": [2.22, -9999.0, -7777.0, -8888.0],
4026                    "dims": ("POINTS",),
4027                    "attrs": myatts(
4028                        units="ppbv", standard_name="TESTM_ppbv", llod_value=0.025
4029                    ),
4030                },
4031            },
4032        }
4033        chkfile = Dataset.from_dict(input)
4034        with open_example_dataset(
4035            "example.ict", engine="pseudonetcdf", backend_kwargs={"format": "ffi1001"}
4036        ) as ictfile:
4037            assert_identical(ictfile, chkfile)
4038
4039    def test_ict_format_write(self):
4040        fmtkw = {"format": "ffi1001"}
4041        with open_example_dataset(
4042            "example.ict", engine="pseudonetcdf", backend_kwargs=fmtkw
4043        ) as expected:
4044            with self.roundtrip(
4045                expected, save_kwargs=fmtkw, open_kwargs={"backend_kwargs": fmtkw}
4046            ) as actual:
4047                assert_identical(expected, actual)
4048
4049    def test_uamiv_format_read(self):
4050        """
4051        Open a CAMx file and test data variables
4052        """
4053
4054        camxfile = open_example_dataset(
4055            "example.uamiv", engine="pseudonetcdf", backend_kwargs={"format": "uamiv"}
4056        )
4057        data = np.arange(20, dtype="f").reshape(1, 1, 4, 5)
4058        expected = xr.Variable(
4059            ("TSTEP", "LAY", "ROW", "COL"),
4060            data,
4061            dict(units="ppm", long_name="O3".ljust(16), var_desc="O3".ljust(80)),
4062        )
4063        actual = camxfile.variables["O3"]
4064        assert_allclose(expected, actual)
4065
4066        data = np.array([[[2002154, 0]]], dtype="i")
4067        expected = xr.Variable(
4068            ("TSTEP", "VAR", "DATE-TIME"),
4069            data,
4070            dict(
4071                long_name="TFLAG".ljust(16),
4072                var_desc="TFLAG".ljust(80),
4073                units="DATE-TIME".ljust(16),
4074            ),
4075        )
4076        actual = camxfile.variables["TFLAG"]
4077        assert_allclose(expected, actual)
4078        camxfile.close()
4079
4080    @requires_dask
4081    def test_uamiv_format_mfread(self):
4082        """
4083        Open a CAMx file and test data variables
4084        """
4085
4086        camxfile = open_example_mfdataset(
4087            ["example.uamiv", "example.uamiv"],
4088            engine="pseudonetcdf",
4089            concat_dim="TSTEP",
4090            combine="nested",
4091            backend_kwargs={"format": "uamiv"},
4092        )
4093
4094        data1 = np.arange(20, dtype="f").reshape(1, 1, 4, 5)
4095        data = np.concatenate([data1] * 2, axis=0)
4096        expected = xr.Variable(
4097            ("TSTEP", "LAY", "ROW", "COL"),
4098            data,
4099            dict(units="ppm", long_name="O3".ljust(16), var_desc="O3".ljust(80)),
4100        )
4101        actual = camxfile.variables["O3"]
4102        assert_allclose(expected, actual)
4103
4104        data = np.array([[[2002154, 0]]], dtype="i").repeat(2, 0)
4105        attrs = dict(
4106            long_name="TFLAG".ljust(16),
4107            var_desc="TFLAG".ljust(80),
4108            units="DATE-TIME".ljust(16),
4109        )
4110        dims = ("TSTEP", "VAR", "DATE-TIME")
4111        expected = xr.Variable(dims, data, attrs)
4112        actual = camxfile.variables["TFLAG"]
4113        assert_allclose(expected, actual)
4114        camxfile.close()
4115
4116    @pytest.mark.xfail(reason="Flaky; see GH3711")
4117    def test_uamiv_format_write(self):
4118        fmtkw = {"format": "uamiv"}
4119
4120        expected = open_example_dataset(
4121            "example.uamiv", engine="pseudonetcdf", backend_kwargs=fmtkw
4122        )
4123        with self.roundtrip(
4124            expected,
4125            save_kwargs=fmtkw,
4126            open_kwargs={"backend_kwargs": fmtkw},
4127            allow_cleanup_failure=True,
4128        ) as actual:
4129            assert_identical(expected, actual)
4130
4131        expected.close()
4132
4133    def save(self, dataset, path, **save_kwargs):
4134        import PseudoNetCDF as pnc
4135
4136        pncf = pnc.PseudoNetCDFFile()
4137        pncf.dimensions = {
4138            k: pnc.PseudoNetCDFDimension(pncf, k, v) for k, v in dataset.dims.items()
4139        }
4140        pncf.variables = {
4141            k: pnc.PseudoNetCDFVariable(
4142                pncf, k, v.dtype.char, v.dims, values=v.data[...], **v.attrs
4143            )
4144            for k, v in dataset.variables.items()
4145        }
4146        for pk, pv in dataset.attrs.items():
4147            setattr(pncf, pk, pv)
4148
4149        pnc.pncwrite(pncf, path, **save_kwargs)
4150
4151
4152@requires_rasterio
4153@contextlib.contextmanager
4154def create_tmp_geotiff(
4155    nx=4,
4156    ny=3,
4157    nz=3,
4158    transform=None,
4159    transform_args=default_value,
4160    crs=default_value,
4161    open_kwargs=None,
4162    additional_attrs=None,
4163):
4164    if transform_args is default_value:
4165        transform_args = [5000, 80000, 1000, 2000.0]
4166    if crs is default_value:
4167        crs = {
4168            "units": "m",
4169            "no_defs": True,
4170            "ellps": "WGS84",
4171            "proj": "utm",
4172            "zone": 18,
4173        }
4174    # yields a temporary geotiff file and a corresponding expected DataArray
4175    import rasterio
4176    from rasterio.transform import from_origin
4177
4178    if open_kwargs is None:
4179        open_kwargs = {}
4180
4181    with create_tmp_file(suffix=".tif", allow_cleanup_failure=ON_WINDOWS) as tmp_file:
4182        # allow 2d or 3d shapes
4183        if nz == 1:
4184            data_shape = ny, nx
4185            write_kwargs = {"indexes": 1}
4186        else:
4187            data_shape = nz, ny, nx
4188            write_kwargs = {}
4189        data = np.arange(nz * ny * nx, dtype=rasterio.float32).reshape(*data_shape)
4190        if transform is None:
4191            transform = from_origin(*transform_args)
4192        if additional_attrs is None:
4193            additional_attrs = {
4194                "descriptions": tuple("d{}".format(n + 1) for n in range(nz)),
4195                "units": tuple("u{}".format(n + 1) for n in range(nz)),
4196            }
4197        with rasterio.open(
4198            tmp_file,
4199            "w",
4200            driver="GTiff",
4201            height=ny,
4202            width=nx,
4203            count=nz,
4204            crs=crs,
4205            transform=transform,
4206            dtype=rasterio.float32,
4207            **open_kwargs,
4208        ) as s:
4209            for attr, val in additional_attrs.items():
4210                setattr(s, attr, val)
4211            s.write(data, **write_kwargs)
4212            dx, dy = s.res[0], -s.res[1]
4213
4214        a, b, c, d = transform_args
4215        data = data[np.newaxis, ...] if nz == 1 else data
4216        expected = DataArray(
4217            data,
4218            dims=("band", "y", "x"),
4219            coords={
4220                "band": np.arange(nz) + 1,
4221                "y": -np.arange(ny) * d + b + dy / 2,
4222                "x": np.arange(nx) * c + a + dx / 2,
4223            },
4224        )
4225        yield tmp_file, expected
4226
4227
4228@requires_rasterio
4229class TestRasterio:
4230    @requires_scipy_or_netCDF4
4231    def test_serialization(self):
4232        with create_tmp_geotiff(additional_attrs={}) as (tmp_file, expected):
4233            # Write it to a netcdf and read again (roundtrip)
4234            with pytest.warns(DeprecationWarning), xr.open_rasterio(tmp_file) as rioda:
4235                with create_tmp_file(suffix=".nc") as tmp_nc_file:
4236                    rioda.to_netcdf(tmp_nc_file)
4237                    with xr.open_dataarray(tmp_nc_file) as ncds:
4238                        assert_identical(rioda, ncds)
4239
4240    def test_utm(self):
4241        with create_tmp_geotiff() as (tmp_file, expected):
4242            with pytest.warns(DeprecationWarning), xr.open_rasterio(tmp_file) as rioda:
4243                assert_allclose(rioda, expected)
4244                assert rioda.attrs["scales"] == (1.0, 1.0, 1.0)
4245                assert rioda.attrs["offsets"] == (0.0, 0.0, 0.0)
4246                assert rioda.attrs["descriptions"] == ("d1", "d2", "d3")
4247                assert rioda.attrs["units"] == ("u1", "u2", "u3")
4248                assert isinstance(rioda.attrs["crs"], str)
4249                assert isinstance(rioda.attrs["res"], tuple)
4250                assert isinstance(rioda.attrs["is_tiled"], np.uint8)
4251                assert isinstance(rioda.attrs["transform"], tuple)
4252                assert len(rioda.attrs["transform"]) == 6
4253                np.testing.assert_array_equal(
4254                    rioda.attrs["nodatavals"], [np.NaN, np.NaN, np.NaN]
4255                )
4256
4257            # Check no parse coords
4258            with pytest.warns(DeprecationWarning), xr.open_rasterio(
4259                tmp_file, parse_coordinates=False
4260            ) as rioda:
4261                assert "x" not in rioda.coords
4262                assert "y" not in rioda.coords
4263
4264    def test_non_rectilinear(self):
4265        from rasterio.transform import from_origin
4266
4267        # Create a geotiff file with 2d coordinates
4268        with create_tmp_geotiff(
4269            transform=from_origin(0, 3, 1, 1).rotation(45), crs=None
4270        ) as (tmp_file, _):
4271            # Default is to not parse coords
4272            with pytest.warns(DeprecationWarning), xr.open_rasterio(tmp_file) as rioda:
4273                assert "x" not in rioda.coords
4274                assert "y" not in rioda.coords
4275                assert "crs" not in rioda.attrs
4276                assert rioda.attrs["scales"] == (1.0, 1.0, 1.0)
4277                assert rioda.attrs["offsets"] == (0.0, 0.0, 0.0)
4278                assert rioda.attrs["descriptions"] == ("d1", "d2", "d3")
4279                assert rioda.attrs["units"] == ("u1", "u2", "u3")
4280                assert isinstance(rioda.attrs["res"], tuple)
4281                assert isinstance(rioda.attrs["is_tiled"], np.uint8)
4282                assert isinstance(rioda.attrs["transform"], tuple)
4283                assert len(rioda.attrs["transform"]) == 6
4284
4285            # See if a warning is raised if we force it
4286            with pytest.warns(Warning, match="transformation isn't rectilinear"):
4287                with xr.open_rasterio(tmp_file, parse_coordinates=True) as rioda:
4288                    assert "x" not in rioda.coords
4289                    assert "y" not in rioda.coords
4290
4291    def test_platecarree(self):
4292        with create_tmp_geotiff(
4293            8,
4294            10,
4295            1,
4296            transform_args=[1, 2, 0.5, 2.0],
4297            crs="+proj=latlong",
4298            open_kwargs={"nodata": -9765},
4299        ) as (tmp_file, expected):
4300            with pytest.warns(DeprecationWarning), xr.open_rasterio(tmp_file) as rioda:
4301                assert_allclose(rioda, expected)
4302                assert rioda.attrs["scales"] == (1.0,)
4303                assert rioda.attrs["offsets"] == (0.0,)
4304                assert isinstance(rioda.attrs["descriptions"], tuple)
4305                assert isinstance(rioda.attrs["units"], tuple)
4306                assert isinstance(rioda.attrs["crs"], str)
4307                assert isinstance(rioda.attrs["res"], tuple)
4308                assert isinstance(rioda.attrs["is_tiled"], np.uint8)
4309                assert isinstance(rioda.attrs["transform"], tuple)
4310                assert len(rioda.attrs["transform"]) == 6
4311                np.testing.assert_array_equal(rioda.attrs["nodatavals"], [-9765.0])
4312
4313    # rasterio throws a Warning, which is expected since we test rasterio's defaults
4314    @pytest.mark.filterwarnings("ignore:Dataset has no geotransform")
4315    def test_notransform(self):
4316        # regression test for https://github.com/pydata/xarray/issues/1686
4317
4318        import rasterio
4319
4320        # Create a geotiff file
4321        with create_tmp_file(suffix=".tif") as tmp_file:
4322            # data
4323            nx, ny, nz = 4, 3, 3
4324            data = np.arange(nx * ny * nz, dtype=rasterio.float32).reshape(nz, ny, nx)
4325            with rasterio.open(
4326                tmp_file,
4327                "w",
4328                driver="GTiff",
4329                height=ny,
4330                width=nx,
4331                count=nz,
4332                dtype=rasterio.float32,
4333            ) as s:
4334                s.descriptions = ("nx", "ny", "nz")
4335                s.units = ("cm", "m", "km")
4336                s.write(data)
4337
4338            # Tests
4339            expected = DataArray(
4340                data,
4341                dims=("band", "y", "x"),
4342                coords={
4343                    "band": [1, 2, 3],
4344                    "y": [0.5, 1.5, 2.5],
4345                    "x": [0.5, 1.5, 2.5, 3.5],
4346                },
4347            )
4348            with pytest.warns(DeprecationWarning), xr.open_rasterio(tmp_file) as rioda:
4349                assert_allclose(rioda, expected)
4350                assert rioda.attrs["scales"] == (1.0, 1.0, 1.0)
4351                assert rioda.attrs["offsets"] == (0.0, 0.0, 0.0)
4352                assert rioda.attrs["descriptions"] == ("nx", "ny", "nz")
4353                assert rioda.attrs["units"] == ("cm", "m", "km")
4354                assert isinstance(rioda.attrs["res"], tuple)
4355                assert isinstance(rioda.attrs["is_tiled"], np.uint8)
4356                assert isinstance(rioda.attrs["transform"], tuple)
4357                assert len(rioda.attrs["transform"]) == 6
4358
4359    def test_indexing(self):
4360        with create_tmp_geotiff(
4361            8, 10, 3, transform_args=[1, 2, 0.5, 2.0], crs="+proj=latlong"
4362        ) as (tmp_file, expected):
4363            with pytest.warns(DeprecationWarning), xr.open_rasterio(
4364                tmp_file, cache=False
4365            ) as actual:
4366
4367                # tests
4368                # assert_allclose checks all data + coordinates
4369                assert_allclose(actual, expected)
4370                assert not actual.variable._in_memory
4371
4372                # Basic indexer
4373                ind = {"x": slice(2, 5), "y": slice(5, 7)}
4374                assert_allclose(expected.isel(**ind), actual.isel(**ind))
4375                assert not actual.variable._in_memory
4376
4377                ind = {"band": slice(1, 2), "x": slice(2, 5), "y": slice(5, 7)}
4378                assert_allclose(expected.isel(**ind), actual.isel(**ind))
4379                assert not actual.variable._in_memory
4380
4381                ind = {"band": slice(1, 2), "x": slice(2, 5), "y": 0}
4382                assert_allclose(expected.isel(**ind), actual.isel(**ind))
4383                assert not actual.variable._in_memory
4384
4385                # orthogonal indexer
4386                ind = {
4387                    "band": np.array([2, 1, 0]),
4388                    "x": np.array([1, 0]),
4389                    "y": np.array([0, 2]),
4390                }
4391                assert_allclose(expected.isel(**ind), actual.isel(**ind))
4392                assert not actual.variable._in_memory
4393
4394                ind = {"band": np.array([2, 1, 0]), "x": np.array([1, 0]), "y": 0}
4395                assert_allclose(expected.isel(**ind), actual.isel(**ind))
4396                assert not actual.variable._in_memory
4397
4398                ind = {"band": 0, "x": np.array([0, 0]), "y": np.array([1, 1, 1])}
4399                assert_allclose(expected.isel(**ind), actual.isel(**ind))
4400                assert not actual.variable._in_memory
4401
4402                # minus-stepped slice
4403                ind = {"band": np.array([2, 1, 0]), "x": slice(-1, None, -1), "y": 0}
4404                assert_allclose(expected.isel(**ind), actual.isel(**ind))
4405                assert not actual.variable._in_memory
4406
4407                ind = {"band": np.array([2, 1, 0]), "x": 1, "y": slice(-1, 1, -2)}
4408                assert_allclose(expected.isel(**ind), actual.isel(**ind))
4409                assert not actual.variable._in_memory
4410
4411                # empty selection
4412                ind = {"band": np.array([2, 1, 0]), "x": 1, "y": slice(2, 2, 1)}
4413                assert_allclose(expected.isel(**ind), actual.isel(**ind))
4414                assert not actual.variable._in_memory
4415
4416                ind = {"band": slice(0, 0), "x": 1, "y": 2}
4417                assert_allclose(expected.isel(**ind), actual.isel(**ind))
4418                assert not actual.variable._in_memory
4419
4420                # vectorized indexer
4421                ind = {
4422                    "band": DataArray([2, 1, 0], dims="a"),
4423                    "x": DataArray([1, 0, 0], dims="a"),
4424                    "y": np.array([0, 2]),
4425                }
4426                assert_allclose(expected.isel(**ind), actual.isel(**ind))
4427                assert not actual.variable._in_memory
4428
4429                ind = {
4430                    "band": DataArray([[2, 1, 0], [1, 0, 2]], dims=["a", "b"]),
4431                    "x": DataArray([[1, 0, 0], [0, 1, 0]], dims=["a", "b"]),
4432                    "y": 0,
4433                }
4434                assert_allclose(expected.isel(**ind), actual.isel(**ind))
4435                assert not actual.variable._in_memory
4436
4437                # Selecting lists of bands is fine
4438                ex = expected.isel(band=[1, 2])
4439                ac = actual.isel(band=[1, 2])
4440                assert_allclose(ac, ex)
4441                ex = expected.isel(band=[0, 2])
4442                ac = actual.isel(band=[0, 2])
4443                assert_allclose(ac, ex)
4444
4445                # Integer indexing
4446                ex = expected.isel(band=1)
4447                ac = actual.isel(band=1)
4448                assert_allclose(ac, ex)
4449
4450                ex = expected.isel(x=1, y=2)
4451                ac = actual.isel(x=1, y=2)
4452                assert_allclose(ac, ex)
4453
4454                ex = expected.isel(band=0, x=1, y=2)
4455                ac = actual.isel(band=0, x=1, y=2)
4456                assert_allclose(ac, ex)
4457
4458                # Mixed
4459                ex = actual.isel(x=slice(2), y=slice(2))
4460                ac = actual.isel(x=[0, 1], y=[0, 1])
4461                assert_allclose(ac, ex)
4462
4463                ex = expected.isel(band=0, x=1, y=slice(5, 7))
4464                ac = actual.isel(band=0, x=1, y=slice(5, 7))
4465                assert_allclose(ac, ex)
4466
4467                ex = expected.isel(band=0, x=slice(2, 5), y=2)
4468                ac = actual.isel(band=0, x=slice(2, 5), y=2)
4469                assert_allclose(ac, ex)
4470
4471                # One-element lists
4472                ex = expected.isel(band=[0], x=slice(2, 5), y=[2])
4473                ac = actual.isel(band=[0], x=slice(2, 5), y=[2])
4474                assert_allclose(ac, ex)
4475
4476    def test_caching(self):
4477        with create_tmp_geotiff(
4478            8, 10, 3, transform_args=[1, 2, 0.5, 2.0], crs="+proj=latlong"
4479        ) as (tmp_file, expected):
4480            # Cache is the default
4481            with pytest.warns(DeprecationWarning), xr.open_rasterio(tmp_file) as actual:
4482
4483                # This should cache everything
4484                assert_allclose(actual, expected)
4485
4486                # once cached, non-windowed indexing should become possible
4487                ac = actual.isel(x=[2, 4])
4488                ex = expected.isel(x=[2, 4])
4489                assert_allclose(ac, ex)
4490
4491    @requires_dask
4492    def test_chunks(self):
4493        with create_tmp_geotiff(
4494            8, 10, 3, transform_args=[1, 2, 0.5, 2.0], crs="+proj=latlong"
4495        ) as (tmp_file, expected):
4496            # Chunk at open time
4497            with pytest.warns(DeprecationWarning), xr.open_rasterio(
4498                tmp_file, chunks=(1, 2, 2)
4499            ) as actual:
4500
4501                import dask.array as da
4502
4503                assert isinstance(actual.data, da.Array)
4504                assert "open_rasterio" in actual.data.name
4505
4506                # do some arithmetic
4507                ac = actual.mean()
4508                ex = expected.mean()
4509                assert_allclose(ac, ex)
4510
4511                ac = actual.sel(band=1).mean(dim="x")
4512                ex = expected.sel(band=1).mean(dim="x")
4513                assert_allclose(ac, ex)
4514
4515    @pytest.mark.xfail(
4516        not has_dask, reason="without dask, a non-serializable lock is used"
4517    )
4518    def test_pickle_rasterio(self):
4519        # regression test for https://github.com/pydata/xarray/issues/2121
4520        with create_tmp_geotiff() as (tmp_file, expected):
4521            with pytest.warns(DeprecationWarning), xr.open_rasterio(tmp_file) as rioda:
4522                temp = pickle.dumps(rioda)
4523                with pickle.loads(temp) as actual:
4524                    assert_equal(actual, rioda)
4525
4526    def test_ENVI_tags(self):
4527        rasterio = pytest.importorskip("rasterio", minversion="1.0a")
4528        from rasterio.transform import from_origin
4529
4530        # Create an ENVI file with some tags in the ENVI namespace
4531        # this test uses a custom driver, so we can't use create_tmp_geotiff
4532        with create_tmp_file(suffix=".dat") as tmp_file:
4533            # data
4534            nx, ny, nz = 4, 3, 3
4535            data = np.arange(nx * ny * nz, dtype=rasterio.float32).reshape(nz, ny, nx)
4536            transform = from_origin(5000, 80000, 1000, 2000.0)
4537            with rasterio.open(
4538                tmp_file,
4539                "w",
4540                driver="ENVI",
4541                height=ny,
4542                width=nx,
4543                count=nz,
4544                crs={
4545                    "units": "m",
4546                    "no_defs": True,
4547                    "ellps": "WGS84",
4548                    "proj": "utm",
4549                    "zone": 18,
4550                },
4551                transform=transform,
4552                dtype=rasterio.float32,
4553            ) as s:
4554                s.update_tags(
4555                    ns="ENVI",
4556                    description="{Tagged file}",
4557                    wavelength="{123.000000, 234.234000, 345.345678}",
4558                    fwhm="{1.000000, 0.234000, 0.000345}",
4559                )
4560                s.write(data)
4561                dx, dy = s.res[0], -s.res[1]
4562
4563            # Tests
4564            coords = {
4565                "band": [1, 2, 3],
4566                "y": -np.arange(ny) * 2000 + 80000 + dy / 2,
4567                "x": np.arange(nx) * 1000 + 5000 + dx / 2,
4568                "wavelength": ("band", np.array([123, 234.234, 345.345678])),
4569                "fwhm": ("band", np.array([1, 0.234, 0.000345])),
4570            }
4571            expected = DataArray(data, dims=("band", "y", "x"), coords=coords)
4572
4573            with pytest.warns(DeprecationWarning), xr.open_rasterio(tmp_file) as rioda:
4574                assert_allclose(rioda, expected)
4575                assert isinstance(rioda.attrs["crs"], str)
4576                assert isinstance(rioda.attrs["res"], tuple)
4577                assert isinstance(rioda.attrs["is_tiled"], np.uint8)
4578                assert isinstance(rioda.attrs["transform"], tuple)
4579                assert len(rioda.attrs["transform"]) == 6
4580                # from ENVI tags
4581                assert isinstance(rioda.attrs["description"], str)
4582                assert isinstance(rioda.attrs["map_info"], str)
4583                assert isinstance(rioda.attrs["samples"], str)
4584
4585    def test_geotiff_tags(self):
4586        # Create a geotiff file with some tags
4587        with create_tmp_geotiff() as (tmp_file, _):
4588            with pytest.warns(DeprecationWarning), xr.open_rasterio(tmp_file) as rioda:
4589                assert isinstance(rioda.attrs["AREA_OR_POINT"], str)
4590
4591    @requires_dask
4592    def test_no_mftime(self):
4593        # rasterio can accept "filename" urguments that are actually urls,
4594        # including paths to remote files.
4595        # In issue #1816, we found that these caused dask to break, because
4596        # the modification time was used to determine the dask token. This
4597        # tests ensure we can still chunk such files when reading with
4598        # rasterio.
4599        with create_tmp_geotiff(
4600            8, 10, 3, transform_args=[1, 2, 0.5, 2.0], crs="+proj=latlong"
4601        ) as (tmp_file, expected):
4602            with mock.patch("os.path.getmtime", side_effect=OSError):
4603                with pytest.warns(DeprecationWarning), xr.open_rasterio(
4604                    tmp_file, chunks=(1, 2, 2)
4605                ) as actual:
4606                    import dask.array as da
4607
4608                    assert isinstance(actual.data, da.Array)
4609                    assert_allclose(actual, expected)
4610
4611    @network
4612    def test_http_url(self):
4613        # more examples urls here
4614        # http://download.osgeo.org/geotiff/samples/
4615        url = "http://download.osgeo.org/geotiff/samples/made_up/ntf_nord.tif"
4616        with pytest.warns(DeprecationWarning), xr.open_rasterio(url) as actual:
4617            assert actual.shape == (1, 512, 512)
4618        # make sure chunking works
4619        with pytest.warns(DeprecationWarning), xr.open_rasterio(
4620            url, chunks=(1, 256, 256)
4621        ) as actual:
4622            import dask.array as da
4623
4624            assert isinstance(actual.data, da.Array)
4625
4626    def test_rasterio_environment(self):
4627        import rasterio
4628
4629        with create_tmp_geotiff() as (tmp_file, expected):
4630            # Should fail with error since suffix not allowed
4631            with pytest.raises(Exception):
4632                with rasterio.Env(GDAL_SKIP="GTiff"):
4633                    with pytest.warns(DeprecationWarning), xr.open_rasterio(
4634                        tmp_file
4635                    ) as actual:
4636                        assert_allclose(actual, expected)
4637
4638    @pytest.mark.xfail(reason="rasterio 1.1.1 is broken. GH3573")
4639    def test_rasterio_vrt(self):
4640        import rasterio
4641
4642        # tmp_file default crs is UTM: CRS({'init': 'epsg:32618'}
4643        with create_tmp_geotiff() as (tmp_file, expected):
4644            with rasterio.open(tmp_file) as src:
4645                with rasterio.vrt.WarpedVRT(src, crs="epsg:4326") as vrt:
4646                    expected_shape = (vrt.width, vrt.height)
4647                    expected_crs = vrt.crs
4648                    expected_res = vrt.res
4649                    # Value of single pixel in center of image
4650                    lon, lat = vrt.xy(vrt.width // 2, vrt.height // 2)
4651                    expected_val = next(vrt.sample([(lon, lat)]))
4652                    with pytest.warns(DeprecationWarning), xr.open_rasterio(vrt) as da:
4653                        actual_shape = (da.sizes["x"], da.sizes["y"])
4654                        actual_crs = da.crs
4655                        actual_res = da.res
4656                        actual_val = da.sel(dict(x=lon, y=lat), method="nearest").data
4657
4658                        assert actual_crs == expected_crs
4659                        assert actual_res == expected_res
4660                        assert actual_shape == expected_shape
4661                        assert expected_val.all() == actual_val.all()
4662
4663    def test_rasterio_vrt_with_transform_and_size(self):
4664        # Test open_rasterio() support of WarpedVRT with transform, width and
4665        # height (issue #2864)
4666
4667        # https://github.com/mapbox/rasterio/1768
4668        rasterio = pytest.importorskip("rasterio", minversion="1.0.28")
4669        from affine import Affine
4670        from rasterio.warp import calculate_default_transform
4671
4672        with create_tmp_geotiff() as (tmp_file, expected):
4673            with rasterio.open(tmp_file) as src:
4674                # Estimate the transform, width and height
4675                # for a change of resolution
4676                # tmp_file initial res is (1000,2000) (default values)
4677                trans, w, h = calculate_default_transform(
4678                    src.crs, src.crs, src.width, src.height, resolution=500, *src.bounds
4679                )
4680                with rasterio.vrt.WarpedVRT(
4681                    src, transform=trans, width=w, height=h
4682                ) as vrt:
4683                    expected_shape = (vrt.width, vrt.height)
4684                    expected_res = vrt.res
4685                    expected_transform = vrt.transform
4686                    with xr.open_rasterio(vrt) as da:
4687                        actual_shape = (da.sizes["x"], da.sizes["y"])
4688                        actual_res = da.res
4689                        actual_transform = Affine(*da.transform)
4690                        assert actual_res == expected_res
4691                        assert actual_shape == expected_shape
4692                        assert actual_transform == expected_transform
4693
4694    def test_rasterio_vrt_with_src_crs(self):
4695        # Test open_rasterio() support of WarpedVRT with specified src_crs
4696
4697        # https://github.com/mapbox/rasterio/1768
4698        rasterio = pytest.importorskip("rasterio", minversion="1.0.28")
4699
4700        # create geotiff with no CRS and specify it manually
4701        with create_tmp_geotiff(crs=None) as (tmp_file, expected):
4702            src_crs = rasterio.crs.CRS({"init": "epsg:32618"})
4703            with rasterio.open(tmp_file) as src:
4704                assert src.crs is None
4705                with rasterio.vrt.WarpedVRT(src, src_crs=src_crs) as vrt:
4706                    with pytest.warns(DeprecationWarning), xr.open_rasterio(vrt) as da:
4707                        assert da.crs == src_crs
4708
4709    @network
4710    def test_rasterio_vrt_network(self):
4711        # Make sure loading w/ rasterio give same results as xarray
4712        import rasterio
4713
4714        # use same url that rasterio package uses in tests
4715        prefix = "https://landsat-pds.s3.amazonaws.com/L8/139/045/"
4716        image = "LC81390452014295LGN00/LC81390452014295LGN00_B1.TIF"
4717        httpstif = prefix + image
4718        with rasterio.Env(aws_unsigned=True):
4719            with rasterio.open(httpstif) as src:
4720                with rasterio.vrt.WarpedVRT(src, crs="epsg:4326") as vrt:
4721                    expected_shape = vrt.width, vrt.height
4722                    expected_res = vrt.res
4723                    # Value of single pixel in center of image
4724                    lon, lat = vrt.xy(vrt.width // 2, vrt.height // 2)
4725                    expected_val = next(vrt.sample([(lon, lat)]))
4726                    with pytest.warns(DeprecationWarning), xr.open_rasterio(vrt) as da:
4727                        actual_shape = da.sizes["x"], da.sizes["y"]
4728                        actual_res = da.res
4729                        actual_val = da.sel(dict(x=lon, y=lat), method="nearest").data
4730
4731                        assert actual_shape == expected_shape
4732                        assert actual_res == expected_res
4733                        assert expected_val == actual_val
4734
4735
4736class TestEncodingInvalid:
4737    def test_extract_nc4_variable_encoding(self):
4738        var = xr.Variable(("x",), [1, 2, 3], {}, {"foo": "bar"})
4739        with pytest.raises(ValueError, match=r"unexpected encoding"):
4740            _extract_nc4_variable_encoding(var, raise_on_invalid=True)
4741
4742        var = xr.Variable(("x",), [1, 2, 3], {}, {"chunking": (2, 1)})
4743        encoding = _extract_nc4_variable_encoding(var)
4744        assert {} == encoding
4745
4746        # regression test
4747        var = xr.Variable(("x",), [1, 2, 3], {}, {"shuffle": True})
4748        encoding = _extract_nc4_variable_encoding(var, raise_on_invalid=True)
4749        assert {"shuffle": True} == encoding
4750
4751        # Variables with unlim dims must be chunked on output.
4752        var = xr.Variable(("x",), [1, 2, 3], {}, {"contiguous": True})
4753        encoding = _extract_nc4_variable_encoding(var, unlimited_dims=("x",))
4754        assert {} == encoding
4755
4756    def test_extract_h5nc_encoding(self):
4757        # not supported with h5netcdf (yet)
4758        var = xr.Variable(("x",), [1, 2, 3], {}, {"least_sigificant_digit": 2})
4759        with pytest.raises(ValueError, match=r"unexpected encoding"):
4760            _extract_nc4_variable_encoding(var, raise_on_invalid=True)
4761
4762
4763class MiscObject:
4764    pass
4765
4766
4767@requires_netCDF4
4768class TestValidateAttrs:
4769    def test_validating_attrs(self):
4770        def new_dataset():
4771            return Dataset({"data": ("y", np.arange(10.0))}, {"y": np.arange(10)})
4772
4773        def new_dataset_and_dataset_attrs():
4774            ds = new_dataset()
4775            return ds, ds.attrs
4776
4777        def new_dataset_and_data_attrs():
4778            ds = new_dataset()
4779            return ds, ds.data.attrs
4780
4781        def new_dataset_and_coord_attrs():
4782            ds = new_dataset()
4783            return ds, ds.coords["y"].attrs
4784
4785        for new_dataset_and_attrs in [
4786            new_dataset_and_dataset_attrs,
4787            new_dataset_and_data_attrs,
4788            new_dataset_and_coord_attrs,
4789        ]:
4790            ds, attrs = new_dataset_and_attrs()
4791
4792            attrs[123] = "test"
4793            with pytest.raises(TypeError, match=r"Invalid name for attr: 123"):
4794                ds.to_netcdf("test.nc")
4795
4796            ds, attrs = new_dataset_and_attrs()
4797            attrs[MiscObject()] = "test"
4798            with pytest.raises(TypeError, match=r"Invalid name for attr: "):
4799                ds.to_netcdf("test.nc")
4800
4801            ds, attrs = new_dataset_and_attrs()
4802            attrs[""] = "test"
4803            with pytest.raises(ValueError, match=r"Invalid name for attr '':"):
4804                ds.to_netcdf("test.nc")
4805
4806            # This one should work
4807            ds, attrs = new_dataset_and_attrs()
4808            attrs["test"] = "test"
4809            with create_tmp_file() as tmp_file:
4810                ds.to_netcdf(tmp_file)
4811
4812            ds, attrs = new_dataset_and_attrs()
4813            attrs["test"] = {"a": 5}
4814            with pytest.raises(TypeError, match=r"Invalid value for attr 'test'"):
4815                ds.to_netcdf("test.nc")
4816
4817            ds, attrs = new_dataset_and_attrs()
4818            attrs["test"] = MiscObject()
4819            with pytest.raises(TypeError, match=r"Invalid value for attr 'test'"):
4820                ds.to_netcdf("test.nc")
4821
4822            ds, attrs = new_dataset_and_attrs()
4823            attrs["test"] = 5
4824            with create_tmp_file() as tmp_file:
4825                ds.to_netcdf(tmp_file)
4826
4827            ds, attrs = new_dataset_and_attrs()
4828            attrs["test"] = 3.14
4829            with create_tmp_file() as tmp_file:
4830                ds.to_netcdf(tmp_file)
4831
4832            ds, attrs = new_dataset_and_attrs()
4833            attrs["test"] = [1, 2, 3, 4]
4834            with create_tmp_file() as tmp_file:
4835                ds.to_netcdf(tmp_file)
4836
4837            ds, attrs = new_dataset_and_attrs()
4838            attrs["test"] = (1.9, 2.5)
4839            with create_tmp_file() as tmp_file:
4840                ds.to_netcdf(tmp_file)
4841
4842            ds, attrs = new_dataset_and_attrs()
4843            attrs["test"] = np.arange(5)
4844            with create_tmp_file() as tmp_file:
4845                ds.to_netcdf(tmp_file)
4846
4847            ds, attrs = new_dataset_and_attrs()
4848            attrs["test"] = "This is a string"
4849            with create_tmp_file() as tmp_file:
4850                ds.to_netcdf(tmp_file)
4851
4852            ds, attrs = new_dataset_and_attrs()
4853            attrs["test"] = ""
4854            with create_tmp_file() as tmp_file:
4855                ds.to_netcdf(tmp_file)
4856
4857
4858@requires_scipy_or_netCDF4
4859class TestDataArrayToNetCDF:
4860    def test_dataarray_to_netcdf_no_name(self):
4861        original_da = DataArray(np.arange(12).reshape((3, 4)))
4862
4863        with create_tmp_file() as tmp:
4864            original_da.to_netcdf(tmp)
4865
4866            with open_dataarray(tmp) as loaded_da:
4867                assert_identical(original_da, loaded_da)
4868
4869    def test_dataarray_to_netcdf_with_name(self):
4870        original_da = DataArray(np.arange(12).reshape((3, 4)), name="test")
4871
4872        with create_tmp_file() as tmp:
4873            original_da.to_netcdf(tmp)
4874
4875            with open_dataarray(tmp) as loaded_da:
4876                assert_identical(original_da, loaded_da)
4877
4878    def test_dataarray_to_netcdf_coord_name_clash(self):
4879        original_da = DataArray(
4880            np.arange(12).reshape((3, 4)), dims=["x", "y"], name="x"
4881        )
4882
4883        with create_tmp_file() as tmp:
4884            original_da.to_netcdf(tmp)
4885
4886            with open_dataarray(tmp) as loaded_da:
4887                assert_identical(original_da, loaded_da)
4888
4889    def test_open_dataarray_options(self):
4890        data = DataArray(np.arange(5), coords={"y": ("x", range(5))}, dims=["x"])
4891
4892        with create_tmp_file() as tmp:
4893            data.to_netcdf(tmp)
4894
4895            expected = data.drop_vars("y")
4896            with open_dataarray(tmp, drop_variables=["y"]) as loaded:
4897                assert_identical(expected, loaded)
4898
4899    @requires_scipy
4900    def test_dataarray_to_netcdf_return_bytes(self):
4901        # regression test for GH1410
4902        data = xr.DataArray([1, 2, 3])
4903        output = data.to_netcdf()
4904        assert isinstance(output, bytes)
4905
4906    def test_dataarray_to_netcdf_no_name_pathlib(self):
4907        original_da = DataArray(np.arange(12).reshape((3, 4)))
4908
4909        with create_tmp_file() as tmp:
4910            tmp = Path(tmp)
4911            original_da.to_netcdf(tmp)
4912
4913            with open_dataarray(tmp) as loaded_da:
4914                assert_identical(original_da, loaded_da)
4915
4916
4917@requires_scipy_or_netCDF4
4918def test_no_warning_from_dask_effective_get():
4919    with create_tmp_file() as tmpfile:
4920        with pytest.warns(None) as record:
4921            ds = Dataset()
4922            ds.to_netcdf(tmpfile)
4923        assert len(record) == 0
4924
4925
4926@requires_scipy_or_netCDF4
4927def test_source_encoding_always_present():
4928    # Test for GH issue #2550.
4929    rnddata = np.random.randn(10)
4930    original = Dataset({"foo": ("x", rnddata)})
4931    with create_tmp_file() as tmp:
4932        original.to_netcdf(tmp)
4933        with open_dataset(tmp) as ds:
4934            assert ds.encoding["source"] == tmp
4935
4936
4937def _assert_no_dates_out_of_range_warning(record):
4938    undesired_message = "dates out of range"
4939    for warning in record:
4940        assert undesired_message not in str(warning.message)
4941
4942
4943@requires_scipy_or_netCDF4
4944@pytest.mark.parametrize("calendar", _STANDARD_CALENDARS)
4945def test_use_cftime_standard_calendar_default_in_range(calendar):
4946    x = [0, 1]
4947    time = [0, 720]
4948    units_date = "2000-01-01"
4949    units = "days since 2000-01-01"
4950    original = DataArray(x, [("time", time)], name="x")
4951    original = original.to_dataset()
4952    for v in ["x", "time"]:
4953        original[v].attrs["units"] = units
4954        original[v].attrs["calendar"] = calendar
4955
4956    x_timedeltas = np.array(x).astype("timedelta64[D]")
4957    time_timedeltas = np.array(time).astype("timedelta64[D]")
4958    decoded_x = np.datetime64(units_date, "ns") + x_timedeltas
4959    decoded_time = np.datetime64(units_date, "ns") + time_timedeltas
4960    expected_x = DataArray(decoded_x, [("time", decoded_time)], name="x")
4961    expected_time = DataArray(decoded_time, [("time", decoded_time)], name="time")
4962
4963    with create_tmp_file() as tmp_file:
4964        original.to_netcdf(tmp_file)
4965        with pytest.warns(None) as record:
4966            with open_dataset(tmp_file) as ds:
4967                assert_identical(expected_x, ds.x)
4968                assert_identical(expected_time, ds.time)
4969            _assert_no_dates_out_of_range_warning(record)
4970
4971
4972@requires_cftime
4973@requires_scipy_or_netCDF4
4974@pytest.mark.parametrize("calendar", _STANDARD_CALENDARS)
4975@pytest.mark.parametrize("units_year", [1500, 2500])
4976def test_use_cftime_standard_calendar_default_out_of_range(calendar, units_year):
4977    import cftime
4978
4979    x = [0, 1]
4980    time = [0, 720]
4981    units = f"days since {units_year}-01-01"
4982    original = DataArray(x, [("time", time)], name="x")
4983    original = original.to_dataset()
4984    for v in ["x", "time"]:
4985        original[v].attrs["units"] = units
4986        original[v].attrs["calendar"] = calendar
4987
4988    decoded_x = cftime.num2date(x, units, calendar, only_use_cftime_datetimes=True)
4989    decoded_time = cftime.num2date(
4990        time, units, calendar, only_use_cftime_datetimes=True
4991    )
4992    expected_x = DataArray(decoded_x, [("time", decoded_time)], name="x")
4993    expected_time = DataArray(decoded_time, [("time", decoded_time)], name="time")
4994
4995    with create_tmp_file() as tmp_file:
4996        original.to_netcdf(tmp_file)
4997        with pytest.warns(SerializationWarning):
4998            with open_dataset(tmp_file) as ds:
4999                assert_identical(expected_x, ds.x)
5000                assert_identical(expected_time, ds.time)
5001
5002
5003@requires_cftime
5004@requires_scipy_or_netCDF4
5005@pytest.mark.parametrize("calendar", _ALL_CALENDARS)
5006@pytest.mark.parametrize("units_year", [1500, 2000, 2500])
5007def test_use_cftime_true(calendar, units_year):
5008    import cftime
5009
5010    x = [0, 1]
5011    time = [0, 720]
5012    units = f"days since {units_year}-01-01"
5013    original = DataArray(x, [("time", time)], name="x")
5014    original = original.to_dataset()
5015    for v in ["x", "time"]:
5016        original[v].attrs["units"] = units
5017        original[v].attrs["calendar"] = calendar
5018
5019    decoded_x = cftime.num2date(x, units, calendar, only_use_cftime_datetimes=True)
5020    decoded_time = cftime.num2date(
5021        time, units, calendar, only_use_cftime_datetimes=True
5022    )
5023    expected_x = DataArray(decoded_x, [("time", decoded_time)], name="x")
5024    expected_time = DataArray(decoded_time, [("time", decoded_time)], name="time")
5025
5026    with create_tmp_file() as tmp_file:
5027        original.to_netcdf(tmp_file)
5028        with pytest.warns(None) as record:
5029            with open_dataset(tmp_file, use_cftime=True) as ds:
5030                assert_identical(expected_x, ds.x)
5031                assert_identical(expected_time, ds.time)
5032            _assert_no_dates_out_of_range_warning(record)
5033
5034
5035@requires_scipy_or_netCDF4
5036@pytest.mark.parametrize("calendar", _STANDARD_CALENDARS)
5037def test_use_cftime_false_standard_calendar_in_range(calendar):
5038    x = [0, 1]
5039    time = [0, 720]
5040    units_date = "2000-01-01"
5041    units = "days since 2000-01-01"
5042    original = DataArray(x, [("time", time)], name="x")
5043    original = original.to_dataset()
5044    for v in ["x", "time"]:
5045        original[v].attrs["units"] = units
5046        original[v].attrs["calendar"] = calendar
5047
5048    x_timedeltas = np.array(x).astype("timedelta64[D]")
5049    time_timedeltas = np.array(time).astype("timedelta64[D]")
5050    decoded_x = np.datetime64(units_date, "ns") + x_timedeltas
5051    decoded_time = np.datetime64(units_date, "ns") + time_timedeltas
5052    expected_x = DataArray(decoded_x, [("time", decoded_time)], name="x")
5053    expected_time = DataArray(decoded_time, [("time", decoded_time)], name="time")
5054
5055    with create_tmp_file() as tmp_file:
5056        original.to_netcdf(tmp_file)
5057        with pytest.warns(None) as record:
5058            with open_dataset(tmp_file, use_cftime=False) as ds:
5059                assert_identical(expected_x, ds.x)
5060                assert_identical(expected_time, ds.time)
5061            _assert_no_dates_out_of_range_warning(record)
5062
5063
5064@requires_scipy_or_netCDF4
5065@pytest.mark.parametrize("calendar", _STANDARD_CALENDARS)
5066@pytest.mark.parametrize("units_year", [1500, 2500])
5067def test_use_cftime_false_standard_calendar_out_of_range(calendar, units_year):
5068    x = [0, 1]
5069    time = [0, 720]
5070    units = f"days since {units_year}-01-01"
5071    original = DataArray(x, [("time", time)], name="x")
5072    original = original.to_dataset()
5073    for v in ["x", "time"]:
5074        original[v].attrs["units"] = units
5075        original[v].attrs["calendar"] = calendar
5076
5077    with create_tmp_file() as tmp_file:
5078        original.to_netcdf(tmp_file)
5079        with pytest.raises((OutOfBoundsDatetime, ValueError)):
5080            open_dataset(tmp_file, use_cftime=False)
5081
5082
5083@requires_scipy_or_netCDF4
5084@pytest.mark.parametrize("calendar", _NON_STANDARD_CALENDARS)
5085@pytest.mark.parametrize("units_year", [1500, 2000, 2500])
5086def test_use_cftime_false_nonstandard_calendar(calendar, units_year):
5087    x = [0, 1]
5088    time = [0, 720]
5089    units = f"days since {units_year}"
5090    original = DataArray(x, [("time", time)], name="x")
5091    original = original.to_dataset()
5092    for v in ["x", "time"]:
5093        original[v].attrs["units"] = units
5094        original[v].attrs["calendar"] = calendar
5095
5096    with create_tmp_file() as tmp_file:
5097        original.to_netcdf(tmp_file)
5098        with pytest.raises((OutOfBoundsDatetime, ValueError)):
5099            open_dataset(tmp_file, use_cftime=False)
5100
5101
5102@pytest.mark.parametrize("engine", ["netcdf4", "scipy"])
5103def test_invalid_netcdf_raises(engine):
5104    data = create_test_data()
5105    with pytest.raises(ValueError, match=r"unrecognized option 'invalid_netcdf'"):
5106        data.to_netcdf("foo.nc", engine=engine, invalid_netcdf=True)
5107
5108
5109@requires_zarr
5110def test_encode_zarr_attr_value():
5111    # array -> list
5112    arr = np.array([1, 2, 3])
5113    expected = [1, 2, 3]
5114    actual = backends.zarr.encode_zarr_attr_value(arr)
5115    assert isinstance(actual, list)
5116    assert actual == expected
5117
5118    # scalar array -> scalar
5119    sarr = np.array(1)[()]
5120    expected = 1
5121    actual = backends.zarr.encode_zarr_attr_value(sarr)
5122    assert isinstance(actual, int)
5123    assert actual == expected
5124
5125    # string -> string (no change)
5126    expected = "foo"
5127    actual = backends.zarr.encode_zarr_attr_value(expected)
5128    assert isinstance(actual, str)
5129    assert actual == expected
5130
5131
5132@requires_zarr
5133def test_extract_zarr_variable_encoding():
5134
5135    var = xr.Variable("x", [1, 2])
5136    actual = backends.zarr.extract_zarr_variable_encoding(var)
5137    assert "chunks" in actual
5138    assert actual["chunks"] is None
5139
5140    var = xr.Variable("x", [1, 2], encoding={"chunks": (1,)})
5141    actual = backends.zarr.extract_zarr_variable_encoding(var)
5142    assert actual["chunks"] == (1,)
5143
5144    # does not raise on invalid
5145    var = xr.Variable("x", [1, 2], encoding={"foo": (1,)})
5146    actual = backends.zarr.extract_zarr_variable_encoding(var)
5147
5148    # raises on invalid
5149    var = xr.Variable("x", [1, 2], encoding={"foo": (1,)})
5150    with pytest.raises(ValueError, match=r"unexpected encoding parameters"):
5151        actual = backends.zarr.extract_zarr_variable_encoding(
5152            var, raise_on_invalid=True
5153        )
5154
5155
5156@requires_zarr
5157@requires_fsspec
5158@pytest.mark.filterwarnings("ignore:deallocating CachingFileManager")
5159def test_open_fsspec():
5160    import fsspec
5161    import zarr
5162
5163    if not hasattr(zarr.storage, "FSStore") or not hasattr(
5164        zarr.storage.FSStore, "getitems"
5165    ):
5166        pytest.skip("zarr too old")
5167
5168    ds = open_dataset(os.path.join(os.path.dirname(__file__), "data", "example_1.nc"))
5169
5170    m = fsspec.filesystem("memory")
5171    mm = m.get_mapper("out1.zarr")
5172    ds.to_zarr(mm)  # old interface
5173    ds0 = ds.copy()
5174    ds0["time"] = ds.time + pd.to_timedelta("1 day")
5175    mm = m.get_mapper("out2.zarr")
5176    ds0.to_zarr(mm)  # old interface
5177
5178    # single dataset
5179    url = "memory://out2.zarr"
5180    ds2 = open_dataset(url, engine="zarr")
5181    assert ds0 == ds2
5182
5183    # single dataset with caching
5184    url = "simplecache::memory://out2.zarr"
5185    ds2 = open_dataset(url, engine="zarr")
5186    assert ds0 == ds2
5187
5188    # multi dataset
5189    url = "memory://out*.zarr"
5190    ds2 = open_mfdataset(url, engine="zarr")
5191    assert xr.concat([ds, ds0], dim="time") == ds2
5192
5193    # multi dataset with caching
5194    url = "simplecache::memory://out*.zarr"
5195    ds2 = open_mfdataset(url, engine="zarr")
5196    assert xr.concat([ds, ds0], dim="time") == ds2
5197
5198
5199@requires_h5netcdf
5200@requires_netCDF4
5201def test_load_single_value_h5netcdf(tmp_path):
5202    """Test that numeric single-element vector attributes are handled fine.
5203
5204    At present (h5netcdf v0.8.1), the h5netcdf exposes single-valued numeric variable
5205    attributes as arrays of length 1, as opposed to scalars for the NetCDF4
5206    backend.  This was leading to a ValueError upon loading a single value from
5207    a file, see #4471.  Test that loading causes no failure.
5208    """
5209    ds = xr.Dataset(
5210        {
5211            "test": xr.DataArray(
5212                np.array([0]), dims=("x",), attrs={"scale_factor": 1, "add_offset": 0}
5213            )
5214        }
5215    )
5216    ds.to_netcdf(tmp_path / "test.nc")
5217    with xr.open_dataset(tmp_path / "test.nc", engine="h5netcdf") as ds2:
5218        ds2["test"][0].load()
5219
5220
5221@requires_zarr
5222@requires_dask
5223@pytest.mark.parametrize(
5224    "chunks", ["auto", -1, {}, {"x": "auto"}, {"x": -1}, {"x": "auto", "y": -1}]
5225)
5226def test_open_dataset_chunking_zarr(chunks, tmp_path):
5227    encoded_chunks = 100
5228    dask_arr = da.from_array(
5229        np.ones((500, 500), dtype="float64"), chunks=encoded_chunks
5230    )
5231    ds = xr.Dataset(
5232        {
5233            "test": xr.DataArray(
5234                dask_arr,
5235                dims=("x", "y"),
5236            )
5237        }
5238    )
5239    ds["test"].encoding["chunks"] = encoded_chunks
5240    ds.to_zarr(tmp_path / "test.zarr")
5241
5242    with dask.config.set({"array.chunk-size": "1MiB"}):
5243        expected = ds.chunk(chunks)
5244        with open_dataset(
5245            tmp_path / "test.zarr", engine="zarr", chunks=chunks
5246        ) as actual:
5247            xr.testing.assert_chunks_equal(actual, expected)
5248
5249
5250@requires_zarr
5251@requires_dask
5252@pytest.mark.parametrize(
5253    "chunks", ["auto", -1, {}, {"x": "auto"}, {"x": -1}, {"x": "auto", "y": -1}]
5254)
5255@pytest.mark.filterwarnings("ignore:Specified Dask chunks")
5256def test_chunking_consintency(chunks, tmp_path):
5257    encoded_chunks = {}
5258    dask_arr = da.from_array(
5259        np.ones((500, 500), dtype="float64"), chunks=encoded_chunks
5260    )
5261    ds = xr.Dataset(
5262        {
5263            "test": xr.DataArray(
5264                dask_arr,
5265                dims=("x", "y"),
5266            )
5267        }
5268    )
5269    ds["test"].encoding["chunks"] = encoded_chunks
5270    ds.to_zarr(tmp_path / "test.zarr")
5271    ds.to_netcdf(tmp_path / "test.nc")
5272
5273    with dask.config.set({"array.chunk-size": "1MiB"}):
5274        expected = ds.chunk(chunks)
5275        with xr.open_dataset(
5276            tmp_path / "test.zarr", engine="zarr", chunks=chunks
5277        ) as actual:
5278            xr.testing.assert_chunks_equal(actual, expected)
5279
5280        with xr.open_dataset(tmp_path / "test.nc", chunks=chunks) as actual:
5281            xr.testing.assert_chunks_equal(actual, expected)
5282
5283
5284def _check_guess_can_open_and_open(entrypoint, obj, engine, expected):
5285    assert entrypoint.guess_can_open(obj)
5286    with open_dataset(obj, engine=engine) as actual:
5287        assert_identical(expected, actual)
5288
5289
5290@requires_netCDF4
5291def test_netcdf4_entrypoint(tmp_path):
5292    entrypoint = NetCDF4BackendEntrypoint()
5293    ds = create_test_data()
5294
5295    path = tmp_path / "foo"
5296    ds.to_netcdf(path, format="netcdf3_classic")
5297    _check_guess_can_open_and_open(entrypoint, path, engine="netcdf4", expected=ds)
5298    _check_guess_can_open_and_open(entrypoint, str(path), engine="netcdf4", expected=ds)
5299
5300    path = tmp_path / "bar"
5301    ds.to_netcdf(path, format="netcdf4_classic")
5302    _check_guess_can_open_and_open(entrypoint, path, engine="netcdf4", expected=ds)
5303    _check_guess_can_open_and_open(entrypoint, str(path), engine="netcdf4", expected=ds)
5304
5305    assert entrypoint.guess_can_open("http://something/remote")
5306    assert entrypoint.guess_can_open("something-local.nc")
5307    assert entrypoint.guess_can_open("something-local.nc4")
5308    assert entrypoint.guess_can_open("something-local.cdf")
5309    assert not entrypoint.guess_can_open("not-found-and-no-extension")
5310
5311    path = tmp_path / "baz"
5312    with open(path, "wb") as f:
5313        f.write(b"not-a-netcdf-file")
5314    assert not entrypoint.guess_can_open(path)
5315
5316
5317@requires_scipy
5318def test_scipy_entrypoint(tmp_path):
5319    entrypoint = ScipyBackendEntrypoint()
5320    ds = create_test_data()
5321
5322    path = tmp_path / "foo"
5323    ds.to_netcdf(path, engine="scipy")
5324    _check_guess_can_open_and_open(entrypoint, path, engine="scipy", expected=ds)
5325    _check_guess_can_open_and_open(entrypoint, str(path), engine="scipy", expected=ds)
5326    with open(path, "rb") as f:
5327        _check_guess_can_open_and_open(entrypoint, f, engine="scipy", expected=ds)
5328
5329    contents = ds.to_netcdf(engine="scipy")
5330    _check_guess_can_open_and_open(entrypoint, contents, engine="scipy", expected=ds)
5331    _check_guess_can_open_and_open(
5332        entrypoint, BytesIO(contents), engine="scipy", expected=ds
5333    )
5334
5335    path = tmp_path / "foo.nc.gz"
5336    with gzip.open(path, mode="wb") as f:
5337        f.write(contents)
5338    _check_guess_can_open_and_open(entrypoint, path, engine="scipy", expected=ds)
5339    _check_guess_can_open_and_open(entrypoint, str(path), engine="scipy", expected=ds)
5340
5341    assert entrypoint.guess_can_open("something-local.nc")
5342    assert entrypoint.guess_can_open("something-local.nc.gz")
5343    assert not entrypoint.guess_can_open("not-found-and-no-extension")
5344    assert not entrypoint.guess_can_open(b"not-a-netcdf-file")
5345
5346
5347@requires_h5netcdf
5348def test_h5netcdf_entrypoint(tmp_path):
5349    entrypoint = H5netcdfBackendEntrypoint()
5350    ds = create_test_data()
5351
5352    path = tmp_path / "foo"
5353    ds.to_netcdf(path, engine="h5netcdf")
5354    _check_guess_can_open_and_open(entrypoint, path, engine="h5netcdf", expected=ds)
5355    _check_guess_can_open_and_open(
5356        entrypoint, str(path), engine="h5netcdf", expected=ds
5357    )
5358    with open(path, "rb") as f:
5359        _check_guess_can_open_and_open(entrypoint, f, engine="h5netcdf", expected=ds)
5360
5361    assert entrypoint.guess_can_open("something-local.nc")
5362    assert entrypoint.guess_can_open("something-local.nc4")
5363    assert entrypoint.guess_can_open("something-local.cdf")
5364    assert not entrypoint.guess_can_open("not-found-and-no-extension")
5365