1# pylint: disable=redefined-outer-name, no-member
2from copy import deepcopy
3
4import numpy as np
5import pytest
6from numpy.testing import assert_allclose, assert_array_almost_equal, assert_array_equal
7from scipy.stats import linregress
8from xarray import DataArray, Dataset
9
10from ...data import concat, convert_to_inference_data, from_dict, load_arviz_data
11from ...rcparams import rcParams
12from ...stats import (
13    apply_test_function,
14    compare,
15    ess,
16    hdi,
17    loo,
18    loo_pit,
19    psislw,
20    r2_score,
21    summary,
22    waic,
23)
24from ...stats.stats import _gpinv
25from ...stats.stats_utils import get_log_likelihood
26from ..helpers import check_multiple_attrs, multidim_models  # pylint: disable=unused-import
27
# Switch the "data.load" rcParam to "eager" for the whole test module.
# NOTE(review): presumably this makes load_arviz_data read the example datasets
# fully into memory instead of lazily — confirm against the rcParams docs.
rcParams["data.load"] = "eager"
29
30
@pytest.fixture(scope="session")
def centered_eight():
    """Session-wide ``centered_eight`` example dataset loaded via load_arviz_data."""
    return load_arviz_data("centered_eight")
35
36
@pytest.fixture(scope="session")
def non_centered_eight():
    """Session-wide ``non_centered_eight`` example dataset loaded via load_arviz_data."""
    return load_arviz_data("non_centered_eight")
41
42
@pytest.fixture(scope="module")
def multivariable_log_likelihood(centered_eight):
    """InferenceData whose ``log_likelihood`` group holds two variables.

    The ``log_likelihood`` variable of ``sample_stats`` is copied into its own
    group and renamed to ``obs``. An all-zeros ``decoy`` variable of the same
    shape is added next to it, and ``sample_stats`` is removed afterwards.
    """
    idata = centered_eight.copy()
    idata.add_groups({"log_likelihood": idata.sample_stats.log_likelihood})
    idata.log_likelihood = idata.log_likelihood.rename_vars({"log_likelihood": "obs"})
    log_lik_group = idata.log_likelihood
    decoy = DataArray(
        np.zeros(log_lik_group["obs"].values.shape),
        dims=["chain", "draw", "school"],
        coords=log_lik_group.coords,
    )
    idata.log_likelihood["decoy"] = decoy
    delattr(idata, "sample_stats")
    return idata
58
59
def test_hdp():
    """The HDI of a large standard-normal sample is close to [-1.88, 1.88]."""
    interval = hdi(np.random.randn(5000000))
    assert_array_almost_equal(interval, [-1.88, 1.88], 2)
64
65
def test_hdp_2darray():
    """2d input emits a FutureWarning and yields one interval per column."""
    sample = np.random.randn(12000, 5)
    expected_warning = (
        r"hdi currently interprets 2d data as \(draw, shape\) but this will "
        r"change in a future release to \(chain, draw\) for coherence with other functions"
    )
    with pytest.warns(FutureWarning, match=expected_warning):
        intervals = hdi(sample)
    assert intervals.shape == (5, 2)
75
76
def test_hdi_multidimension():
    """A 3d array gives one (lower, upper) pair per element of the trailing axis."""
    intervals = hdi(np.random.randn(12000, 10, 3))
    assert intervals.shape == (3, 2)
81
82
def test_hdi_idata(centered_eight):
    """hdi on a posterior Dataset returns a Dataset with an extra ``hdi`` dimension.

    Dimension lengths are read from ``Dataset.sizes``, the mapping API.
    ``Dataset.dims`` is being changed by xarray to return only dimension names.
    """
    data = centered_eight.posterior
    result = hdi(data)
    assert isinstance(result, Dataset)
    assert dict(result.sizes) == {"school": 8, "hdi": 2}

    # reducing over "chain" alone leaves "draw" in the output
    result = hdi(data, input_core_dims=[["chain"]])
    assert isinstance(result, Dataset)
    assert dict(result.sizes) == {"draw": 500, "hdi": 2, "school": 8}
92
93
def test_hdi_idata_varnames(centered_eight):
    """``var_names`` restricts (and orders) the variables in the hdi result.

    Dimension lengths are read from ``Dataset.sizes``, the mapping API.
    """
    data = centered_eight.posterior
    result = hdi(data, var_names=["mu", "theta"])
    assert isinstance(result, Dataset)
    assert dict(result.sizes) == {"hdi": 2, "school": 8}
    assert list(result.data_vars.keys()) == ["mu", "theta"]
