"""Stats-utility functions for ArviZ."""
import warnings
from collections.abc import Sequence
from copy import copy as _copy
from copy import deepcopy as _deepcopy

import numpy as np
import pandas as pd
from scipy.fftpack import next_fast_len
from scipy.interpolate import CubicSpline
from scipy.stats.mstats import mquantiles
from xarray import apply_ufunc

from .. import _log
from ..utils import conditional_jit, conditional_vect, conditional_dask
from .density_utils import histogram as _histogram


__all__ = ["autocorr", "autocov", "ELPDData", "make_ufunc", "wrap_xarray_ufunc"]


def autocov(ary, axis=-1):
    """Compute autocovariance estimates for every lag for the input array.

    Parameters
    ----------
    ary : Numpy array
        An array containing MCMC samples
    axis : int, optional
        Axis along which the autocovariance is computed. Negative values count from the
        last axis. Defaults to -1.

    Returns
    -------
    acov: Numpy array same size as the input array
    """
    # Normalise to a non-negative axis index. The test has to be ``>= 0``: with ``> 0`` an
    # explicit ``axis=0`` was shifted to ``ndim`` and the shape lookup below raised IndexError.
    axis = axis if axis >= 0 else len(ary.shape) + axis
    n = ary.shape[axis]
    # FFT length: a fast size of at least 2 * n (rfft zero-pads the input up to it).
    m = next_fast_len(2 * n)

    ary = ary - ary.mean(axis, keepdims=True)

    # added to silence tuple warning for a submodule
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        ifft_ary = np.fft.rfft(ary, n=m, axis=axis)
        ifft_ary *= np.conjugate(ifft_ary)

        # Keep only the first n entries along `axis`; every other axis is taken whole.
        lag_slice = tuple(
            slice(0, n) if dim == axis else slice(None) for dim in range(len(ary.shape))
        )
        cov = np.fft.irfft(ifft_ary, n=m, axis=axis)[lag_slice]
        cov /= n

    return cov


def autocorr(ary, axis=-1):
    """Compute autocorrelation using FFT for every lag for the input array.

    See https://en.wikipedia.org/wiki/autocorrelation#Efficient_computation

    Parameters
    ----------
    ary : Numpy array
        An array containing MCMC samples
    axis : int, optional
        Axis along which the autocorrelation is computed. Negative values count from the
        last axis. Defaults to -1.

    Returns
    -------
    acorr: Numpy array same size as the input array
    """
    corr = autocov(ary, axis=axis)
    # Same axis normalisation as in `autocov` (``>= 0`` so that ``axis=0`` is left untouched).
    axis = axis if axis >= 0 else len(corr.shape) + axis
    # Selects the first entry along `axis` while keeping that axis, so it broadcasts below.
    first_lag = tuple(
        slice(None, 1) if dim == axis else slice(None, None) for dim in range(len(corr.shape))
    )
    # A zero first entry yields 0 / 0; the resulting invalid-value warning is silenced and the
    # NaN is left in the output.
    with np.errstate(invalid="ignore"):
        corr /= corr[first_lag]
    return corr


