# pylint: disable=too-many-lines, too-many-function-args, redefined-outer-name
"""Diagnostic functions for ArviZ."""
import warnings
from collections.abc import Sequence

import numpy as np
import pandas as pd
from scipy import stats

from ..data import convert_to_dataset
from ..utils import Numba, _numba_var, _stack, _var_names
from .density_utils import histogram as _histogram
from .stats_utils import _circular_standard_deviation, _sqrt
from .stats_utils import autocov as _autocov
from .stats_utils import not_valid as _not_valid
from .stats_utils import quantile as _quantile
from .stats_utils import stats_variance_2d as svar
from .stats_utils import wrap_xarray_ufunc as _wrap_xarray_ufunc

__all__ = ["bfmi", "ess", "rhat", "mcse"]


def bfmi(data):
    r"""Calculate the estimated Bayesian fraction of missing information (BFMI).

    BFMI quantifies how well momentum resampling matches the marginal energy distribution. For more
    information on BFMI, see https://arxiv.org/pdf/1604.00695v1.pdf. The current advice is that
    values smaller than 0.3 indicate poor sampling. However, this threshold is provisional and may
    change. See http://mc-stan.org/users/documentation/case-studies/pystan_workflow.html for more
    information.

    Parameters
    ----------
    data : obj
        Any object that can be converted to an az.InferenceData object.
        Refer to documentation of az.convert_to_dataset for details.
        If InferenceData, energy variable needs to be found.
        A ``numpy.ndarray`` is interpreted directly as the energy values.

    Returns
    -------
    z : array
        The Bayesian fraction of missing information of the model and trace. One element per
        chain in the trace.

    Raises
    ------
    TypeError
        If the converted ``sample_stats`` group has no ``energy`` variable.

    Examples
    --------
    Compute the BFMI of an InferenceData object

    .. ipython::

        In [1]: import arviz as az
           ...: data = az.load_arviz_data('radon')
           ...: az.bfmi(data)

    """
    # Raw arrays are taken to be the energy itself; no conversion is needed.
    if isinstance(data, np.ndarray):
        return _bfmi(data)

    dataset = convert_to_dataset(data, group="sample_stats")
    if not hasattr(dataset, "energy"):
        raise TypeError("Energy variable was not found.")
    return _bfmi(dataset.energy)