100
101
def test_hdi_idata_group(centered_eight):
    """``group`` selects the source group; the posterior HDI of mu is narrower than the prior's.

    Dimension lengths are read from ``Dataset.sizes``, the mapping API.
    """
    result_posterior = hdi(centered_eight, group="posterior", var_names="mu")
    result_prior = hdi(centered_eight, group="prior", var_names="mu")
    assert dict(result_prior.sizes) == {"hdi": 2}
    range_posterior = result_posterior.mu.values[1] - result_posterior.mu.values[0]
    range_prior = result_prior.mu.values[1] - result_prior.mu.values[0]
    assert range_posterior < range_prior
109
110
def test_hdi_coords(centered_eight):
    """``coords`` subsets the chains kept in the result."""
    selected = hdi(
        centered_eight.posterior,
        coords={"chain": [0, 1, 3]},
        input_core_dims=[["draw"]],
    )
    assert_array_equal(selected.coords["chain"], [0, 1, 3])
115
116
def test_hdi_multimodal():
    """A well-separated bimodal sample yields two disjoint intervals."""
    left_mode = np.random.normal(-4, 1, 2500000)
    right_mode = np.random.normal(2, 0.5, 2500000)
    intervals = hdi(np.concatenate((left_mode, right_mode)), multimodal=True)
    assert_array_almost_equal(intervals, [[-5.8, -2.2], [0.9, 3.1]], 1)
123
124
def test_hdi_circular():
    """Circular HDI of a von Mises sample centred at pi wraps around the boundary."""
    angles = np.random.vonmises(np.pi, 1, 5000000)
    assert_array_almost_equal(hdi(angles, circular=True), [0.6, -0.6], 1)
129
130
def test_hdi_bad_ci():
    """hdi_prob=2 is rejected with ValueError."""
    sample = np.random.randn(10)
    with pytest.raises(ValueError):
        hdi(sample, hdi_prob=2)
135
136
def test_hdi_skipna():
    """With skipna=True, NaN entries are ignored as if they were absent."""
    sample = np.random.randn(500)
    expected = hdi(sample[10:])
    sample[:10] = np.nan
    assert_array_almost_equal(expected, hdi(sample, skipna=True))
143
144
def test_r2_score():
    """r2_score of an OLS fit matches the squared Pearson correlation from linregress."""
    x = np.linspace(0, 1, 100)
    y = np.random.normal(x, 1)
    res = linregress(x, y)
    y_pred = res.intercept + res.slope * x
    # rtol is passed by keyword: assert_allclose's third positional argument is
    # rtol, so a bare ``2`` there allowed a 200% relative error.
    assert_allclose(res.rvalue ** 2, r2_score(y, y_pred).r2, rtol=1e-2)
150
151
def test_r2_score_multivariate():
    """r2_score accepts 2d inputs and returns a non-NaN R^2."""
    x = np.linspace(0, 1, 100)
    y = np.random.normal(x, 1)
    fit = linregress(x, y)
    fitted = fit.intercept + fit.slope * x
    score = r2_score(np.c_[y, y], np.c_[fitted, fitted])
    assert not np.isnan(score.r2)
159
160
@pytest.mark.parametrize("method", ["stacking", "BB-pseudo-BMA", "pseudo-BMA"])
@pytest.mark.parametrize("multidim", [True, False])
def test_compare_same(centered_eight, multidim_models, method, multidim):
    """Comparing a model with itself gives equal weights that sum to one."""
    model = multidim_models.model_1 if multidim else centered_eight
    data_dict = {"first": model, "second": model}

    weight = compare(data_dict, method=method)["weight"]
    # The weight Series is indexed by model name, so positional access uses
    # .iloc; integer keys on a label-indexed Series are deprecated in pandas.
    assert_allclose(weight.iloc[0], weight.iloc[1])
    assert_allclose(np.sum(weight), 1.0)
172
173
def test_compare_unknown_ic_and_method(centered_eight, non_centered_eight):
    """Unsupported ``ic`` or ``method`` values raise ValueError."""
    models = {"centered": centered_eight, "non_centered": non_centered_eight}
    bad_options = (
        {"ic": "Unknown", "method": "stacking"},
        {"ic": "loo", "method": "Unknown"},
    )
    for options in bad_options:
        with pytest.raises(ValueError):
            compare(models, **options)
