import importlib
import platform
import warnings
from contextlib import contextmanager
from distutils import version
from unittest import mock  # noqa: F401

import numpy as np
import pandas as pd
import pytest
from numpy.testing import assert_array_equal  # noqa: F401
from pandas.testing import assert_frame_equal  # noqa: F401

import xarray.testing
from xarray import Dataset
from xarray.core import utils
from xarray.core.duck_array_ops import allclose_or_equiv  # noqa: F401
from xarray.core.indexing import ExplicitlyIndexed
from xarray.core.options import set_options
from xarray.testing import (  # noqa: F401
    assert_chunks_equal,
    assert_duckarray_allclose,
    assert_duckarray_equal,
)

# import mpl and change the backend before other mpl imports
try:
    import matplotlib as mpl

    # Order of imports is important here.
    # Using a different backend makes Travis CI work
    mpl.use("Agg")
except ImportError:
    pass


# Marks a test as an expected failure when the machine string reported by
# ``platform.machine()`` is "aarch64" or contains "arm".
arm_xfail = pytest.mark.xfail(
    platform.machine() == "aarch64" or "arm" in platform.machine(),
    reason="expected failure on ARM",
)


def _importorskip(modname, minversion=None):
    """Probe for an optional dependency and build a matching skip marker.

    Parameters
    ----------
    modname : str
        Dotted name of the module to import.
    minversion : str, optional
        If given, the module's ``__version__`` must not compare lower than
        this (using ``LooseVersion``) for the module to count as available.

    Returns
    -------
    has : bool
        True if the import succeeded and the version check (if any) passed.
    func : pytest mark
        ``pytest.mark.skipif`` marker that skips when ``has`` is False.
    """
    try:
        mod = importlib.import_module(modname)
        has = True
        if minversion is not None and LooseVersion(mod.__version__) < LooseVersion(
            minversion
        ):
            # Reuse the ImportError path so a too-old module is reported
            # exactly like a missing one.
            raise ImportError("Minimum version not satisfied")
    except ImportError:
        has = False
    func = pytest.mark.skipif(not has, reason=f"requires {modname}")
    return has, func


def LooseVersion(vstring):
    """Return a ``distutils`` ``LooseVersion`` for ``vstring`` without its
    local ``+...`` suffix.
    """
    # Our development version is something like '0.10.9+aac7bfc'
    # This function just ignores the git commit id.
    # NOTE(review): distutils is deprecated (PEP 632) and removed from the
    # standard library in Python 3.12; this needs a replacement before then.
    vstring = vstring.split("+")[0]
    return version.LooseVersion(vstring)


has_matplotlib, requires_matplotlib = _importorskip("matplotlib")
has_scipy, requires_scipy = _importorskip("scipy")
has_pydap, requires_pydap = _importorskip("pydap.client")
has_netCDF4, requires_netCDF4 = _importorskip("netCDF4")
has_h5netcdf, requires_h5netcdf = _importorskip("h5netcdf")
has_pynio, requires_pynio = _importorskip("Nio")
has_pseudonetcdf, requires_pseudonetcdf = _importorskip("PseudoNetCDF")
has_cftime, requires_cftime = _importorskip("cftime")
has_cftime_1_4_1, requires_cftime_1_4_1 = _importorskip("cftime", minversion="1.4.1")
has_dask, requires_dask = _importorskip("dask")
has_bottleneck, requires_bottleneck = _importorskip("bottleneck")
has_nc_time_axis, requires_nc_time_axis = _importorskip("nc_time_axis")
has_rasterio, requires_rasterio = _importorskip("rasterio")
has_zarr, requires_zarr = _importorskip("zarr")
has_fsspec, requires_fsspec = _importorskip("fsspec")
has_iris, requires_iris = _importorskip("iris")
has_cfgrib, requires_cfgrib = _importorskip("cfgrib")
has_numbagg, requires_numbagg = _importorskip("numbagg")
has_seaborn, requires_seaborn = _importorskip("seaborn")
has_sparse, requires_sparse = _importorskip("sparse")
has_cupy, requires_cupy = _importorskip("cupy")
has_cartopy, requires_cartopy = _importorskip("cartopy")
has_pint, requires_pint = _importorskip("pint")
has_numexpr, requires_numexpr = _importorskip("numexpr")

# some special cases
has_scipy_or_netCDF4 = has_scipy or has_netCDF4
requires_scipy_or_netCDF4 = pytest.mark.skipif(
    not has_scipy_or_netCDF4, reason="requires scipy or netCDF4"
)

# change some global options for tests
set_options(warn_for_unclosed_files=True)

if has_dask:
    import dask

    dask.config.set(scheduler="single-threaded")


class CountingScheduler:
    """Simple dask scheduler counting the number of computes.

    Each call increments ``total_computes``; once that count exceeds
    ``max_computes`` a ``RuntimeError`` is raised instead of computing.

    Reference: https://stackoverflow.com/questions/53289286/"""

    def __init__(self, max_computes=0):
        self.total_computes = 0
        self.max_computes = max_computes

    def __call__(self, dsk, keys, **kwargs):
        self.total_computes += 1
        if self.total_computes > self.max_computes:
            raise RuntimeError(
                "Too many computes. Total: %d > max: %d."
                % (self.total_computes, self.max_computes)
            )
        # ``dask`` is only bound at module level when ``has_dask`` is True.
        return dask.get(dsk, keys, **kwargs)


