1"""Stats-utility functions for ArviZ."""
2import warnings
3from collections.abc import Sequence
4from copy import copy as _copy
5from copy import deepcopy as _deepcopy
7import numpy as np
8import pandas as pd
9from scipy.fftpack import next_fast_len
10from scipy.interpolate import CubicSpline
11from scipy.stats.mstats import mquantiles
12from xarray import apply_ufunc
14from .. import _log
15from ..utils import conditional_jit, conditional_vect, conditional_dask
16from .density_utils import histogram as _histogram
# Public names exported by ``from <this module> import *``; the remaining helpers are
# module-internal.
__all__ = ["autocorr", "autocov", "ELPDData", "make_ufunc", "wrap_xarray_ufunc"]
def autocov(ary, axis=-1):
    """Compute autocovariance estimates for every lag for the input array.

    The estimate at lag ``k`` is ``sum_t (x_t - mean) * (x_{t+k} - mean) / n``,
    computed with a zero-padded real FFT along ``axis``.

    Parameters
    ----------
    ary : Numpy array
        An array containing MCMC samples
    axis : int, optional
        Axis along which the autocovariance is computed. Negative values count
        from the last axis. Defaults to the last axis.

    Returns
    -------
    acov: Numpy array same size as the input array
    """
    # Normalise negative axes only; ``>= 0`` keeps axis=0 as-is.
    axis = axis if axis >= 0 else len(ary.shape) + axis
    n = ary.shape[axis]
    # Pad to at least 2n so the circular FFT correlation does not wrap around.
    m = next_fast_len(2 * n)

    ary = ary - ary.mean(axis, keepdims=True)

    # added to silence tuple warning for a submodule
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        fft_ary = np.fft.rfft(ary, n=m, axis=axis)
        # |F|^2: its inverse transform is the (unnormalised) autocovariance.
        fft_ary *= np.conjugate(fft_ary)

        # Keep only the first n lags along ``axis``; other axes are kept whole.
        lags = tuple(
            slice(0, n) if dim == axis else slice(None) for dim in range(len(ary.shape))
        )
        cov = np.fft.irfft(fft_ary, n=m, axis=axis)[lags]
        cov /= n

    return cov
def autocorr(ary, axis=-1):
    """Compute autocorrelation using FFT for every lag for the input array.

    See https://en.wikipedia.org/wiki/autocorrelation#Efficient_computation

    The autocovariance from :func:`autocov` is divided by its lag-0 value, so the
    first lag along ``axis`` is 1 (NaN where the lag-0 variance is zero).

    Parameters
    ----------
    ary : Numpy array
        An array containing MCMC samples
    axis : int, optional
        Axis along which the autocorrelation is computed. Negative values count
        from the last axis. Defaults to the last axis.

    Returns
    -------
    acorr: Numpy array same size as the input array
    """
    corr = autocov(ary, axis=axis)
    # Normalise negative axes only; ``>= 0`` keeps axis=0 as-is.
    axis = axis if axis >= 0 else len(corr.shape) + axis
    # Lag-0 slice kept as a length-1 axis so it broadcasts over all lags.
    lag0 = tuple(
        slice(None, 1) if dim == axis else slice(None, None) for dim in range(len(corr.shape))
    )
    # 0 / 0 for constant chains yields NaN; silence the associated warning.
    with np.errstate(invalid="ignore"):
        corr /= corr[lag0]
    return corr
