"""Tests for interoperability between pint Quantities and xarray ("upcast") types.

An xarray object may wrap a pint Quantity, but a Quantity must not wrap an
xarray object. These tests cover construction, arithmetic commutativity,
comparisons, in-place operations and deferral of the NumPy
``__array_function__`` / ``__array_ufunc__`` protocols from Quantity to xarray.
"""
import pytest

from pint import UnitRegistry

# Conditionally import NumPy and any upcast type libraries; the whole module
# is skipped when either one is missing.
np = pytest.importorskip("numpy", reason="NumPy is not available")
xr = pytest.importorskip("xarray", reason="xarray is not available")

# Set up unit registry and sample
ureg = UnitRegistry()
# Shared module-level 2x2 length Quantity; tests that mutate data work on copies.
q = [[1.0, 2.0], [3.0, 4.0]] * ureg.m


@pytest.fixture
def da():
    """Return a fresh DataArray wrapping a copy of the shared Quantity ``q``."""
    return xr.DataArray(q.copy())


@pytest.fixture
def ds():
    """Return a Dataset whose variables and coords carry units only as attrs strings."""
    return xr.Dataset(
        {
            "a": (("x", "y"), [[0, 1], [2, 3], [4, 5]], {"units": "K"}),
            "b": ("x", [0, 2, 4], {"units": "degC"}),
            "c": ("y", [-1, 1], {"units": "hPa"}),
        },
        coords={
            "x": ("x", [-1, 0, 1], {"units": "degree"}),
            "y": ("y", [0, 1], {"units": "degree"}),
        },
    )


def test_xarray_quantity_creation():
    """A Quantity refuses to wrap a DataArray; a DataArray may wrap a Quantity."""
    # ``match`` checks the exception message itself; ``str(excinfo)`` would only
    # inspect pytest's ExceptionInfo repr.
    with pytest.raises(TypeError, match="Quantity cannot wrap upcast type"):
        ureg.Quantity(xr.DataArray(np.arange(4)), "m")
    # The DataArray keeps the very same Quantity object as its data.
    assert xr.DataArray(q).data is q


def test_quantification(ds):
    """Turning a ``units`` attr into a Quantity preserves units through a reduction."""
    # Named ``arr`` to avoid shadowing the module-level ``da`` fixture.
    arr = ds["a"]
    arr.data = ureg.Quantity(arr.values, arr.attrs.pop("units"))
    mean = arr.mean().item()
    assert mean.units == ureg.K
    assert np.isclose(mean, 2.5 * ureg.K)


@pytest.mark.parametrize(
    "op",
    [
        # Each op is mathematically symmetric in its operands, so swapping the
        # operand order must give the same result.
        lambda x, y: x + y,
        lambda x, y: x - (-y),
        lambda x, y: x * y,
        lambda x, y: x / (y ** -1),
    ],
)
@pytest.mark.parametrize(
    "pair",
    [
        (q, xr.DataArray(q)),
        (
            xr.DataArray([1.0, 2.0] * ureg.m, dims=("y",)),
            xr.DataArray(
                np.arange(6, dtype="float").reshape(3, 2, 1), dims=("z", "y", "x")
            )
            * ureg.km,
        ),
        (1 * ureg.m, xr.DataArray(q)),
    ],
)
def test_binary_arithmetic_commutativity(op, pair):
    """``op(a, b)`` and ``op(b, a)`` agree up to dim order and unit conversion."""
    z0 = op(*pair)
    z1 = op(*pair[::-1])
    # Broadcasting may order dims differently depending on operand order.
    z1 = z1.transpose(*z0.dims)
    assert np.all(np.isclose(z0.data, z1.data.to(z0.data.units)))


def test_eq_commutativity(da):
    """``==`` between a Quantity and a DataArray is independent of operand order."""
    assert np.all((q.T == da) == (da.transpose() == q))


def test_ne_commutativity(da):
    """``!=`` between a Quantity and a DataArray is independent of operand order."""
    assert np.all((q != da.transpose()) == (da != q.T))


def test_dataset_operation_with_unit(ds):
    """Multiplying a Dataset by a unit commutes with selection and applies the unit."""
    ds0 = ureg.K * ds.isel(x=0)
    ds1 = (ds * ureg.K).isel(x=0)
    xr.testing.assert_identical(ds0, ds1)
    assert np.isclose(ds0["a"].mean().item(), 0.5 * ureg.K)


def test_dataarray_inplace_arithmetic_roundtrip(da):
    """In-place ops on a DataArray round-trip; in-place ops on a Quantity also work."""
    da_original = da.copy()
    q_to_modify = q.copy()
    da += q
    xr.testing.assert_identical(da, xr.DataArray([[2, 4], [6, 8]] * ureg.m))
    da -= q
    xr.testing.assert_identical(da, da_original)
    da *= ureg.m
    xr.testing.assert_identical(da, xr.DataArray(q * ureg.m))
    da /= ureg.m
    xr.testing.assert_identical(da, da_original)
    # Operating inplace with DataArray converts to DataArray
    q_to_modify += da
    q_to_modify -= da
    assert np.all(np.isclose(q_to_modify.data, q))


def test_dataarray_inequalities(da):
    """Ordering comparisons need compatible units; a bare int is rejected."""
    xr.testing.assert_identical(
        2 * ureg.m > da, xr.DataArray([[True, False], [False, False]])
    )
    xr.testing.assert_identical(
        2 * ureg.m < da, xr.DataArray([[False, False], [True, True]])
    )
    # ``match`` verifies the message; the pattern has no regex metacharacters.
    with pytest.raises(ValueError, match="Cannot compare Quantity and <class 'int'>"):
        da > 2


def test_array_function_deferral(da):
    """Quantity.__array_function__ returns NotImplemented when a DataArray is involved."""
    lower = 2 * ureg.m
    upper = 3 * ureg.m
    args = (da, lower, upper)
    relevant_types = tuple({type(arg) for arg in args})
    assert (
        lower.__array_function__(np.clip, relevant_types, args, {})
        is NotImplemented
    )


def test_array_ufunc_deferral(da):
    """Quantity.__array_ufunc__ returns NotImplemented when a DataArray is an input."""
    lower = 2 * ureg.m
    assert lower.__array_ufunc__(np.maximum, "__call__", lower, da) is NotImplemented