def make_ufunc(
    func, n_dims=2, n_output=1, n_input=1, index=Ellipsis, ravel=True, check_shape=None
):  # noqa: D202
    """Make ufunc from a function taking 1D array input.

    Parameters
    ----------
    func : callable
    n_dims : int, optional
        Number of core dimensions not broadcasted. Dimensions are skipped from the end.
        At minimum n_dims > 0.
    n_output : int, optional
        Select number of results returned by `func`.
        If n_output > 1, ufunc returns a tuple of objects else returns an object.
    n_input : int, optional
        Number of **array** inputs to func, i.e. ``n_input=2`` means that func is called
        with ``func(ary1, ary2, *args, **kwargs)``
    index : int, optional
        Slice ndarray with `index`. Defaults to `Ellipsis`.
    ravel : bool, optional
        If true, ravel the ndarray before calling `func`.
    check_shape: bool, optional
        If false, do not check if the shape of the output is compatible with n_dims and
        n_output. By default, True only for n_input=1. If n_input is larger than 1, the last
        input array is used to check the shape, however, shape checking with multiple inputs
        may not be correct.

    Returns
    -------
    callable
        ufunc wrapper for `func`.

    Raises
    ------
    TypeError
        If ``n_dims`` is smaller than one. The returned wrapper also raises ``TypeError``
        when a user supplied ``out`` fails the shape check.
    """
    if n_dims < 1:
        raise TypeError("n_dims must be one or higher.")

    if check_shape is None:
        # Shape checking is only enabled by default for the single-input case.
        check_shape = n_input == 1

    def _ufunc(*args, out=None, out_shape=None, **kwargs):
        """General ufunc for single-output function."""
        arys = args[:n_input]
        batch_shape = arys[-1].shape[:-n_dims]
        n_dims_out = None
        if out is None:
            if out_shape is None:
                out = np.empty(batch_shape)
            else:
                out = np.empty((*batch_shape, *out_shape))
                # An empty `out_shape` must not produce ``-0``: ``out.shape[:0]`` would drop
                # every batch dimension instead of none.
                if len(out_shape) > 0:
                    n_dims_out = -len(out_shape)
        elif check_shape:
            if out.shape != batch_shape:
                msg = f"Shape incorrect for `out`: {out.shape}."
                msg += f" Correct shape is {batch_shape}"
                raise TypeError(msg)
        # Loop over the broadcast (non-core) dimensions only.
        for idx in np.ndindex(out.shape[:n_dims_out]):
            arys_idx = [ary[idx].ravel() if ravel else ary[idx] for ary in arys]
            out[idx] = np.asarray(func(*arys_idx, *args[n_input:], **kwargs))[index]
        return out

    def _multi_ufunc(*args, out=None, out_shape=None, **kwargs):
        """General ufunc for multi-output function."""
        arys = args[:n_input]
        element_shape = arys[-1].shape[:-n_dims]
        if out is None:
            if out_shape is None:
                out = tuple(np.empty(element_shape) for _ in range(n_output))
            else:
                out = tuple(np.empty((*element_shape, *out_shape[i])) for i in range(n_output))

        elif check_shape:
            correct_shape = tuple(element_shape for _ in range(n_output))
            if isinstance(out, tuple):
                found = tuple(item.shape for item in out)
                raise_error = found != correct_shape
            else:
                raise_error = True
                # f-string prefix was missing, so the placeholder was printed literally.
                found = f"not tuple, type={type(out)}"
            if raise_error:
                msg = f"Shapes incorrect for `out`: {found}."
                msg += f" Correct shapes are {correct_shape}"
                raise TypeError(msg)
        for idx in np.ndindex(element_shape):
            arys_idx = [ary[idx].ravel() if ravel else ary[idx] for ary in arys]
            results = func(*arys_idx, *args[n_input:], **kwargs)
            for i, res in enumerate(results):
                out[i][idx] = np.asarray(res)[index]
        return out

    ufunc = _multi_ufunc if n_output > 1 else _ufunc

    update_docstring(ufunc, func, n_output)
    return ufunc


@conditional_dask
def wrap_xarray_ufunc(
    ufunc,
    *datasets,
    ufunc_kwargs=None,
    func_args=None,
    func_kwargs=None,
    dask_kwargs=None,
    **kwargs,
):
    """Wrap make_ufunc with xarray.apply_ufunc.

    Parameters
    ----------
    ufunc : callable
    datasets : xarray.dataset
    ufunc_kwargs : dict
        Keyword arguments passed to `make_ufunc`. The dictionary is copied, it is not
        modified in place.
            - 'n_dims', int, by default the length of the last entry of ``input_core_dims``
              (2 with the default ``("chain", "draw")``)
            - 'n_output', int, by default 1
            - 'n_input', int, by default len(datasets)
            - 'index', slice, by default Ellipsis
            - 'ravel', bool, by default True
    func_args : tuple
        Arguments passed to 'ufunc'.
    func_kwargs : dict
        Keyword arguments passed to 'ufunc'.
            - 'out_shape', int, by default None
    dask_kwargs : dict
        Dask related kwargs passed to :func:`xarray:xarray.apply_ufunc`.
        Use :meth:`~arviz.Dask.enable_dask` to set default kwargs.
    **kwargs
        Passed to xarray.apply_ufunc.

    Returns
    -------
    xarray.dataset
    """
    # Work on a copy: the ``setdefault`` calls below used to write ``n_input`` / ``n_dims``
    # into the caller's dictionary, which then leaked into later calls reusing it.
    ufunc_kwargs = {} if ufunc_kwargs is None else dict(ufunc_kwargs)
    ufunc_kwargs.setdefault("n_input", len(datasets))
    if func_args is None:
        func_args = tuple()
    if func_kwargs is None:
        func_kwargs = {}
    if dask_kwargs is None:
        dask_kwargs = {}

    kwargs.setdefault(
        "input_core_dims", tuple(("chain", "draw") for _ in range(len(func_args) + len(datasets)))
    )
    ufunc_kwargs.setdefault("n_dims", len(kwargs["input_core_dims"][-1]))
    kwargs.setdefault("output_core_dims", tuple([] for _ in range(ufunc_kwargs.get("n_output", 1))))

    callable_ufunc = make_ufunc(ufunc, **ufunc_kwargs)

    return apply_ufunc(
        callable_ufunc, *datasets, *func_args, kwargs=func_kwargs, **dask_kwargs, **kwargs
    )


