1import importlib
2import platform
3import warnings
4from contextlib import contextmanager
5from distutils import version
6from unittest import mock  # noqa: F401
7
8import numpy as np
9import pandas as pd
10import pytest
11from numpy.testing import assert_array_equal  # noqa: F401
12from pandas.testing import assert_frame_equal  # noqa: F401
13
14import xarray.testing
15from xarray import Dataset
16from xarray.core import utils
17from xarray.core.duck_array_ops import allclose_or_equiv  # noqa: F401
18from xarray.core.indexing import ExplicitlyIndexed
19from xarray.core.options import set_options
20from xarray.testing import (  # noqa: F401
21    assert_chunks_equal,
22    assert_duckarray_allclose,
23    assert_duckarray_equal,
24)
25
# import mpl and change the backend before other mpl imports
try:
    import matplotlib as mpl

    # Order of imports is important here.
    # Using a different backend makes Travis CI work
    mpl.use("Agg")
except ImportError:
    # matplotlib is an optional test dependency; tests that need it are
    # skipped via the ``requires_matplotlib`` marker defined below.
    pass
35
36
# xfail marker for tests that are expected to fail on ARM hardware.
_machine = platform.machine()
arm_xfail = pytest.mark.xfail(
    _machine == "aarch64" or "arm" in _machine,
    reason="expected failure on ARM",
)
41
42
def _importorskip(modname, minversion=None):
    """Return ``(has, marker)`` for an optional dependency.

    ``has`` is True when *modname* can be imported (and, if *minversion*
    is given, its ``__version__`` is at least that version).  ``marker``
    is a ``pytest.mark.skipif`` that skips a test when ``has`` is False.
    """
    try:
        mod = importlib.import_module(modname)
    except ImportError:
        has = False
    else:
        has = True
        # A too-old version is treated exactly like a missing module.
        if minversion is not None and LooseVersion(mod.__version__) < LooseVersion(
            minversion
        ):
            has = False
    return has, pytest.mark.skipif(not has, reason=f"requires {modname}")
54
55
def LooseVersion(vstring):
    """Parse *vstring* into a comparable version, ignoring any local part.

    Development versions look like ``'0.10.9+aac7bfc'``; everything after
    the ``'+'`` (the git commit id) is discarded before parsing.
    """
    public, _, _commit = vstring.partition("+")
    return version.LooseVersion(public)
61
62
# Availability flag / skip-marker pairs for optional test dependencies.
# ``has_x`` is a bool; ``requires_x`` is a pytest.mark.skipif marker that
# skips the test when the dependency is missing (or, where a minversion is
# given, too old).
has_matplotlib, requires_matplotlib = _importorskip("matplotlib")
has_scipy, requires_scipy = _importorskip("scipy")
has_pydap, requires_pydap = _importorskip("pydap.client")
has_netCDF4, requires_netCDF4 = _importorskip("netCDF4")
has_h5netcdf, requires_h5netcdf = _importorskip("h5netcdf")
has_pynio, requires_pynio = _importorskip("Nio")
has_pseudonetcdf, requires_pseudonetcdf = _importorskip("PseudoNetCDF")
has_cftime, requires_cftime = _importorskip("cftime")
has_cftime_1_4_1, requires_cftime_1_4_1 = _importorskip("cftime", minversion="1.4.1")
has_dask, requires_dask = _importorskip("dask")
has_bottleneck, requires_bottleneck = _importorskip("bottleneck")
has_nc_time_axis, requires_nc_time_axis = _importorskip("nc_time_axis")
has_rasterio, requires_rasterio = _importorskip("rasterio")
has_zarr, requires_zarr = _importorskip("zarr")
has_fsspec, requires_fsspec = _importorskip("fsspec")
has_iris, requires_iris = _importorskip("iris")
has_cfgrib, requires_cfgrib = _importorskip("cfgrib")
has_numbagg, requires_numbagg = _importorskip("numbagg")
has_seaborn, requires_seaborn = _importorskip("seaborn")
has_sparse, requires_sparse = _importorskip("sparse")
has_cupy, requires_cupy = _importorskip("cupy")
has_cartopy, requires_cartopy = _importorskip("cartopy")
has_pint, requires_pint = _importorskip("pint")
has_numexpr, requires_numexpr = _importorskip("numexpr")

# some special cases
# Either scipy or netCDF4 can provide a usable netCDF backend.
has_scipy_or_netCDF4 = has_scipy or has_netCDF4
requires_scipy_or_netCDF4 = pytest.mark.skipif(
    not has_scipy_or_netCDF4, reason="requires scipy or netCDF4"
)
93
# change some global options for tests
set_options(warn_for_unclosed_files=True)

if has_dask:
    import dask

    # Run dask single-threaded so failures are deterministic and easy to
    # debug, and so CountingScheduler below can intercept computes.
    dask.config.set(scheduler="single-threaded")
101
102
class CountingScheduler:
    """Simple dask scheduler counting the number of computes.

    Reference: https://stackoverflow.com/questions/53289286/"""

    def __init__(self, max_computes=0):
        # Computes observed so far, and the count above which we fail.
        self.total_computes = 0
        self.max_computes = max_computes

    def __call__(self, dsk, keys, **kwargs):
        # Count this compute, then delegate to dask only while the budget
        # has not been exceeded.
        self.total_computes += 1
        if self.total_computes <= self.max_computes:
            return dask.get(dsk, keys, **kwargs)
        raise RuntimeError(
            "Too many computes. Total: %d > max: %d."
            % (self.total_computes, self.max_computes)
        )