def ess(
    data,
    *,
    var_names=None,
    method="bulk",
    relative=False,
    prob=None,
    dask_kwargs=None,
):
    r"""Calculate estimate of the effective sample size (ess).

    Parameters
    ----------
    data : obj
        Any object that can be converted to an ``az.InferenceData`` object.
        Refer to documentation of ``az.convert_to_dataset`` for details.
        For ndarray: shape = (chain, draw).
        For n-dimensional ndarray transform first to dataset with ``az.convert_to_dataset``.
    var_names : str or list of str
        Names of variables to include in the return value Dataset.
    method : str, optional, default "bulk"
        Select ess method. Valid methods are:

        - "bulk"
        - "tail"     # prob, optional
        - "quantile" # prob
        - "mean" (old ess)
        - "sd"
        - "median"
        - "mad" (median absolute deviation)
        - "z_scale"
        - "folded"
        - "identity"
        - "local"    # prob (lower and upper bound)
    relative : bool
        Return relative ess
        `ress = ess / n`
    prob : float, or tuple of two floats, optional
        probability value for "tail", "quantile" or "local" ess functions.
    dask_kwargs : dict, optional
        Dask related kwargs passed to :func:`~arviz.wrap_xarray_ufunc`.

    Returns
    -------
    xarray.Dataset
        Return the effective sample size, :math:`\hat{N}_{eff}`

    Raises
    ------
    TypeError
        If ``method`` is unknown, if ``prob`` is missing for the "quantile" or "local"
        methods, or if ``data`` is an ndarray with more than two dimensions.

    Notes
    -----
    The basic ess (:math:`N_{\mathit{eff}}`) diagnostic is computed by:

    .. math:: \hat{N}_{\mathit{eff}} = \frac{MN}{\hat{\tau}}

    .. math:: \hat{\tau} = -1 + 2 \sum_{t'=0}^K \hat{P}_{t'}

    where :math:`M` is the number of chains, :math:`N` the number of draws,
    :math:`\hat{\rho}_t` is the estimated _autocorrelation at lag :math:`t`, and
    :math:`K` is the last integer for which :math:`\hat{P}_{K} = \hat{\rho}_{2K} +
    \hat{\rho}_{2K+1}` is still positive.

    The current implementation is similar to Stan, which uses Geyer's initial monotone sequence
    criterion (Geyer, 1992; Geyer, 2011).

    References
    ----------
    * Vehtari et al. (2019) see https://arxiv.org/abs/1903.08008
    * https://mc-stan.org/docs/2_18/reference-manual/effective-sample-size-section.html
      Section 15.4.2
    * Gelman et al. BDA (2014) Formula 11.8

    Examples
    --------
    Calculate the effective_sample_size using the default arguments:

    .. ipython::

        In [1]: import arviz as az
           ...: data = az.load_arviz_data('non_centered_eight')
           ...: az.ess(data)

    Calculate the ress of some of the variables

    .. ipython::

        In [1]: az.ess(data, relative=True, var_names=["mu", "theta_t"])

    Calculate the ess using the "tail" method, leaving the `prob` argument at its default
    value.

    .. ipython::

        In [1]: az.ess(data, method="tail")

    """
    methods = {
        "bulk": _ess_bulk,
        "tail": _ess_tail,
        "quantile": _ess_quantile,
        "mean": _ess_mean,
        "sd": _ess_sd,
        "median": _ess_median,
        "mad": _ess_mad,
        "z_scale": _ess_z_scale,
        "folded": _ess_folded,
        "identity": _ess_identity,
        "local": _ess_local,
    }

    if method not in methods:
        raise TypeError(f"ess method {method} not found. Valid methods are:\n{', '.join(methods)}")
    ess_func = methods[method]

    # Both `_ess_quantile` and `_ess_local` take `prob` as a required argument;
    # fail early with a clear message instead of a "missing argument" error.
    if method in ("quantile", "local") and prob is None:
        raise TypeError("Quantile (prob) information needs to be defined.")

    if isinstance(data, np.ndarray):
        data = np.atleast_2d(data)
        if len(data.shape) >= 3:
            msg = (
                "Only uni-dimensional ndarray variables are supported."
                " Please transform first to dataset with `az.convert_to_dataset`."
            )
            raise TypeError(msg)
        # `prob` is only forwarded when given, because not every method accepts it.
        if prob is not None:
            return ess_func(  # pylint: disable=unexpected-keyword-arg
                data, prob=prob, relative=relative
            )
        return ess_func(data, relative=relative)

    dataset = convert_to_dataset(data, group="posterior")
    var_names = _var_names(var_names, dataset)

    dataset = dataset if var_names is None else dataset[var_names]

    ufunc_kwargs = {"ravel": False}
    func_kwargs = {"relative": relative} if prob is None else {"prob": prob, "relative": relative}
    return _wrap_xarray_ufunc(
        ess_func,
        dataset,
        ufunc_kwargs=ufunc_kwargs,
        func_kwargs=func_kwargs,
        dask_kwargs=dask_kwargs,
    )


