1import numpy as np
2import pytest
3
4import xarray as xr
5from xarray import DataArray
6from xarray.tests import assert_allclose, assert_equal
7
8from . import raise_if_dask_computes, requires_cftime, requires_dask
9
10
@pytest.mark.parametrize("as_dataset", (True, False))
def test_weighted_non_DataArray_weights(as_dataset):
    """``weighted`` raises a ValueError when given a plain list as weights."""
    obj = DataArray([1, 2])
    obj = obj.to_dataset(name="data") if as_dataset else obj

    with pytest.raises(ValueError, match=r"`weights` must be a DataArray"):
        obj.weighted([1, 2])
20
21
@pytest.mark.parametrize("as_dataset", (True, False))
@pytest.mark.parametrize("weights", ([np.nan, 2], [np.nan, np.nan]))
def test_weighted_weights_nan_raises(as_dataset, weights):
    """``weighted`` raises a ValueError when the weights contain NaN.

    Runs for both a DataArray and a Dataset, with partly and fully missing
    weights.
    """
    data = DataArray([1, 2])
    if as_dataset:
        data = data.to_dataset(name="data")

    # ``match`` is interpreted as a regular expression, so the trailing period
    # is escaped to require a literal "." instead of any character.
    expected_msg = r"`weights` cannot contain missing values\."

    with pytest.raises(ValueError, match=expected_msg):
        data.weighted(DataArray(weights))
32
33
@requires_dask
@pytest.mark.parametrize("as_dataset", (True, False))
@pytest.mark.parametrize("weights", ([np.nan, 2], [np.nan, np.nan]))
def test_weighted_weights_nan_raises_dask(as_dataset, weights):
    """Chunked NaN weights raise a ValueError only once the result is loaded.

    The ``weighted`` object is created inside ``raise_if_dask_computes``;
    the error is expected when ``sum().load()`` is called afterwards.
    """
    data = DataArray([1, 2]).chunk({"dim_0": -1})
    if as_dataset:
        data = data.to_dataset(name="data")

    weights = DataArray(weights).chunk({"dim_0": -1})

    with raise_if_dask_computes():
        weighted = data.weighted(weights)

    # ``match`` is interpreted as a regular expression, so the trailing period
    # is escaped to require a literal "." instead of any character.
    expected_msg = r"`weights` cannot contain missing values\."

    with pytest.raises(ValueError, match=expected_msg):
        weighted.sum().load()
50
51
@requires_cftime
@requires_dask
@pytest.mark.parametrize("time_chunks", (1, 5))
@pytest.mark.parametrize("resample_spec", ("1AS", "5AS", "10AS"))
def test_weighted_lazy_resample(time_chunks, resample_spec):
    """Map a weighted mean over resample bins under ``raise_if_dask_computes``.

    See https://github.com/pydata/xarray/issues/4625
    """

    def weighted_time_mean(group):
        # custom reduction: weighted mean over "time" using the "weights" coord
        return group.weighted(group.weights).mean("time")

    # 20 yearly time steps with random weights and random data
    times = xr.cftime_range(start="2000", periods=20, freq="1AS")
    n_times = len(times)
    time_coord = {"time": times}

    weights = xr.DataArray(np.random.rand(n_times), dims=["time"], coords=time_coord)
    values = xr.DataArray(
        np.random.rand(n_times),
        dims=["time"],
        coords={"time": times, "weights": weights},
    )
    chunked = xr.Dataset({"data": values}).chunk({"time": time_chunks})

    with raise_if_dask_computes():
        chunked.resample(time=resample_spec).map(weighted_time_mean)
73
74
@pytest.mark.parametrize(
    ("weights", "expected"),
    (([1, 2], 3), ([2, 0], 2), ([0, 0], np.nan), ([-1, 1], np.nan)),
)
def test_weighted_sum_of_weights_no_nan(weights, expected):
    """``sum_of_weights`` on NaN-free data equals the tabulated value."""
    weighted = DataArray([1, 2]).weighted(DataArray(weights))
    actual = weighted.sum_of_weights()

    assert_equal(DataArray(expected), actual)
