# pylint: disable=redefined-outer-name, no-member
from copy import deepcopy

import numpy as np
import pytest
from numpy.testing import assert_allclose, assert_array_almost_equal, assert_array_equal
from scipy.stats import linregress
from xarray import DataArray, Dataset

from ...data import concat, convert_to_inference_data, from_dict, load_arviz_data
from ...rcparams import rcParams
from ...stats import (
    apply_test_function,
    compare,
    ess,
    hdi,
    loo,
    loo_pit,
    psislw,
    r2_score,
    summary,
    waic,
)
from ...stats.stats import _gpinv
from ...stats.stats_utils import get_log_likelihood
from ..helpers import check_multiple_attrs, multidim_models  # pylint: disable=unused-import

rcParams["data.load"] = "eager"


@pytest.fixture(scope="session")
def centered_eight():
    """Load the ``centered_eight`` example dataset once per test session."""
    return load_arviz_data("centered_eight")


@pytest.fixture(scope="session")
def non_centered_eight():
    """Load the ``non_centered_eight`` example dataset once per test session."""
    return load_arviz_data("non_centered_eight")


@pytest.fixture(scope="module")
def multivariable_log_likelihood(centered_eight):
    """Build an InferenceData whose log_likelihood group holds two variables.

    The log likelihood stored in ``sample_stats`` is moved to a dedicated
    ``log_likelihood`` group under the name ``obs`` and a zero-filled ``decoy``
    variable of the same shape is added next to it. The ``sample_stats`` group
    is removed afterwards.
    """
    centered_eight = centered_eight.copy()
    centered_eight.add_groups({"log_likelihood": centered_eight.sample_stats.log_likelihood})
    centered_eight.log_likelihood = centered_eight.log_likelihood.rename_vars(
        {"log_likelihood": "obs"}
    )
    new_arr = DataArray(
        np.zeros(centered_eight.log_likelihood["obs"].values.shape),
        dims=["chain", "draw", "school"],
        coords=centered_eight.log_likelihood.coords,
    )
    centered_eight.log_likelihood["decoy"] = new_arr
    delattr(centered_eight, "sample_stats")
    return centered_eight


def test_hdp():
    """Check the default hdi of a standard normal sample to 2 decimals."""
    normal_sample = np.random.randn(5000000)
    interval = hdi(normal_sample)
    assert_array_almost_equal(interval, [-1.88, 1.88], 2)


def test_hdp_2darray():
    """Check that 2d input emits the FutureWarning and yields one interval per column."""
    normal_sample = np.random.randn(12000, 5)
    msg = (
        r"hdi currently interprets 2d data as \(draw, shape\) but this will "
        r"change in a future release to \(chain, draw\) for coherence with other functions"
    )
    with pytest.warns(FutureWarning, match=msg):
        result = hdi(normal_sample)
    assert result.shape == (5, 2)


def test_hdi_multidimension():
    """Check the output shape of hdi for a 3d array."""
    normal_sample = np.random.randn(12000, 10, 3)
    result = hdi(normal_sample)
    assert result.shape == (3, 2)


def test_hdi_idata(centered_eight):
    """Check hdi on a posterior Dataset, with default and custom input_core_dims."""
    data = centered_eight.posterior
    result = hdi(data)
    assert isinstance(result, Dataset)
    assert dict(result.dims) == {"school": 8, "hdi": 2}

    result = hdi(data, input_core_dims=[["chain"]])
    assert isinstance(result, Dataset)
    assert result.dims == {"draw": 500, "hdi": 2, "school": 8}


def test_hdi_idata_varnames(centered_eight):
    """Check that var_names restricts the variables present in the hdi result."""
    data = centered_eight.posterior
    result = hdi(data, var_names=["mu", "theta"])
    assert isinstance(result, Dataset)
    assert result.dims == {"hdi": 2, "school": 8}
    assert list(result.data_vars.keys()) == ["mu", "theta"]


def test_hdi_idata_group(centered_eight):
    """Check the group argument: the posterior interval of mu is narrower than the prior one."""
    result_posterior = hdi(centered_eight, group="posterior", var_names="mu")
    result_prior = hdi(centered_eight, group="prior", var_names="mu")
    assert result_prior.dims == {"hdi": 2}
    range_posterior = result_posterior.mu.values[1] - result_posterior.mu.values[0]
    range_prior = result_prior.mu.values[1] - result_prior.mu.values[0]
    assert range_posterior < range_prior