def make_ufunc(
    func, n_dims=2, n_output=1, n_input=1, index=Ellipsis, ravel=True, check_shape=None
):  # noqa: D202
    """Make ufunc from a function taking 1D array input.

    Parameters
    ----------
    func : callable
    n_dims : int, optional
        Number of core dimensions not broadcasted. Dimensions are skipped from the end.
        At minimum n_dims > 0.
    n_output : int, optional
        Select number of results returned by `func`.
        If n_output > 1, ufunc returns a tuple of objects else returns an object.
    n_input : int, optional
        Number of **array** inputs to func, i.e. ``n_input=2`` means that func is called
        with ``func(ary1, ary2, *args, **kwargs)``
    index : int, optional
        Slice ndarray with `index`. Defaults to `Ellipsis`.
    ravel : bool, optional
        If true, ravel the ndarray before calling `func`.
    check_shape: bool, optional
        If false, do not check if the shape of the output is compatible with n_dims and
        n_output. By default, True only for n_input=1. If n_input is larger than 1, the last
        input array is used to check the shape, however, shape checking with multiple inputs
        may not be correct.

    Returns
    -------
    callable
        ufunc wrapper for `func`. It accepts the keyword arguments ``out`` (preallocated
        result array, or tuple of arrays when ``n_output > 1``) and ``out_shape``
        (trailing shape of each result, or a sequence of such shapes when
        ``n_output > 1``); any other keyword arguments are forwarded to `func`.

    Raises
    ------
    TypeError
        If ``n_dims`` is smaller than 1.
    """
    if n_dims < 1:
        raise TypeError("n_dims must be one or higher.")

    if check_shape is None:
        check_shape = n_input == 1

    def _ufunc(*args, out=None, out_shape=None, **kwargs):
        """General ufunc for single-output function."""
        arys = args[:n_input]
        element_shape = arys[-1].shape[:-n_dims]
        trailing_shape = () if out_shape is None else tuple(out_shape)
        expected_shape = (*element_shape, *trailing_shape)
        if out is None:
            out = np.empty(expected_shape)
        elif check_shape and out.shape != expected_shape:
            msg = f"Shape incorrect for `out`: {out.shape}."
            msg += f" Correct shape is {expected_shape}"
            raise TypeError(msg)
        # One call to func per index over the leading dims; the ``trailing_shape``
        # dims of out[idx] receive the whole result of that call.
        n_loop_dims = len(out.shape) - len(trailing_shape)
        for idx in np.ndindex(out.shape[:n_loop_dims]):
            arys_idx = [ary[idx].ravel() if ravel else ary[idx] for ary in arys]
            out[idx] = np.asarray(func(*arys_idx, *args[n_input:], **kwargs))[index]
        return out

    def _multi_ufunc(*args, out=None, out_shape=None, **kwargs):
        """General ufunc for multi-output function."""
        arys = args[:n_input]
        element_shape = arys[-1].shape[:-n_dims]
        if out_shape is None:
            expected_shapes = tuple(element_shape for _ in range(n_output))
        else:
            expected_shapes = tuple((*element_shape, *out_shape[i]) for i in range(n_output))
        if out is None:
            out = tuple(np.empty(shape) for shape in expected_shapes)
        elif check_shape:
            if isinstance(out, tuple):
                found_shapes = tuple(item.shape for item in out)
            else:
                found_shapes = f"not tuple, type={type(out)}"
            if found_shapes != expected_shapes:
                msg = f"Shapes incorrect for `out`: {found_shapes}."
                msg += f" Correct shapes are {expected_shapes}"
                raise TypeError(msg)
        for idx in np.ndindex(element_shape):
            arys_idx = [ary[idx].ravel() if ravel else ary[idx] for ary in arys]
            results = func(*arys_idx, *args[n_input:], **kwargs)
            for i, res in enumerate(results):
                out[i][idx] = np.asarray(res)[index]
        return out

    ufunc = _multi_ufunc if n_output > 1 else _ufunc

    update_docstring(ufunc, func, n_output)
    return ufunc
def wrap_xarray_ufunc(
    ufunc,
    *datasets,
    ufunc_kwargs=None,
    func_args=None,
    func_kwargs=None,
    dask_kwargs=None,
    **kwargs,
):
    """Wrap make_ufunc with xarray.apply_ufunc.

    Parameters
    ----------
    ufunc : callable
    datasets : xarray.dataset
    ufunc_kwargs : dict
        Keyword arguments passed to `make_ufunc`. The dict is copied, so the
        caller's object is never modified.
            - 'n_dims', int, by default the number of core dims of the last entry of
              ``input_core_dims`` (2 with the default ``("chain", "draw")``)
            - 'n_output', int, by default 1
            - 'n_input', int, by default len(datasets)
            - 'index', slice, by default Ellipsis
            - 'ravel', bool, by default True
    func_args : tuple
        Arguments passed to 'ufunc'.
    func_kwargs : dict
        Keyword arguments passed to 'ufunc'.
            - 'out_shape', tuple, by default None
    dask_kwargs : dict
        Dask related kwargs passed to :func:`xarray:xarray.apply_ufunc`.
        Use :meth:`~arviz.Dask.enable_dask` to set default kwargs.
    **kwargs
        Passed to xarray.apply_ufunc. ``input_core_dims`` defaults to
        ``("chain", "draw")`` for every dataset and every entry of ``func_args``;
        ``output_core_dims`` defaults to no core dims for each output.

    Returns
    -------
    xarray.dataset
    """
    # Copy so the defaults filled in below never leak into the caller's dict.
    ufunc_kwargs = {} if ufunc_kwargs is None else dict(ufunc_kwargs)
    ufunc_kwargs.setdefault("n_input", len(datasets))
    if func_args is None:
        func_args = tuple()
    if func_kwargs is None:
        func_kwargs = {}
    if dask_kwargs is None:
        dask_kwargs = {}

    kwargs.setdefault(
        "input_core_dims", tuple(("chain", "draw") for _ in range(len(func_args) + len(datasets)))
    )
    ufunc_kwargs.setdefault("n_dims", len(kwargs["input_core_dims"][-1]))
    kwargs.setdefault("output_core_dims", tuple([] for _ in range(ufunc_kwargs.get("n_output", 1))))

    callable_ufunc = make_ufunc(ufunc, **ufunc_kwargs)

    return apply_ufunc(
        callable_ufunc, *datasets, *func_args, kwargs=func_kwargs, **dask_kwargs, **kwargs
    )
