import numpy as np
import pytest

import xarray as xr
from xarray import DataArray
from xarray.tests import assert_allclose, assert_equal

from . import raise_if_dask_computes, requires_cftime, requires_dask


@pytest.mark.parametrize("as_dataset", (True, False))
def test_weighted_non_DataArray_weights(as_dataset):
    # weights passed as a plain list (not a DataArray) are rejected

    data = DataArray([1, 2])
    if as_dataset:
        data = data.to_dataset(name="data")

    with pytest.raises(ValueError, match=r"`weights` must be a DataArray"):
        data.weighted([1, 2])


@pytest.mark.parametrize("as_dataset", (True, False))
@pytest.mark.parametrize("weights", ([np.nan, 2], [np.nan, np.nan]))
def test_weighted_weights_nan_raises(as_dataset, weights):
    # numpy-backed weights containing NaN raise when the Weighted object is built

    data = DataArray([1, 2])
    if as_dataset:
        data = data.to_dataset(name="data")

    with pytest.raises(ValueError, match="`weights` cannot contain missing values."):
        data.weighted(DataArray(weights))


@requires_dask
@pytest.mark.parametrize("as_dataset", (True, False))
@pytest.mark.parametrize("weights", ([np.nan, 2], [np.nan, np.nan]))
def test_weighted_weights_nan_raises_dask(as_dataset, weights):
    # with dask-backed weights, building the Weighted object must not trigger a
    # compute; the missing-value error only surfaces once the result is loaded

    data = DataArray([1, 2]).chunk({"dim_0": -1})
    if as_dataset:
        data = data.to_dataset(name="data")

    weights = DataArray(weights).chunk({"dim_0": -1})

    with raise_if_dask_computes():
        weighted = data.weighted(weights)

    with pytest.raises(ValueError, match="`weights` cannot contain missing values."):
        weighted.sum().load()


@requires_cftime
@requires_dask
@pytest.mark.parametrize("time_chunks", (1, 5))
@pytest.mark.parametrize("resample_spec", ("1AS", "5AS", "10AS"))
def test_weighted_lazy_resample(time_chunks, resample_spec):
    # https://github.com/pydata/xarray/issues/4625

    # simple customized weighted mean function
    def mean_func(ds):
        return ds.weighted(ds.weights).mean("time")

    # example dataset
    t = xr.cftime_range(start="2000", periods=20, freq="1AS")
    weights = xr.DataArray(np.random.rand(len(t)), dims=["time"], coords={"time": t})
    data = xr.DataArray(
        np.random.rand(len(t)), dims=["time"], coords={"time": t, "weights": weights}
    )
    ds = xr.Dataset({"data": data}).chunk({"time": time_chunks})

    # a weighted reduction inside resample().map() must stay lazy
    with raise_if_dask_computes():
        ds.resample(time=resample_spec).map(mean_func)


@pytest.mark.parametrize(
    ("weights", "expected"),
    (([1, 2], 3), ([2, 0], 2), ([0, 0], np.nan), ([-1, 1], np.nan)),
)
def test_weighted_sum_of_weights_no_nan(weights, expected):
    # a sum of weights equal to zero is reported as NaN

    da = DataArray([1, 2])
    weights = DataArray(weights)
    result = da.weighted(weights).sum_of_weights()

    expected = DataArray(expected)

    assert_equal(expected, result)


@pytest.mark.parametrize(
    ("weights", "expected"),
    (([1, 2], 2), ([2, 0], np.nan), ([0, 0], np.nan), ([-1, 1], 1)),
)
def test_weighted_sum_of_weights_nan(weights, expected):
    # weights at positions where the data is NaN do not contribute

    da = DataArray([np.nan, 2])
    weights = DataArray(weights)
    result = da.weighted(weights).sum_of_weights()

    expected = DataArray(expected)

    assert_equal(expected, result)


def test_weighted_sum_of_weights_bool():
    # https://github.com/pydata/xarray/issues/4074

    da = DataArray([1, 2])
    weights = DataArray([True, True])
    result = da.weighted(weights).sum_of_weights()

    expected = DataArray(2)

    assert_equal(expected, result)


@pytest.mark.parametrize("da", ([1.0, 2], [1, np.nan], [np.nan, np.nan]))
@pytest.mark.parametrize("factor", [0, 1, 3.14])
@pytest.mark.parametrize("skipna", (True, False))
def test_weighted_sum_equal_weights(da, factor, skipna):
    # if all weights equal ``factor``, the weighted sum is ``factor`` times the
    # ordinary sum

    da = DataArray(da)
    weights = xr.full_like(da, factor)

    expected = da.sum(skipna=skipna) * factor
    result = da.weighted(weights).sum(skipna=skipna)

    assert_equal(expected, result)