180
181
@pytest.mark.parametrize("ic", ["loo", "waic"])
@pytest.mark.parametrize("method", ["stacking", "BB-pseudo-BMA", "pseudo-BMA"])
@pytest.mark.parametrize("scale", ["log", "negative_log", "deviance"])
def test_compare_different(centered_eight, non_centered_eight, ic, method, scale):
    """The non-centered model receives the larger weight for every ic/method/scale."""
    comparison = compare(
        {"centered": centered_eight, "non_centered": non_centered_eight},
        ic=ic,
        method=method,
        scale=scale,
    )
    weights = comparison["weight"]
    assert weights["centered"] < weights["non_centered"]
    assert_allclose(np.sum(weights), 1.0)
190
191
@pytest.mark.parametrize("ic", ["loo", "waic"])
@pytest.mark.parametrize("method", ["stacking", "BB-pseudo-BMA", "pseudo-BMA"])
def test_compare_different_multidim(multidim_models, ic, method):
    """model_1 outweighs model_2; this relies on the fixture always using the same seed."""
    models = {"model_1": multidim_models.model_1, "model_2": multidim_models.model_2}
    weights = compare(models, ic=ic, method=method)["weight"]
    assert weights["model_2"] < weights["model_1"]
    assert_allclose(np.sum(weights), 1.0)
201
202
def test_compare_different_size(centered_eight, non_centered_eight):
    """Models fitted to a different number of observations cannot be compared."""
    centered_eight = deepcopy(centered_eight)
    # Remove the "Choate" school from every group that carries the school dim.
    # drop_sel replaces the deprecated Dataset.drop(labels, dim) signature.
    for group in ("posterior", "sample_stats", "posterior_predictive", "prior", "observed_data"):
        dataset = getattr(centered_eight, group)
        setattr(centered_eight, group, dataset.drop_sel(school="Choate"))
    model_dict = {"centered": centered_eight, "non_centered": non_centered_eight}
    with pytest.raises(ValueError):
        compare(model_dict, ic="waic", method="stacking")
215
216
@pytest.mark.parametrize("ic", ["loo", "waic"])
def test_compare_multiple_obs(multivariable_log_likelihood, centered_eight, non_centered_eight, ic):
    """A model with several log likelihood arrays needs an explicit ``var_name``."""
    problematic = multivariable_log_likelihood
    compare_dict = {
        "centered_eight": centered_eight,
        "non_centered_eight": non_centered_eight,
        "problematic": problematic,
    }
    with pytest.raises(TypeError, match="several log likelihood arrays"):
        get_log_likelihood(problematic)
    with pytest.raises(TypeError, match=f"{ic}.*model problematic"):
        compare(compare_dict, ic=ic)
    explicit = compare(compare_dict, ic=ic, var_name="obs")
    assert explicit is not None
229
230
def test_summary_ndarray():
    """summary accepts a plain numpy array."""
    table = summary(np.random.randn(4, 100, 2))
    assert table.shape
235
236
@pytest.mark.parametrize("var_names_expected", ((None, 10), ("mu", 1), (["mu", "tau"], 2)))
def test_summary_var_names(centered_eight, var_names_expected):
    """``var_names`` controls how many rows the summary contains."""
    var_names = var_names_expected[0]
    expected_rows = var_names_expected[1]
    table = summary(centered_eight, var_names=var_names)
    assert len(table.index) == expected_rows
242
243
@pytest.mark.parametrize("missing_groups", (None, "posterior", "prior"))
def test_summary_groups(centered_eight, missing_groups):
    """summary still returns output when posterior (and optionally prior) are removed.

    Removing both groups is expected to emit a UserWarning.
    """
    drop_prior_too = missing_groups == "prior"
    if missing_groups in ("posterior", "prior"):
        centered_eight = deepcopy(centered_eight)
        del centered_eight.posterior
        if drop_prior_too:
            del centered_eight.prior

    if drop_prior_too:
        with pytest.warns(UserWarning):
            table = summary(centered_eight)
    else:
        table = summary(centered_eight)
    assert table.shape
259
260
def test_summary_group_argument(centered_eight):
    """``group`` selects which InferenceData group is summarized."""
    posterior_rows = list(summary(centered_eight, group="posterior").index)
    prior_rows = list(summary(centered_eight, group="prior").index)
    assert posterior_rows != prior_rows
265
266
def test_summary_wrong_group(centered_eight):
    """Asking for a group that does not exist raises TypeError naming it."""
    expected_message = r"InferenceData does not contain group: InvalidGroup"
    with pytest.raises(TypeError, match=expected_message):
        summary(centered_eight, group="InvalidGroup")
270
271
# Expected summary columns, in order. test_summary_kind treats the first four as
# the "stats" columns and the remaining five as the "diagnostics" columns.
METRICS_NAMES = "mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat".split()
283
284
@pytest.mark.parametrize(
    "params",
    (("all", METRICS_NAMES), ("stats", METRICS_NAMES[:4]), ("diagnostics", METRICS_NAMES[4:])),
)
def test_summary_kind(centered_eight, params):
    """``kind`` selects the stats columns, the diagnostic columns, or both."""
    kind, expected_columns = params
    assert_array_equal(summary(centered_eight, kind=kind).columns, expected_columns)
