1"""Stats-utility functions for ArviZ."""
2import warnings
3from collections.abc import Sequence
4from copy import copy as _copy
5from copy import deepcopy as _deepcopy
6
7import numpy as np
8import pandas as pd
9from scipy.fftpack import next_fast_len
10from scipy.interpolate import CubicSpline
11from scipy.stats.mstats import mquantiles
12from xarray import apply_ufunc
13
14from .. import _log
15from ..utils import conditional_jit, conditional_vect, conditional_dask
16from .density_utils import histogram as _histogram
17
18
# Names exported by ``from <this module> import *``; the other helpers defined
# below are not part of that list.
__all__ = ["autocorr", "autocov", "ELPDData", "make_ufunc", "wrap_xarray_ufunc"]
20
21
def autocov(ary, axis=-1):
    """Compute autocovariance estimates for every lag for the input array.

    Parameters
    ----------
    ary : Numpy array
        An array containing MCMC samples
    axis : int, optional
        Axis along which the autocovariance is computed. Negative values count
        from the last axis. Defaults to -1.

    Returns
    -------
    acov: Numpy array same size as the input array
    """
    ary = np.asarray(ary)
    # Normalise negative axes; ``axis=0`` must stay 0 (hence ``>=``).
    axis = axis if axis >= 0 else ary.ndim + axis
    n = ary.shape[axis]
    # Zero-pad to at least 2n so the circular FFT correlation equals the linear one.
    m = next_fast_len(2 * n)

    ary = ary - ary.mean(axis, keepdims=True)

    # added to silence tuple warning for a submodule
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        ifft_ary = np.fft.rfft(ary, n=m, axis=axis)
        ifft_ary *= np.conjugate(ifft_ary)

        # Keep only the first n lags along ``axis``; all other axes are untouched.
        keep_lags = tuple(
            slice(0, n) if dim == axis else slice(None) for dim in range(ary.ndim)
        )
        cov = np.fft.irfft(ifft_ary, n=m, axis=axis)[keep_lags]
        cov /= n

    return cov
54
55
def autocorr(ary, axis=-1):
    """Compute autocorrelation using FFT for every lag for the input array.

    See https://en.wikipedia.org/wiki/autocorrelation#Efficient_computation

    Parameters
    ----------
    ary : Numpy array
        An array containing MCMC samples
    axis : int, optional
        Axis along which the autocorrelation is computed. Negative values count
        from the last axis. Defaults to -1.

    Returns
    -------
    acorr: Numpy array same size as the input array
    """
    corr = autocov(ary, axis=axis)
    # Normalise negative axes; ``axis=0`` must stay 0 (hence ``>=``).
    axis = axis if axis >= 0 else corr.ndim + axis
    # Select the lag-0 autocovariance (kept as a length-1 axis so it broadcasts).
    lag0 = tuple(
        slice(None, 1) if dim == axis else slice(None, None) for dim in range(corr.ndim)
    )
    # A constant series has zero variance at lag 0; the resulting 0/0 is left as nan.
    with np.errstate(invalid="ignore"):
        corr /= corr[lag0]
    return corr