@pytest.mark.parametrize(
    ("weights", "expected"), (([1, 2], 5), ([0, 2], 4), ([0, 0], 0))
)
def test_weighted_sum_no_nan(weights, expected):

    da = DataArray([1, 2])

    weights = DataArray(weights)
    result = da.weighted(weights).sum()
    expected = DataArray(expected)

    assert_equal(expected, result)


@pytest.mark.parametrize(
    ("weights", "expected"), (([1, 2], 4), ([0, 2], 4), ([1, 0], 0), ([0, 0], 0))
)
@pytest.mark.parametrize("skipna", (True, False))
def test_weighted_sum_nan(weights, expected, skipna):

    da = DataArray([np.nan, 2])

    weights = DataArray(weights)
    result = da.weighted(weights).sum(skipna=skipna)

    # with skipna=False the NaN in the data propagates to the result
    if skipna:
        expected = DataArray(expected)
    else:
        expected = DataArray(np.nan)

    assert_equal(expected, result)


@pytest.mark.filterwarnings("error")
@pytest.mark.parametrize("da", ([1.0, 2], [1, np.nan], [np.nan, np.nan]))
@pytest.mark.parametrize("skipna", (True, False))
@pytest.mark.parametrize("factor", [1, 2, 3.14])
def test_weighted_mean_equal_weights(da, skipna, factor):
    # if all weights are equal (!= 0), should yield the same result as mean

    da = DataArray(da)

    # all weights equal to ``factor``
    weights = xr.full_like(da, factor)

    expected = da.mean(skipna=skipna)
    result = da.weighted(weights).mean(skipna=skipna)

    assert_equal(expected, result)


@pytest.mark.parametrize(
    ("weights", "expected"), (([4, 6], 1.6), ([1, 0], 1.0), ([0, 0], np.nan))
)
def test_weighted_mean_no_nan(weights, expected):

    da = DataArray([1, 2])
    weights = DataArray(weights)
    expected = DataArray(expected)

    result = da.weighted(weights).mean()

    assert_equal(expected, result)


@pytest.mark.parametrize(
    ("weights", "expected"), (([4, 6], 2.0), ([1, 0], np.nan), ([0, 0], np.nan))
)
@pytest.mark.parametrize("skipna", (True, False))
def test_weighted_mean_nan(weights, expected, skipna):

    da = DataArray([np.nan, 2])
    weights = DataArray(weights)

    # with skipna=False the NaN in the data propagates to the result
    if skipna:
        expected = DataArray(expected)
    else:
        expected = DataArray(np.nan)

    result = da.weighted(weights).mean(skipna=skipna)

    assert_equal(expected, result)


def test_weighted_mean_bool():
    # https://github.com/pydata/xarray/issues/4074
    da = DataArray([1, 1])
    weights = DataArray([True, True])
    expected = DataArray(1)

    result = da.weighted(weights).mean()

    assert_equal(expected, result)


@pytest.mark.parametrize(
    ("weights", "expected"),
    (([1, 2], 2 / 3), ([2, 0], 0), ([0, 0], 0), ([-1, 1], 0)),
)
def test_weighted_sum_of_squares_no_nan(weights, expected):

    da = DataArray([1, 2])
    weights = DataArray(weights)
    result = da.weighted(weights).sum_of_squares()

    expected = DataArray(expected)

    assert_equal(expected, result)


@pytest.mark.parametrize(
    ("weights", "expected"),
    (([1, 2], 0), ([2, 0], 0), ([0, 0], 0), ([-1, 1], 0)),
)
def test_weighted_sum_of_squares_nan(weights, expected):

    da = DataArray([np.nan, 2])
    weights = DataArray(weights)
    result = da.weighted(weights).sum_of_squares()

    expected = DataArray(expected)

    assert_equal(expected, result)


@pytest.mark.filterwarnings("error")
@pytest.mark.parametrize("da", ([1.0, 2], [1, np.nan]))
@pytest.mark.parametrize("skipna", (True, False))
@pytest.mark.parametrize("factor", [1, 2, 3.14])
def test_weighted_var_equal_weights(da, skipna, factor):
    # if all weights are equal (!= 0), should yield the same result as var

    da = DataArray(da)

    # all weights equal to ``factor``
    weights = xr.full_like(da, factor)

    expected = da.var(skipna=skipna)
    result = da.weighted(weights).var(skipna=skipna)

    assert_equal(expected, result)


@pytest.mark.parametrize(
    ("weights", "expected"), (([4, 6], 0.24), ([1, 0], 0.0), ([0, 0], np.nan))
)
def test_weighted_var_no_nan(weights, expected):

    da = DataArray([1, 2])
    weights = DataArray(weights)
    expected = DataArray(expected)

    result = da.weighted(weights).var()

    assert_equal(expected, result)