def update_docstring(ufunc, func, n_output=1):
    """Update ArviZ generated ufunc docstring.

    Appends usage notes about ``func`` to ``ufunc.__doc__`` in place.

    Parameters
    ----------
    ufunc : callable
        Wrapper whose ``__doc__`` is extended. A missing (``None``) docstring, as
        produced under ``python -OO``, is treated as empty.
    func : callable
        Wrapped function. Its ``__module__``, ``__name__`` and ``__doc__`` are used
        when they are available as strings.
    n_output : int, optional
        Number of outputs of ``func``. When larger than 1 the usage example also
        shows ``output_core_dims``.
    """
    module = getattr(func, "__module__", None)
    module = module if isinstance(module, str) else ""
    name = getattr(func, "__name__", None)
    name = name if isinstance(name, str) else ""
    docstring = getattr(func, "__doc__", None)
    docstring = docstring if isinstance(docstring, str) else ""
    # "module.name", skipping whichever part is missing.
    qualified_name = ".".join(part for part in (module, name) if part)

    parts = [ufunc.__doc__ or "", "\n\n"]
    if qualified_name:
        parts += ["This function is a ufunc wrapper for ", qualified_name, "\n"]
    parts += ['Call ufunc with n_args from xarray against "chain" and "draw" dimensions:', "\n\n"]
    input_core_dims = 'tuple(("chain", "draw") for _ in range(n_args))'
    msg = f"xr.apply_ufunc(ufunc, dataset, input_core_dims={input_core_dims}"
    if n_output > 1:
        msg += f", output_core_dims=tuple([] for _ in range({n_output}))"
    parts += [msg, ")", "\n\n", "For example: np.std(data, ddof=1) --> n_args=2"]
    if docstring:
        parts += ["\n\n", qualified_name, " docstring:", "\n\n", docstring]
    ufunc.__doc__ = "".join(parts)
def logsumexp(ary, *, b=None, b_inv=None, axis=None, keepdims=False, out=None, copy=True):
    """Stable logsumexp when b >= 0 and b is scalar.

    Computes ``log(sum(b * exp(ary)))`` over ``axis`` by shifting with the maximum.

    b_inv overwrites b unless b_inv is None.

    Parameters
    ----------
    ary : array_like
        Input values. Boolean and integer inputs are converted to float64.
    b, b_inv : scalar, optional
        Multiplicative weight ``b`` or its inverse ``1 / b_inv``.
    axis : int or sequence of int, optional
        Axes to reduce over; all axes when None. Negative values are allowed.
    keepdims : bool, optional
        Keep the reduced axes with length one.
    out : ndarray, optional
        Preallocated result array.
    copy : bool, optional
        When False, ``ary`` (if already a floating array) is overwritten.

    Returns
    -------
    ndarray or scalar
        Scalar when the result is zero-dimensional.
    """
    # check dimensions for result arrays
    ary = np.asarray(ary)
    if ary.dtype.kind in "biu":
        # bool/int/uint arrays cannot hold the in-place exp() results below
        ary = ary.astype(np.float64)
    dtype = ary.dtype.type
    shape = ary.shape
    shape_len = len(shape)
    if isinstance(axis, Sequence):
        axis = tuple(axis_i if axis_i >= 0 else shape_len + axis_i for axis_i in axis)
        agroup = axis
    else:
        axis = axis if (axis is None) or (axis >= 0) else shape_len + axis
        agroup = (axis,)
    shape_max = (
        tuple(1 for _ in shape)
        if axis is None
        else tuple(1 if i in agroup else d for i, d in enumerate(shape))
    )
    # create result arrays
    if out is None:
        if not keepdims:
            out_shape = (
                tuple()
                if axis is None
                else tuple(d for i, d in enumerate(shape) if i not in agroup)
            )
        else:
            out_shape = shape_max
        out = np.empty(out_shape, dtype=dtype)
    if b_inv == 0:
        return np.full_like(out, np.inf, dtype=dtype) if out.shape else np.inf
    if b_inv is None and b == 0:
        return np.full_like(out, -np.inf) if out.shape else -np.inf
    ary_max = np.empty(shape_max, dtype=dtype)
    # calculations
    ary.max(axis=axis, keepdims=True, out=ary_max)
    # A non-finite max (all -inf, or +inf) would make ``ary - ary_max`` NaN;
    # shifting by 0 instead gives the correct -inf / inf result.
    np.copyto(ary_max, 0.0, where=~np.isfinite(ary_max))
    if copy:
        ary = ary.copy()
    ary -= ary_max
    np.exp(ary, out=ary)
    ary.sum(axis=axis, keepdims=keepdims, out=out)
    # log(0) = -inf is the expected result when every term is exp(-inf) = 0.
    with np.errstate(divide="ignore"):
        np.log(out, out=out)
    if b_inv is not None:
        ary_max -= np.log(b_inv)
    elif b:
        ary_max += np.log(b)
    # Drop only the reduced axes so the remaining ones stay aligned with ``out``
    # (a plain squeeze() would also drop non-reduced length-1 axes).
    out += ary_max if keepdims else np.squeeze(ary_max, axis=axis)
    # transform to scalar if possible
    return out if out.shape else dtype(out)