120
121
@contextmanager
def dummy_context():
    """No-op context manager that yields None and cleans up nothing."""
    yield
125
126
def raise_if_dask_computes(max_computes=0):
    """Context manager raising RuntimeError after more than *max_computes*
    dask computes happen inside it."""
    if has_dask:
        # dask.config.set is itself a context manager restoring the
        # previous scheduler on exit.
        return dask.config.set(scheduler=CountingScheduler(max_computes))
    # Without dask there is nothing to count; hand back a no-op context
    # manager so callers can use this unconditionally on non-dask objects.
    return dummy_context()
133
134
# Convenience aliases for pytest marks used across the test suite
# (presumably registered by plugins/ini configuration elsewhere — e.g.
# flaky reruns and network-access gating; not visible from this file).
flaky = pytest.mark.flaky
network = pytest.mark.network
137
138
class UnexpectedDataAccess(Exception):
    """Raised when test data that is supposed to stay lazy gets accessed."""
141
142
class InaccessibleArray(utils.NDArrayMixin, ExplicitlyIndexed):
    """Array wrapper whose values must never be read.

    Any indexing raises UnexpectedDataAccess, letting tests verify that
    lazy code paths do not touch the underlying data.
    """

    def __init__(self, array):
        # Kept only so metadata (shape/dtype via NDArrayMixin) is available.
        self.array = array

    def __getitem__(self, key):
        raise UnexpectedDataAccess("Tried accessing data")
149
150
class ReturnItem:
    """Echo whatever key is used to index it.

    Handy for building slice/tuple indexer objects with ``[]`` syntax,
    e.g. ``ReturnItem()[1:3]`` is ``slice(1, 3)``.
    """

    def __getitem__(self, key):
        return key
154
155
class IndexerMaker:
    """Build *indexer_cls* instances from ``[]`` syntax.

    ``IndexerMaker(cls)[a, b]`` returns ``cls((a, b))``; a single
    non-tuple key is wrapped in a 1-tuple first.
    """

    def __init__(self, indexer_cls):
        self._indexer_cls = indexer_cls

    def __getitem__(self, key):
        wrapped = key if isinstance(key, tuple) else (key,)
        return self._indexer_cls(wrapped)
164
165
def source_ndarray(array):
    """Given an ndarray, return the base object which holds its memory, or the
    object itself.
    """
    with warnings.catch_warnings():
        # Older pandas warns when ``.base`` is accessed on its index types.
        for message in ("DatetimeIndex.base", "TimedeltaIndex.base"):
            warnings.filterwarnings("ignore", message)
        base = getattr(array, "base", np.asarray(array).base)
    return array if base is None else base
177
178
179# Internal versions of xarray's test functions that validate additional
180# invariants
181
182
def assert_equal(a, b):
    """xarray.testing.assert_equal plus internal-invariant checks on both
    arguments."""
    __tracebackhide__ = True
    xarray.testing.assert_equal(a, b)
    for obj in (a, b):
        xarray.testing._assert_internal_invariants(obj)
188
189
def assert_identical(a, b):
    """xarray.testing.assert_identical plus internal-invariant checks on
    both arguments."""
    __tracebackhide__ = True
    xarray.testing.assert_identical(a, b)
    for obj in (a, b):
        xarray.testing._assert_internal_invariants(obj)
195
196
def assert_allclose(a, b, **kwargs):
    """xarray.testing.assert_allclose plus internal-invariant checks on
    both arguments."""
    __tracebackhide__ = True
    xarray.testing.assert_allclose(a, b, **kwargs)
    for obj in (a, b):
        xarray.testing._assert_internal_invariants(obj)
202
203
def create_test_data(seed=None, add_attrs=True):
    """Build a small Dataset with reproducible random variables.

    Parameters
    ----------
    seed : int, optional
        Seed for the RandomState generating the variable data.
    add_attrs : bool, default True
        Whether to attach a ``{"foo": "variable"}`` attrs dict to each
        data variable.
    """
    rng = np.random.RandomState(seed)
    sizes = {"dim1": 8, "dim2": 9, "dim3": 10}
    var_dims = {
        "var1": ["dim1", "dim2"],
        "var2": ["dim1", "dim2"],
        "var3": ["dim3", "dim1"],
    }

    ds = Dataset()
    ds["dim2"] = ("dim2", 0.5 * np.arange(sizes["dim2"]))
    ds["dim3"] = ("dim3", list("abcdefghij"))
    ds["time"] = ("time", pd.date_range("2000-01-01", periods=20))
    # Fill variables in sorted name order so the random stream (and hence
    # the data) is reproducible for a given seed.
    for name in sorted(var_dims):
        dims = var_dims[name]
        ds[name] = (dims, rng.normal(size=tuple(sizes[d] for d in dims)))
        if add_attrs:
            ds[name].attrs = {"foo": "variable"}
    ds.coords["numbers"] = (
        "dim3",
        np.array([0, 1, 2, 0, 0, 1, 1, 2, 2, 3], dtype="int64"),
    )
    ds.encoding = {"foo": "bar"}
    assert all(v.data.flags.writeable for v in ds.variables.values())
    return ds
229
230
# CF-convention calendar names for cftime-based tests (order as originally
# listed; used wherever calendar-sensitive behavior needs exercising).
_CFTIME_CALENDARS = [
    "365_day",
    "360_day",
    "julian",
    "all_leap",
    "366_day",
    "gregorian",
    "proleptic_gregorian",
    "standard",
]
241