@contextmanager
def dummy_context():
    """Context manager that does nothing and yields ``None``."""
    yield None


def raise_if_dask_computes(max_computes=0):
    """Return a context manager that errors after too many dask computes.

    With dask available, this is ``dask.config.set`` installing a
    ``CountingScheduler`` limited to ``max_computes``. Without dask it is a
    no-op context manager.
    """
    # return a dummy context manager so that this can be used for non-dask objects
    if not has_dask:
        return dummy_context()
    scheduler = CountingScheduler(max_computes)
    return dask.config.set(scheduler=scheduler)


flaky = pytest.mark.flaky
network = pytest.mark.network


class UnexpectedDataAccess(Exception):
    """Raised by ``InaccessibleArray`` when it is indexed."""


class InaccessibleArray(utils.NDArrayMixin, ExplicitlyIndexed):
    """Array wrapper whose ``__getitem__`` always raises
    ``UnexpectedDataAccess``.
    """

    def __init__(self, array):
        self.array = array

    def __getitem__(self, key):
        raise UnexpectedDataAccess("Tried accessing data")


class ReturnItem:
    """Helper whose ``__getitem__`` returns the key it was given."""

    def __getitem__(self, key):
        return key


class IndexerMaker:
    """Build ``indexer_cls`` instances using subscript syntax."""

    def __init__(self, indexer_cls):
        self._indexer_cls = indexer_cls

    def __getitem__(self, key):
        # A non-tuple key is wrapped in a 1-tuple before being passed on.
        if not isinstance(key, tuple):
            key = (key,)
        return self._indexer_cls(key)


def source_ndarray(array):
    """Given an ndarray, return the base object which holds its memory, or the
    object itself.
    """
    with warnings.catch_warnings():
        # Silence warnings whose message starts with these patterns while
        # the ``base`` attribute is read.
        warnings.filterwarnings("ignore", "DatetimeIndex.base")
        warnings.filterwarnings("ignore", "TimedeltaIndex.base")
        base = getattr(array, "base", np.asarray(array).base)
    if base is None:
        base = array
    return base


# Internal versions of xarray's test functions that validate additional
# invariants


def assert_equal(a, b):
    """Run ``xarray.testing.assert_equal``, then check the internal invariants
    of both arguments.
    """
    __tracebackhide__ = True
    xarray.testing.assert_equal(a, b)
    xarray.testing._assert_internal_invariants(a)
    xarray.testing._assert_internal_invariants(b)


def assert_identical(a, b):
    """Run ``xarray.testing.assert_identical``, then check the internal
    invariants of both arguments.
    """
    __tracebackhide__ = True
    xarray.testing.assert_identical(a, b)
    xarray.testing._assert_internal_invariants(a)
    xarray.testing._assert_internal_invariants(b)


def assert_allclose(a, b, **kwargs):
    """Run ``xarray.testing.assert_allclose`` (forwarding ``kwargs``), then
    check the internal invariants of both arguments.
    """
    __tracebackhide__ = True
    xarray.testing.assert_allclose(a, b, **kwargs)
    xarray.testing._assert_internal_invariants(a)
    xarray.testing._assert_internal_invariants(b)


def create_test_data(seed=None, add_attrs=True):
    """Build a small ``Dataset`` of random data for use in tests.

    Parameters
    ----------
    seed : optional
        Passed to ``np.random.RandomState`` to seed the random values.
    add_attrs : bool, default True
        If True, each of ``var1``/``var2``/``var3`` gets
        ``attrs = {"foo": "variable"}``.

    Returns
    -------
    Dataset
        Holds variables ``var1``, ``var2`` (dims ``dim1``, ``dim2``) and
        ``var3`` (dims ``dim3``, ``dim1``), the ``dim2``, ``dim3`` and
        ``time`` variables, a ``numbers`` coordinate along ``dim3``, and
        ``encoding = {"foo": "bar"}``.
    """
    rs = np.random.RandomState(seed)
    _vars = {
        "var1": ["dim1", "dim2"],
        "var2": ["dim1", "dim2"],
        "var3": ["dim3", "dim1"],
    }
    _dims = {"dim1": 8, "dim2": 9, "dim3": 10}

    obj = Dataset()
    obj["dim2"] = ("dim2", 0.5 * np.arange(_dims["dim2"]))
    obj["dim3"] = ("dim3", list("abcdefghij"))
    obj["time"] = ("time", pd.date_range("2000-01-01", periods=20))
    # Iterate in sorted name order so the random draws are assigned to the
    # variables in a fixed sequence.
    for v, dims in sorted(_vars.items()):
        data = rs.normal(size=tuple(_dims[d] for d in dims))
        obj[v] = (dims, data)
        if add_attrs:
            obj[v].attrs = {"foo": "variable"}
    obj.coords["numbers"] = (
        "dim3",
        np.array([0, 1, 2, 0, 0, 1, 1, 2, 2, 3], dtype="int64"),
    )
    obj.encoding = {"foo": "bar"}
    # Sanity check: every variable's underlying array must be writeable.
    assert all(var.data.flags.writeable for var in obj.variables.values())
    return obj


_CFTIME_CALENDARS = [
    "365_day",
    "360_day",
    "julian",
    "all_leap",
    "366_day",
    "gregorian",
    "proleptic_gregorian",
    "standard",
]