def test_hdi_coords(centered_eight):
    """Check that the coords argument selects the requested chains."""
    data = centered_eight.posterior
    result = hdi(data, coords={"chain": [0, 1, 3]}, input_core_dims=[["draw"]])
    assert_array_equal(result.coords["chain"], [0, 1, 3])


def test_hdi_multimodal():
    """Check that multimodal=True returns one interval per mode of a bimodal sample."""
    normal_sample = np.concatenate(
        (np.random.normal(-4, 1, 2500000), np.random.normal(2, 0.5, 2500000))
    )
    intervals = hdi(normal_sample, multimodal=True)
    assert_array_almost_equal(intervals, [[-5.8, -2.2], [0.9, 3.1]], 1)


def test_hdi_circular():
    """Check hdi with circular=True on a von Mises sample centered at pi."""
    normal_sample = np.random.vonmises(np.pi, 1, 5000000)
    interval = hdi(normal_sample, circular=True)
    assert_array_almost_equal(interval, [0.6, -0.6], 1)


def test_hdi_bad_ci():
    """Check that an hdi_prob outside the valid range raises ValueError."""
    normal_sample = np.random.randn(10)
    with pytest.raises(ValueError):
        hdi(normal_sample, hdi_prob=2)


def test_hdi_skipna():
    """Check that skipna=True gives the same interval as dropping the NaN values."""
    normal_sample = np.random.randn(500)
    interval = hdi(normal_sample[10:])
    normal_sample[:10] = np.nan
    interval_ = hdi(normal_sample, skipna=True)
    assert_array_almost_equal(interval, interval_)


def test_r2_score():
    """Compare r2_score of a linear fit with the squared correlation from linregress."""
    x = np.linspace(0, 1, 100)
    y = np.random.normal(x, 1)
    res = linregress(x, y)
    # The tolerance is passed by keyword on purpose: the third positional
    # argument of assert_allclose is rtol, and rtol=2 accepts almost any value.
    assert_allclose(res.rvalue ** 2, r2_score(y, res.intercept + res.slope * x).r2, atol=1e-2)


def test_r2_score_multivariate():
    """Check that r2_score returns a non-NaN r2 for 2d inputs."""
    x = np.linspace(0, 1, 100)
    y = np.random.normal(x, 1)
    res = linregress(x, y)
    y_multivariate = np.c_[y, y]
    y_multivariate_pred = np.c_[res.intercept + res.slope * x, res.intercept + res.slope * x]
    assert not np.isnan(r2_score(y_multivariate, y_multivariate_pred).r2)


@pytest.mark.parametrize("method", ["stacking", "BB-pseudo-BMA", "pseudo-BMA"])
@pytest.mark.parametrize("multidim", [True, False])
def test_compare_same(centered_eight, multidim_models, method, multidim):
    """Check that comparing a model with itself gives equal weights summing to 1."""
    if multidim:
        data_dict = {"first": multidim_models.model_1, "second": multidim_models.model_1}
    else:
        data_dict = {"first": centered_eight, "second": centered_eight}

    weight = compare(data_dict, method=method)["weight"]
    assert_allclose(weight[0], weight[1])
    assert_allclose(np.sum(weight), 1.0)


def test_compare_unknown_ic_and_method(centered_eight, non_centered_eight):
    """Check that an unknown ic or method raises ValueError."""
    model_dict = {"centered": centered_eight, "non_centered": non_centered_eight}
    with pytest.raises(ValueError):
        compare(model_dict, ic="Unknown", method="stacking")
    with pytest.raises(ValueError):
        compare(model_dict, ic="loo", method="Unknown")


@pytest.mark.parametrize("ic", ["loo", "waic"])
@pytest.mark.parametrize("method", ["stacking", "BB-pseudo-BMA", "pseudo-BMA"])
@pytest.mark.parametrize("scale", ["log", "negative_log", "deviance"])
def test_compare_different(centered_eight, non_centered_eight, ic, method, scale):
    """Check that the non centered model gets the larger weight and weights sum to 1."""
    model_dict = {"centered": centered_eight, "non_centered": non_centered_eight}
    weight = compare(model_dict, ic=ic, method=method, scale=scale)["weight"]
    assert weight["non_centered"] > weight["centered"]
    assert_allclose(np.sum(weight), 1.0)