293
294
@pytest.mark.parametrize("fmt", ["wide", "long", "xarray"])
def test_summary_fmt(centered_eight, fmt):
    """Every supported output format produces a result."""
    result = summary(centered_eight, fmt=fmt)
    assert result is not None
298
299
def test_summary_labels():
    """Row labels are ``a[<dim1>, <dim2>]`` with dim2 varying fastest.

    The full index is compared at once. Zipping the index with the expected
    labels would silently ignore missing or extra rows.
    """
    coords1 = list("abcd")
    coords2 = np.arange(1, 6)
    data = from_dict(
        {"a": np.random.randn(4, 100, 4, 5)},
        coords={"dim1": coords1, "dim2": coords2},
        dims={"a": ["dim1", "dim2"]},
    )
    az_summary = summary(data, fmt="wide")
    assert az_summary is not None
    expected_index = [f"a[{coord1}, {coord2}]" for coord1 in coords1 for coord2 in coords2]
    assert list(az_summary.index) == expected_index
316
317
@pytest.mark.parametrize(
    "stat_funcs", [[np.var], {"var": np.var, "var2": lambda x: np.var(x) ** 2}]
)
def test_summary_stat_func(centered_eight, stat_funcs):
    """Custom ``stat_funcs`` (list or dict) add a ``var`` column to the summary."""
    arviz_summary = summary(centered_eight, stat_funcs=stat_funcs)
    assert arviz_summary is not None
    # Check the columns, not hasattr: DataFrame.var is a method, so
    # hasattr(df, "var") is always true and proves nothing.
    assert "var" in arviz_summary.columns
325
326
def test_summary_nan(centered_eight):
    """An all-NaN school yields an all-NaN row while every other row stays finite."""
    idata = deepcopy(centered_eight)
    idata.posterior["theta"].loc[{"school": "Deerfield"}] = np.nan
    table = summary(idata)
    assert table is not None

    nan_row = "theta[Deerfield]"
    assert table.loc[nan_row].isnull().all()
    other_rows = [label for label in table.index if label != nan_row]
    assert table.loc[other_rows].notnull().all().all()
339
340
def test_summary_skip_nan(centered_eight):
    """Partially-NaN data: some of the first four summary columns stay non-NaN.

    NaN draws are written into theta[Deerfield]. The last columns of that row
    are then expected to be NaN, and not all of the first four.
    """
    centered_eight = deepcopy(centered_eight)
    centered_eight.posterior["theta"].loc[{"draw": slice(10), "school": "Deerfield"}] = np.nan
    summary_xarray = summary(centered_eight)
    theta_1 = summary_xarray.loc["theta[Deerfield]"].isnull()
    assert summary_xarray is not None
    # `not` rather than `~`: if .all() returned a plain bool, `~True` is -2
    # (truthy), and the assertion could never fail. .iloc makes the slice
    # explicitly positional on this label-indexed Series.
    assert not theta_1.iloc[:4].all()
    assert theta_1.iloc[4:].all()
349
350
@pytest.mark.parametrize("fmt", [1, "bad_fmt"])
def test_summary_bad_fmt(centered_eight, fmt):
    """Non-string or unknown ``fmt`` values raise TypeError."""
    expected_message = "Invalid format"
    with pytest.raises(TypeError, match=expected_message):
        summary(centered_eight, fmt=fmt)
355
356
def test_summary_order_deprecation(centered_eight):
    """Passing ``order`` emits a DeprecationWarning mentioning it."""
    deprecated_kwargs = {"order": "C"}
    with pytest.warns(DeprecationWarning, match="order"):
        summary(centered_eight, **deprecated_kwargs)
360
361
def test_summary_index_origin_deprecation(centered_eight):
    """Passing ``index_origin`` emits a DeprecationWarning mentioning it."""
    deprecated_kwargs = {"index_origin": 1}
    with pytest.warns(DeprecationWarning, match="index_origin"):
        summary(centered_eight, **deprecated_kwargs)
365
366
@pytest.mark.parametrize("scale", ["log", "negative_log", "deviance"])
@pytest.mark.parametrize("multidim", (True, False))
def test_waic(centered_eight, multidim_models, scale, multidim):
    """WAIC (widely applicable information criterion) for every scale.

    Also checks that the pointwise result exposes ``waic_i``.
    """
    idata = multidim_models.model_1 if multidim else centered_eight
    assert waic(idata, scale=scale) is not None
    pointwise_result = waic(idata, pointwise=True, scale=scale)
    assert pointwise_result is not None
    assert "waic_i" in pointwise_result