88
89
@pytest.mark.parametrize(
    ("weights", "expected"),
    (([1, 2], 2), ([2, 0], np.nan), ([0, 0], np.nan), ([-1, 1], 1)),
)
def test_weighted_sum_of_weights_nan(weights, expected):
    """``sum_of_weights`` on data whose first value is NaN equals the tabulated value."""
    weighted = DataArray([np.nan, 2]).weighted(DataArray(weights))
    actual = weighted.sum_of_weights()

    assert_equal(DataArray(expected), actual)
103
104
def test_weighted_sum_of_weights_bool():
    """``sum_of_weights`` of two ``True`` weights is 2.

    See https://github.com/pydata/xarray/issues/4074
    """
    bool_weights = DataArray([True, True])
    actual = DataArray([1, 2]).weighted(bool_weights).sum_of_weights()

    assert_equal(DataArray(2), actual)
115
116
@pytest.mark.parametrize("da", ([1.0, 2], [1, np.nan], [np.nan, np.nan]))
@pytest.mark.parametrize("factor", [0, 1, 3.14])
@pytest.mark.parametrize("skipna", (True, False))
def test_weighted_sum_equal_weights(da, factor, skipna):
    """With every weight equal to ``factor``, the weighted sum is ``factor * sum``."""
    values = DataArray(da)
    constant_weights = xr.full_like(values, factor)

    scaled_plain_sum = values.sum(skipna=skipna) * factor
    weighted_sum = values.weighted(constant_weights).sum(skipna=skipna)

    assert_equal(scaled_plain_sum, weighted_sum)
130
131
@pytest.mark.parametrize(
    ("weights", "expected"), (([1, 2], 5), ([0, 2], 4), ([0, 0], 0))
)
def test_weighted_sum_no_nan(weights, expected):
    """Weighted sum of ``[1, 2]`` equals the tabulated value for each weight set."""
    weighted = DataArray([1, 2]).weighted(DataArray(weights))
    actual = weighted.sum()

    assert_equal(DataArray(expected), actual)
144
145
@pytest.mark.parametrize(
    ("weights", "expected"), (([1, 2], 4), ([0, 2], 4), ([1, 0], 0), ([0, 0], 0))
)
@pytest.mark.parametrize("skipna", (True, False))
def test_weighted_sum_nan(weights, expected, skipna):
    """Weighted sum of ``[nan, 2]``: tabulated value if ``skipna``, otherwise NaN."""
    weighted = DataArray([np.nan, 2]).weighted(DataArray(weights))
    actual = weighted.sum(skipna=skipna)

    # without skipna the NaN in the data is expected to propagate
    wanted = DataArray(expected if skipna else np.nan)

    assert_equal(wanted, actual)
163
164
@pytest.mark.filterwarnings("error")
@pytest.mark.parametrize("da", ([1.0, 2], [1, np.nan], [np.nan, np.nan]))
@pytest.mark.parametrize("skipna", (True, False))
@pytest.mark.parametrize("factor", [1, 2, 3.14])
def test_weighted_mean_equal_weights(da, skipna, factor):
    """With equal non-zero weights, the weighted mean equals the plain mean."""
    values = DataArray(da)

    # every weight is set to the same non-zero ``factor``
    constant_weights = xr.full_like(values, factor)

    plain_mean = values.mean(skipna=skipna)
    weighted_mean = values.weighted(constant_weights).mean(skipna=skipna)

    assert_equal(plain_mean, weighted_mean)
181
182
@pytest.mark.parametrize(
    ("weights", "expected"), (([4, 6], 1.6), ([1, 0], 1.0), ([0, 0], np.nan))
)
def test_weighted_mean_no_nan(weights, expected):
    """Weighted mean of ``[1, 2]`` equals the tabulated value for each weight set."""
    weighted = DataArray([1, 2]).weighted(DataArray(weights))
    actual = weighted.mean()

    assert_equal(DataArray(expected), actual)