@pytest.mark.parametrize(
    ("weights", "expected"), (([4, 6], 0), ([1, 0], np.nan), ([0, 0], np.nan))
)
def test_weighted_var_nan(weights, expected):

    da = DataArray([np.nan, 2])
    weights = DataArray(weights)
    expected = DataArray(expected)

    result = da.weighted(weights).var()

    assert_equal(expected, result)


def test_weighted_var_bool():
    # https://github.com/pydata/xarray/issues/4074
    da = DataArray([1, 1])
    weights = DataArray([True, True])
    expected = DataArray(0)

    result = da.weighted(weights).var()

    assert_equal(expected, result)


@pytest.mark.filterwarnings("error")
@pytest.mark.parametrize("da", ([1.0, 2], [1, np.nan]))
@pytest.mark.parametrize("skipna", (True, False))
@pytest.mark.parametrize("factor", [1, 2, 3.14])
def test_weighted_std_equal_weights(da, skipna, factor):
    # if all weights are equal (!= 0), should yield the same result as std

    da = DataArray(da)

    # all weights equal to ``factor``
    weights = xr.full_like(da, factor)

    expected = da.std(skipna=skipna)
    result = da.weighted(weights).std(skipna=skipna)

    assert_equal(expected, result)


@pytest.mark.parametrize(
    ("weights", "expected"), (([4, 6], np.sqrt(0.24)), ([1, 0], 0.0), ([0, 0], np.nan))
)
def test_weighted_std_no_nan(weights, expected):

    da = DataArray([1, 2])
    weights = DataArray(weights)
    expected = DataArray(expected)

    result = da.weighted(weights).std()

    assert_equal(expected, result)


@pytest.mark.parametrize(
    ("weights", "expected"), (([4, 6], 0), ([1, 0], np.nan), ([0, 0], np.nan))
)
def test_weighted_std_nan(weights, expected):

    da = DataArray([np.nan, 2])
    weights = DataArray(weights)
    expected = DataArray(expected)

    result = da.weighted(weights).std()

    assert_equal(expected, result)


def test_weighted_std_bool():
    # https://github.com/pydata/xarray/issues/4074
    da = DataArray([1, 1])
    weights = DataArray([True, True])
    expected = DataArray(0)

    result = da.weighted(weights).std()

    assert_equal(expected, result)


def expected_weighted(da, weights, dim, skipna, operation):
    """
    Generate expected result using ``*`` and ``sum``. This is checked against
    the result of da.weighted which uses ``dot``.

    Parameters
    ----------
    da : DataArray or Dataset
        Data to reduce.
    weights : DataArray
        Weights applied to ``da``.
    dim : str, sequence of str or None
        Dimension(s) to reduce over.
    skipna : bool or None
        Passed to the ``sum`` reductions over the (weighted) data.
    operation : {"sum", "sum_of_weights", "mean", "sum_of_squares", "var", "std"}
        Which reduction to compute. Any other value returns ``None``.
    """

    weighted_sum = (da * weights).sum(dim=dim, skipna=skipna)

    if operation == "sum":
        return weighted_sum

    # only weights at non-missing data points count; a zero total is invalid
    # and replaced by NaN
    masked_weights = weights.where(da.notnull())
    sum_of_weights = masked_weights.sum(dim=dim, skipna=True)
    valid_weights = sum_of_weights != 0
    sum_of_weights = sum_of_weights.where(valid_weights)

    if operation == "sum_of_weights":
        return sum_of_weights

    weighted_mean = weighted_sum / sum_of_weights

    if operation == "mean":
        return weighted_mean

    demeaned = da - weighted_mean
    sum_of_squares = ((demeaned ** 2) * weights).sum(dim=dim, skipna=skipna)

    if operation == "sum_of_squares":
        return sum_of_squares

    var = sum_of_squares / sum_of_weights

    if operation == "var":
        return var

    if operation == "std":
        return np.sqrt(var)


def check_weighted_operations(data, weights, dim, skipna):
    """Compare every ``Weighted`` reduction against :func:`expected_weighted`."""

    weighted = data.weighted(weights)

    # check sum of weights (this reduction takes no ``skipna`` argument)
    result = weighted.sum_of_weights(dim)
    expected = expected_weighted(data, weights, dim, skipna, "sum_of_weights")
    assert_allclose(expected, result)

    # check weighted sum, mean, sum of squares, var and std
    for operation in ("sum", "mean", "sum_of_squares", "var", "std"):
        result = getattr(weighted, operation)(dim, skipna=skipna)
        expected = expected_weighted(data, weights, dim, skipna, operation)
        assert_allclose(expected, result)


def _set_random_nans(data, fraction=0.25):
    """Set ``int(data.size * fraction)`` random entries of ``data`` to NaN, in place.

    ``data`` must be a contiguous float numpy array so that ``ravel`` returns
    a view rather than a copy (https://stackoverflow.com/a/32182680/3010700).
    """
    n_nans = int(data.size * fraction)
    data.ravel()[np.random.choice(data.size, n_nans, replace=False)] = np.nan