78
79
def make_ufunc(
    func, n_dims=2, n_output=1, n_input=1, index=Ellipsis, ravel=True, check_shape=None
):  # noqa: D202
    """Make ufunc from a function taking 1D array input.

    Parameters
    ----------
    func : callable
    n_dims : int, optional
        Number of core dimensions not broadcasted. Dimensions are skipped from the end.
        At minimum n_dims > 0.
    n_output : int, optional
        Select number of results returned by `func`.
        If n_output > 1, ufunc returns a tuple of objects else returns an object.
    n_input : int, optional
        Number of **array** inputs to func, i.e. ``n_input=2`` means that func is called
        with ``func(ary1, ary2, *args, **kwargs)``
    index : int, slice or Ellipsis, optional
        Applied to each result of `func` (``np.asarray(result)[index]``) before it is
        stored. Defaults to `Ellipsis`.
    ravel : bool, optional
        If true, ravel the ndarray before calling `func`.
    check_shape: bool, optional
        If false, do not check if the shape of the output is compatible with n_dims and
        n_output. By default, True only for n_input=1. If n_input is larger than 1, the last
        input array is used to check the shape, however, shape checking with multiple inputs
        may not be correct.

    Returns
    -------
    callable
        ufunc wrapper for `func`. It accepts the keyword-only arguments ``out``
        (preallocated output; a tuple of arrays when ``n_output > 1``) and ``out_shape``
        (trailing shape of each result of `func`; one shape per output when
        ``n_output > 1``).

    Raises
    ------
    TypeError
        If ``n_dims < 1``.
    """
    if n_dims < 1:
        raise TypeError("n_dims must be one or higher.")

    if check_shape is None:
        # Shape checking is only reliable with a single array input.
        check_shape = n_input == 1

    def _ufunc(*args, out=None, out_shape=None, **kwargs):
        """General ufunc for single-output function."""
        arys = args[:n_input]
        element_shape = arys[-1].shape[:-n_dims]
        core_out = () if out_shape is None else tuple(out_shape)
        # Trailing dims of ``out`` filled by one call to ``func``; None means each call
        # fills a single element. An empty ``out_shape`` therefore behaves like None.
        n_dims_out = -len(core_out) if core_out else None
        expected_shape = (*element_shape, *core_out)
        if out is None:
            out = np.empty(expected_shape)
        elif check_shape and out.shape != expected_shape:
            msg = f"Shape incorrect for `out`: {out.shape}."
            msg += f" Correct shape is {expected_shape}"
            raise TypeError(msg)
        for idx in np.ndindex(out.shape[:n_dims_out]):
            arys_idx = [ary[idx].ravel() if ravel else ary[idx] for ary in arys]
            out[idx] = np.asarray(func(*arys_idx, *args[n_input:], **kwargs))[index]
        return out

    def _multi_ufunc(*args, out=None, out_shape=None, **kwargs):
        """General ufunc for multi-output function."""
        arys = args[:n_input]
        element_shape = arys[-1].shape[:-n_dims]
        if out_shape is None:
            correct_shape = tuple(element_shape for _ in range(n_output))
        else:
            correct_shape = tuple((*element_shape, *out_shape[i]) for i in range(n_output))
        if out is None:
            out = tuple(np.empty(shape) for shape in correct_shape)
        elif check_shape:
            if isinstance(out, tuple):
                found_shape = tuple(item.shape for item in out)
                raise_error = found_shape != correct_shape
            else:
                found_shape = f"not tuple, type={type(out)}"
                raise_error = True
            if raise_error:
                msg = f"Shapes incorrect for `out`: {found_shape}."
                msg += f" Correct shapes are {correct_shape}"
                raise TypeError(msg)
        for idx in np.ndindex(element_shape):
            arys_idx = [ary[idx].ravel() if ravel else ary[idx] for ary in arys]
            results = func(*arys_idx, *args[n_input:], **kwargs)
            for i, res in enumerate(results):
                out[i][idx] = np.asarray(res)[index]
        return out

    if n_output > 1:
        ufunc = _multi_ufunc
    else:
        ufunc = _ufunc

    update_docstring(ufunc, func, n_output)
    return ufunc
