1import warnings
2
3import numpy as np
4import pytest
5
6import xarray as xr
7
8from . import has_dask
9
try:
    from dask.array import from_array as dask_from_array
except ImportError:

    def dask_from_array(x):
        """Fallback used when dask is not installed: return ``x`` unchanged.

        Tests parametrized with this constructor are skipped via ``has_dask``
        in that case, so the identity function only needs to be importable.
        """
        return x
14
try:
    import pint

    # Registry configured so that quantities wrap ndarray-like magnitudes.
    unit_registry = pint.UnitRegistry(force_ndarray_like=True)
    has_pint = True

    def quantity(x):
        """Wrap ``x`` in a pint Quantity with units of ``"m"``."""
        return unit_registry.Quantity(x, "m")

except ImportError:
    has_pint = False

    def quantity(x):
        """Fallback when pint is unavailable: return ``x`` unchanged."""
        return x
30
31
def test_allclose_regression() -> None:
    # Two scalars 0.01 apart must compare as close when atol=0.01.
    first, second = xr.DataArray(1.01), xr.DataArray(1.02)
    xr.testing.assert_allclose(first, second, atol=0.01)
36
37
@pytest.mark.parametrize(
    ("obj1", "obj2"),
    [
        pytest.param(
            xr.Variable("x", [1e-17, 2]),
            xr.Variable("x", [0, 3]),
            id="Variable",
        ),
        pytest.param(
            xr.DataArray([1e-17, 2], dims="x"),
            xr.DataArray([0, 3], dims="x"),
            id="DataArray",
        ),
        pytest.param(
            # "a" agrees within tolerance, but "b" differs (2 vs 1).
            xr.Dataset({"a": ("x", [1e-17, 2]), "b": ("y", [-2e-18, 2])}),
            xr.Dataset({"a": ("x", [0, 2]), "b": ("y", [0, 1])}),
            id="Dataset",
        ),
    ],
)
def test_assert_allclose(obj1, obj2) -> None:
    """assert_allclose must raise for Variable/DataArray/Dataset pairs that differ."""
    with pytest.raises(AssertionError):
        xr.testing.assert_allclose(obj1, obj2)
59
60
# Duck-array constructors shared by the two assert_duckarray_equal tests below.
# Each one converts a plain list/scalar into the corresponding array type; the
# optional backends are skipped when their library is not installed.
_duckarray_params = [
    pytest.param(np.array, id="numpy"),
    pytest.param(
        dask_from_array,
        id="dask",
        marks=pytest.mark.skipif(not has_dask, reason="requires dask"),
    ),
    pytest.param(
        quantity,
        id="pint",
        marks=pytest.mark.skipif(not has_pint, reason="requires pint"),
    ),
]


@pytest.mark.filterwarnings("error")
@pytest.mark.parametrize("duckarray", _duckarray_params)
@pytest.mark.parametrize(
    ["obj1", "obj2"],
    (
        pytest.param([1e-10, 2], [0.0, 2.0], id="both arrays"),
        pytest.param([1e-17, 2], 0.0, id="second scalar"),
        pytest.param(0.0, [1e-17, 2], id="first scalar"),
    ),
)
def test_assert_duckarray_equal_failing(duckarray, obj1, obj2) -> None:
    """assert_duckarray_equal must raise when values differ, even by tiny amounts.

    The ``filterwarnings("error")`` mark makes any warning emitted during the
    test fail it.
    """
    # TODO: actually check the repr
    a = duckarray(obj1)
    b = duckarray(obj2)
    with pytest.raises(AssertionError):
        xr.testing.assert_duckarray_equal(a, b)


@pytest.mark.filterwarnings("error")
@pytest.mark.parametrize("duckarray", _duckarray_params)
@pytest.mark.parametrize(
    ["obj1", "obj2"],
    (
        pytest.param([0, 2], [0.0, 2.0], id="both arrays"),
        pytest.param([0, 0], 0.0, id="second scalar"),
        pytest.param(0.0, [0, 0], id="first scalar"),
    ),
)
def test_assert_duckarray_equal(duckarray, obj1, obj2) -> None:
    """assert_duckarray_equal must pass for equal values (int vs float, scalar vs array).

    The ``filterwarnings("error")`` mark makes any warning emitted during the
    test fail it.
    """
    a = duckarray(obj1)
    b = duckarray(obj2)

    xr.testing.assert_duckarray_equal(a, b)
127
128
@pytest.mark.parametrize(
    "func",
    [
        "assert_equal",
        "assert_identical",
        "assert_allclose",
        "assert_duckarray_equal",
        "assert_duckarray_allclose",
    ],
)
def test_ensure_warnings_not_elevated(func) -> None:
    # make sure warnings are not elevated to errors in the assertion functions
    # e.g. by @pytest.mark.filterwarnings("error")
    # see https://github.com/pydata/xarray/pull/4760#issuecomment-774101639

    # define a custom Variable class that raises a warning in assert_*
    class WarningVariable(xr.Variable):
        @property  # type: ignore[misc]
        def dims(self):
            warnings.warn("warning in test", stacklevel=2)
            return super().dims

        def __array__(self, dtype=None, copy=None):
            # Accept the full NumPy ``__array__(dtype=None, copy=None)``
            # signature so NumPy calls that pass these arguments do not fail
            # with TypeError. ``copy`` is accepted but not forwarded.
            # NOTE(review): presumably because older Variable.__array__
            # implementations may not take ``copy`` — confirm.
            warnings.warn("warning in test", stacklevel=2)
            return super().__array__(dtype)

    a = WarningVariable("x", [1])
    b = WarningVariable("x", [2])

    with warnings.catch_warnings(record=True) as w:
        # elevate warnings to errors
        warnings.filterwarnings("error")
        # The values differ, so the assertion itself must fail; the warnings
        # raised by WarningVariable must only be recorded, not raised.
        with pytest.raises(AssertionError):
            getattr(xr.testing, func)(a, b)

        assert len(w) > 0
165