def rhat(data, *, var_names=None, method="rank", dask_kwargs=None):
    r"""Compute estimate of rank normalized splitR-hat for a set of traces.

    The rank normalized R-hat diagnostic tests for lack of convergence by comparing the variance
    between multiple chains to the variance within each chain. If convergence has been achieved,
    the between-chain and within-chain variances should be identical. To be most effective in
    detecting evidence for nonconvergence, each chain should have been initialized to starting
    values that are dispersed relative to the target distribution.

    Parameters
    ----------
    data : obj
        Any object that can be converted to an az.InferenceData object.
        Refer to documentation of az.convert_to_dataset for details.
        At least 2 posterior chains are needed to compute this diagnostic of one or more
        stochastic parameters.
        For ndarray: shape = (chain, draw).
        For n-dimensional ndarray transform first to dataset with az.convert_to_dataset.
    var_names : list
        Names of variables to include in the rhat report
    method : str
        Select R-hat method. Valid methods are:

        - "rank"        # recommended by Vehtari et al. (2019)
        - "split"
        - "folded"
        - "z_scale"
        - "identity"
    dask_kwargs : dict, optional
        Dask related kwargs passed to :func:`~arviz.wrap_xarray_ufunc`.

    Returns
    -------
    xarray.Dataset
        Returns dataset of the potential scale reduction factors, :math:`\hat{R}`

    Raises
    ------
    TypeError
        If ``method`` is unknown or ``data`` is an ndarray with more than two dimensions.

    Notes
    -----
    The diagnostic is computed by:

      .. math:: \hat{R} = \frac{\hat{V}}{W}

    where :math:`W` is the within-chain variance and :math:`\hat{V}` is the posterior variance
    estimate for the pooled rank-traces. This is the potential scale reduction factor, which
    converges to unity when each of the traces is a sample from the target posterior. Values
    greater than one indicate that one or more chains have not yet converged.

    Rank values are calculated over all the chains with `scipy.stats.rankdata`.
    Each chain is split in two and normalized with the z-transform following Vehtari et al. (2019).

    References
    ----------
    * Vehtari et al. (2019) see https://arxiv.org/abs/1903.08008
    * Gelman et al. BDA (2014)
    * Brooks and Gelman (1998)
    * Gelman and Rubin (1992)

    Examples
    --------
    Calculate the R-hat using the default arguments:

    .. ipython::

        In [1]: import arviz as az
           ...: data = az.load_arviz_data("non_centered_eight")
           ...: az.rhat(data)

    Calculate the R-hat of some variables using the folded method:

    .. ipython::

        In [1]: az.rhat(data, var_names=["mu", "theta_t"], method="folded")

    """
    methods = {
        "rank": _rhat_rank,
        "split": _rhat_split,
        "folded": _rhat_folded,
        "z_scale": _rhat_z_scale,
        "identity": _rhat_identity,
    }
    if method not in methods:
        raise TypeError(
            f"R-hat method {method} not found. Valid methods are:\n{', '.join(methods)}"
        )
    rhat_func = methods[method]

    if isinstance(data, np.ndarray):
        data = np.atleast_2d(data)
        if len(data.shape) >= 3:
            msg = (
                "Only uni-dimensional ndarray variables are supported."
                " Please transform first to dataset with `az.convert_to_dataset`."
            )
            raise TypeError(msg)
        return rhat_func(data)

    dataset = convert_to_dataset(data, group="posterior")
    var_names = _var_names(var_names, dataset)

    dataset = dataset if var_names is None else dataset[var_names]

    ufunc_kwargs = {"ravel": False}
    func_kwargs = {}
    return _wrap_xarray_ufunc(
        rhat_func,
        dataset,
        ufunc_kwargs=ufunc_kwargs,
        func_kwargs=func_kwargs,
        dask_kwargs=dask_kwargs,
    )


def mcse(data, *, var_names=None, method="mean", prob=None, dask_kwargs=None):
    """Calculate Markov Chain Standard Error statistic.

    Parameters
    ----------
    data : obj
        Any object that can be converted to an az.InferenceData object
        Refer to documentation of az.convert_to_dataset for details
        For ndarray: shape = (chain, draw).
        For n-dimensional ndarray transform first to dataset with az.convert_to_dataset.
    var_names : list
        Names of variables to include in the mcse report
    method : str
        Select mcse method. Valid methods are:

        - "mean"
        - "sd"
        - "median"
        - "quantile"

    prob : float
        Quantile information. Required when ``method="quantile"``.
    dask_kwargs : dict, optional
        Dask related kwargs passed to :func:`~arviz.wrap_xarray_ufunc`.

    Returns
    -------
    xarray.Dataset
        Return the mcse dataset

    Raises
    ------
    TypeError
        If ``method`` is unknown, if ``prob`` is missing for the "quantile" method, or if
        ``data`` is an ndarray with more than two dimensions.

    Examples
    --------
    Calculate the Markov Chain Standard Error using the default arguments:

    .. ipython::

        In [1]: import arviz as az
           ...: data = az.load_arviz_data("non_centered_eight")
           ...: az.mcse(data)

    Calculate the Markov Chain Standard Error using the quantile method:

    .. ipython::

        In [1]: az.mcse(data, method="quantile", prob=0.7)

    """
    methods = {
        "mean": _mcse_mean,
        "sd": _mcse_sd,
        "median": _mcse_median,
        "quantile": _mcse_quantile,
    }
    if method not in methods:
        # NOTE(review): the joiner's indentation may have been wider before the file's
        # whitespace was collapsed -- confirm against version control.
        raise TypeError(
            "mcse method {} not found. Valid methods are:\n{}".format(
                method, "\n ".join(methods)
            )
        )
    mcse_func = methods[method]

    if method == "quantile" and prob is None:
        raise TypeError("Quantile (prob) information needs to be defined.")

    if isinstance(data, np.ndarray):
        data = np.atleast_2d(data)
        if len(data.shape) >= 3:
            msg = (
                "Only uni-dimensional ndarray variables are supported."
                " Please transform first to dataset with `az.convert_to_dataset`."
            )
            raise TypeError(msg)
        # `prob` is only forwarded when given, because not every method accepts it.
        if prob is not None:
            return mcse_func(data, prob=prob)  # pylint: disable=unexpected-keyword-arg
        return mcse_func(data)

    dataset = convert_to_dataset(data, group="posterior")
    var_names = _var_names(var_names, dataset)

    dataset = dataset if var_names is None else dataset[var_names]

    ufunc_kwargs = {"ravel": False}
    func_kwargs = {} if prob is None else {"prob": prob}
    return _wrap_xarray_ufunc(
        mcse_func,
        dataset,
        ufunc_kwargs=ufunc_kwargs,
        func_kwargs=func_kwargs,
        dask_kwargs=dask_kwargs,
    )