178
179
@conditional_dask
def wrap_xarray_ufunc(
    ufunc,
    *datasets,
    ufunc_kwargs=None,
    func_args=None,
    func_kwargs=None,
    dask_kwargs=None,
    **kwargs,
):
    """Wrap make_ufunc with xarray.apply_ufunc.

    Parameters
    ----------
    ufunc : callable
    datasets : xarray.dataset
    ufunc_kwargs : dict
        Keyword arguments passed to `make_ufunc`. The dict is not modified.
            - 'n_dims', int, by default 2
            - 'n_output', int, by default 1
            - 'n_input', int, by default len(datasets)
            - 'index', slice, by default Ellipsis
            - 'ravel', bool, by default True
    func_args : tuple
        Arguments passed to 'ufunc'.
    func_kwargs : dict
        Keyword arguments passed to 'ufunc'.
            - 'out_shape', tuple, by default None
    dask_kwargs : dict
        Dask related kwargs passed to :func:`xarray:xarray.apply_ufunc`.
        Use :meth:`~arviz.Dask.enable_dask` to set default kwargs.
    **kwargs
        Passed to xarray.apply_ufunc.

    Returns
    -------
    xarray.dataset
    """
    # Copy so that filling in defaults below never mutates a dict owned by the caller;
    # otherwise defaults from one call would leak into later calls reusing that dict.
    ufunc_kwargs = {} if ufunc_kwargs is None else dict(ufunc_kwargs)
    ufunc_kwargs.setdefault("n_input", len(datasets))
    if func_args is None:
        func_args = ()
    if func_kwargs is None:
        func_kwargs = {}
    if dask_kwargs is None:
        dask_kwargs = {}

    # By default every positional input (datasets and extra func_args) has
    # ("chain", "draw") as core dimensions.
    kwargs.setdefault(
        "input_core_dims", tuple(("chain", "draw") for _ in range(len(func_args) + len(datasets)))
    )
    # The number of trailing dims consumed by the ufunc follows the last input's core dims.
    ufunc_kwargs.setdefault("n_dims", len(kwargs["input_core_dims"][-1]))
    kwargs.setdefault("output_core_dims", tuple([] for _ in range(ufunc_kwargs.get("n_output", 1))))

    callable_ufunc = make_ufunc(ufunc, **ufunc_kwargs)

    return apply_ufunc(
        callable_ufunc, *datasets, *func_args, kwargs=func_kwargs, **dask_kwargs, **kwargs
    )
239
240
def update_docstring(ufunc, func, n_output=1):
    """Update ArviZ generated ufunc docstring.

    Appends a usage note for ``xarray.apply_ufunc`` and, when available, the
    docstring of the wrapped ``func`` to ``ufunc.__doc__`` (modified in place).

    Parameters
    ----------
    ufunc : callable
        Generated ufunc whose ``__doc__`` is extended. A missing (None) docstring is
        treated as empty.
    func : callable
        Wrapped function; its ``__module__``, ``__name__`` and ``__doc__`` are used
        when they are strings.
    n_output : int, optional
        Number of outputs of ``func``. When larger than 1 the usage note also shows
        ``output_core_dims``.
    """
    module = getattr(func, "__module__", None)
    module = module if isinstance(module, str) else ""
    name = getattr(func, "__name__", None)
    name = name if isinstance(name, str) else ""
    docstring = getattr(func, "__doc__", None)
    docstring = docstring if isinstance(docstring, str) else ""
    # Join with a dot only between non-empty parts (no leading/trailing dot).
    qualified_name = ".".join(part for part in (module, name) if part)

    input_core_dims = 'tuple(("chain", "draw") for _ in range(n_args))'
    usage = f"xr.apply_ufunc(ufunc, dataset, input_core_dims={input_core_dims}"
    if n_output > 1:
        usage += f", output_core_dims=tuple([] for _ in range({n_output}))"
    usage += ")"

    # ``__doc__`` is None when docstrings are stripped (e.g. ``python -OO``).
    parts = [ufunc.__doc__ or "", "\n\n"]
    if qualified_name:
        parts += ["This function is a ufunc wrapper for ", qualified_name, "\n"]
    parts += [
        'Call ufunc with n_args from xarray against "chain" and "draw" dimensions:',
        "\n\n",
        usage,
        "\n\n",
        "For example: np.std(data, ddof=1) --> n_args=2",
    ]
    if docstring:
        parts += ["\n\n", qualified_name, " docstring:", "\n\n", docstring]
    ufunc.__doc__ = "".join(parts)