def update_docstring(ufunc, func, n_output=1):
    """Update ArviZ generated ufunc docstring.

    Appends to ``ufunc.__doc__`` a usage note for ``xr.apply_ufunc`` and, when available,
    the qualified name and the docstring of the wrapped ``func``.

    Parameters
    ----------
    ufunc : callable
        Wrapper whose ``__doc__`` is extended in place.
    func : callable
        Wrapped function; its ``__module__``, ``__name__`` and ``__doc__`` are used when
        they are strings.
    n_output : int, optional
        Number of outputs of the wrapper; values above one add ``output_core_dims`` to the
        usage example.
    """
    module = getattr(func, "__module__", "")
    name = getattr(func, "__name__", "")
    docstring = getattr(func, "__doc__", "")
    module = module if isinstance(module, str) else ""
    name = name if isinstance(name, str) else ""
    docstring = docstring if isinstance(docstring, str) else ""
    # Use one spelling of the name everywhere (the docstring header used to glue module and
    # name together without the separating dot).
    qualified_name = ".".join(part for part in (module, name) if part)

    # ``__doc__`` is None when docstrings are stripped (``python -OO``); start from "" then.
    parts = [ufunc.__doc__ or "", "\n\n"]
    if qualified_name:
        parts.append("This function is a ufunc wrapper for ")
        parts.append(qualified_name)
        parts.append("\n")
    parts.append('Call ufunc with n_args from xarray against "chain" and "draw" dimensions:')
    parts.append("\n\n")
    input_core_dims = 'tuple(("chain", "draw") for _ in range(n_args))'
    if n_output > 1:
        output_core_dims = f" tuple([] for _ in range({n_output}))"
        msg = f"xr.apply_ufunc(ufunc, dataset, input_core_dims={input_core_dims}, "
        msg += f"output_core_dims={output_core_dims})"
    else:
        msg = f"xr.apply_ufunc(ufunc, dataset, input_core_dims={input_core_dims})"
    parts.append(msg)
    parts.append("\n\n")
    parts.append("For example: np.std(data, ddof=1) --> n_args=2")
    if docstring:
        parts.append("\n\n")
        parts.append(qualified_name)
        parts.append(" docstring:")
        parts.append("\n\n")
        parts.append(docstring)
    ufunc.__doc__ = "".join(parts)


def logsumexp(ary, *, b=None, b_inv=None, axis=None, keepdims=False, out=None, copy=True):
    """Stable logsumexp when b >= 0 and b is scalar.

    b_inv overwrites b unless b_inv is None.

    Parameters
    ----------
    ary : array-like
        Input values. Integer and boolean input is converted to float64.
    b : scalar, optional
        Scaling factor; ``log(b)`` is added to the result. ``b == 0`` returns ``-inf``.
    b_inv : scalar, optional
        Inverse scaling factor; ``log(b_inv)`` is subtracted from the result.
        ``b_inv == 0`` returns ``inf``.
    axis : int or sequence of int, optional
        Axis or axes to reduce over. None reduces over all axes.
    keepdims : bool, optional
        Keep the reduced axes with length one.
    out : numpy.ndarray, optional
        Array the result is written into.
    copy : bool, optional
        If False, ``ary`` is used as scratch space and overwritten.

    Returns
    -------
    numpy.ndarray or scalar
        A scalar when the result has no dimensions.
    """
    # check dimensions for result arrays
    ary = np.asarray(ary)
    # The in-place exp/log below need a floating dtype. Unsigned and boolean input is
    # promoted as well (previously only signed integers were, the others failed).
    if ary.dtype.kind in ("i", "u", "b"):
        ary = ary.astype(np.float64)
    dtype = ary.dtype.type
    shape = ary.shape
    shape_len = len(shape)
    if isinstance(axis, Sequence):
        axis = tuple(axis_i if axis_i >= 0 else shape_len + axis_i for axis_i in axis)
        agroup = axis
    else:
        axis = axis if (axis is None) or (axis >= 0) else shape_len + axis
        agroup = (axis,)
    # shape of the per-group maximum: reduced axes kept with length one
    shape_max = (
        tuple(1 for _ in shape)
        if axis is None
        else tuple(1 if i in agroup else d for i, d in enumerate(shape))
    )
    # create result arrays
    if out is None:
        if not keepdims:
            out_shape = (
                tuple()
                if axis is None
                else tuple(d for i, d in enumerate(shape) if i not in agroup)
            )
        else:
            out_shape = shape_max
        out = np.empty(out_shape, dtype=dtype)
    if b_inv == 0:
        return np.full_like(out, np.inf, dtype=dtype) if out.shape else np.inf
    if b_inv is None and b == 0:
        return np.full_like(out, -np.inf) if out.shape else -np.inf
    ary_max = np.empty(shape_max, dtype=dtype)
    # calculations
    ary.max(axis=axis, keepdims=True, out=ary_max)
    if copy:
        ary = ary.copy()
    # subtract the maximum before exponentiating, add it back after the log
    ary -= ary_max
    np.exp(ary, out=ary)
    ary.sum(axis=axis, keepdims=keepdims, out=out)
    np.log(out, out=out)
    if b_inv is not None:
        ary_max -= np.log(b_inv)
    elif b:
        ary_max += np.log(b)
    out += ary_max.squeeze() if not keepdims else ary_max
    # transform to scalar if possible
    return out if out.shape else dtype(out)