@pytest.mark.parametrize("ic", ["loo", "waic"])
@pytest.mark.parametrize("method", ["stacking", "BB-pseudo-BMA", "pseudo-BMA"])
def test_compare_different_multidim(multidim_models, ic, method):
    """Check compare weights on models with multidimensional observations."""
    model_dict = {"model_1": multidim_models.model_1, "model_2": multidim_models.model_2}
    weight = compare(model_dict, ic=ic, method=method)["weight"]

    # this should hold because the same seed is always used
    assert weight["model_1"] > weight["model_2"]
    assert_allclose(np.sum(weight), 1.0)


def test_compare_different_size(centered_eight, non_centered_eight):
    """Check that models with a different number of observations raise ValueError."""
    centered_eight = deepcopy(centered_eight)
    centered_eight.posterior = centered_eight.posterior.drop("Choate", "school")
    centered_eight.sample_stats = centered_eight.sample_stats.drop("Choate", "school")
    centered_eight.posterior_predictive = centered_eight.posterior_predictive.drop(
        "Choate", "school"
    )
    centered_eight.prior = centered_eight.prior.drop("Choate", "school")
    centered_eight.observed_data = centered_eight.observed_data.drop("Choate", "school")
    model_dict = {"centered": centered_eight, "non_centered": non_centered_eight}
    with pytest.raises(ValueError):
        compare(model_dict, ic="waic", method="stacking")


@pytest.mark.parametrize("ic", ["loo", "waic"])
def test_compare_multiple_obs(multivariable_log_likelihood, centered_eight, non_centered_eight, ic):
    """Check that several log likelihood arrays require var_name to be given."""
    compare_dict = {
        "centered_eight": centered_eight,
        "non_centered_eight": non_centered_eight,
        "problematic": multivariable_log_likelihood,
    }
    with pytest.raises(TypeError, match="several log likelihood arrays"):
        get_log_likelihood(compare_dict["problematic"])
    with pytest.raises(TypeError, match=f"{ic}.*model problematic"):
        compare(compare_dict, ic=ic)
    assert compare(compare_dict, ic=ic, var_name="obs") is not None


def test_summary_ndarray():
    """Check that summary accepts a plain numpy array."""
    array = np.random.randn(4, 100, 2)
    summary_df = summary(array)
    assert summary_df.shape


@pytest.mark.parametrize("var_names_expected", ((None, 10), ("mu", 1), (["mu", "tau"], 2)))
def test_summary_var_names(centered_eight, var_names_expected):
    """Check the number of summary rows for different var_names values."""
    var_names, expected = var_names_expected
    summary_df = summary(centered_eight, var_names=var_names)
    assert len(summary_df.index) == expected


@pytest.mark.parametrize("missing_groups", (None, "posterior", "prior"))
def test_summary_groups(centered_eight, missing_groups):
    """Check summary when the posterior, or both posterior and prior, are missing.

    A UserWarning is expected only when both groups have been removed.
    """
    if missing_groups == "posterior":
        centered_eight = deepcopy(centered_eight)
        del centered_eight.posterior
    elif missing_groups == "prior":
        centered_eight = deepcopy(centered_eight)
        del centered_eight.posterior
        del centered_eight.prior
    if missing_groups == "prior":
        with pytest.warns(UserWarning):
            summary_df = summary(centered_eight)
    else:
        summary_df = summary(centered_eight)
    assert summary_df.shape


def test_summary_group_argument(centered_eight):
    """Check that the posterior and prior summaries have different indexes."""
    summary_df_posterior = summary(centered_eight, group="posterior")
    summary_df_prior = summary(centered_eight, group="prior")
    assert list(summary_df_posterior.index) != list(summary_df_prior.index)


def test_summary_wrong_group(centered_eight):
    """Check that an unknown group name raises TypeError."""
    with pytest.raises(TypeError, match=r"InferenceData does not contain group: InvalidGroup"):
        summary(centered_eight, group="InvalidGroup")


# Column names expected from summary: the first four are the "stats" columns,
# the remaining ones the "diagnostics" columns (see test_summary_kind).
METRICS_NAMES = [
    "mean",
    "sd",
    "hdi_3%",
    "hdi_97%",
    "mcse_mean",
    "mcse_sd",
    "ess_bulk",
    "ess_tail",
    "r_hat",
]


