# pylint: disable=too-many-lines, too-many-function-args, redefined-outer-name
"""Diagnostic functions for ArviZ."""
import warnings
from collections.abc import Sequence

import numpy as np
import pandas as pd
from scipy import stats

from ..data import convert_to_dataset
from ..utils import Numba, _numba_var, _stack, _var_names
from .density_utils import histogram as _histogram
from .stats_utils import _circular_standard_deviation, _sqrt
from .stats_utils import autocov as _autocov
from .stats_utils import not_valid as _not_valid
from .stats_utils import quantile as _quantile
from .stats_utils import stats_variance_2d as svar
from .stats_utils import wrap_xarray_ufunc as _wrap_xarray_ufunc

__all__ = ["bfmi", "ess", "rhat", "mcse"]


def bfmi(data):
    r"""Calculate the estimated Bayesian fraction of missing information (BFMI).

    BFMI quantifies how well momentum resampling matches the marginal energy distribution. For more
    information on BFMI, see https://arxiv.org/pdf/1604.00695v1.pdf. The current advice is that
    values smaller than 0.3 indicate poor sampling. However, this threshold is provisional and may
    change. See http://mc-stan.org/users/documentation/case-studies/pystan_workflow.html for more
    information.

    Parameters
    ----------
    data : obj
        Any object that can be converted to an az.InferenceData object.
        Refer to documentation of az.convert_to_dataset for details.
        If InferenceData, an energy variable must be present in the sample_stats group.

    Returns
    -------
    z : array
        The Bayesian fraction of missing information of the model and trace. One element per
        chain in the trace.

    Examples
    --------
    Compute the BFMI of an InferenceData object

    .. ipython::

        In [1]: import arviz as az
           ...: data = az.load_arviz_data('radon')
           ...: az.bfmi(data)

    """
    if isinstance(data, np.ndarray):
        return _bfmi(data)

    dataset = convert_to_dataset(data, group="sample_stats")
    if not hasattr(dataset, "energy"):
        raise TypeError("Energy variable was not found.")
    return _bfmi(dataset.energy)


def ess(
    data,
    *,
    var_names=None,
    method="bulk",
    relative=False,
    prob=None,
    dask_kwargs=None,
):
    r"""Calculate estimate of the effective sample size (ess).

    Parameters
    ----------
    data : obj
        Any object that can be converted to an ``az.InferenceData`` object.
        Refer to documentation of ``az.convert_to_dataset`` for details.
        For ndarray: shape = (chain, draw).
        For n-dimensional ndarray transform first to dataset with ``az.convert_to_dataset``.
    var_names : str or list of str
        Names of variables to include in the return value Dataset.
    method : str, optional, default "bulk"
        Select ess method. Valid methods are:

        - "bulk"
        - "tail"     # prob, optional
        - "quantile" # prob
        - "mean" (old ess)
        - "sd"
        - "median"
        - "mad" (median absolute deviance)
        - "z_scale"
        - "folded"
        - "identity"
        - "local"
    relative : bool
        Return relative ess
        `ress = ess / n`
    prob : float, or tuple of two floats, optional
        probability value for "tail", "quantile" or "local" ess functions.
    dask_kwargs : dict, optional
        Dask related kwargs passed to :func:`~arviz.wrap_xarray_ufunc`.

    Returns
    -------
    xarray.Dataset
        Return the effective sample size, :math:`\hat{N}_{eff}`

    Notes
    -----
    The basic ess (:math:`N_{\mathit{eff}}`) diagnostic is computed by:

    .. math:: \hat{N}_{\mathit{eff}} = \frac{MN}{\hat{\tau}}

    .. math:: \hat{\tau} = -1 + 2 \sum_{t'=0}^K \hat{P}_{t'}

    where :math:`M` is the number of chains, :math:`N` the number of draws,
    :math:`\hat{\rho}_t` is the estimated autocorrelation at lag :math:`t`, and
    :math:`K` is the last integer for which :math:`\hat{P}_{K} = \hat{\rho}_{2K} +
    \hat{\rho}_{2K+1}` is still positive.

    The current implementation is similar to Stan, which uses Geyer's initial monotone sequence
    criterion (Geyer, 1992; Geyer, 2011).

    References
    ----------
    * Vehtari et al. (2019) see https://arxiv.org/abs/1903.08008
    * https://mc-stan.org/docs/2_18/reference-manual/effective-sample-size-section.html
      Section 15.4.2
    * Gelman et al. BDA (2014) Formula 11.8

    Examples
    --------
    Calculate the effective sample size using the default arguments:

    .. ipython::

        In [1]: import arviz as az
           ...: data = az.load_arviz_data('non_centered_eight')
           ...: az.ess(data)

    Calculate the relative ess (ress) of some of the variables:

    .. ipython::

        In [1]: az.ess(data, relative=True, var_names=["mu", "theta_t"])

    Calculate the ess using the "tail" method, leaving the `prob` argument at its default
    value.

    .. ipython::

        In [1]: az.ess(data, method="tail")

    """
    methods = {
        "bulk": _ess_bulk,
        "tail": _ess_tail,
        "quantile": _ess_quantile,
        "mean": _ess_mean,
        "sd": _ess_sd,
        "median": _ess_median,
        "mad": _ess_mad,
        "z_scale": _ess_z_scale,
        "folded": _ess_folded,
        "identity": _ess_identity,
        "local": _ess_local,
    }

    if method not in methods:
        raise TypeError(f"ess method {method} not found. Valid methods are:\n{', '.join(methods)}")
    ess_func = methods[method]

    if (method == "quantile") and prob is None:
        raise TypeError("Quantile (prob) information needs to be defined.")

    if isinstance(data, np.ndarray):
        data = np.atleast_2d(data)
        if len(data.shape) < 3:
            if prob is not None:
                return ess_func(  # pylint: disable=unexpected-keyword-arg
                    data, prob=prob, relative=relative
                )
            else:
                return ess_func(data, relative=relative)
        else:
            msg = (
                "Only uni-dimensional ndarray variables are supported."
                " Please transform first to dataset with `az.convert_to_dataset`."
            )
            raise TypeError(msg)

    dataset = convert_to_dataset(data, group="posterior")
    var_names = _var_names(var_names, dataset)

    dataset = dataset if var_names is None else dataset[var_names]

    ufunc_kwargs = {"ravel": False}
    func_kwargs = {"relative": relative} if prob is None else {"prob": prob, "relative": relative}
    return _wrap_xarray_ufunc(
        ess_func,
        dataset,
        ufunc_kwargs=ufunc_kwargs,
        func_kwargs=func_kwargs,
        dask_kwargs=dask_kwargs,
    )


def rhat(data, *, var_names=None, method="rank", dask_kwargs=None):
    r"""Compute estimate of rank normalized split-R-hat for a set of traces.

    The rank normalized R-hat diagnostic tests for lack of convergence by comparing the variance
    between multiple chains to the variance within each chain. If convergence has been achieved,
    the between-chain and within-chain variances should be identical. To be most effective in
    detecting evidence for nonconvergence, each chain should have been initialized to starting
    values that are dispersed relative to the target distribution.

    Parameters
    ----------
    data : obj
        Any object that can be converted to an az.InferenceData object.
        Refer to documentation of az.convert_to_dataset for details.
        At least 2 posterior chains are needed to compute this diagnostic of one or more
        stochastic parameters.
        For ndarray: shape = (chain, draw).
        For n-dimensional ndarray transform first to dataset with az.convert_to_dataset.
    var_names : list
        Names of variables to include in the rhat report
    method : str
        Select R-hat method. Valid methods are:
        - "rank"        # recommended by Vehtari et al. (2019)
        - "split"
        - "folded"
        - "z_scale"
        - "identity"
    dask_kwargs : dict, optional
        Dask related kwargs passed to :func:`~arviz.wrap_xarray_ufunc`.

    Returns
    -------
    xarray.Dataset
      Returns dataset of the potential scale reduction factors, :math:`\hat{R}`

    Notes
    -----
    The diagnostic is computed by:

      .. math:: \hat{R} = \frac{\hat{V}}{W}

    where :math:`W` is the within-chain variance and :math:`\hat{V}` is the posterior variance
    estimate for the pooled rank-traces. This is the potential scale reduction factor, which
    converges to unity when each of the traces is a sample from the target posterior. Values
    greater than one indicate that one or more chains have not yet converged.

    Rank values are calculated over all the chains with `scipy.stats.rankdata`.
    Each chain is split in two and normalized with the z-transform following Vehtari et al. (2019).

    References
    ----------
    * Vehtari et al. (2019) see https://arxiv.org/abs/1903.08008
    * Gelman et al. BDA (2014)
    * Brooks and Gelman (1998)
    * Gelman and Rubin (1992)

    Examples
    --------
    Calculate the R-hat using the default arguments:

    .. ipython::

        In [1]: import arviz as az
           ...: data = az.load_arviz_data("non_centered_eight")
           ...: az.rhat(data)

    Calculate the R-hat of some variables using the folded method:

    .. ipython::

        In [1]: az.rhat(data, var_names=["mu", "theta_t"], method="folded")

    """
    methods = {
        "rank": _rhat_rank,
        "split": _rhat_split,
        "folded": _rhat_folded,
        "z_scale": _rhat_z_scale,
        "identity": _rhat_identity,
    }
    if method not in methods:
        raise TypeError(
            f"R-hat method {method} not found. Valid methods are:\n{', '.join(methods)}"
        )
    rhat_func = methods[method]

    if isinstance(data, np.ndarray):
        data = np.atleast_2d(data)
        if len(data.shape) < 3:
            return rhat_func(data)
        else:
            msg = (
                "Only uni-dimensional ndarray variables are supported."
                " Please transform first to dataset with `az.convert_to_dataset`."
            )
            raise TypeError(msg)

    dataset = convert_to_dataset(data, group="posterior")
    var_names = _var_names(var_names, dataset)

    dataset = dataset if var_names is None else dataset[var_names]

    ufunc_kwargs = {"ravel": False}
    func_kwargs = {}
    return _wrap_xarray_ufunc(
        rhat_func,
        dataset,
        ufunc_kwargs=ufunc_kwargs,
        func_kwargs=func_kwargs,
        dask_kwargs=dask_kwargs,
    )


def mcse(data, *, var_names=None, method="mean", prob=None, dask_kwargs=None):
    """Calculate Markov Chain Standard Error statistic.

    Parameters
    ----------
    data : obj
        Any object that can be converted to an az.InferenceData object
        Refer to documentation of az.convert_to_dataset for details
        For ndarray: shape = (chain, draw).
        For n-dimensional ndarray transform first to dataset with az.convert_to_dataset.
    var_names : list
        Names of variables to include in the mcse report
    method : str
        Select mcse method. Valid methods are:
        - "mean"
        - "sd"
        - "median"
        - "quantile"

    prob : float
        Quantile information.
    dask_kwargs : dict, optional
        Dask related kwargs passed to :func:`~arviz.wrap_xarray_ufunc`.

    Returns
    -------
    xarray.Dataset
        Return the mcse dataset

    Examples
    --------
    Calculate the Markov Chain Standard Error using the default arguments:

    .. ipython::

        In [1]: import arviz as az
           ...: data = az.load_arviz_data("non_centered_eight")
           ...: az.mcse(data)

    Calculate the Markov Chain Standard Error using the quantile method:

    .. ipython::

        In [1]: az.mcse(data, method="quantile", prob=0.7)

    """
    methods = {
        "mean": _mcse_mean,
        "sd": _mcse_sd,
        "median": _mcse_median,
        "quantile": _mcse_quantile,
    }
    if method not in methods:
        raise TypeError(
            "mcse method {} not found. Valid methods are:\n{}".format(
                method, "\n    ".join(methods)
            )
        )
    mcse_func = methods[method]

    if method == "quantile" and prob is None:
        raise TypeError("Quantile (prob) information needs to be defined.")

    if isinstance(data, np.ndarray):
        data = np.atleast_2d(data)
        if len(data.shape) < 3:
            if prob is not None:
                return mcse_func(data, prob=prob)  # pylint: disable=unexpected-keyword-arg
            else:
                return mcse_func(data)
        else:
            msg = (
                "Only uni-dimensional ndarray variables are supported."
                " Please transform first to dataset with `az.convert_to_dataset`."
            )
            raise TypeError(msg)

    dataset = convert_to_dataset(data, group="posterior")
    var_names = _var_names(var_names, dataset)

    dataset = dataset if var_names is None else dataset[var_names]

    ufunc_kwargs = {"ravel": False}
    func_kwargs = {} if prob is None else {"prob": prob}
    return _wrap_xarray_ufunc(
        mcse_func,
        dataset,
        ufunc_kwargs=ufunc_kwargs,
        func_kwargs=func_kwargs,
        dask_kwargs=dask_kwargs,
    )


def ks_summary(pareto_tail_indices):
    """Display a summary of Pareto tail indices.

    Parameters
    ----------
    pareto_tail_indices : array
      Pareto tail indices.

    Returns
    -------
    df_k : dataframe
      Dataframe containing k diagnostic values.
    """
    _numba_flag = Numba.numba_flag
    if _numba_flag:
        bins = np.asarray([-np.Inf, 0.5, 0.7, 1, np.Inf])
        kcounts, *_ = _histogram(pareto_tail_indices, bins)
    else:
        kcounts, *_ = _histogram(pareto_tail_indices, bins=[-np.Inf, 0.5, 0.7, 1, np.Inf])
    kprop = kcounts / len(pareto_tail_indices) * 100
    df_k = pd.DataFrame(
        dict(_=["(good)", "(ok)", "(bad)", "(very bad)"], Count=kcounts, Pct=kprop)
    ).rename(index={0: "(-Inf, 0.5]", 1: " (0.5, 0.7]", 2: "   (0.7, 1]", 3: "   (1, Inf)"})

    if np.sum(kcounts[1:]) == 0:
        warnings.warn("All Pareto k estimates are good (k < 0.5)")
    elif np.sum(kcounts[2:]) == 0:
        warnings.warn("All Pareto k estimates are ok (k < 0.7)")

    return df_k


def _bfmi(energy):
    r"""Calculate the estimated Bayesian fraction of missing information (BFMI).

    BFMI quantifies how well momentum resampling matches the marginal energy distribution. For more
    information on BFMI, see https://arxiv.org/pdf/1604.00695v1.pdf. The current advice is that
    values smaller than 0.3 indicate poor sampling. However, this threshold is provisional and may
    change. See http://mc-stan.org/users/documentation/case-studies/pystan_workflow.html for more
    information.

    Parameters
    ----------
    energy : NumPy array
        Should be extracted from a gradient based sampler, such as in Stan or PyMC3. Typically,
        after converting a trace or fit to InferenceData, the energy will be in
        `data.sample_stats.energy`.

    Returns
    -------
    z : array
        The Bayesian fraction of missing information of the model and trace. One element per
        chain in the trace.
    """
    energy_mat = np.atleast_2d(energy)
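    # Per-chain BFMI: mean squared difference of successive energy values divided by
    # the variance of the energy; values near 1 indicate good momentum resampling.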
    num = np.square(np.diff(energy_mat, axis=1)).mean(axis=1)  # pylint: disable=no-member
    if energy_mat.ndim == 2:
        den = _numba_var(svar, np.var, energy_mat, axis=1, ddof=1)
    else:
        den = np.var(energy, axis=1, ddof=1)
    return num / den


def _backtransform_ranks(arr, c=3 / 8):  # pylint: disable=invalid-name
    """Backtransformation of ranks.

    Parameters
    ----------
    arr : np.ndarray
        Ranks array
    c : float
        Fractional offset. Defaults to c = 3/8 as recommended by Blom (1958).

    Returns
    -------
    np.ndarray

    References
    ----------
    Blom, G. (1958). Statistical Estimates and Transformed Beta-Variables. Wiley; New York.
    """
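    # Map ranks 1..n onto fractional positions in (0, 1) via (rank - c) / (n - 2c + 1).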
    arr = np.asarray(arr)
    size = arr.size
    return (arr - c) / (size - 2 * c + 1)


def _z_scale(ary):
    """Calculate z_scale.

    Parameters
    ----------
    ary : np.ndarray

    Returns
    -------
    np.ndarray
    """
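    # Rank the pooled draws, convert the ranks to fractional positions in (0, 1),
    # then map them through the standard normal quantile function (inverse CDF).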
    ary = np.asarray(ary)
    rank = stats.rankdata(ary, method="average")
    rank = _backtransform_ranks(rank)
    z = stats.norm.ppf(rank)
    z = z.reshape(ary.shape)
    return z


def _split_chains(ary):
    """Split and stack chains."""
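    # Each chain is cut in half and the two halves are stacked as separate chains,
    # so within-chain trends show up as apparent between-chain differences.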
    ary = np.asarray(ary)
    if len(ary.shape) > 1:
        _, n_draw = ary.shape
    else:
        ary = np.atleast_2d(ary)
        _, n_draw = ary.shape
    half = n_draw // 2
    return _stack(ary[:, :half], ary[:, -half:])


def _z_fold(ary):
    """Fold and z-scale values."""
    ary = np.asarray(ary)
    ary = abs(ary - np.median(ary))
    ary = _z_scale(ary)
    return ary


def _rhat(ary):
    """Compute the rhat for a 2d array."""
    _numba_flag = Numba.numba_flag
    ary = np.asarray(ary, dtype=float)
    if _not_valid(ary, check_shape=False):
        return np.nan
    _, num_samples = ary.shape

    # Calculate chain mean
    chain_mean = np.mean(ary, axis=1)
    # Calculate chain variance
    chain_var = _numba_var(svar, np.var, ary, axis=1, ddof=1)
    # Calculate between-chain variance
    between_chain_variance = num_samples * _numba_var(svar, np.var, chain_mean, axis=None, ddof=1)
    # Calculate within-chain variance
    within_chain_variance = np.mean(chain_var)
    # Estimate of marginal posterior variance
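    # V_hat = ((num_samples - 1) * W + B) / num_samples, so this is sqrt(V_hat / W)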
    rhat_value = np.sqrt(
        (between_chain_variance / within_chain_variance + num_samples - 1) / (num_samples)
    )
    return rhat_value


def _rhat_rank(ary):
    """Compute the rank normalized rhat for 2d array.

    Computation follows https://arxiv.org/abs/1903.08008
    """
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=2)):
        return np.nan
    split_ary = _split_chains(ary)
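    # Bulk R-hat uses rank-normalized draws; tail R-hat uses rank-normalized folded
    # draws (absolute deviations from the median). Report the worse (larger) of the two.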
    rhat_bulk = _rhat(_z_scale(split_ary))

    split_ary_folded = abs(split_ary - np.median(split_ary))
    rhat_tail = _rhat(_z_scale(split_ary_folded))

    rhat_rank = max(rhat_bulk, rhat_tail)
    return rhat_rank


def _rhat_folded(ary):
    """Calculate split-Rhat for folded z-values."""
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=2)):
        return np.nan
    ary = _z_fold(_split_chains(ary))
    return _rhat(ary)


def _rhat_z_scale(ary):
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=2)):
        return np.nan
    return _rhat(_z_scale(_split_chains(ary)))


def _rhat_split(ary):
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=2)):
        return np.nan
    return _rhat(_split_chains(ary))


def _rhat_identity(ary):
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=2)):
        return np.nan
    return _rhat(ary)


def _ess(ary, relative=False):
    """Compute the effective sample size for a 2D array."""
    _numba_flag = Numba.numba_flag
    ary = np.asarray(ary, dtype=float)
    if _not_valid(ary, check_shape=False):
        return np.nan
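    # A (numerically) constant array carries no autocorrelation information;
    # report its effective sample size as the total number of values.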
    if (np.max(ary) - np.min(ary)) < np.finfo(float).resolution:  # pylint: disable=no-member
        return ary.size
    if len(ary.shape) < 2:
        ary = np.atleast_2d(ary)
    n_chain, n_draw = ary.shape
    acov = _autocov(ary, axis=1)
    chain_mean = ary.mean(axis=1)
    mean_var = np.mean(acov[:, 0]) * n_draw / (n_draw - 1.0)
    var_plus = mean_var * (n_draw - 1.0) / n_draw
    if n_chain > 1:
        var_plus += _numba_var(svar, np.var, chain_mean, axis=None, ddof=1)

    rho_hat_t = np.zeros(n_draw)
    rho_hat_even = 1.0
    rho_hat_t[0] = rho_hat_even
    rho_hat_odd = 1.0 - (mean_var - np.mean(acov[:, 1])) / var_plus
    rho_hat_t[1] = rho_hat_odd

    # Geyer's initial positive sequence
    t = 1
    while t < (n_draw - 3) and (rho_hat_even + rho_hat_odd) > 0.0:
        rho_hat_even = 1.0 - (mean_var - np.mean(acov[:, t + 1])) / var_plus
        rho_hat_odd = 1.0 - (mean_var - np.mean(acov[:, t + 2])) / var_plus
        if (rho_hat_even + rho_hat_odd) >= 0:
            rho_hat_t[t + 1] = rho_hat_even
            rho_hat_t[t + 2] = rho_hat_odd
        t += 2

    max_t = t - 2
    # improve estimation
    if rho_hat_even > 0:
        rho_hat_t[max_t + 1] = rho_hat_even
    # Geyer's initial monotone sequence
    t = 1
    while t <= max_t - 2:
        if (rho_hat_t[t + 1] + rho_hat_t[t + 2]) > (rho_hat_t[t - 1] + rho_hat_t[t]):
            rho_hat_t[t + 1] = (rho_hat_t[t - 1] + rho_hat_t[t]) / 2.0
            rho_hat_t[t + 2] = rho_hat_t[t + 1]
        t += 2

    ess = n_chain * n_draw
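    # tau_hat is the truncated sum of autocorrelation pair sums (Geyer's estimator),
    # bounded below by 1 / log10(ess) so the reported ESS cannot exceed ess * log10(ess).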
    tau_hat = -1.0 + 2.0 * np.sum(rho_hat_t[: max_t + 1]) + np.sum(rho_hat_t[max_t + 1 : max_t + 2])
    tau_hat = max(tau_hat, 1 / np.log10(ess))
    ess = (1 if relative else ess) / tau_hat
    if np.isnan(rho_hat_t).any():
        ess = np.nan
    return ess


def _ess_bulk(ary, relative=False):
    """Compute the effective sample size for the bulk."""
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
        return np.nan
    z_scaled = _z_scale(_split_chains(ary))
    ess_bulk = _ess(z_scaled, relative=relative)
    return ess_bulk


def _ess_tail(ary, prob=None, relative=False):
    """Compute the effective sample size for the tail.

    If `prob` defined, ess = min(qess(prob), qess(1-prob))
    """
    if prob is None:
        prob = (0.05, 0.95)
    elif not isinstance(prob, Sequence):
        prob = (prob, 1 - prob)

    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
        return np.nan

    prob_low, prob_high = prob
    quantile_low_ess = _ess_quantile(ary, prob_low, relative=relative)
    quantile_high_ess = _ess_quantile(ary, prob_high, relative=relative)
    return min(quantile_low_ess, quantile_high_ess)


def _ess_mean(ary, relative=False):
    """Compute the effective sample size for the mean."""
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
        return np.nan
    return _ess(_split_chains(ary), relative=relative)


def _ess_sd(ary, relative=False):
    """Compute the effective sample size for the sd."""
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
        return np.nan
    ary = _split_chains(ary)
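    # The sd depends on both the first and second moments, so take the more
    # conservative (smaller) ESS of the draws and of their squares.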
    return min(_ess(ary, relative=relative), _ess(ary ** 2, relative=relative))


def _ess_quantile(ary, prob, relative=False):
    """Compute the effective sample size for a specific quantile."""
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
        return np.nan
    if prob is None:
        raise TypeError("Prob not defined.")
    (quantile,) = _quantile(ary, prob)
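    # ESS of the indicator I(theta <= quantile), whose mean is the quantile's probability.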
    iquantile = ary <= quantile
    return _ess(_split_chains(iquantile), relative=relative)


def _ess_local(ary, prob, relative=False):
    """Compute the effective sample size for a specific probability interval (local ess)."""
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
        return np.nan
    if prob is None:
        raise TypeError("Prob not defined.")
    if len(prob) != 2:
        raise ValueError("Prob argument in ess local must be upper and lower bound")
    quantile = _quantile(ary, prob)
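    # Indicator of draws falling between the quantiles at prob[0] and prob[1].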
    iquantile = (quantile[0] <= ary) & (ary <= quantile[1])
    return _ess(_split_chains(iquantile), relative=relative)


def _ess_z_scale(ary, relative=False):
    """Calculate ess for z-scale."""
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
        return np.nan
    return _ess(_z_scale(_split_chains(ary)), relative=relative)


def _ess_folded(ary, relative=False):
    """Calculate split-ess for folded data."""
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
        return np.nan
    return _ess(_z_fold(_split_chains(ary)), relative=relative)


def _ess_median(ary, relative=False):
    """Calculate split-ess for median."""
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
        return np.nan
    return _ess_quantile(ary, 0.5, relative=relative)


def _ess_mad(ary, relative=False):
    """Calculate split-ess for median absolute deviance."""
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
        return np.nan
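    # Indicator of |theta - median| <= MAD (the median of the absolute deviations),
    # rank-normalized before computing the ESS.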
    ary = abs(ary - np.median(ary))
    ary = ary <= np.median(ary)
    ary = _z_scale(_split_chains(ary))
    return _ess(ary, relative=relative)


def _ess_identity(ary, relative=False):
    """Calculate ess."""
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
        return np.nan
    return _ess(ary, relative=relative)


def _mcse_mean(ary):
    """Compute the Markov Chain mean error."""
    _numba_flag = Numba.numba_flag
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
        return np.nan
    ess = _ess_mean(ary)
    if _numba_flag:
        sd = _sqrt(svar(np.ravel(ary), ddof=1), np.zeros(1))
    else:
        sd = np.std(ary, ddof=1)
    mcse_mean_value = sd / np.sqrt(ess)
    return mcse_mean_value


def _mcse_sd(ary):
    """Compute the Markov Chain sd error."""
    _numba_flag = Numba.numba_flag
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
        return np.nan
    ess = _ess_sd(ary)
    if _numba_flag:
        sd = float(_sqrt(svar(np.ravel(ary), ddof=1), np.zeros(1)))
    else:
        sd = np.std(ary, ddof=1)
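    # Scale factor that converts the sample sd and its ESS into the Monte Carlo SE of the sd.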
    fac_mcse_sd = np.sqrt(np.exp(1) * (1 - 1 / ess) ** (ess - 1) - 1)
    mcse_sd_value = sd * fac_mcse_sd
    return mcse_sd_value


def _mcse_median(ary):
    """Compute the Markov Chain median error."""
    return _mcse_quantile(ary, 0.5)


def _mcse_quantile(ary, prob):
    """Compute the Markov Chain quantile error at quantile=prob."""
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
        return np.nan
    ess = _ess_quantile(ary, prob)
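    # probability = [Phi(-1), Phi(1)]: a central +/- 1 sd interval of the standard normal.
    # A Beta(ess * prob + 1, ess * (1 - prob) + 1) distribution approximates the sampling
    # uncertainty of the quantile's rank; the matching order statistics bracket the quantile
    # and half of their distance is reported as the MCSE.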
    probability = [0.1586553, 0.8413447]
    with np.errstate(invalid="ignore"):
        ppf = stats.beta.ppf(probability, ess * prob + 1, ess * (1 - prob) + 1)
    sorted_ary = np.sort(ary.ravel())
    size = sorted_ary.size
    ppf_size = ppf * size - 1
    th1 = sorted_ary[int(np.floor(np.nanmax((ppf_size[0], 0))))]
    th2 = sorted_ary[int(np.ceil(np.nanmin((ppf_size[1], size - 1))))]
    return (th2 - th1) / 2


def _mc_error(ary, batches=5, circular=False):
    """Calculate the simulation standard error, accounting for non-independent samples.

    The trace is divided into batches, and the standard deviation of the batch
    means is calculated.

    Parameters
    ----------
    ary : Numpy array
        An array containing MCMC samples
    batches : integer
        Number of batches
    circular : bool
        Whether to compute the error taking into account `ary` is a circular variable
        (in the range [-np.pi, np.pi]) or not. Defaults to False (i.e. non-circular variables).

    Returns
    -------
    mc_error : float
        Simulation standard error
    """
    _numba_flag = Numba.numba_flag
    if ary.ndim > 1:

        dims = np.shape(ary)
        trace = np.transpose([t.ravel() for t in ary])

        return np.reshape([_mc_error(t, batches, circular) for t in trace], dims[1:])

    else:
        if _not_valid(ary, check_shape=False):
            return np.nan
        if batches == 1:
            if circular:
                if _numba_flag:
                    std = _circular_standard_deviation(ary, high=np.pi, low=-np.pi)
                else:
                    std = stats.circstd(ary, high=np.pi, low=-np.pi)
            else:
                if _numba_flag:
                    std = float(_sqrt(svar(ary), np.zeros(1)))
                else:
                    std = np.std(ary)
            return std / np.sqrt(len(ary))

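        # Batch means: split the trace into `batches` consecutive blocks and use the
        # spread of the block means to estimate the simulation error.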
        batched_traces = np.resize(ary, (batches, int(len(ary) / batches)))

        if circular:
            means = stats.circmean(batched_traces, high=np.pi, low=-np.pi, axis=1)
            if _numba_flag:
                std = _circular_standard_deviation(means, high=np.pi, low=-np.pi)
            else:
                std = stats.circstd(means, high=np.pi, low=-np.pi)
        else:
            means = np.mean(batched_traces, 1)
            if _numba_flag:
                std = _sqrt(svar(means), np.zeros(1))
            else:
                std = np.std(means)

        return std / np.sqrt(batches)


def _multichain_statistics(ary):
    """Efficiently calculate multichain statistics for summary.

    Parameters
    ----------
    ary : numpy.ndarray

    Returns
    -------
    tuple
        Order of return parameters is
            - mcse_mean, mcse_sd, ess_bulk, ess_tail, r_hat
    """
    ary = np.atleast_2d(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
        return np.nan, np.nan, np.nan, np.nan, np.nan
    # ess mean
    ess_mean_value = _ess_mean(ary)

    # ess sd
    ess_sd_value = _ess_sd(ary)

    # ess bulk
    z_split = _z_scale(_split_chains(ary))
    ess_bulk_value = _ess(z_split)

    # ess tail
    quantile05, quantile95 = _quantile(ary, [0.05, 0.95])
    iquantile05 = ary <= quantile05
    quantile05_ess = _ess(_split_chains(iquantile05))
    iquantile95 = ary <= quantile95
    quantile95_ess = _ess(_split_chains(iquantile95))
    ess_tail_value = min(quantile05_ess, quantile95_ess)

    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=2)):
        rhat_value = np.nan
    else:
        # r_hat
        rhat_bulk = _rhat(z_split)
        ary_folded = np.abs(ary - np.median(ary))
        rhat_tail = _rhat(_z_scale(_split_chains(ary_folded)))
        rhat_value = max(rhat_bulk, rhat_tail)

    # mcse_mean
    sd = np.std(ary, ddof=1)
    mcse_mean_value = sd / np.sqrt(ess_mean_value)

    # mcse_sd
    fac_mcse_sd = np.sqrt(np.exp(1) * (1 - 1 / ess_sd_value) ** (ess_sd_value - 1) - 1)
    mcse_sd_value = sd * fac_mcse_sd

    return (
        mcse_mean_value,
        mcse_sd_value,
        ess_bulk_value,
        ess_tail_value,
        rhat_value,
    )