379
380
def test_waic_bad(centered_eight):
    """waic raises TypeError once the log likelihood (or sample_stats) is removed."""
    idata = deepcopy(centered_eight)

    del idata.sample_stats["log_likelihood"]
    with pytest.raises(TypeError):
        waic(idata)

    del idata.sample_stats
    with pytest.raises(TypeError):
        waic(idata)
391
392
def test_waic_bad_scale(centered_eight):
    """An unrecognised ``scale`` makes waic raise TypeError."""
    invalid_scale = "bad_value"
    with pytest.raises(TypeError):
        waic(centered_eight, scale=invalid_scale)
397
398
def test_waic_warning(centered_eight):
    """waic emits a UserWarning for the manipulated log likelihood values below."""
    centered_eight = deepcopy(centered_eight)
    # overwrite half of the draws for one observation with an outlying value
    centered_eight.sample_stats["log_likelihood"][:, :250, 1] = 10
    with pytest.warns(UserWarning):
        assert waic(centered_eight, pointwise=True) is not None
    # NOTE(review): presumably a constant log likelihood should not warn in exact
    # arithmetic, but a warning is raised in practice due to numerical issues.
    # This assertion pins the current behaviour — confirm.
    centered_eight.sample_stats["log_likelihood"][:, :, :] = 0
    with pytest.warns(UserWarning):
        assert waic(centered_eight, pointwise=True) is not None
408
409
@pytest.mark.parametrize("scale", ["log", "negative_log", "deviance"])
def test_waic_print(centered_eight, scale):
    """The repr of a waic result does not depend on ``pointwise``."""
    summary_repr = waic(centered_eight, scale=scale).__repr__()
    pointwise_repr = waic(centered_eight, scale=scale, pointwise=True).__repr__()
    assert summary_repr is not None
    assert pointwise_repr is not None
    assert pointwise_repr == summary_repr
417
418
@pytest.mark.parametrize("scale", ["log", "negative_log", "deviance"])
@pytest.mark.parametrize("multidim", (True, False))
def test_loo(centered_eight, multidim_models, scale, multidim):
    """PSIS leave-one-out criterion for every scale.

    Also checks the fields of the pointwise result.
    """
    idata = multidim_models.model_1 if multidim else centered_eight
    assert loo(idata, scale=scale) is not None
    pointwise_result = loo(idata, pointwise=True, scale=scale)
    assert pointwise_result is not None
    for field in ("loo_i", "pareto_k", "loo_scale"):
        assert field in pointwise_result
433
434
def test_loo_one_chain(centered_eight):
    """loo works when only a single chain (chain 0) is left."""
    centered_eight = deepcopy(centered_eight)
    # drop_sel replaces the deprecated Dataset.drop(labels, dim) signature
    centered_eight.posterior = centered_eight.posterior.drop_sel(chain=[1, 2, 3])
    centered_eight.sample_stats = centered_eight.sample_stats.drop_sel(chain=[1, 2, 3])
    assert loo(centered_eight) is not None
440
441
def test_loo_bad(centered_eight):
    """loo rejects a raw array and an InferenceData without a log likelihood."""
    with pytest.raises(TypeError):
        loo(np.random.randn(2, 10))

    idata = deepcopy(centered_eight)
    del idata.sample_stats["log_likelihood"]
    with pytest.raises(TypeError):
        loo(idata)
450
451
def test_loo_bad_scale(centered_eight):
    """An unrecognised ``scale`` makes loo raise TypeError."""
    invalid_scale = "bad_scale"
    with pytest.raises(TypeError):
        loo(centered_eight, scale=invalid_scale)
456
457
def test_loo_bad_no_posterior_reff(centered_eight):
    """Without a posterior group, loo needs an explicit ``reff``."""
    loo(centered_eight, reff=None)

    idata = deepcopy(centered_eight)
    del idata.posterior
    with pytest.raises(TypeError):
        loo(idata, reff=None)
    loo(idata, reff=0.7)
465
466
def test_loo_warning(centered_eight):
    """Degenerate log likelihood values trigger the Pareto shape-parameter warning."""
    idata = deepcopy(centered_eight)

    def check_shape_warning():
        # loo must warn, and one of the warnings must mention the shape parameter
        with pytest.warns(UserWarning) as caught:
            assert loo(idata, pointwise=True) is not None
        assert any("Estimated shape parameter" in str(item.message) for item in caught)

    # make one of the khats infinity
    idata.sample_stats["log_likelihood"][:, :, 1] = 10
    check_shape_warning()

    # make all of the khats infinity
    idata.sample_stats["log_likelihood"][:, :, :] = 1
    check_shape_warning()