@pytest.mark.parametrize(
    "params",
    (("all", METRICS_NAMES), ("stats", METRICS_NAMES[:4]), ("diagnostics", METRICS_NAMES[4:])),
)
def test_summary_kind(centered_eight, params):
    """Check the summary columns returned for each value of kind."""
    kind, metrics_names_ = params
    summary_df = summary(centered_eight, kind=kind)
    assert_array_equal(summary_df.columns, metrics_names_)


@pytest.mark.parametrize("fmt", ["wide", "long", "xarray"])
def test_summary_fmt(centered_eight, fmt):
    """Check that every supported fmt returns a result."""
    assert summary(centered_eight, fmt=fmt) is not None


def test_summary_labels():
    """Check the row labels and their order for a variable with two extra dimensions."""
    coords1 = list("abcd")
    coords2 = np.arange(1, 6)
    data = from_dict(
        {"a": np.random.randn(4, 100, 4, 5)},
        coords={"dim1": coords1, "dim2": coords2},
        dims={"a": ["dim1", "dim2"]},
    )
    az_summary = summary(data, fmt="wide")
    assert az_summary is not None
    column_order = [f"a[{coord1}, {coord2}]" for coord1 in coords1 for coord2 in coords2]
    # Compare whole lists: zipping the two sequences would silently ignore
    # missing or extra rows.
    assert list(az_summary.index) == column_order


@pytest.mark.parametrize(
    "stat_funcs", [[np.var], {"var": np.var, "var2": lambda x: np.var(x) ** 2}]
)
def test_summary_stat_func(centered_eight, stat_funcs):
    """Check that custom stat_funcs add a ``var`` column to the summary."""
    arviz_summary = summary(centered_eight, stat_funcs=stat_funcs)
    assert arviz_summary is not None
    # hasattr(df, "var") is always true (DataFrame.var is a method), so the
    # column names are checked instead.
    assert "var" in arviz_summary.columns


def test_summary_nan(centered_eight):
    """Check that an all-NaN variable yields an all-null row and leaves other rows intact."""
    centered_eight = deepcopy(centered_eight)
    centered_eight.posterior["theta"].loc[{"school": "Deerfield"}] = np.nan
    summary_xarray = summary(centered_eight)
    assert summary_xarray is not None
    assert summary_xarray.loc["theta[Deerfield]"].isnull().all()
    assert (
        summary_xarray.loc[[ix for ix in summary_xarray.index if ix != "theta[Deerfield]"]]
        .notnull()
        .all()
        .all()
    )


def test_summary_skip_nan(centered_eight):
    """Check summary when only part of the draws of a variable are NaN.

    The first four columns must not all be null, the remaining ones must all
    be null.
    """
    centered_eight = deepcopy(centered_eight)
    centered_eight.posterior["theta"].loc[{"draw": slice(10), "school": "Deerfield"}] = np.nan
    summary_xarray = summary(centered_eight)
    theta_1 = summary_xarray.loc["theta[Deerfield]"].isnull()
    assert summary_xarray is not None
    # ``not`` instead of ``~``: bitwise inversion of a plain bool is always truthy.
    assert not theta_1[:4].all()
    assert theta_1[4:].all()


@pytest.mark.parametrize("fmt", [1, "bad_fmt"])
def test_summary_bad_fmt(centered_eight, fmt):
    """Check that an invalid fmt raises TypeError."""
    with pytest.raises(TypeError, match="Invalid format"):
        summary(centered_eight, fmt=fmt)


def test_summary_order_deprecation(centered_eight):
    """Check that the order argument emits a DeprecationWarning."""
    with pytest.warns(DeprecationWarning, match="order"):
        summary(centered_eight, order="C")


def test_summary_index_origin_deprecation(centered_eight):
    """Check that the index_origin argument emits a DeprecationWarning."""
    with pytest.warns(DeprecationWarning, match="index_origin"):
        summary(centered_eight, index_origin=1)


@pytest.mark.parametrize("scale", ["log", "negative_log", "deviance"])
@pytest.mark.parametrize("multidim", (True, False))
def test_waic(centered_eight, multidim_models, scale, multidim):
    """Test widely applicable information criterion calculation"""
    if multidim:
        assert waic(multidim_models.model_1, scale=scale) is not None
        waic_pointwise = waic(multidim_models.model_1, pointwise=True, scale=scale)
    else:
        assert waic(centered_eight, scale=scale) is not None
        waic_pointwise = waic(centered_eight, pointwise=True, scale=scale)
    assert waic_pointwise is not None
    assert "waic_i" in waic_pointwise