def quantile(ary, q, axis=None, limit=None):
    """Use same quantile function as R (Type 7).

    Thin wrapper around :func:`scipy.stats.mstats.mquantiles` called with
    ``alphap=1, betap=1``; ``limit=None`` is forwarded as an empty tuple.
    """
    if limit is None:
        limit = tuple()
    return mquantiles(ary, q, alphap=1, betap=1, axis=axis, limit=limit)


def not_valid(ary, check_nan=True, check_shape=True, nan_kwargs=None, shape_kwargs=None):
    """Validate ndarray.

    Parameters
    ----------
    ary : numpy.ndarray
    check_nan : bool
        Check if any value contains NaN.
    check_shape : bool
        Check if array has correct shape. Assumes dimensions in order (chain, draw, *shape).
        For 1D arrays (shape = (n,)) assumes chain equals 1.
    nan_kwargs : dict
        Valid kwargs are:
            axis : int,
                Defaults to None.
            how : str, {"all", "any"}
                Default to "any".
    shape_kwargs : dict
        Valid kwargs are:
            min_chains : int
                Defaults to 2.
            min_draws : int
                Defaults to 4.

    Returns
    -------
    bool
        True if any enabled check failed; a warning is logged for each failing check.
        When ``nan_kwargs["axis"]`` is given the NaN check is evaluated per slice and the
        result is an array of booleans.
    """
    ary = np.asarray(ary)

    nan_error = False
    draw_error = False
    chain_error = False

    if check_nan:
        if nan_kwargs is None:
            nan_kwargs = {}

        isnan = np.isnan(ary)
        axis = nan_kwargs.get("axis", None)
        if nan_kwargs.get("how", "any").lower() == "all":
            nan_error = isnan.all(axis)
        else:
            nan_error = isnan.any(axis)

        # nan_error is a scalar or, with an axis, an array of flags
        if np.any(nan_error):
            _log.warning("Array contains NaN-value.")

    if check_shape:
        shape = ary.shape

        if shape_kwargs is None:
            shape_kwargs = {}

        min_chains = shape_kwargs.get("min_chains", 2)
        min_draws = shape_kwargs.get("min_draws", 4)
        error_msg = f"Shape validation failed: input_shape: {shape}, "
        error_msg += f"minimum_shape: (chains={min_chains}, draws={min_draws})"

        # A 1D array counts as a single chain whose draws are along the only axis.
        chain_error = ((min_chains > 1) and (len(shape) < 2)) or (shape[0] < min_chains)
        draw_error = ((len(shape) < 2) and (shape[0] < min_draws)) or (
            (len(shape) > 1) and (shape[1] < min_draws)
        )

        if chain_error or draw_error:
            _log.warning(error_msg)

    return nan_error | chain_error | draw_error


