1import pytest
2
3from pint import UnitRegistry
4
5# Conditionally import NumPy and any upcast type libraries
6np = pytest.importorskip("numpy", reason="NumPy is not available")
7xr = pytest.importorskip("xarray", reason="xarray is not available")
8
# Shared unit registry plus a sample 2x2 float Quantity in metres used by the
# tests below. Fixtures and in-place tests work on ``q.copy()`` rather than
# mutating ``q`` directly.
ureg = UnitRegistry()
q = [[1.0, 2.0], [3.0, 4.0]] * ureg.m
12
13
@pytest.fixture
def da():
    """Return a fresh DataArray wrapping a private copy of the sample ``q``."""
    data = q.copy()
    return xr.DataArray(data)
17
18
@pytest.fixture
def ds():
    """Return a small Dataset whose variables and coords carry ``units`` attrs.

    The stored values are plain numbers; the units live only in each
    variable's ``attrs`` so tests can quantify them explicitly.
    """
    data_vars = {
        "a": (("x", "y"), [[0, 1], [2, 3], [4, 5]], {"units": "K"}),
        "b": ("x", [0, 2, 4], {"units": "degC"}),
        "c": ("y", [-1, 1], {"units": "hPa"}),
    }
    coords = {
        "x": ("x", [-1, 0, 1], {"units": "degree"}),
        "y": ("y", [0, 1], {"units": "degree"}),
    }
    return xr.Dataset(data_vars, coords=coords)
32
33
def test_xarray_quantity_creation():
    """Quantity refuses to wrap a DataArray, while DataArray can wrap a Quantity."""
    # Verify the message via ``match=``: any statement placed after the raising
    # call inside the ``with`` body would never execute.
    with pytest.raises(TypeError, match="Quantity cannot wrap upcast type"):
        ureg.Quantity(xr.DataArray(np.arange(4)), "m")
    # xarray keeps the very same Quantity object as its data (no copy/unwrap).
    assert xr.DataArray(q).data is q
39
40
def test_quantification(ds):
    """Attaching units taken from ``attrs`` makes reductions unit-aware."""
    arr = ds["a"]
    units = arr.attrs.pop("units")
    arr.data = ureg.Quantity(arr.values, units)
    result = arr.mean().item()
    assert result.units == ureg.K
    assert np.isclose(result, 2.5 * ureg.K)
47
48
@pytest.mark.parametrize(
    "op",
    [
        lambda x, y: x + y,
        lambda x, y: x - (-y),
        lambda x, y: x * y,
        lambda x, y: x / (y ** -1),
    ],
)
@pytest.mark.parametrize(
    "pair",
    [
        (q, xr.DataArray(q)),
        (
            xr.DataArray([1.0, 2.0] * ureg.m, dims=("y",)),
            xr.DataArray(
                np.arange(6, dtype="float").reshape(3, 2, 1), dims=("z", "y", "x")
            )
            * ureg.km,
        ),
        (1 * ureg.m, xr.DataArray(q)),
    ],
)
def test_binary_arithmetic_commutativity(op, pair):
    """Swapping the operands of a commutative op yields the same magnitudes.

    Every ``op`` above is mathematically commutative (``x - (-y)`` is ``x + y``
    and ``x / y**-1`` is ``x * y``). The reversed result is transposed to the
    forward result's dims and converted to its units before comparison.
    """
    left, right = pair
    forward = op(left, right)
    backward = op(right, left).transpose(*forward.dims)
    expected = forward.data
    actual = backward.data.to(expected.units)
    assert np.all(np.isclose(expected, actual))
77
78
def test_eq_commutativity(da):
    """``==`` between Quantity and DataArray agrees regardless of operand order."""
    quantity_first = q.T == da
    dataarray_first = da.transpose() == q
    assert np.all(quantity_first == dataarray_first)
81
82
def test_ne_commutativity(da):
    """``!=`` between Quantity and DataArray agrees regardless of operand order."""
    quantity_first = q != da.transpose()
    dataarray_first = da != q.T
    assert np.all(quantity_first == dataarray_first)
85
86
def test_dataset_operation_with_unit(ds):
    """Scaling a Dataset by a unit commutes with selection and operand order."""
    unit = ureg.K
    selected_then_scaled = unit * ds.isel(x=0)
    scaled_then_selected = (ds * unit).isel(x=0)
    xr.testing.assert_identical(selected_then_scaled, scaled_then_selected)
    assert np.isclose(selected_then_scaled["a"].mean().item(), 0.5 * unit)
92
93
def test_dataarray_inplace_arithmetic_roundtrip(da):
    """In-place arithmetic with Quantities/units round-trips a DataArray.

    Also checks that in-place arithmetic with a Quantity on the left and a
    DataArray on the right rebinds the name to a DataArray.
    """
    da_original = da.copy()
    q_to_modify = q.copy()
    da += q
    xr.testing.assert_identical(da, xr.DataArray([[2, 4], [6, 8]] * ureg.m))
    da -= q
    xr.testing.assert_identical(da, da_original)
    da *= ureg.m
    xr.testing.assert_identical(da, xr.DataArray(q * ureg.m))
    da /= ureg.m
    xr.testing.assert_identical(da, da_original)
    # Operating inplace with DataArray converts to DataArray; assert it
    # explicitly instead of relying on ``.data`` happening to exist below.
    q_to_modify += da
    assert isinstance(q_to_modify, xr.DataArray)
    q_to_modify -= da
    assert isinstance(q_to_modify, xr.DataArray)
    assert np.all(np.isclose(q_to_modify.data, q))
109
110
def test_dataarray_inequalities(da):
    """Ordering comparisons between Quantity and DataArray, and their failure mode."""
    xr.testing.assert_identical(
        2 * ureg.m > da, xr.DataArray([[True, False], [False, False]])
    )
    xr.testing.assert_identical(
        2 * ureg.m < da, xr.DataArray([[False, False], [True, True]])
    )
    # Comparing metre-valued data against a bare int must raise. Check the
    # message via ``match=``: an assert placed after the raising call inside
    # the ``with`` body would never execute.
    with pytest.raises(
        ValueError, match="Cannot compare Quantity and <class 'int'>"
    ):
        da > 2
121
122
def test_array_function_deferral(da):
    """Quantity.__array_function__ defers when a DataArray is among the args."""
    lower = 2 * ureg.m
    upper = 3 * ureg.m
    args = (da, lower, upper)
    relevant_types = tuple({type(arg) for arg in args})
    result = lower.__array_function__(np.clip, relevant_types, args, {})
    assert result is NotImplemented
133
134
def test_array_ufunc_deferral(da):
    """Quantity.__array_ufunc__ defers when a DataArray is one of the inputs."""
    lower = 2 * ureg.m
    result = lower.__array_ufunc__(np.maximum, "__call__", lower, da)
    assert result is NotImplemented
138