def test_waic_bad(centered_eight):
    """Test that waic raises TypeError when the log likelihood is missing."""
    centered_eight = deepcopy(centered_eight)
    del centered_eight.sample_stats["log_likelihood"]
    with pytest.raises(TypeError):
        waic(centered_eight)

    del centered_eight.sample_stats
    with pytest.raises(TypeError):
        waic(centered_eight)


def test_waic_bad_scale(centered_eight):
    """Test widely applicable information criterion calculation with bad scale."""
    with pytest.raises(TypeError):
        waic(centered_eight, scale="bad_value")


def test_waic_warning(centered_eight):
    """Test that waic emits a UserWarning for the modified log likelihood values."""
    centered_eight = deepcopy(centered_eight)
    centered_eight.sample_stats["log_likelihood"][:, :250, 1] = 10
    with pytest.warns(UserWarning):
        assert waic(centered_eight, pointwise=True) is not None
    # this should throw a warning, but due to numerical issues it fails
    centered_eight.sample_stats["log_likelihood"][:, :, :] = 0
    with pytest.warns(UserWarning):
        assert waic(centered_eight, pointwise=True) is not None


@pytest.mark.parametrize("scale", ["log", "negative_log", "deviance"])
def test_waic_print(centered_eight, scale):
    """Test that the waic repr is the same with and without pointwise results."""
    waic_data = repr(waic(centered_eight, scale=scale))
    waic_pointwise = repr(waic(centered_eight, scale=scale, pointwise=True))
    assert waic_data is not None
    assert waic_pointwise is not None
    assert waic_data == waic_pointwise


@pytest.mark.parametrize("scale", ["log", "negative_log", "deviance"])
@pytest.mark.parametrize("multidim", (True, False))
def test_loo(centered_eight, multidim_models, scale, multidim):
    """Test approximate leave one out criterion calculation"""
    if multidim:
        assert loo(multidim_models.model_1, scale=scale) is not None
        loo_pointwise = loo(multidim_models.model_1, pointwise=True, scale=scale)
    else:
        assert loo(centered_eight, scale=scale) is not None
        loo_pointwise = loo(centered_eight, pointwise=True, scale=scale)
    assert loo_pointwise is not None
    assert "loo_i" in loo_pointwise
    assert "pareto_k" in loo_pointwise
    assert "loo_scale" in loo_pointwise


def test_loo_one_chain(centered_eight):
    """Test loo on data reduced to a single chain."""
    centered_eight = deepcopy(centered_eight)
    centered_eight.posterior = centered_eight.posterior.drop([1, 2, 3], "chain")
    centered_eight.sample_stats = centered_eight.sample_stats.drop([1, 2, 3], "chain")
    assert loo(centered_eight) is not None


def test_loo_bad(centered_eight):
    """Test that loo raises TypeError for array input and for a missing log likelihood."""
    with pytest.raises(TypeError):
        loo(np.random.randn(2, 10))

    centered_eight = deepcopy(centered_eight)
    del centered_eight.sample_stats["log_likelihood"]
    with pytest.raises(TypeError):
        loo(centered_eight)


def test_loo_bad_scale(centered_eight):
    """Test loo with bad scale value."""
    with pytest.raises(TypeError):
        loo(centered_eight, scale="bad_scale")


def test_loo_bad_no_posterior_reff(centered_eight):
    """Test that loo without a posterior group needs an explicit reff."""
    loo(centered_eight, reff=None)
    centered_eight = deepcopy(centered_eight)
    del centered_eight.posterior
    with pytest.raises(TypeError):
        loo(centered_eight, reff=None)
    loo(centered_eight, reff=0.7)


def test_loo_warning(centered_eight):
    """Test that loo warns about the estimated shape parameter."""
    centered_eight = deepcopy(centered_eight)
    # make one of the khats infinity
    centered_eight.sample_stats["log_likelihood"][:, :, 1] = 10
    with pytest.warns(UserWarning) as records:
        assert loo(centered_eight, pointwise=True) is not None
    assert any("Estimated shape parameter" in str(record.message) for record in records)

    # make all of the khats infinity
    centered_eight.sample_stats["log_likelihood"][:, :, :] = 1
    with pytest.warns(UserWarning) as records:
        assert loo(centered_eight, pointwise=True) is not None
    assert any("Estimated shape parameter" in str(record.message) for record in records)