def get_log_likelihood(idata, var_name=None):
    """Retrieve the log likelihood dataarray of a given variable.

    Parameters
    ----------
    idata : object
        Object exposing a ``log_likelihood`` group. For backwards compatibility a
        ``log_likelihood`` entry inside ``sample_stats`` is also accepted, with a
        ``DeprecationWarning``.
    var_name : str, optional
        Name of the variable inside the ``log_likelihood`` group. May be omitted when the
        group holds a single variable.

    Returns
    -------
    The selected log likelihood data.

    Raises
    ------
    TypeError
        If no log likelihood is available, if ``var_name`` is None while several variables
        are stored, or if ``var_name`` is not found.
    """
    if (
        not hasattr(idata, "log_likelihood")
        and hasattr(idata, "sample_stats")
        and hasattr(idata.sample_stats, "log_likelihood")
    ):
        warnings.warn(
            "Storing the log_likelihood in sample_stats groups has been deprecated",
            DeprecationWarning,
        )
        return idata.sample_stats.log_likelihood
    if not hasattr(idata, "log_likelihood"):
        raise TypeError("log likelihood not found in inference data object")
    if var_name is None:
        var_names = list(idata.log_likelihood.data_vars)
        if len(var_names) > 1:
            raise TypeError(
                f"Found several log likelihood arrays {var_names}, var_name cannot be None"
            )
        return idata.log_likelihood[var_names[0]]
    try:
        return idata.log_likelihood[var_name]
    except KeyError as err:
        raise TypeError(f"No log likelihood data named {var_name} found") from err


# Two-stage templates used by ``ELPDData.__str__``: the first ``str.format`` call fills in the
# column widths (single braces), the second one fills in the values (doubled braces).
# NOTE(review): the padding between columns was reconstructed so that the headers line up with
# the field widths used in the rows (8.2f / 7.2f, and an 11 + 3 + 10 + 1 character label
# column) -- confirm against the canonical templates.
BASE_FMT = """Computed from {{n_samples}} by {{n_points}} log-likelihood matrix

{{0:{0}}} Estimate       SE
{{scale}}_{{kind}} {{1:8.2f}}  {{2:7.2f}}
p_{{kind:{1}}} {{3:8.2f}}        -"""
POINTWISE_LOO_FMT = """------

Pareto k diagnostic values:
                         {{0:>{0}}} {{1:>6}}
(-Inf, 0.5]   (good)     {{2:{0}d}} {{6:6.1f}}%
 (0.5, 0.7]   (ok)       {{3:{0}d}} {{7:6.1f}}%
   (0.7, 1]   (bad)      {{4:{0}d}} {{8:6.1f}}%
   (1, Inf)   (very bad) {{5:{0}d}} {{9:6.1f}}%
"""
# Maps the stored ``<kind>_scale`` value to the label printed in front of the estimate.
SCALE_DICT = {"deviance": "deviance", "log": "elpd", "negative_log": "-elpd"}


class ELPDData(pd.Series):  # pylint: disable=too-many-ancestors
    """Class to contain the data from elpd information criterion like waic or loo."""

    def __str__(self):
        """Print elpd data in a user friendly way.

        The first index label selects the kind and must be ``"loo"`` or ``"waic"``,
        otherwise ``ValueError`` is raised. For ``"loo"`` objects holding ``pareto_k``
        values a table with the Pareto k diagnostic counts is appended.
        """
        kind = self.index[0]

        if kind not in ("loo", "waic"):
            raise ValueError("Invalid ELPDData object")

        scale_str = SCALE_DICT[self[f"{kind}_scale"]]
        # width of the "<scale>_<kind>" row label
        padding = len(scale_str) + len(kind) + 1
        base = BASE_FMT.format(padding, padding - 2)
        # Positional field 0 is the empty header cell; fields 1.. are the stored values in
        # index order.
        base = base.format(
            "",
            *self.values,
            kind=kind,
            scale=scale_str,
            n_samples=self.n_samples,
            n_points=self.n_data_points,
        )

        if self.warning:
            base += "\n\nThere has been a warning during the calculation. Please check the results."

        if kind == "loo" and "pareto_k" in self:
            # ``np.inf`` instead of the ``np.Inf`` alias, which NumPy 2.0 removed.
            bins = np.asarray([-np.inf, 0.5, 0.7, 1, np.inf])
            counts, *_ = _histogram(self.pareto_k.values, bins)
            extended = POINTWISE_LOO_FMT.format(max(4, len(str(np.max(counts)))))
            percentages = counts / np.sum(counts) * 100
            extended = extended.format("Count", "Pct.", *counts, *percentages)
            base = "\n".join([base, extended])
        return base

    def __repr__(self):
        """Alias to ``__str__``."""
        return self.__str__()

    def copy(self, deep=True):
        """Perform a pandas deep copy of the ELPDData plus a copy of the stored data.

        Parameters
        ----------
        deep : bool, optional
            If True every stored value is duplicated with ``copy.deepcopy``, otherwise
            with ``copy.copy``.

        Returns
        -------
        ELPDData
        """
        copied_obj = pd.Series.copy(self)
        duplicate = _deepcopy if deep else _copy
        for key in copied_obj.keys():
            copied_obj[key] = duplicate(copied_obj[key])
        return ELPDData(copied_obj)