278
279
def logsumexp(ary, *, b=None, b_inv=None, axis=None, keepdims=False, out=None, copy=True):
    """Stable logsumexp when b >= 0 and b is scalar.

    b_inv overwrites b unless b_inv is None.

    Computes ``log(b * sum(exp(ary)))`` (or ``log(sum(exp(ary)) / b_inv)``) along
    ``axis``, subtracting the maximum first to avoid overflow.

    Parameters
    ----------
    ary : array_like
        Input values. Integer, unsigned and boolean inputs are converted to float64.
    b : scalar, optional
        Non-negative scale factor applied inside the log. Ignored if ``b_inv`` is given.
    b_inv : scalar, optional
        Inverse scale factor; the result is ``logsumexp(ary) - log(b_inv)``.
    axis : int or sequence of int, optional
        Axis or axes to reduce. ``None`` reduces over all elements.
    keepdims : bool, optional
        Keep the reduced axes as size-1 dimensions.
    out : numpy.ndarray, optional
        Preallocated output with the reduced shape.
    copy : bool, optional
        If False, the (float) input array may be overwritten as scratch space.

    Returns
    -------
    numpy.ndarray or scalar
        The reduction; 0-d results are returned as a scalar of the input float type.
        ``b_inv == 0`` yields ``inf`` and ``b == 0`` yields ``-inf``.
    """
    # check dimensions for result arrays
    ary = np.asarray(ary)
    # exp/log must run in floating point; in-place ufuncs fail on int/uint/bool arrays.
    if ary.dtype.kind in "biu":
        ary = ary.astype(np.float64)
    dtype = ary.dtype.type
    shape = ary.shape
    shape_len = len(shape)
    if isinstance(axis, Sequence):
        axis = tuple(axis_i if axis_i >= 0 else shape_len + axis_i for axis_i in axis)
        agroup = axis
    else:
        axis = axis if (axis is None) or (axis >= 0) else shape_len + axis
        agroup = (axis,)
    shape_max = (
        tuple(1 for _ in shape)
        if axis is None
        else tuple(1 if i in agroup else d for i, d in enumerate(shape))
    )
    # create result arrays
    if out is None:
        if not keepdims:
            out_shape = (
                tuple()
                if axis is None
                else tuple(d for i, d in enumerate(shape) if i not in agroup)
            )
        else:
            out_shape = shape_max
        out = np.empty(out_shape, dtype=dtype)
    if b_inv == 0:
        return np.full_like(out, np.inf, dtype=dtype) if out.shape else np.inf
    if b_inv is None and b == 0:
        return np.full_like(out, -np.inf) if out.shape else -np.inf
    ary_max = np.empty(shape_max, dtype=dtype)
    # calculations
    ary.max(axis=axis, keepdims=True, out=ary_max)
    # A non-finite maximum (e.g. a slice that is entirely -inf) would make the shift
    # below produce inf - inf = nan; shift such slices by 0 instead.
    ary_max[~np.isfinite(ary_max)] = 0
    if copy:
        ary = ary.copy()
    ary -= ary_max
    np.exp(ary, out=ary)
    ary.sum(axis=axis, keepdims=keepdims, out=out)
    # log(0) = -inf is the correct answer for an all -inf slice.
    with np.errstate(divide="ignore"):
        np.log(out, out=out)
    if b_inv is not None:
        ary_max -= np.log(b_inv)
    elif b:
        ary_max += np.log(b)
    # Drop only the reduced axes so other size-1 dimensions still align with ``out``.
    out += ary_max if keepdims else np.squeeze(ary_max, axis=axis)
    # transform to scalar if possible
    return out if out.shape else dtype(out)
334
335
def quantile(ary, q, axis=None, limit=None):
    """Use same quantile function as R (Type 7).

    Thin wrapper around ``scipy.stats.mstats.mquantiles`` with plotting positions
    ``alphap=1, betap=1``, which corresponds to R's default quantile definition.

    Parameters
    ----------
    ary : array_like
        Input data.
    q : array_like
        Probabilities at which quantiles are computed.
    axis : int, optional
        Axis along which quantiles are computed; None flattens the input.
    limit : tuple, optional
        Forwarded to ``mquantiles``; None means no limit (an empty tuple).
    """
    limits = () if limit is None else limit
    return mquantiles(ary, q, alphap=1, betap=1, axis=axis, limit=limits)