@pytest.mark.parametrize("dim", ("a", "b", "c", ("a", "b"), ("a", "b", "c"), None))
@pytest.mark.parametrize("add_nans", (True, False))
@pytest.mark.parametrize("skipna", (None, True, False))
def test_weighted_operations_3D(dim, add_nans, skipna):

    dims = ("a", "b", "c")
    coords = dict(a=[0, 1, 2, 3], b=[0, 1, 2, 3], c=[0, 1, 2, 3])

    weights = DataArray(np.random.randn(4, 4, 4), dims=dims, coords=coords)

    data = np.random.randn(4, 4, 4)

    # add approximately 25 % NaNs
    if add_nans:
        _set_random_nans(data)

    data = DataArray(data, dims=dims, coords=coords)

    check_weighted_operations(data, weights, dim, skipna)

    data = data.to_dataset(name="data")
    check_weighted_operations(data, weights, dim, skipna)


def test_weighted_operations_nonequal_coords():
    # data and weights only partially overlap along "a"

    weights = DataArray(np.random.randn(4), dims=("a",), coords=dict(a=[0, 1, 2, 3]))
    data = DataArray(np.random.randn(4), dims=("a",), coords=dict(a=[1, 2, 3, 4]))

    check_weighted_operations(data, weights, dim="a", skipna=None)

    data = data.to_dataset(name="data")
    check_weighted_operations(data, weights, dim="a", skipna=None)


@pytest.mark.parametrize("shape_data", ((4,), (4, 4), (4, 4, 4)))
@pytest.mark.parametrize("shape_weights", ((4,), (4, 4), (4, 4, 4)))
@pytest.mark.parametrize("add_nans", (True, False))
@pytest.mark.parametrize("skipna", (None, True, False))
def test_weighted_operations_different_shapes(
    shape_data, shape_weights, add_nans, skipna
):

    weights = DataArray(np.random.randn(*shape_weights))

    data = np.random.randn(*shape_data)

    # add approximately 25 % NaNs
    if add_nans:
        _set_random_nans(data)

    data = DataArray(data)

    check_weighted_operations(data, weights, "dim_0", skipna)
    check_weighted_operations(data, weights, None, skipna)

    data = data.to_dataset(name="data")
    check_weighted_operations(data, weights, "dim_0", skipna)
    check_weighted_operations(data, weights, None, skipna)


@pytest.mark.parametrize(
    "operation", ("sum_of_weights", "sum", "mean", "sum_of_squares", "var", "std")
)
@pytest.mark.parametrize("as_dataset", (True, False))
@pytest.mark.parametrize("keep_attrs", (True, False, None))
def test_weighted_operations_keep_attr(operation, as_dataset, keep_attrs):

    weights = DataArray(np.random.randn(2, 2), attrs=dict(attr="weights"))
    data = DataArray(np.random.randn(2, 2))

    if as_dataset:
        data = data.to_dataset(name="data")

    data.attrs = dict(attr="weights")

    # exercise the parametrized value instead of hard-coding all three options
    result = getattr(data.weighted(weights), operation)(keep_attrs=keep_attrs)

    if not keep_attrs:
        # keep_attrs=False and the default (None) both drop the attributes
        assert not result.attrs
    elif operation == "sum_of_weights":
        assert weights.attrs == result.attrs
    else:
        assert data.attrs == result.attrs


@pytest.mark.parametrize(
    "operation", ("sum_of_weights", "sum", "mean", "sum_of_squares", "var", "std")
)
def test_weighted_operations_keep_attr_da_in_ds(operation):
    # GH #3595

    weights = DataArray(np.random.randn(2, 2))
    data = DataArray(np.random.randn(2, 2), attrs=dict(attr="data"))
    data = data.to_dataset(name="a")

    result = getattr(data.weighted(weights), operation)(keep_attrs=True)

    assert data.a.attrs == result.a.attrs

@pytest.mark.parametrize("as_dataset", (True, False))
def test_weighted_bad_dim(as_dataset):
    """Reducing over a dimension missing from the data raises ValueError."""
    import re

    data = DataArray(np.random.randn(2, 2))
    weights = xr.ones_like(data)
    if as_dataset:
        data = data.to_dataset(name="data")

    error_msg = (
        f"{data.__class__.__name__}Weighted"
        " does not contain the dimensions: {'bad_dim'}"
    )
    # ``match`` is a regular expression searched in the message; escape the
    # literal text so the braces and quotes of the set repr match verbatim
    with pytest.raises(ValueError, match=re.escape(error_msg)):
        data.weighted(weights).mean("bad_dim")