195
196
@pytest.mark.parametrize(
    ("weights", "expected"), (([4, 6], 2.0), ([1, 0], np.nan), ([0, 0], np.nan))
)
@pytest.mark.parametrize("skipna", (True, False))
def test_weighted_mean_nan(weights, expected, skipna):
    """Weighted mean of ``[nan, 2]``: tabulated value if ``skipna``, otherwise NaN."""
    weighted = DataArray([np.nan, 2]).weighted(DataArray(weights))

    # without skipna the NaN in the data is expected to propagate
    wanted = DataArray(expected if skipna else np.nan)
    actual = weighted.mean(skipna=skipna)

    assert_equal(wanted, actual)
214
215
def test_weighted_mean_bool():
    """Weighted mean of ``[1, 1]`` with two ``True`` weights is 1.

    See https://github.com/pydata/xarray/issues/4074
    """
    bool_weights = DataArray([True, True])
    actual = DataArray([1, 1]).weighted(bool_weights).mean()

    assert_equal(DataArray(1), actual)
225
226
@pytest.mark.parametrize(
    ("weights", "expected"),
    (([1, 2], 2 / 3), ([2, 0], 0), ([0, 0], 0), ([-1, 1], 0)),
)
def test_weighted_sum_of_squares_no_nan(weights, expected):
    """``sum_of_squares`` of ``[1, 2]`` equals the tabulated value."""
    weighted = DataArray([1, 2]).weighted(DataArray(weights))
    actual = weighted.sum_of_squares()

    assert_equal(DataArray(expected), actual)
240
241
@pytest.mark.parametrize(
    ("weights", "expected"),
    (([1, 2], 0), ([2, 0], 0), ([0, 0], 0), ([-1, 1], 0)),
)
def test_weighted_sum_of_squares_nan(weights, expected):
    """``sum_of_squares`` of ``[nan, 2]`` equals the tabulated value."""
    weighted = DataArray([np.nan, 2]).weighted(DataArray(weights))
    actual = weighted.sum_of_squares()

    assert_equal(DataArray(expected), actual)
255
256
@pytest.mark.filterwarnings("error")
@pytest.mark.parametrize("da", ([1.0, 2], [1, np.nan]))
@pytest.mark.parametrize("skipna", (True, False))
@pytest.mark.parametrize("factor", [1, 2, 3.14])
def test_weighted_var_equal_weights(da, skipna, factor):
    """With equal non-zero weights, the weighted variance equals the plain variance."""
    values = DataArray(da)

    # every weight is set to the same non-zero ``factor``
    constant_weights = xr.full_like(values, factor)

    plain_var = values.var(skipna=skipna)
    weighted_var = values.weighted(constant_weights).var(skipna=skipna)

    assert_equal(plain_var, weighted_var)
273
274
@pytest.mark.parametrize(
    ("weights", "expected"), (([4, 6], 0.24), ([1, 0], 0.0), ([0, 0], np.nan))
)
def test_weighted_var_no_nan(weights, expected):
    """Weighted variance of ``[1, 2]`` equals the tabulated value."""
    weighted = DataArray([1, 2]).weighted(DataArray(weights))
    actual = weighted.var()

    assert_equal(DataArray(expected), actual)
287
288
@pytest.mark.parametrize(
    ("weights", "expected"), (([4, 6], 0), ([1, 0], np.nan), ([0, 0], np.nan))
)
def test_weighted_var_nan(weights, expected):
    """Weighted variance of ``[nan, 2]`` equals the tabulated value."""
    weighted = DataArray([np.nan, 2]).weighted(DataArray(weights))
    actual = weighted.var()

    assert_equal(DataArray(expected), actual)