341
342
def not_valid(ary, check_nan=True, check_shape=True, nan_kwargs=None, shape_kwargs=None):
    """Validate ndarray.

    Validation failures are reported through ``_log.warning`` and signalled by the
    return value rather than raised.

    Parameters
    ----------
    ary : numpy.ndarray
    check_nan : bool
        Check if any value contains NaN.
    check_shape : bool
        Check if array has correct shape. Assumes dimensions in order (chain, draw, *shape).
        For 1D arrays (shape = (n,)) assumes chain equals 1.
    nan_kwargs : dict
        Valid kwargs are:
            axis : int,
                Defaults to None.
            how : str, {"all", "any"}
                Default to "any".
    shape_kwargs : dict
        Valid kwargs are:
            min_chains : int
                Defaults to 2.
            min_draws : int
                Defaults to 4.

    Returns
    -------
    bool or numpy.ndarray of bool
        True where validation failed. An array is returned when ``nan_kwargs["axis"]``
        leaves some dimensions unreduced.
    """
    ary = np.asarray(ary)

    nan_error = False
    draw_error = False
    chain_error = False

    if check_nan:
        if nan_kwargs is None:
            nan_kwargs = {}

        isnan = np.isnan(ary)
        axis = nan_kwargs.get("axis", None)
        if nan_kwargs.get("how", "any").lower() == "all":
            nan_error = isnan.all(axis)
        else:
            nan_error = isnan.any(axis)

        # nan_error is a numpy bool or, when an axis is given, a boolean array.
        if np.any(nan_error):
            _log.warning("Array contains NaN-value.")

    if check_shape:
        shape = ary.shape

        if shape_kwargs is None:
            shape_kwargs = {}

        min_chains = shape_kwargs.get("min_chains", 2)
        min_draws = shape_kwargs.get("min_draws", 4)
        error_msg = f"Shape validation failed: input_shape: {shape}, "
        error_msg += f"minimum_shape: (chains={min_chains}, draws={min_draws})"

        # A 1D array counts as a single chain, so it fails whenever min_chains > 1.
        chain_error = ((min_chains > 1) and (len(shape) < 2)) or (shape[0] < min_chains)
        draw_error = ((len(shape) < 2) and (shape[0] < min_draws)) or (
            (len(shape) > 1) and (shape[1] < min_draws)
        )

        if chain_error or draw_error:
            _log.warning(error_msg)

    return nan_error | chain_error | draw_error
411
412
def get_log_likelihood(idata, var_name=None):
    """Retrieve the log likelihood dataarray of a given variable.

    Parameters
    ----------
    idata : InferenceData-like
        Object with a ``log_likelihood`` group, or (deprecated) a ``log_likelihood``
        entry inside its ``sample_stats`` group.
    var_name : str, optional
        Variable to select from the ``log_likelihood`` group. May be omitted only
        when the group holds a single variable.

    Returns
    -------
    The selected log likelihood data.

    Raises
    ------
    TypeError
        If no log likelihood is found, ``var_name`` is missing from the group, or
        ``var_name`` is None while several variables are present.
    """
    if not hasattr(idata, "log_likelihood"):
        legacy_location = hasattr(idata, "sample_stats") and hasattr(
            idata.sample_stats, "log_likelihood"
        )
        if not legacy_location:
            raise TypeError("log likelihood not found in inference data object")
        warnings.warn(
            "Storing the log_likelihood in sample_stats groups has been deprecated",
            DeprecationWarning,
        )
        return idata.sample_stats.log_likelihood

    group = idata.log_likelihood
    if var_name is not None:
        try:
            return group[var_name]
        except KeyError as err:
            raise TypeError(f"No log likelihood data named {var_name} found") from err

    available = list(group.data_vars)
    if len(available) > 1:
        raise TypeError(
            f"Found several log likelihood arrays {available}, var_name cannot be None"
        )
    return group[available[0]]