480
481
@pytest.mark.parametrize("scale", ["log", "negative_log", "deviance"])
def test_loo_print(centered_eight, scale):
    """The pointwise loo repr is longer than the plain one."""
    summary_repr = loo(centered_eight, scale=scale).__repr__()
    pointwise_repr = loo(centered_eight, scale=scale, pointwise=True).__repr__()
    assert summary_repr is not None
    assert pointwise_repr is not None
    assert len(pointwise_repr) > len(summary_repr)
489
490
def test_psislw(centered_eight):
    """psislw reproduces the pareto_k values reported by loo for the same reff."""
    reff = 0.7
    expected_k = loo(centered_eight, pointwise=True, reff=reff)["pareto_k"]
    stacked = get_log_likelihood(centered_eight).stack(__sample__=("chain", "draw"))
    khats = psislw(-stacked, reff)[1]
    assert_allclose(expected_k, khats)
496
497
@pytest.mark.parametrize("probs", [True, False])
@pytest.mark.parametrize("kappa", [-1, -0.5, 1e-30, 0.5, 1])
@pytest.mark.parametrize("sigma", [0, 2])
def test_gpinv(probs, kappa, sigma):
    """_gpinv returns one value per input probability, including a negative one."""
    # ``probs`` is a flag: when False the first probability is made invalid (-0.1)
    first_prob = 0.1 if probs else -0.1
    prob_values = np.array([first_prob, 0.1, 0.1, 0.2, 0.3])
    assert len(_gpinv(prob_values, kappa, sigma)) == len(prob_values)
507
508
@pytest.mark.parametrize("func", [loo, waic])
def test_multidimensional_log_likelihood(func):
    """Flattening extra log likelihood dims (a, b) into one dim does not change the result."""
    ll_multi = np.random.rand(4, 23, 15, 2)
    ll_flat = ll_multi.reshape(4, 23, 15 * 2)
    stats_multi = Dataset({"log_likelihood": DataArray(ll_multi, dims=["chain", "draw", "a", "b"])})
    stats_flat = Dataset({"log_likelihood": DataArray(ll_flat, dims=["chain", "draw", "v"])})
    posterior = Dataset(
        {"mu": DataArray(np.random.rand(4, 23, 2), dims=["chain", "draw", "v"])}
    )

    idata_multi = convert_to_inference_data(stats_multi, group="sample_stats")
    idata_flat = convert_to_inference_data(stats_flat, group="sample_stats")
    idata_post = convert_to_inference_data(posterior, group="posterior")

    idata_multi = concat(idata_post, idata_multi)
    idata_flat = concat(idata_post, idata_flat)

    result_multi = func(idata_multi)
    result_flat = func(idata_flat)

    assert (result_flat == result_multi).all()
    assert_array_almost_equal(result_multi[:4], result_flat[:4])
531
532
@pytest.mark.parametrize(
    "args",
    [
        {"y": "obs"},
        {"y": "obs", "y_hat": "obs"},
        {"y": "arr", "y_hat": "obs"},
        {"y": "obs", "y_hat": "arr"},
        {"y": "arr", "y_hat": "arr"},
        {"y": "obs", "y_hat": "obs", "log_weights": "arr"},
        {"y": "arr", "y_hat": "obs", "log_weights": "arr"},
        {"y": "obs", "y_hat": "arr", "log_weights": "arr"},
        {"idata": False},
    ],
)
def test_loo_pit(centered_eight, args):
    """LOO-PIT values lie in [0, 1] for each way of passing y, y_hat and log_weights.

    In a parameter set, the string ``"arr"`` means "pass the precomputed array".
    Any other string is a variable name, and a missing key means None.
    """
    y_arr = centered_eight.observed_data.obs
    y_hat_arr = centered_eight.posterior_predictive.obs.stack(__sample__=("chain", "draw"))
    log_like = get_log_likelihood(centered_eight).stack(__sample__=("chain", "draw"))
    n_samples = len(log_like.__sample__)
    ess_p = ess(centered_eight.posterior, method="mean")
    reff = np.hstack([ess_p[v].values.flatten() for v in ess_p.data_vars]).mean() / n_samples
    log_weights_arr = psislw(-log_like, reff=reff)[0]

    if not args.get("idata", True):
        loo_pit_data = loo_pit(idata=None, y=y_arr, y_hat=y_hat_arr, log_weights=log_weights_arr)
    else:
        precomputed = {"y": y_arr, "y_hat": y_hat_arr, "log_weights": log_weights_arr}
        call_kwargs = {}
        for name, array in precomputed.items():
            requested = args.get(name, None)
            call_kwargs[name] = array if requested == "arr" else requested
        loo_pit_data = loo_pit(idata=centered_eight, **call_kwargs)
    assert np.all((loo_pit_data >= 0) & (loo_pit_data <= 1))