def ks_summary(pareto_tail_indices):
    """Display a summary of Pareto tail indices.

    The indices are counted in the bins ``(-inf, 0.5]``, ``(0.5, 0.7]``, ``(0.7, 1]`` and
    ``(1, inf)``. A warning is emitted when every value falls in the first bin, or when
    every value falls in the first two bins.

    Parameters
    ----------
    pareto_tail_indices : array
      Pareto tail indices.

    Returns
    -------
    df_k : dataframe
      Dataframe containing k diagnostic values.
    """
    _numba_flag = Numba.numba_flag
    # `np.inf` is used instead of the `np.Inf` alias, which NumPy 2.0 removed.
    if _numba_flag:
        bins = np.asarray([-np.inf, 0.5, 0.7, 1, np.inf])
        kcounts, *_ = _histogram(pareto_tail_indices, bins)
    else:
        kcounts, *_ = _histogram(pareto_tail_indices, bins=[-np.inf, 0.5, 0.7, 1, np.inf])
    kprop = kcounts / len(pareto_tail_indices) * 100
    # NOTE(review): the index labels may have carried extra alignment padding before the
    # file's whitespace was collapsed -- confirm against version control.
    df_k = pd.DataFrame(
        dict(_=["(good)", "(ok)", "(bad)", "(very bad)"], Count=kcounts, Pct=kprop)
    ).rename(index={0: "(-Inf, 0.5]", 1: " (0.5, 0.7]", 2: " (0.7, 1]", 3: " (1, Inf)"})

    if np.sum(kcounts[1:]) == 0:
        warnings.warn("All Pareto k estimates are good (k < 0.5)")
    elif np.sum(kcounts[2:]) == 0:
        warnings.warn("All Pareto k estimates are ok (k < 0.7)")

    return df_k


def _bfmi(energy):
    r"""Calculate the estimated Bayesian fraction of missing information (BFMI).

    BFMI quantifies how well momentum resampling matches the marginal energy distribution. For more
    information on BFMI, see https://arxiv.org/pdf/1604.00695v1.pdf. The current advice is that
    values smaller than 0.3 indicate poor sampling. However, this threshold is provisional and may
    change. See http://mc-stan.org/users/documentation/case-studies/pystan_workflow.html for more
    information.

    Parameters
    ----------
    energy : NumPy array
        Should be extracted from a gradient based sampler, such as in Stan or PyMC3. Typically,
        after converting a trace or fit to InferenceData, the energy will be in
        `data.sample_stats.energy`.

    Returns
    -------
    z : array
        The Bayesian fraction of missing information of the model and trace. One element per
        chain in the trace.
    """
    energy_mat = np.atleast_2d(energy)
    # Numerator: mean squared first difference of the energy along axis 1.
    num = np.square(np.diff(energy_mat, axis=1)).mean(axis=1)  # pylint: disable=no-member
    # Denominator: sample variance (ddof=1) of the energy along axis 1.
    if energy_mat.ndim == 2:
        den = _numba_var(svar, np.var, energy_mat, axis=1, ddof=1)
    else:
        den = np.var(energy, axis=1, ddof=1)
    return num / den