301
302
def test_weighted_var_bool():
    """Weighted variance of ``[1, 1]`` with two ``True`` weights is 0.

    See https://github.com/pydata/xarray/issues/4074
    """
    bool_weights = DataArray([True, True])
    actual = DataArray([1, 1]).weighted(bool_weights).var()

    assert_equal(DataArray(0), actual)
312
313
@pytest.mark.filterwarnings("error")
@pytest.mark.parametrize("da", ([1.0, 2], [1, np.nan]))
@pytest.mark.parametrize("skipna", (True, False))
@pytest.mark.parametrize("factor", [1, 2, 3.14])
def test_weighted_std_equal_weights(da, skipna, factor):
    """With equal non-zero weights, the weighted std equals the plain std."""
    values = DataArray(da)

    # every weight is set to the same non-zero ``factor``
    constant_weights = xr.full_like(values, factor)

    plain_std = values.std(skipna=skipna)
    weighted_std = values.weighted(constant_weights).std(skipna=skipna)

    assert_equal(plain_std, weighted_std)
330
331
@pytest.mark.parametrize(
    ("weights", "expected"), (([4, 6], np.sqrt(0.24)), ([1, 0], 0.0), ([0, 0], np.nan))
)
def test_weighted_std_no_nan(weights, expected):
    """Weighted std of ``[1, 2]`` equals the tabulated value."""
    weighted = DataArray([1, 2]).weighted(DataArray(weights))
    actual = weighted.std()

    assert_equal(DataArray(expected), actual)
344
345
@pytest.mark.parametrize(
    ("weights", "expected"), (([4, 6], 0), ([1, 0], np.nan), ([0, 0], np.nan))
)
def test_weighted_std_nan(weights, expected):
    """Weighted std of ``[nan, 2]`` equals the tabulated value."""
    weighted = DataArray([np.nan, 2]).weighted(DataArray(weights))
    actual = weighted.std()

    assert_equal(DataArray(expected), actual)
358
359
def test_weighted_std_bool():
    """Weighted std of ``[1, 1]`` with two ``True`` weights is 0.

    See https://github.com/pydata/xarray/issues/4074
    """
    bool_weights = DataArray([True, True])
    actual = DataArray([1, 1]).weighted(bool_weights).std()

    assert_equal(DataArray(0), actual)
369
370
def expected_weighted(da, weights, dim, skipna, operation):
    """
    Generate expected result using ``*`` and ``sum``. This is checked against
    the result of da.weighted which uses ``dot``

    Parameters
    ----------
    da : object
        Data to reduce; must support ``*``, ``-``, ``notnull()`` and
        ``sum(dim=..., skipna=...)``.
    weights : object
        Weights; must support ``where(...)`` and ``sum(dim=..., skipna=...)``.
    dim : object
        Forwarded unchanged as ``dim`` to every ``sum`` call.
    skipna : object
        Forwarded as ``skipna`` to the sums over the data. The sum of weights
        always uses ``skipna=True``.
    operation : str
        One of ``"sum"``, ``"sum_of_weights"``, ``"mean"``,
        ``"sum_of_squares"``, ``"var"`` or ``"std"``.

    Returns
    -------
    The reference result for ``operation``.

    Raises
    ------
    ValueError
        If ``operation`` is not one of the supported names.
    """

    valid_operations = (
        "sum",
        "sum_of_weights",
        "mean",
        "sum_of_squares",
        "var",
        "std",
    )
    # Fail loudly on a typo instead of silently returning ``None``, which
    # would only surface later as a confusing comparison failure.
    if operation not in valid_operations:
        raise ValueError(
            f"unknown operation {operation!r}; expected one of {valid_operations}"
        )

    weighted_sum = (da * weights).sum(dim=dim, skipna=skipna)

    if operation == "sum":
        return weighted_sum

    # only weights at positions where the data is not null contribute
    masked_weights = weights.where(da.notnull())
    sum_of_weights = masked_weights.sum(dim=dim, skipna=True)
    # a sum of weights equal to zero is replaced by a missing value
    valid_weights = sum_of_weights != 0
    sum_of_weights = sum_of_weights.where(valid_weights)

    if operation == "sum_of_weights":
        return sum_of_weights

    weighted_mean = weighted_sum / sum_of_weights

    if operation == "mean":
        return weighted_mean

    demeaned = da - weighted_mean
    sum_of_squares = ((demeaned ** 2) * weights).sum(dim=dim, skipna=skipna)

    if operation == "sum_of_squares":
        return sum_of_squares

    var = sum_of_squares / sum_of_weights

    if operation == "var":
        return var

    # the only remaining valid operation is "std"
    return np.sqrt(var)