570
571
@pytest.mark.parametrize(
    "args",
    [
        {"y": "y"},
        {"y": "y", "y_hat": "y"},
        {"y": "arr", "y_hat": "y"},
        {"y": "y", "y_hat": "arr"},
        {"y": "arr", "y_hat": "arr"},
        {"y": "y", "y_hat": "y", "log_weights": "arr"},
        {"y": "arr", "y_hat": "y", "log_weights": "arr"},
        {"y": "y", "y_hat": "arr", "log_weights": "arr"},
        {"idata": False},
    ],
)
def test_loo_pit_multidim(multidim_models, args):
    """LOO-PIT values lie in [0, 1] for a model with multidimensional observations.

    In a parameter set, the string ``"arr"`` means "pass the precomputed array".
    Any other string is a variable name, and a missing key means None.
    """
    idata = multidim_models.model_1
    y_arr = idata.observed_data.y
    y_hat_arr = idata.posterior_predictive.y.stack(__sample__=("chain", "draw"))
    log_like = get_log_likelihood(idata).stack(__sample__=("chain", "draw"))
    n_samples = len(log_like.__sample__)
    ess_p = ess(idata.posterior, method="mean")
    reff = np.hstack([ess_p[v].values.flatten() for v in ess_p.data_vars]).mean() / n_samples
    log_weights_arr = psislw(-log_like, reff=reff)[0]

    if not args.get("idata", True):
        loo_pit_data = loo_pit(idata=None, y=y_arr, y_hat=y_hat_arr, log_weights=log_weights_arr)
    else:
        precomputed = {"y": y_arr, "y_hat": y_hat_arr, "log_weights": log_weights_arr}
        call_kwargs = {}
        for name, array in precomputed.items():
            requested = args.get(name, None)
            call_kwargs[name] = array if requested == "arr" else requested
        loo_pit_data = loo_pit(idata=idata, **call_kwargs)
    assert np.all((loo_pit_data >= 0) & (loo_pit_data <= 1))
610
611
def test_loo_pit_multi_lik():
    """loo_pit works when the log_likelihood group holds several variables.

    ``y="y"`` selects the observed variable. An all-zeros ``decoy`` log
    likelihood sits alongside it.
    """
    rng = np.random.default_rng(0)
    post_pred = rng.standard_normal(size=(4, 100, 10))
    obs = np.quantile(post_pred, np.linspace(0, 1, 10))
    # push the extreme observations slightly inward/outward
    obs[0] *= 0.9
    obs[-1] *= 1.1
    idata = from_dict(
        # drawn from the seeded generator (not the global np.random state) so the
        # whole dataset is reproducible
        posterior={"a": rng.standard_normal(size=(4, 100))},
        posterior_predictive={"y": post_pred},
        observed_data={"y": obs},
        log_likelihood={"y": -(post_pred ** 2), "decoy": np.zeros_like(post_pred)},
    )
    loo_pit_data = loo_pit(idata, y="y")
    assert np.all((loo_pit_data >= 0) & (loo_pit_data <= 1))
626
627
@pytest.mark.parametrize("input_type", ["idataarray", "idatanone_ystr", "yarr_yhatnone"])
def test_loo_pit_bad_input(centered_eight, input_type):
    """Incompatible combinations of idata, y and y_hat raise ValueError."""
    arr = np.random.random((8, 200))
    cases = {
        "idataarray": (r"type InferenceData or None", {"idata": arr, "y": "obs"}),
        "idatanone_ystr": (r"all 3.+must be array or DataArray", {"idata": None, "y": "obs"}),
        "yarr_yhatnone": (
            r"y_hat.+None.+y.+str",
            {"idata": centered_eight, "y": arr, "y_hat": None},
        ),
    }
    if input_type in cases:
        pattern, call_kwargs = cases[input_type]
        with pytest.raises(ValueError, match=pattern):
            loo_pit(**call_kwargs)