def _backtransform_ranks(arr, c=3 / 8):  # pylint: disable=invalid-name
    """Backtransformation of ranks.

    Computes ``(arr - c) / (arr.size - 2 * c + 1)`` element-wise.

    Parameters
    ----------
    arr : np.ndarray
        Ranks array
    c : float
        Fractional offset. Defaults to c = 3/8 as recommended by Blom (1958).

    Returns
    -------
    np.ndarray

    References
    ----------
    Blom, G. (1958). Statistical Estimates and Transformed Beta-Variables. Wiley; New York.
    """
    arr = np.asarray(arr)
    size = arr.size
    return (arr - c) / (size - 2 * c + 1)


def _z_scale(ary):
    """Calculate z_scale.

    Values are ranked over the whole array (ties get the average rank), the ranks are
    back-transformed with :func:`_backtransform_ranks` and mapped through the standard
    normal quantile function. The result has the same shape as the input.

    Parameters
    ----------
    ary : np.ndarray

    Returns
    -------
    np.ndarray
    """
    ary = np.asarray(ary)
    rank = stats.rankdata(ary, method="average")
    rank = _backtransform_ranks(rank)
    z = stats.norm.ppf(rank)
    z = z.reshape(ary.shape)
    return z


def _split_chains(ary):
    """Split and stack chains.

    Each row is cut into its first and last ``n_draw // 2`` draws (a middle draw is dropped
    when ``n_draw`` is odd) and the two halves are combined with ``_stack``.
    """
    ary = np.asarray(ary)
    if len(ary.shape) <= 1:
        ary = np.atleast_2d(ary)
    _, n_draw = ary.shape
    half = n_draw // 2
    return _stack(ary[:, :half], ary[:, -half:])


def _z_fold(ary):
    """Fold and z-scale values.

    Folding takes the absolute deviation from the median of the whole array.
    """
    ary = np.asarray(ary)
    ary = abs(ary - np.median(ary))
    ary = _z_scale(ary)
    return ary


def _rhat(ary):
    """Compute the rhat for a 2d array.

    Returns ``sqrt((B / W + n - 1) / n)`` where ``n`` is the number of draws per chain,
    ``B`` is ``n`` times the variance of the chain means and ``W`` is the mean of the
    per-chain variances (both with ``ddof=1``). Returns NaN when the input is not valid.
    """
    ary = np.asarray(ary, dtype=float)
    if _not_valid(ary, check_shape=False):
        return np.nan
    _, num_samples = ary.shape

    # Calculate chain mean
    chain_mean = np.mean(ary, axis=1)
    # Calculate chain variance
    chain_var = _numba_var(svar, np.var, ary, axis=1, ddof=1)
    # Calculate between-chain variance
    between_chain_variance = num_samples * _numba_var(svar, np.var, chain_mean, axis=None, ddof=1)
    # Calculate within-chain variance
    within_chain_variance = np.mean(chain_var)
    # Estimate of marginal posterior variance
    rhat_value = np.sqrt(
        (between_chain_variance / within_chain_variance + num_samples - 1) / (num_samples)
    )
    return rhat_value


def _rhat_rank(ary):
    """Compute the rank normalized rhat for 2d array.

    The result is the maximum of the rhat of the z-scaled split chains and the rhat of the
    z-scaled folded split chains.

    Computation follows https://arxiv.org/abs/1903.08008
    """
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=2)):
        return np.nan
    split_ary = _split_chains(ary)
    rhat_bulk = _rhat(_z_scale(split_ary))

    split_ary_folded = abs(split_ary - np.median(split_ary))
    rhat_tail = _rhat(_z_scale(split_ary_folded))

    rhat_rank = max(rhat_bulk, rhat_tail)
    return rhat_rank


def _rhat_folded(ary):
    """Calculate split-Rhat for folded z-values."""
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=2)):
        return np.nan
    ary = _z_fold(_split_chains(ary))
    return _rhat(ary)


def _rhat_z_scale(ary):
    """Calculate split-Rhat for z-scaled values."""
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=2)):
        return np.nan
    return _rhat(_z_scale(_split_chains(ary)))


def _rhat_split(ary):
    """Calculate split-Rhat on the untransformed values."""
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=2)):
        return np.nan
    return _rhat(_split_chains(ary))


def _rhat_identity(ary):
    """Calculate Rhat on the values as given (no splitting, no transformation)."""
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=2)):
        return np.nan
    return _rhat(ary)