440
441
442BASE_FMT = """Computed from {{n_samples}} by {{n_points}} log-likelihood matrix
443
444{{0:{0}}} Estimate       SE
445{{scale}}_{{kind}} {{1:8.2f}}  {{2:7.2f}}
446p_{{kind:{1}}} {{3:8.2f}}        -"""
447POINTWISE_LOO_FMT = """------
448
449Pareto k diagnostic values:
450                         {{0:>{0}}} {{1:>6}}
451(-Inf, 0.5]   (good)     {{2:{0}d}} {{6:6.1f}}%
452 (0.5, 0.7]   (ok)       {{3:{0}d}} {{7:6.1f}}%
453   (0.7, 1]   (bad)      {{4:{0}d}} {{8:6.1f}}%
454   (1, Inf)   (very bad) {{5:{0}d}} {{9:6.1f}}%
455"""
456SCALE_DICT = {"deviance": "deviance", "log": "elpd", "negative_log": "-elpd"}
457
458
class ELPDData(pd.Series):  # pylint: disable=too-many-ancestors
    """Class to contain the data from elpd information criterion like waic or loo.

    The first index label gives the criterion kind (``"loo"`` or ``"waic"``). Printing
    additionally reads the entries ``<kind>_scale``, ``n_samples``, ``n_data_points``,
    ``warning`` and, for loo, ``pareto_k``.
    """

    def __str__(self):
        """Print elpd data in a user friendly way.

        Raises
        ------
        ValueError
            If the first index label is neither ``"loo"`` nor ``"waic"``.
        """
        kind = self.index[0]

        if kind not in ("loo", "waic"):
            raise ValueError("Invalid ELPDData object")

        scale_str = SCALE_DICT[self[f"{kind}_scale"]]
        # Width of the "<scale>_<kind>" row label; used to align the table columns.
        padding = len(scale_str) + len(kind) + 1
        base = BASE_FMT.format(padding, padding - 2)
        base = base.format(
            "",
            kind=kind,
            scale=scale_str,
            n_samples=self.n_samples,
            n_points=self.n_data_points,
            *self.values,
        )

        if self.warning:
            base += "\n\nThere has been a warning during the calculation. Please check the results."

        if kind == "loo" and "pareto_k" in self:
            # Use ``np.inf``: the ``np.Inf`` alias was removed in NumPy 2.0.
            bins = np.asarray([-np.inf, 0.5, 0.7, 1, np.inf])
            counts, *_ = _histogram(self.pareto_k.values, bins)
            extended = POINTWISE_LOO_FMT.format(max(4, len(str(np.max(counts)))))
            extended = extended.format(
                "Count", "Pct.", *[*counts, *(counts / np.sum(counts) * 100)]
            )
            base = "\n".join([base, extended])
        return base

    def __repr__(self):
        """Alias to ``__str__``."""
        return self.__str__()

    def copy(self, deep=True):
        """Perform a pandas deep copy of the ELPDData plus a copy of the stored data.

        Parameters
        ----------
        deep : bool, optional
            If True, each stored value is copied with ``copy.deepcopy``; otherwise with
            ``copy.copy``. The Series container itself is always copied by pandas.
        """
        copied_obj = pd.Series.copy(self)
        for key in copied_obj.keys():
            if deep:
                copied_obj[key] = _deepcopy(copied_obj[key])
            else:
                copied_obj[key] = _copy(copied_obj[key])
        return ELPDData(copied_obj)
507
508
@conditional_jit
def stats_variance_1d(data, ddof=0):
    """Variance of a 1D sequence with ``ddof`` delta degrees of freedom.

    Uses a two-pass algorithm (mean first, then the sum of squared deviations). This
    avoids the catastrophic cancellation of the one-pass ``E[x^2] - E[x]^2`` formula
    when the mean is large relative to the spread. The result equals
    ``sum((x - mean)**2) / (len(data) - ddof)``.
    """
    n_obs = len(data)
    total = 0.0
    for value in data:
        total += value
    mean = total / n_obs
    sq_dev = 0.0
    for value in data:
        diff = value - mean
        sq_dev += diff * diff
    return sq_dev / (n_obs - ddof)