408
409
def check_weighted_operations(data, weights, dim, skipna):
    """Compare every weighted reduction of ``data`` with ``expected_weighted``.

    ``sum_of_weights`` is called with ``dim`` only; the other reductions are
    also passed ``skipna``. Each result is compared with ``assert_allclose``.
    """

    # check sum of weights (it takes no ``skipna`` argument here)
    actual = data.weighted(weights).sum_of_weights(dim)
    reference = expected_weighted(data, weights, dim, skipna, "sum_of_weights")
    assert_allclose(reference, actual)

    # check weighted sum, mean, sum of squares, var and std
    for operation in ("sum", "mean", "sum_of_squares", "var", "std"):
        reduce = getattr(data.weighted(weights), operation)
        actual = reduce(dim, skipna=skipna)
        reference = expected_weighted(data, weights, dim, skipna, operation)
        assert_allclose(reference, actual)
441
442
@pytest.mark.parametrize("dim", ("a", "b", "c", ("a", "b"), ("a", "b", "c"), None))
@pytest.mark.parametrize("add_nans", (True, False))
@pytest.mark.parametrize("skipna", (None, True, False))
def test_weighted_operations_3D(dim, add_nans, skipna):
    """Check all weighted reductions on random 4x4x4 data with named coords.

    The checks run once on the DataArray and once on the Dataset built from
    it, optionally with about a quarter of the values replaced by NaN.
    """

    dims = ("a", "b", "c")
    coords = dict(a=[0, 1, 2, 3], b=[0, 1, 2, 3], c=[0, 1, 2, 3])

    weights = DataArray(np.random.randn(4, 4, 4), dims=dims, coords=coords)

    data = np.random.randn(4, 4, 4)

    # add approximately 25 % NaNs (https://stackoverflow.com/a/32182680/3010700)
    if add_nans:
        c = int(data.size * 0.25)
        # ``np.nan`` rather than the ``np.NaN`` alias, which NumPy 2.0 removed
        data.ravel()[np.random.choice(data.size, c, replace=False)] = np.nan

    data = DataArray(data, dims=dims, coords=coords)

    check_weighted_operations(data, weights, dim, skipna)

    data = data.to_dataset(name="data")
    check_weighted_operations(data, weights, dim, skipna)
466
467
def test_weighted_operations_nonequal_coords():
    """Check weighted reductions when data and weights coords differ.

    The weights use ``a = 0..3`` and the data ``a = 1..4``.
    """
    weights = DataArray(np.random.randn(4), dims=("a",), coords=dict(a=[0, 1, 2, 3]))
    data = DataArray(np.random.randn(4), dims=("a",), coords=dict(a=[1, 2, 3, 4]))

    # run the checks on the DataArray first, then on its Dataset form
    for obj in (data, data.to_dataset(name="data")):
        check_weighted_operations(obj, weights, dim="a", skipna=None)