@conditional_jit
def stats_variance_1d(data, ddof=0):
    """Compute the variance of a 1D sequence in one pass, with `ddof` degrees of freedom.

    Uses running sums of the values and of their squares:
    ``mean(x**2) - mean(x)**2``, rescaled by ``n / (n - ddof)``.
    """
    a_a, b_b = 0, 0
    for i in data:
        a_a = a_a + i
        b_b = b_b + i * i
    var = b_b / (len(data)) - ((a_a / (len(data))) ** 2)
    var = var * (len(data) / (len(data) - ddof))
    return var


def stats_variance_2d(data, ddof=0, axis=1):
    """Compute the variance of a 1D or 2D array along `axis`.

    1D input is forwarded to :func:`stats_variance_1d`. For 2D input ``axis=1`` returns one
    value per row, any other value of `axis` returns one value per column.
    """
    if data.ndim == 1:
        return stats_variance_1d(data, ddof=ddof)
    n_rows, n_cols = data.shape
    if axis == 1:
        var = np.zeros(n_rows)
        for i in range(n_rows):
            var[i] = stats_variance_1d(data[i], ddof=ddof)
        return var
    var = np.zeros(n_cols)
    for i in range(n_cols):
        var[i] = stats_variance_1d(data[:, i], ddof=ddof)
    return var


@conditional_vect
def _sqrt(a_a, b_b):
    """Return the square root of ``a_a + b_b``."""
    return (a_a + b_b) ** 0.5


def _circfunc(samples, high, low, skipna):
    """Map `samples` to angles; NaNs are dropped first when `skipna` is true.

    Returns NaN if no sample is left.
    """
    samples = np.asarray(samples)
    if skipna:
        samples = samples[~np.isnan(samples)]
    if samples.size == 0:
        return np.nan
    return _angle(samples, low, high, np.pi)


@conditional_vect
def _angle(samples, low, high, p_i=np.pi):
    """Rescale `samples` linearly so that ``low`` maps to 0 and ``high`` to ``2 * p_i``."""
    ang = (samples - low) * 2.0 * p_i / (high - low)
    return ang


def _circular_standard_deviation(samples, high=2 * np.pi, low=0, skipna=False, axis=None):
    """Compute the circular standard deviation of `samples` lying in ``[low, high]``.

    The samples are mapped to angles, the length ``r`` of the mean (cos, sin) vector is
    taken along `axis`, and ``sqrt(-2 * log(r))`` is scaled back to the input range.
    """
    ang = _circfunc(samples, high, low, skipna)
    s_s = np.sin(ang).mean(axis=axis)
    c_c = np.cos(ang).mean(axis=axis)
    r_r = np.hypot(s_s, c_c)
    return ((high - low) / 2.0 / np.pi) * np.sqrt(-2 * np.log(r_r))


def smooth_data(obs_vals, pp_vals):
    """Smooth data, helper function for discrete data in plot_bpv, loo_pit and plot_loo_pit.

    Both inputs are interpolated with a cubic spline defined on an evenly spaced grid over
    [0, 1] and re-evaluated on an evenly spaced grid over [0.01, 0.99] with the same number
    of points.

    Parameters
    ----------
    obs_vals : array-like
        Observed values; the spline is fitted along its first axis.
    pp_vals : numpy.ndarray
        Array whose second axis (``axis=1``) is the one being smoothed.

    Returns
    -------
    tuple
        The smoothed ``(obs_vals, pp_vals)``.
    """
    x = np.linspace(0, 1, len(obs_vals))
    csi = CubicSpline(x, obs_vals)
    obs_vals = csi(np.linspace(0.01, 0.99, len(obs_vals)))

    x = np.linspace(0, 1, pp_vals.shape[1])
    csi = CubicSpline(x, pp_vals, axis=1)
    pp_vals = csi(np.linspace(0.01, 0.99, pp_vals.shape[1]))

    return obs_vals, pp_vals