518
519
def stats_variance_2d(data, ddof=0, axis=1):
    """Variance of a 1D or 2D array computed with ``stats_variance_1d``.

    Parameters
    ----------
    data : numpy.ndarray
        1D or 2D array. For 1D input a single variance is returned.
    ddof : int, optional
        Delta degrees of freedom.
    axis : int, optional
        ``1`` (or ``-1``) computes one variance per row; any other value computes one
        variance per column.

    Returns
    -------
    float or numpy.ndarray
    """
    if data.ndim == 1:
        return stats_variance_1d(data, ddof=ddof)
    # Unpacking also rejects arrays with more than two dimensions.
    n_rows, n_cols = data.shape
    # Treat -1 like 1 so the last axis is reduced, as with numpy's convention.
    row_wise = axis in (1, -1)
    var = np.zeros(n_rows if row_wise else n_cols)
    for i in range(var.size):
        vec = data[i] if row_wise else data[:, i]
        var[i] = stats_variance_1d(vec, ddof=ddof)
    return var
534
535
@conditional_vect
def _sqrt(a_a, b_b):
    """Return the square root of ``a_a + b_b``, computed as a ``** 0.5`` power."""
    total = a_a + b_b
    return total ** 0.5
539
540
def _circfunc(samples, high, low, skipna):
    """Convert samples on ``[low, high]`` to angles, optionally dropping NaNs.

    When ``skipna`` is true, NaN values are removed, which also flattens the array.
    Returns ``np.nan`` if no samples remain.
    """
    values = np.asarray(samples)
    if skipna:
        values = values[~np.isnan(values)]
    return np.nan if values.size == 0 else _angle(values, low, high, np.pi)
548
549
@conditional_vect
def _angle(samples, low, high, p_i=np.pi):
    """Linearly rescale values so that ``low`` maps to 0 and ``high`` to ``2 * p_i``."""
    radians = (samples - low) * 2.0 * p_i / (high - low)
    return radians
554
555
def _circular_standard_deviation(samples, high=2 * np.pi, low=0, skipna=False, axis=None):
    """Circular standard deviation of ``samples`` on the interval ``[low, high]``.

    Computes ``sqrt(-2 * log(R))``, where ``R`` is the length of the mean of the
    unit vectors for the sample angles. The result is rescaled back to the units of
    ``[low, high]``.
    """
    angles = _circfunc(samples, high, low, skipna)
    mean_sin = np.sin(angles).mean(axis=axis)
    mean_cos = np.cos(angles).mean(axis=axis)
    resultant_length = np.hypot(mean_sin, mean_cos)
    scale = (high - low) / 2.0 / np.pi
    return scale * np.sqrt(-2 * np.log(resultant_length))
562
563
def smooth_data(obs_vals, pp_vals):
    """Smooth data, helper function for discrete data in plot_pbv, loo_pit and plot_loo_pit.

    Each series is treated as samples on an evenly spaced grid over ``[0, 1]``. It is
    interpolated with a cubic spline and resampled at the same number of evenly spaced
    points over ``[0.01, 0.99]``. ``pp_vals`` is interpolated along its second axis.

    Returns
    -------
    tuple of numpy.ndarray
        ``(smoothed_obs_vals, smoothed_pp_vals)``.
    """

    def _resample(values, n_points, axis=0):
        # Fit on [0, 1] and evaluate on the slightly shrunk grid [0.01, 0.99].
        spline = CubicSpline(np.linspace(0, 1, n_points), values, axis=axis)
        return spline(np.linspace(0.01, 0.99, n_points))

    smoothed_obs = _resample(obs_vals, len(obs_vals))
    smoothed_pp = _resample(pp_vals, pp_vals.shape[1], axis=1)
    return smoothed_obs, smoothed_pp
575