477
478
@pytest.mark.parametrize("shape_data", ((4,), (4, 4), (4, 4, 4)))
@pytest.mark.parametrize("shape_weights", ((4,), (4, 4), (4, 4, 4)))
@pytest.mark.parametrize("add_nans", (True, False))
@pytest.mark.parametrize("skipna", (None, True, False))
def test_weighted_operations_different_shapes(
    shape_data, shape_weights, add_nans, skipna
):
    """Check all weighted reductions when data and weights differ in rank.

    Reductions are checked over ``"dim_0"`` and over ``dim=None``, for the
    DataArray and for the Dataset built from it.
    """

    weights = DataArray(np.random.randn(*shape_weights))

    data = np.random.randn(*shape_data)

    # add approximately 25 % NaNs
    if add_nans:
        c = int(data.size * 0.25)
        # ``np.nan`` rather than the ``np.NaN`` alias, which NumPy 2.0 removed
        data.ravel()[np.random.choice(data.size, c, replace=False)] = np.nan

    data = DataArray(data)

    check_weighted_operations(data, weights, "dim_0", skipna)
    check_weighted_operations(data, weights, None, skipna)

    data = data.to_dataset(name="data")
    check_weighted_operations(data, weights, "dim_0", skipna)
    check_weighted_operations(data, weights, None, skipna)
504
505
@pytest.mark.parametrize(
    "operation", ("sum_of_weights", "sum", "mean", "sum_of_squares", "var", "std")
)
@pytest.mark.parametrize("as_dataset", (True, False))
@pytest.mark.parametrize("keep_attrs", (True, False, None))
def test_weighted_operations_keep_attr(operation, as_dataset, keep_attrs):
    """Check attribute handling of a weighted reduction for one ``keep_attrs``.

    With ``keep_attrs=True`` the result must carry attrs; with ``False`` or
    ``None`` it must carry none. The parametrized ``keep_attrs`` value is now
    passed through, so each case exercises exactly one setting instead of
    repeating all three.
    """

    weights = DataArray(np.random.randn(2, 2), attrs=dict(attr="weights"))
    data = DataArray(np.random.randn(2, 2))

    if as_dataset:
        data = data.to_dataset(name="data")

    # NOTE(review): data and weights carry identical attrs, so the assertions
    # below cannot tell whose attrs were propagated - confirm this is intended.
    data.attrs = dict(attr="weights")

    result = getattr(data.weighted(weights), operation)(keep_attrs=keep_attrs)

    if keep_attrs:
        if operation == "sum_of_weights":
            assert weights.attrs == result.attrs
        else:
            assert data.attrs == result.attrs
    else:
        # both ``False`` and ``None`` are expected to drop the attrs
        assert not result.attrs
533
534
@pytest.mark.parametrize(
    "operation", ("sum_of_weights", "sum", "mean", "sum_of_squares", "var", "std")
)
def test_weighted_operations_keep_attr_da_in_ds(operation):
    """With ``keep_attrs=True`` a Dataset variable keeps its own attrs (GH #3595)."""
    weights = DataArray(np.random.randn(2, 2))
    variable = DataArray(np.random.randn(2, 2), attrs=dict(attr="data"))
    ds = variable.to_dataset(name="a")

    reduce = getattr(ds.weighted(weights), operation)
    reduced = reduce(keep_attrs=True)

    assert ds.a.attrs == reduced.a.attrs
548
549
@pytest.mark.parametrize("as_dataset", (True, False))
def test_weighted_bad_dim(as_dataset):
    """Reducing over a dimension the object lacks raises a ValueError.

    The expected message names the weighted class (built from the class name
    of ``data`` plus ``"Weighted"``) and the missing dimension.
    """

    data = DataArray(np.random.randn(2, 2))
    weights = xr.ones_like(data)
    if as_dataset:
        data = data.to_dataset(name="data")

    # ``match`` is a regular expression: the braces are escaped so they are
    # matched literally rather than relying on ``re`` treating a malformed
    # quantifier as plain text.
    error_msg = (
        f"{data.__class__.__name__}Weighted"
        r" does not contain the dimensions: \{'bad_dim'\}"
    )
    with pytest.raises(ValueError, match=error_msg):
        data.weighted(weights).mean("bad_dim")
564