@pytest.mark.parametrize("scale", ["log", "negative_log", "deviance"])
def test_loo_print(centered_eight, scale):
    """Test that the loo repr is longer when pointwise results are included."""
    loo_data = repr(loo(centered_eight, scale=scale))
    loo_pointwise = repr(loo(centered_eight, scale=scale, pointwise=True))
    assert loo_data is not None
    assert loo_pointwise is not None
    assert len(loo_data) < len(loo_pointwise)


def test_psislw(centered_eight):
    """Test that psislw returns the same pareto k values as loo for the same reff."""
    pareto_k = loo(centered_eight, pointwise=True, reff=0.7)["pareto_k"]
    log_likelihood = get_log_likelihood(centered_eight)
    log_likelihood = log_likelihood.stack(__sample__=("chain", "draw"))
    assert_allclose(pareto_k, psislw(-log_likelihood, 0.7)[1])


@pytest.mark.parametrize("probs", [True, False])
@pytest.mark.parametrize("kappa", [-1, -0.5, 1e-30, 0.5, 1])
@pytest.mark.parametrize("sigma", [0, 2])
def test_gpinv(probs, kappa, sigma):
    """Test that _gpinv returns one value per input, also when an input is negative."""
    if probs:
        probs = np.array([0.1, 0.1, 0.1, 0.2, 0.3])
    else:
        probs = np.array([-0.1, 0.1, 0.1, 0.2, 0.3])
    assert len(_gpinv(probs, kappa, sigma)) == len(probs)


@pytest.mark.parametrize("func", [loo, waic])
def test_multidimensional_log_likelihood(func):
    """Test that a multidimensional log likelihood gives the same result as its flattened form."""
    llm = np.random.rand(4, 23, 15, 2)
    ll1 = llm.reshape(4, 23, 15 * 2)
    statsm = Dataset(dict(log_likelihood=DataArray(llm, dims=["chain", "draw", "a", "b"])))

    stats1 = Dataset(dict(log_likelihood=DataArray(ll1, dims=["chain", "draw", "v"])))

    post = Dataset(dict(mu=DataArray(np.random.rand(4, 23, 2), dims=["chain", "draw", "v"])))

    dsm = convert_to_inference_data(statsm, group="sample_stats")
    ds1 = convert_to_inference_data(stats1, group="sample_stats")
    dsp = convert_to_inference_data(post, group="posterior")

    dsm = concat(dsp, dsm)
    ds1 = concat(dsp, ds1)

    frm = func(dsm)
    fr1 = func(ds1)

    assert (fr1 == frm).all()
    assert_array_almost_equal(frm[:4], fr1[:4])


@pytest.mark.parametrize(
    "args",
    [
        {"y": "obs"},
        {"y": "obs", "y_hat": "obs"},
        {"y": "arr", "y_hat": "obs"},
        {"y": "obs", "y_hat": "arr"},
        {"y": "arr", "y_hat": "arr"},
        {"y": "obs", "y_hat": "obs", "log_weights": "arr"},
        {"y": "arr", "y_hat": "obs", "log_weights": "arr"},
        {"y": "obs", "y_hat": "arr", "log_weights": "arr"},
        {"idata": False},
    ],
)
def test_loo_pit(centered_eight, args):
    """Test loo_pit with variable names, arrays and mixes of both as inputs.

    The value ``"arr"`` in ``args`` is a placeholder replaced by the
    corresponding array; ``{"idata": False}`` calls loo_pit with arrays only.
    """
    y = args.get("y", None)
    y_hat = args.get("y_hat", None)
    log_weights = args.get("log_weights", None)
    y_arr = centered_eight.observed_data.obs
    y_hat_arr = centered_eight.posterior_predictive.obs.stack(__sample__=("chain", "draw"))
    log_like = get_log_likelihood(centered_eight).stack(__sample__=("chain", "draw"))
    n_samples = len(log_like.__sample__)
    ess_p = ess(centered_eight.posterior, method="mean")
    reff = np.hstack([ess_p[v].values.flatten() for v in ess_p.data_vars]).mean() / n_samples
    log_weights_arr = psislw(-log_like, reff=reff)[0]

    if args.get("idata", True):
        if y == "arr":
            y = y_arr
        if y_hat == "arr":
            y_hat = y_hat_arr
        if log_weights == "arr":
            log_weights = log_weights_arr
        loo_pit_data = loo_pit(idata=centered_eight, y=y, y_hat=y_hat, log_weights=log_weights)
    else:
        loo_pit_data = loo_pit(idata=None, y=y_arr, y_hat=y_hat_arr, log_weights=log_weights_arr)
    assert np.all((loo_pit_data >= 0) & (loo_pit_data <= 1))