def _ess(ary, relative=False):
    """Compute the effective sample size for a 2D array.

    Parameters
    ----------
    ary : array-like
        Draws; a 1D input is promoted to a single row.
    relative : bool
        If True, return ``1 / tau_hat`` instead of ``n_chain * n_draw / tau_hat``.

    Returns
    -------
    float
        The estimate, or NaN when the input is not valid or any estimated
        autocorrelation is NaN.
    """
    ary = np.asarray(ary, dtype=float)
    if _not_valid(ary, check_shape=False):
        return np.nan
    # (Numerically) constant input: report the sample count.
    # NOTE(review): this shortcut ignores `relative` -- confirm that is intended.
    if (np.max(ary) - np.min(ary)) < np.finfo(float).resolution:  # pylint: disable=no-member
        return ary.size
    if len(ary.shape) < 2:
        ary = np.atleast_2d(ary)
    n_chain, n_draw = ary.shape
    acov = _autocov(ary, axis=1)
    chain_mean = ary.mean(axis=1)
    mean_var = np.mean(acov[:, 0]) * n_draw / (n_draw - 1.0)
    var_plus = mean_var * (n_draw - 1.0) / n_draw
    if n_chain > 1:
        var_plus += _numba_var(svar, np.var, chain_mean, axis=None, ddof=1)

    rho_hat_t = np.zeros(n_draw)
    rho_hat_even = 1.0
    rho_hat_t[0] = rho_hat_even
    rho_hat_odd = 1.0 - (mean_var - np.mean(acov[:, 1])) / var_plus
    rho_hat_t[1] = rho_hat_odd

    # Geyer's initial positive sequence
    t = 1
    while t < (n_draw - 3) and (rho_hat_even + rho_hat_odd) > 0.0:
        rho_hat_even = 1.0 - (mean_var - np.mean(acov[:, t + 1])) / var_plus
        rho_hat_odd = 1.0 - (mean_var - np.mean(acov[:, t + 2])) / var_plus
        if (rho_hat_even + rho_hat_odd) >= 0:
            rho_hat_t[t + 1] = rho_hat_even
            rho_hat_t[t + 2] = rho_hat_odd
        t += 2

    max_t = t - 2
    # improve estimation
    if rho_hat_even > 0:
        rho_hat_t[max_t + 1] = rho_hat_even
    # Geyer's initial monotone sequence
    t = 1
    while t <= max_t - 2:
        if (rho_hat_t[t + 1] + rho_hat_t[t + 2]) > (rho_hat_t[t - 1] + rho_hat_t[t]):
            rho_hat_t[t + 1] = (rho_hat_t[t - 1] + rho_hat_t[t]) / 2.0
            rho_hat_t[t + 2] = rho_hat_t[t + 1]
        t += 2

    ess = n_chain * n_draw
    tau_hat = (
        -1.0
        + 2.0 * np.sum(rho_hat_t[: max_t + 1])
        + np.sum(rho_hat_t[max_t + 1 : max_t + 2])
    )
    # Bound tau_hat from below so the estimate cannot exceed ess * log10(ess).
    tau_hat = max(tau_hat, 1 / np.log10(ess))
    ess = (1 if relative else ess) / tau_hat
    if np.isnan(rho_hat_t).any():
        ess = np.nan
    return ess


def _ess_bulk(ary, relative=False):
    """Compute the effective sample size for the bulk."""
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
        return np.nan
    z_scaled = _z_scale(_split_chains(ary))
    ess_bulk = _ess(z_scaled, relative=relative)
    return ess_bulk


def _ess_tail(ary, prob=None, relative=False):
    """Compute the effective sample size for the tail.

    If `prob` defined, ess = min(qess(prob), qess(1-prob))

    When `prob` is None the pair ``(0.05, 0.95)`` is used; a sequence `prob` is unpacked
    as ``(low, high)``.
    """
    if prob is None:
        prob = (0.05, 0.95)
    elif not isinstance(prob, Sequence):
        prob = (prob, 1 - prob)

    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
        return np.nan

    prob_low, prob_high = prob
    quantile_low_ess = _ess_quantile(ary, prob_low, relative=relative)
    quantile_high_ess = _ess_quantile(ary, prob_high, relative=relative)
    return min(quantile_low_ess, quantile_high_ess)


def _ess_mean(ary, relative=False):
    """Compute the effective sample size for the mean."""
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
        return np.nan
    return _ess(_split_chains(ary), relative=relative)