def quantile(ary, q, axis=None, limit=None):
    """Return R-compatible (type 7) sample quantiles of ``ary`` at probabilities ``q``.

    ``alphap = betap = 1`` in :func:`scipy.stats.mstats.mquantiles` reproduces R's
    default plotting positions. ``limit`` (default: no limit) is forwarded unchanged.
    """
    data_limit = () if limit is None else limit
    return mquantiles(ary, q, alphap=1, betap=1, axis=axis, limit=data_limit)
def not_valid(ary, check_nan=True, check_shape=True, nan_kwargs=None, shape_kwargs=None):
    """Validate ndarray.

    Failed checks are reported through ``_log.warning``.

    Parameters
    ----------
    ary : numpy.ndarray
    check_nan : bool
        Check if any value contains NaN.
    check_shape : bool
        Check if array has correct shape. Assumes dimensions in order (chain, draw, *shape).
        For 1D arrays (shape = (n,)) assumes chain equals 1.
    nan_kwargs : dict
        Valid kwargs are:
            axis : int,
                Defaults to None.
            how : str, {"all", "any"}
                Default to "any".
    shape_kwargs : dict
        Valid kwargs are:
            min_chains : int
                Defaults to 2.
            min_draws : int
                Defaults to 4.

    Returns
    -------
    bool or numpy.ndarray of bool
        True where any enabled check failed. When ``nan_kwargs["axis"]`` is given the
        NaN check is reduced along that axis only, so the result is an array.
    """
    ary = np.asarray(ary)

    nan_error = False
    draw_error = False
    chain_error = False

    if check_nan:
        if nan_kwargs is None:
            nan_kwargs = {}

        isnan = np.isnan(ary)
        axis = nan_kwargs.get("axis", None)
        if nan_kwargs.get("how", "any").lower() == "all":
            nan_error = isnan.all(axis)
        else:
            nan_error = isnan.any(axis)

        if np.any(nan_error):
            _log.warning("Array contains NaN-value.")

    if check_shape:
        shape = ary.shape

        if shape_kwargs is None:
            shape_kwargs = {}

        min_chains = shape_kwargs.get("min_chains", 2)
        min_draws = shape_kwargs.get("min_draws", 4)
        error_msg = f"Shape validation failed: input_shape: {shape}, "
        error_msg += f"minimum_shape: (chains={min_chains}, draws={min_draws})"

        # A 1D array is treated as a single chain of draws.
        if len(shape) < 2:
            n_chains, n_draws = 1, shape[0]
        else:
            n_chains, n_draws = shape[0], shape[1]
        chain_error = n_chains < min_chains
        draw_error = n_draws < min_draws

        if chain_error or draw_error:
            _log.warning(error_msg)

    return nan_error | chain_error | draw_error