@pytest.mark.parametrize(
    "args",
    [
        {"y": "y"},
        {"y": "y", "y_hat": "y"},
        {"y": "arr", "y_hat": "y"},
        {"y": "y", "y_hat": "arr"},
        {"y": "arr", "y_hat": "arr"},
        {"y": "y", "y_hat": "y", "log_weights": "arr"},
        {"y": "arr", "y_hat": "y", "log_weights": "arr"},
        {"y": "y", "y_hat": "arr", "log_weights": "arr"},
        {"idata": False},
    ],
)
def test_loo_pit_multidim(multidim_models, args):
    """Test loo_pit input combinations on a model with multidimensional observations.

    The value ``"arr"`` in ``args`` is a placeholder replaced by the
    corresponding array; ``{"idata": False}`` calls loo_pit with arrays only.
    """
    y = args.get("y", None)
    y_hat = args.get("y_hat", None)
    log_weights = args.get("log_weights", None)
    idata = multidim_models.model_1
    y_arr = idata.observed_data.y
    y_hat_arr = idata.posterior_predictive.y.stack(__sample__=("chain", "draw"))
    log_like = get_log_likelihood(idata).stack(__sample__=("chain", "draw"))
    n_samples = len(log_like.__sample__)
    ess_p = ess(idata.posterior, method="mean")
    reff = np.hstack([ess_p[v].values.flatten() for v in ess_p.data_vars]).mean() / n_samples
    log_weights_arr = psislw(-log_like, reff=reff)[0]

    if args.get("idata", True):
        if y == "arr":
            y = y_arr
        if y_hat == "arr":
            y_hat = y_hat_arr
        if log_weights == "arr":
            log_weights = log_weights_arr
        loo_pit_data = loo_pit(idata=idata, y=y, y_hat=y_hat, log_weights=log_weights)
    else:
        loo_pit_data = loo_pit(idata=None, y=y_arr, y_hat=y_hat_arr, log_weights=log_weights_arr)
    assert np.all((loo_pit_data >= 0) & (loo_pit_data <= 1))


def test_loo_pit_multi_lik():
    """Test loo_pit when the log_likelihood group holds more than one variable."""
    rng = np.random.default_rng(0)
    post_pred = rng.standard_normal(size=(4, 100, 10))
    obs = np.quantile(post_pred, np.linspace(0, 1, 10))
    obs[0] *= 0.9
    obs[-1] *= 1.1
    idata = from_dict(
        # All random inputs come from the seeded generator so the data is reproducible.
        posterior={"a": rng.standard_normal(size=(4, 100))},
        posterior_predictive={"y": post_pred},
        observed_data={"y": obs},
        log_likelihood={"y": -(post_pred ** 2), "decoy": np.zeros_like(post_pred)},
    )
    loo_pit_data = loo_pit(idata, y="y")
    assert np.all((loo_pit_data >= 0) & (loo_pit_data <= 1))


@pytest.mark.parametrize("input_type", ["idataarray", "idatanone_ystr", "yarr_yhatnone"])
def test_loo_pit_bad_input(centered_eight, input_type):
    """Test incompatible input combinations."""
    arr = np.random.random((8, 200))
    if input_type == "idataarray":
        with pytest.raises(ValueError, match=r"type InferenceData or None"):
            loo_pit(idata=arr, y="obs")
    elif input_type == "idatanone_ystr":
        with pytest.raises(ValueError, match=r"all 3.+must be array or DataArray"):
            loo_pit(idata=None, y="obs")
    elif input_type == "yarr_yhatnone":
        with pytest.raises(ValueError, match=r"y_hat.+None.+y.+str"):
            loo_pit(idata=centered_eight, y=arr, y_hat=None)