def _ess_sd(ary, relative=False):
    """Compute the effective sample size for the sd.

    This is the smaller of the ess of the split chains and the ess of their squares.
    """
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
        return np.nan
    ary = _split_chains(ary)
    return min(_ess(ary, relative=relative), _ess(ary ** 2, relative=relative))


def _ess_quantile(ary, prob, relative=False):
    """Compute the effective sample size for the specific residual.

    The ess is computed on the indicator ``ary <= quantile(ary, prob)``.
    """
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
        return np.nan
    if prob is None:
        raise TypeError("Prob not defined.")
    (quantile,) = _quantile(ary, prob)
    iquantile = ary <= quantile
    return _ess(_split_chains(iquantile), relative=relative)


def _ess_local(ary, prob, relative=False):
    """Compute the effective sample size for the specific residual.

    The ess is computed on the indicator that a draw lies between the two quantiles given
    by `prob` (both bounds inclusive).
    """
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
        return np.nan
    if prob is None:
        raise TypeError("Prob not defined.")
    if len(prob) != 2:
        raise ValueError("Prob argument in ess local must be upper and lower bound")
    quantile = _quantile(ary, prob)
    iquantile = (quantile[0] <= ary) & (ary <= quantile[1])
    return _ess(_split_chains(iquantile), relative=relative)


def _ess_z_scale(ary, relative=False):
    """Calculate ess for z-scale."""
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
        return np.nan
    return _ess(_z_scale(_split_chains(ary)), relative=relative)


def _ess_folded(ary, relative=False):
    """Calculate split-ess for folded data."""
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
        return np.nan
    return _ess(_z_fold(_split_chains(ary)), relative=relative)


def _ess_median(ary, relative=False):
    """Calculate split-ess for median."""
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
        return np.nan
    return _ess_quantile(ary, 0.5, relative=relative)


def _ess_mad(ary, relative=False):
    """Calculate split-ess for median absolute deviation."""
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
        return np.nan
    # Indicator of draws whose absolute deviation from the median is at most the
    # median of those absolute deviations.
    ary = abs(ary - np.median(ary))
    ary = ary <= np.median(ary)
    ary = _z_scale(_split_chains(ary))
    return _ess(ary, relative=relative)


def _ess_identity(ary, relative=False):
    """Calculate ess."""
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
        return np.nan
    return _ess(ary, relative=relative)


def _mcse_mean(ary):
    """Compute the Markov Chain mean error.

    This is the sample standard deviation (ddof=1) of all draws divided by the square
    root of the ess for the mean.
    """
    _numba_flag = Numba.numba_flag
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
        return np.nan
    ess = _ess_mean(ary)
    if _numba_flag:
        sd = _sqrt(svar(np.ravel(ary), ddof=1), np.zeros(1))
    else:
        sd = np.std(ary, ddof=1)
    mcse_mean_value = sd / np.sqrt(ess)
    return mcse_mean_value


def _mcse_sd(ary):
    """Compute the Markov Chain sd error."""
    _numba_flag = Numba.numba_flag
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
        return np.nan
    ess = _ess_sd(ary)
    if _numba_flag:
        # Builtin `float` replaces the `np.float` alias, which NumPy 1.24 removed.
        sd = float(_sqrt(svar(np.ravel(ary), ddof=1), np.zeros(1)))
    else:
        sd = np.std(ary, ddof=1)
    fac_mcse_sd = np.sqrt(np.exp(1) * (1 - 1 / ess) ** (ess - 1) - 1)
    mcse_sd_value = sd * fac_mcse_sd
    return mcse_sd_value


def _mcse_median(ary):
    """Compute the Markov Chain median error."""
    return _mcse_quantile(ary, 0.5)


def _mcse_quantile(ary, prob):
    """Compute the Markov Chain quantile error at quantile=prob.

    Two Beta quantiles, parameterised by the quantile ess, select two order statistics of
    the sorted draws; the result is half the distance between them.
    """
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
        return np.nan
    ess = _ess_quantile(ary, prob)
    probability = [0.1586553, 0.8413447]
    with np.errstate(invalid="ignore"):
        ppf = stats.beta.ppf(probability, ess * prob + 1, ess * (1 - prob) + 1)
    sorted_ary = np.sort(ary.ravel())
    size = sorted_ary.size
    ppf_size = ppf * size - 1
    # nanmax / nanmin clamp the indices into [0, size - 1] and ignore a NaN ppf value.
    th1 = sorted_ary[int(np.floor(np.nanmax((ppf_size[0], 0))))]
    th2 = sorted_ary[int(np.ceil(np.nanmin((ppf_size[1], size - 1))))]
    return (th2 - th1) / 2