def get_log_likelihood(idata, var_name=None):
    """Retrieve the log likelihood dataarray of a given variable.

    Parameters
    ----------
    idata : InferenceData-like
        Object exposing a ``log_likelihood`` group, or (deprecated) a
        ``log_likelihood`` variable inside its ``sample_stats`` group.
    var_name : str, optional
        Variable to select from the ``log_likelihood`` group. May only be omitted
        when the group holds exactly one variable.

    Returns
    -------
    The selected log likelihood array.

    Raises
    ------
    TypeError
        If no log likelihood is found, if ``var_name`` is omitted while the group
        holds zero or several variables, or if ``var_name`` is not in the group.
    """
    if not hasattr(idata, "log_likelihood"):
        if hasattr(idata, "sample_stats") and hasattr(idata.sample_stats, "log_likelihood"):
            warnings.warn(
                "Storing the log_likelihood in sample_stats groups has been deprecated",
                DeprecationWarning,
                stacklevel=2,
            )
            return idata.sample_stats.log_likelihood
        raise TypeError("log likelihood not found in inference data object")
    if var_name is None:
        var_names = list(idata.log_likelihood.data_vars)
        if len(var_names) > 1:
            raise TypeError(
                f"Found several log likelihood arrays {var_names}, var_name cannot be None"
            )
        if not var_names:
            raise TypeError("log likelihood group does not contain any variable")
        return idata.log_likelihood[var_names[0]]
    try:
        return idata.log_likelihood[var_name]
    except KeyError as err:
        raise TypeError(f"No log likelihood data named {var_name} found") from err
# Template for the ELPDData summary table. ELPDData.__str__ formats it twice: the
# first ``str.format`` call fills the column widths ({0}, {1}) and turns ``{{...}}``
# into ``{...}`` placeholders, which the second call fills with the values.
BASE_FMT = """Computed from {{n_samples}} by {{n_points}} log-likelihood matrix

{{0:{0}}} Estimate       SE
{{scale}}_{{kind}} {{1:8.2f}}  {{2:7.2f}}
p_{{kind:{1}}} {{3:8.2f}}        -"""
# Pareto k diagnostic table appended to loo results; also formatted twice (first
# with the width of the count column, then with the labels, counts and percentages).
POINTWISE_LOO_FMT = """------

Pareto k diagnostic values:
                         {{0:>{0}}} {{1:>6}}
(-Inf, 0.5]   (good)     {{2:{0}d}} {{6:6.1f}}%
 (0.5, 0.7]   (ok)       {{3:{0}d}} {{7:6.1f}}%
   (0.7, 1]   (bad)      {{4:{0}d}} {{8:6.1f}}%
   (1, Inf)   (very bad) {{5:{0}d}} {{9:6.1f}}%
"""
# Maps the ``<kind>_scale`` entry stored in an ELPDData to the label printed in the table.
SCALE_DICT = {"deviance": "deviance", "log": "elpd", "negative_log": "-elpd"}
class ELPDData(pd.Series):  # pylint: disable=too-many-ancestors
    """Class to contain the data from elpd information criterion like waic or loo."""

    def __str__(self):
        """Print elpd data in a user friendly way.

        Raises
        ------
        ValueError
            If the first index entry is neither ``"loo"`` nor ``"waic"``.
        """
        kind = self.index[0]

        if kind not in ("loo", "waic"):
            raise ValueError("Invalid ELPDData object")

        scale_str = SCALE_DICT[self[f"{kind}_scale"]]
        padding = len(scale_str) + len(kind) + 1
        # First pass sets column widths; second pass fills in the values.
        base = BASE_FMT.format(padding, padding - 2)
        base = base.format(
            "",
            *self.values,
            kind=kind,
            scale=scale_str,
            n_samples=self.n_samples,
            n_points=self.n_data_points,
        )

        if self.warning:
            base += "\n\nThere has been a warning during the calculation. Please check the results."

        if kind == "loo" and "pareto_k" in self:
            # Bin edges of the Pareto k diagnostic table rows.
            bins = np.asarray([-np.inf, 0.5, 0.7, 1, np.inf])
            counts, *_ = _histogram(self.pareto_k.values, bins)
            extended = POINTWISE_LOO_FMT.format(max(4, len(str(np.max(counts)))))
            extended = extended.format(
                "Count", "Pct.", *counts, *(counts / np.sum(counts) * 100)
            )
            base = "\n".join([base, extended])
        return base

    def __repr__(self):
        """Alias to ``__str__``."""
        return self.__str__()

    def copy(self, deep=True):
        """Perform a pandas deep copy of the ELPDData plus a copy of the stored data.

        The Series container is always deep-copied; ``deep`` selects whether every
        stored element is copied with ``copy.deepcopy`` (True) or ``copy.copy``.
        """
        copied_obj = pd.Series.copy(self)
        copy_item = _deepcopy if deep else _copy
        for key in copied_obj.index:
            copied_obj[key] = copy_item(copied_obj[key])
        return ELPDData(copied_obj)