@pytest.mark.parametrize("arg", ["y", "y_hat", "log_weights"])
def test_loo_pit_bad_input_type(centered_eight, arg):
    """Test wrong input type (not None, str nor DataArray)."""
    kwargs = {"y": "obs", "y_hat": "obs", "log_weights": None}
    kwargs[arg] = 2  # use int instead of array-like
    with pytest.raises(ValueError, match=f"not {type(2)}"):
        loo_pit(idata=centered_eight, **kwargs)


@pytest.mark.parametrize("incompatibility", ["y-y_hat1", "y-y_hat2", "y_hat-log_weights"])
def test_loo_pit_bad_input_shape(incompatibility):
    """Test shape incompatibilities."""
    y = np.random.random(8)
    y_hat = np.random.random((8, 200))
    log_weights = np.random.random((8, 200))
    if incompatibility == "y-y_hat1":
        with pytest.raises(ValueError, match="1 more dimension"):
            loo_pit(y=y, y_hat=y_hat[None, :], log_weights=log_weights)
    elif incompatibility == "y-y_hat2":
        with pytest.raises(ValueError, match="y has shape"):
            loo_pit(y=y, y_hat=y_hat[1:3, :], log_weights=log_weights)
    elif incompatibility == "y_hat-log_weights":
        with pytest.raises(ValueError, match="must have the same shape"):
            loo_pit(y=y, y_hat=y_hat[:, :100], log_weights=log_weights)


@pytest.mark.parametrize("pointwise", [True, False])
@pytest.mark.parametrize("inplace", [True, False])
@pytest.mark.parametrize(
    "kwargs",
    [
        {},
        {"group": "posterior_predictive", "var_names": {"posterior_predictive": "obs"}},
        {"group": "observed_data", "var_names": {"both": "obs"}, "out_data_shape": "shape"},
        {"var_names": {"both": "obs", "posterior": ["theta", "mu"]}},
        {"group": "observed_data", "out_name_data": "T_name"},
    ],
)
def test_apply_test_function(centered_eight, pointwise, inplace, kwargs):
    """Test some usual call cases of apply_test_function"""
    group = kwargs.get("group", "both")
    var_names = kwargs.get("var_names", None)
    out_data_shape = kwargs.get("out_data_shape", None)
    out_pp_shape = kwargs.get("out_pp_shape", None)
    out_name_data = kwargs.get("out_name_data", "T")
    if out_data_shape == "shape":
        out_data_shape = (8,) if pointwise else ()
    if out_pp_shape == "shape":
        out_pp_shape = (4, 500, 8) if pointwise else (4, 500)
    # Work on a copy so the session scoped fixture is never modified.
    idata = deepcopy(centered_eight)
    idata_out = apply_test_function(
        idata,
        lambda y, theta: np.mean(y),
        group=group,
        var_names=var_names,
        pointwise=pointwise,
        out_name_data=out_name_data,
        out_data_shape=out_data_shape,
        out_pp_shape=out_pp_shape,
        # Forward the parametrized flag so that both code paths are exercised.
        inplace=inplace,
    )
    if inplace:
        assert idata is idata_out
    else:
        assert idata is not idata_out

    if group == "both":
        test_dict = {"observed_data": ["T"], "posterior_predictive": ["T"]}
    else:
        test_dict = {group: [out_name_data]}

    fails = check_multiple_attrs(test_dict, idata_out)
    assert not fails


def test_apply_test_function_bad_group(centered_eight):
    """Test error when group is an invalid name."""
    with pytest.raises(ValueError, match="Invalid group argument"):
        apply_test_function(centered_eight, lambda y, theta: y, group="bad_group")


def test_apply_test_function_missing_group():
    """Test error when InferenceData object is missing a required group.

    The function cannot work if group="both" but InferenceData object has no
    posterior_predictive group.
    """
    idata = from_dict(
        posterior={"a": np.random.random((4, 500, 30))}, observed_data={"y": np.random.random(30)}
    )
    with pytest.raises(ValueError, match="must have posterior_predictive"):
        apply_test_function(idata, lambda y, theta: np.mean(y), group="both")


def test_apply_test_function_should_overwrite_error(centered_eight):
    """Test error when overwrite=False but out_name is already a present variable."""
    with pytest.raises(ValueError, match="Should overwrite"):
        apply_test_function(centered_eight, lambda y, theta: y, out_name_data="obs")