def _mc_error(ary, batches=5, circular=False):
    """Calculate the simulation standard error, accounting for non-independent samples.

    The trace is divided into batches, and the standard deviation of the batch
    means is calculated.

    Parameters
    ----------
    ary : Numpy array
        An array containing MCMC samples
    batches : integer
        Number of batches
    circular : bool
        Whether to compute the error taking into account `ary` is a circular variable
        (in the range [-np.pi, np.pi]) or not. Defaults to False (i.e non-circular variables).

    Returns
    -------
    mc_error : float
        Simulation standard error
    """
    _numba_flag = Numba.numba_flag
    if ary.ndim > 1:
        dims = np.shape(ary)
        trace = np.transpose([t.ravel() for t in ary])

        # Forward `circular` so multi-dimensional circular variables are not silently
        # treated as linear ones in the recursive calls.
        return np.reshape([_mc_error(t, batches, circular) for t in trace], dims[1:])

    if _not_valid(ary, check_shape=False):
        return np.nan

    if batches == 1:
        if circular:
            if _numba_flag:
                std = _circular_standard_deviation(ary, high=np.pi, low=-np.pi)
            else:
                std = stats.circstd(ary, high=np.pi, low=-np.pi)
        else:
            if _numba_flag:
                # Builtin `float` replaces the `np.float` alias, which NumPy 1.24 removed.
                std = float(_sqrt(svar(ary), np.zeros(1)))
            else:
                std = np.std(ary)
        return std / np.sqrt(len(ary))

    batched_traces = np.resize(ary, (batches, int(len(ary) / batches)))

    if circular:
        means = stats.circmean(batched_traces, high=np.pi, low=-np.pi, axis=1)
        if _numba_flag:
            std = _circular_standard_deviation(means, high=np.pi, low=-np.pi)
        else:
            std = stats.circstd(means, high=np.pi, low=-np.pi)
    else:
        means = np.mean(batched_traces, 1)
        if _numba_flag:
            std = _sqrt(svar(means), np.zeros(1))
        else:
            std = np.std(means)

    return std / np.sqrt(batches)


def _multichain_statistics(ary):
    """Calculate efficiently multichain statistics for summary.

    Intermediate results (the z-scaled split chains, the standard deviation) are computed
    once and shared between the individual statistics.

    Parameters
    ----------
    ary : numpy.ndarray

    Returns
    -------
    tuple
        Order of return parameters is
            - mcse_mean, mcse_sd, ess_bulk, ess_tail, r_hat

        All five values are NaN when `ary` is not valid; ``r_hat`` alone is NaN when
        fewer than two chains are available.
    """
    ary = np.atleast_2d(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
        return np.nan, np.nan, np.nan, np.nan, np.nan
    # ess mean (only needed for mcse_mean)
    ess_mean_value = _ess_mean(ary)

    # ess sd (only needed for mcse_sd)
    ess_sd_value = _ess_sd(ary)

    # ess bulk
    z_split = _z_scale(_split_chains(ary))
    ess_bulk_value = _ess(z_split)

    # ess tail
    quantile05, quantile95 = _quantile(ary, [0.05, 0.95])
    iquantile05 = ary <= quantile05
    quantile05_ess = _ess(_split_chains(iquantile05))
    iquantile95 = ary <= quantile95
    quantile95_ess = _ess(_split_chains(iquantile95))
    ess_tail_value = min(quantile05_ess, quantile95_ess)

    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=2)):
        rhat_value = np.nan
    else:
        # r_hat
        rhat_bulk = _rhat(z_split)
        ary_folded = np.abs(ary - np.median(ary))
        rhat_tail = _rhat(_z_scale(_split_chains(ary_folded)))
        rhat_value = max(rhat_bulk, rhat_tail)

    # mcse_mean
    sd = np.std(ary, ddof=1)
    mcse_mean_value = sd / np.sqrt(ess_mean_value)

    # mcse_sd
    fac_mcse_sd = np.sqrt(np.exp(1) * (1 - 1 / ess_sd_value) ** (ess_sd_value - 1) - 1)
    mcse_sd_value = sd * fac_mcse_sd

    return (
        mcse_mean_value,
        mcse_sd_value,
        ess_bulk_value,
        ess_tail_value,
        rhat_value,
    )