def stats_variance_1d(data, ddof=0):
    """Variance of a 1D sequence with ``ddof`` delta degrees of freedom.

    The result is ``sum((x - mean)**2) / (n - ddof)``. It is accumulated with
    Welford's single-pass update rather than ``E[x**2] - E[x]**2``; the latter loses
    most significant digits when the mean is large compared to the spread.

    Raises ZeroDivisionError for empty input or when ``ddof == len(data)``.
    """
    count = 0
    mean = 0.0
    m2 = 0.0  # running sum of squared deviations from the current mean
    for value in data:
        count += 1
        delta = value - mean
        mean += delta / count
        m2 += delta * (value - mean)
    # Population variance followed by the n / (n - ddof) correction.
    return m2 / count * (count / (count - ddof))
def stats_variance_2d(data, ddof=0, axis=1):
    """Variance of a 1D or 2D array along ``axis`` using ``stats_variance_1d``.

    For 1D input the scalar variance is returned. For 2D input ``axis=1`` (or -1)
    gives one variance per row and ``axis=0`` (or -2) one per column.
    """
    if data.ndim == 1:
        return stats_variance_1d(data, ddof=ddof)
    n_rows, n_cols = data.shape
    if axis < 0:
        # Negative axes count from the end, as in numpy: -1 -> 1, -2 -> 0.
        axis += 2
    if axis == 1:
        var = np.zeros(n_rows)
        for i in range(n_rows):
            var[i] = stats_variance_1d(data[i], ddof=ddof)
    else:
        var = np.zeros(n_cols)
        for i in range(n_cols):
            var[i] = stats_variance_1d(data[:, i], ddof=ddof)
    return var
537def _sqrt(a_a, b_b):
538    return (a_a + b_b) ** 0.5
541def _circfunc(samples, high, low, skipna):
542    samples = np.asarray(samples)
543    if skipna:
544        samples = samples[~np.isnan(samples)]
545    if samples.size == 0:
546        return np.nan
547    return _angle(samples, low, high, np.pi)
551def _angle(samples, low, high, p_i=np.pi):
552    ang = (samples - low) * 2.0 * p_i / (high - low)
553    return ang
def _circular_standard_deviation(samples, high=2 * np.pi, low=0, skipna=False, axis=None):
    """Circular standard deviation of ``samples`` defined on the interval [low, high).

    Computes ``sqrt(-2 * log(R))`` where ``R`` is the mean resultant length of the
    angles, rescaled back to the units of [low, high).

    Parameters
    ----------
    samples : array_like
    high, low : float
        Bounds of the circular interval.
    skipna : bool
        Drop NaN values before computing.
    axis : int, optional
        Axis along which the sine/cosine means are taken; all values when None.
    """
    ang = _circfunc(samples, high, low, skipna)
    s_s = np.sin(ang).mean(axis=axis)
    c_c = np.cos(ang).mean(axis=axis)
    # Rounding can push the resultant length slightly above 1, which would make
    # -2 * log(R) negative and the sqrt NaN; clip it (as scipy.stats.circstd does).
    r_r = np.minimum(1, np.hypot(s_s, c_c))
    return ((high - low) / 2.0 / np.pi) * np.sqrt(-2 * np.log(r_r))
def smooth_data(obs_vals, pp_vals):
    """Smooth data, helper function for discrete data in plot_pbv, loo_pit and plot_loo_pit.

    Each series is fitted with a cubic spline on an evenly spaced [0, 1] grid and
    re-evaluated at the same number of points evenly spaced on [0.01, 0.99].
    ``obs_vals`` is treated as 1D; ``pp_vals`` is interpolated along its axis 1.
    """

    def _resample(values, n_points, axis):
        spline = CubicSpline(np.linspace(0, 1, n_points), values, axis=axis)
        return spline(np.linspace(0.01, 0.99, n_points))

    smoothed_obs = _resample(obs_vals, len(obs_vals), 0)
    smoothed_pp = _resample(pp_vals, pp_vals.shape[1], 1)
    return smoothed_obs, smoothed_pp