641
642
@pytest.mark.parametrize("arg", ["y", "y_hat", "log_weights"])
def test_loo_pit_bad_input_type(centered_eight, arg):
    """An argument that is neither None, a str, nor a DataArray raises ValueError."""
    bad_value = 2  # an int is not array-like
    call_kwargs = {"y": "obs", "y_hat": "obs", "log_weights": None}
    call_kwargs[arg] = bad_value
    with pytest.raises(ValueError, match=f"not {type(bad_value)}"):
        loo_pit(idata=centered_eight, **call_kwargs)
650
651
@pytest.mark.parametrize("incompatibility", ["y-y_hat1", "y-y_hat2", "y_hat-log_weights"])
def test_loo_pit_bad_input_shape(incompatibility):
    """Mismatched shapes of y, y_hat and log_weights raise ValueError."""
    y = np.random.random(8)
    y_hat = np.random.random((8, 200))
    log_weights = np.random.random((8, 200))
    # each case corrupts y_hat in a different way; y and log_weights stay valid
    cases = {
        "y-y_hat1": ("1 more dimension", y_hat[None, :]),
        "y-y_hat2": ("y has shape", y_hat[1:3, :]),
        "y_hat-log_weights": ("must have the same shape", y_hat[:, :100]),
    }
    if incompatibility in cases:
        pattern, bad_y_hat = cases[incompatibility]
        with pytest.raises(ValueError, match=pattern):
            loo_pit(y=y, y_hat=bad_y_hat, log_weights=log_weights)
667
668
@pytest.mark.parametrize("pointwise", [True, False])
@pytest.mark.parametrize("inplace", [True, False])
@pytest.mark.parametrize(
    "kwargs",
    [
        {},
        {"group": "posterior_predictive", "var_names": {"posterior_predictive": "obs"}},
        {"group": "observed_data", "var_names": {"both": "obs"}, "out_data_shape": "shape"},
        {"var_names": {"both": "obs", "posterior": ["theta", "mu"]}},
        {"group": "observed_data", "out_name_data": "T_name"},
    ],
)
def test_apply_test_function(centered_eight, pointwise, inplace, kwargs):
    """Test some usual call cases of apply_test_function.

    ``kwargs`` supplies optional arguments. ``"shape"`` for ``out_data_shape``
    or ``out_pp_shape`` is replaced by the shape expected for ``pointwise``.
    """
    # private copy so the session-scoped fixture is never mutated in place
    idata = deepcopy(centered_eight)
    group = kwargs.get("group", "both")
    var_names = kwargs.get("var_names", None)
    out_data_shape = kwargs.get("out_data_shape", None)
    out_pp_shape = kwargs.get("out_pp_shape", None)
    out_name_data = kwargs.get("out_name_data", "T")
    if out_data_shape == "shape":
        out_data_shape = (8,) if pointwise else ()
    if out_pp_shape == "shape":
        out_pp_shape = (4, 500, 8) if pointwise else (4, 500)
    idata_out = apply_test_function(
        idata,
        lambda y, theta: np.mean(y),
        group=group,
        var_names=var_names,
        pointwise=pointwise,
        out_name_data=out_name_data,
        out_data_shape=out_data_shape,
        out_pp_shape=out_pp_shape,
        # forward the parametrized flag so inplace=False exercises the copy path
        inplace=inplace,
    )
    if inplace:
        assert idata is idata_out
    else:
        assert idata is not idata_out

    if group == "both":
        test_dict = {"observed_data": ["T"], "posterior_predictive": ["T"]}
    else:
        test_dict = {group: [out_name_data]}

    fails = check_multiple_attrs(test_dict, idata_out)
    assert not fails
714
715
def test_apply_test_function_bad_group(centered_eight):
    """An unknown ``group`` name is rejected with ValueError."""
    identity = lambda y, theta: y  # noqa: E731
    with pytest.raises(ValueError, match="Invalid group argument"):
        apply_test_function(centered_eight, identity, group="bad_group")
720
721
def test_apply_test_function_missing_group():
    """Test error when InferenceData object is missing a required group.

    The function cannot work if group="both" but InferenceData object has no
    posterior_predictive group.
    """
    idata = from_dict(
        posterior={"a": np.random.random((4, 500, 30))}, observed_data={"y": np.random.random(30)}
    )
    with pytest.raises(ValueError, match="must have posterior_predictive"):
        # the test function must return a value, not the np.mean function itself
        apply_test_function(idata, lambda y, theta: np.mean(y), group="both")
733
734
def test_apply_test_function_should_overwrite_error(centered_eight):
    """Writing into an existing variable name without overwrite=True raises ValueError."""
    existing_name = "obs"
    with pytest.raises(ValueError, match="Should overwrite"):
        apply_test_function(centered_eight, lambda y, theta: y, out_name_data=existing_name)
739