# pylint: disable=too-many-nested-blocks
"""General utilities."""
import functools
import importlib
import importlib.util
import re
import warnings
from functools import lru_cache

import matplotlib.pyplot as plt
import numpy as np
import pkg_resources
from numpy import newaxis

from .rcparams import rcParams

STATIC_FILES = ("static/html/icons-svg-inline.html", "static/css/style.css")


def _var_names(var_names, data, filter_vars=None):
    """Handle var_names input across arviz.

    Parameters
    ----------
    var_names: str, list, or None
        Variable names to select. Names starting with ``~`` that are not themselves
        variable names are interpreted as exclusions (see ``_subset_list``).
    data : xarray.Dataset or list/tuple of xarray.Dataset
        Posterior data in an xarray. When a list or tuple is given, the candidate
        variables are the union of every dataset's ``data_vars`` in first-seen order.
    filter_vars: {None, "like", "regex"}, optional, default=None
        If `None` (default), interpret var_names as the real variables names. If "like",
        interpret var_names as substrings of the real variables names. If "regex",
        interpret var_names as regular expressions on the real variables names. A la
        `pandas.filter`.

    Returns
    -------
    var_name: list or None
        ``None`` when ``var_names`` is ``None``, otherwise the selected variable names.

    Raises
    ------
    ValueError
        If ``filter_vars`` is not one of ``None``, ``"like"`` or ``"regex"``.
    KeyError
        If a requested variable is not present in ``data``.
    """
    if filter_vars not in {None, "like", "regex"}:
        raise ValueError(
            f"'filter_vars' can only be None, 'like', or 'regex', got: '{filter_vars}'"
        )

    if var_names is not None:
        if isinstance(data, (list, tuple)):
            # dict.fromkeys de-duplicates while preserving first-seen order.
            all_vars = list(
                dict.fromkeys(var for dataset in data for var in dataset.data_vars)
            )
        else:
            all_vars = list(data.data_vars)

        all_vars_tilde = [var for var in all_vars if var.startswith("~")]
        if all_vars_tilde:
            # Single-line message: the previous triple-quoted string leaked newlines and
            # source indentation into the warning text.
            warnings.warn(
                "ArviZ treats '~' as a negation character for variable selection. "
                "Your model has variables names starting with '~', {0}. Please double check "
                "your results to ensure all variables are included".format(
                    ", ".join(all_vars_tilde)
                )
            )

        try:
            # warn=False: the '~' warning for variables has already been emitted above.
            var_names = _subset_list(var_names, all_vars, filter_items=filter_vars, warn=False)
        except KeyError as err:
            msg = " ".join(("var names:", f"{err}", "in dataset"))
            raise KeyError(msg) from err
    return var_names


def _subset_list(subset, whole_list, filter_items=None, warn=True):
    """Handle list subsetting (var_names, groups...) across arviz.

    Parameters
    ----------
    subset : str, list, or None
        Elements to select. Entries starting with ``~`` that are not themselves in
        ``whole_list`` are treated as exclusions; when any exclusion is present the
        result is ``whole_list`` minus the excluded elements.
    whole_list : list
        List from which to select a subset according to subset elements and
        filter_items value.
    filter_items : {None, "like", "regex"}, optional
        If `None` (default), interpret `subset` as the exact elements in `whole_list`
        names. If "like", interpret `subset` as substrings of the elements in
        `whole_list`. If "regex", interpret `subset` as regular expressions to match
        elements in `whole_list`. A la `pandas.filter`.
    warn : bool, optional
        Whether to warn when elements of ``whole_list`` themselves start with ``~``.

    Returns
    -------
    list or None
        A subset of ``whole_list`` fulfilling the requests imposed by ``subset``
        and ``filter_items``. ``None`` if ``subset`` is ``None``.

    Raises
    ------
    KeyError
        If some requested elements are not present in ``whole_list``.
    """
    if subset is None:
        return None

    if isinstance(subset, str):
        subset = [subset]

    whole_list_tilde = [item for item in whole_list if item.startswith("~")]
    if whole_list_tilde and warn:
        warnings.warn(
            "ArviZ treats '~' as a negation character for selection. There are "
            "elements in `whole_list` starting with '~', {0}. Please double check "
            "your results to ensure all elements are included".format(
                ", ".join(whole_list_tilde)
            )
        )

    excluded_items = [
        item[1:] for item in subset if item.startswith("~") and item not in whole_list
    ]
    filter_items = str(filter_items).lower()
    not_found = []

    if excluded_items:
        if filter_items in ("like", "regex"):
            # Expand every exclusion pattern into the real elements it matches.
            patterns = excluded_items
            excluded_items = []
            for pattern in patterns:
                if filter_items == "like":
                    real_items = [real_item for real_item in whole_list if pattern in real_item]
                else:
                    # i.e filter_items == "regex"
                    real_items = [
                        real_item for real_item in whole_list if re.search(pattern, real_item)
                    ]
                if not real_items:
                    not_found.append(pattern)
                excluded_items.extend(real_items)
        not_found.extend([item for item in excluded_items if item not in whole_list])
        if not_found:
            warnings.warn(
                f"Items starting with ~: {not_found} have not been found and will be ignored"
            )
        subset = [item for item in whole_list if item not in excluded_items]
    elif filter_items == "like":
        # any() keeps each matching element once, in whole_list order, even if it
        # matches several requested substrings.
        subset = [item for item in whole_list if any(name in item for name in subset)]
    elif filter_items == "regex":
        subset = [
            item for item in whole_list if any(re.search(name, item) for name in subset)
        ]

    existing_items = np.isin(subset, whole_list)
    if not np.all(existing_items):
        raise KeyError(f"{np.array(subset)[~existing_items]} are not present")

    return subset


class lazy_property:  # pylint: disable=invalid-name
    """Used to load numba first time it is needed.

    Non-data descriptor: on first access the getter is called and its result is
    stored on the instance under the getter's name, shadowing this descriptor.
    """

    def __init__(self, fget):
        """Lazy load a property with `fget`."""
        self.fget = fget

        # copy the getter function's docstring and other attributes
        functools.update_wrapper(self, fget)

    def __get__(self, obj, cls):
        """Call the function, set the attribute."""
        if obj is None:
            return self

        value = self.fget(obj)
        setattr(obj, self.fget.__name__, value)
        return value


class maybe_numba_fn:  # pylint: disable=invalid-name
    """Wrap a function to (maybe) use a (lazy) jit-compiled version."""

    def __init__(self, function, **kwargs):
        """Wrap a function and save compilation keywords."""
        self.function = function
        self.kwargs = kwargs

    @lazy_property
    def numba_fn(self):
        """Memoized compiled function; falls back to the plain function without numba."""
        try:
            numba = importlib.import_module("numba")
            numba_fn = numba.jit(**self.kwargs)(self.function)
        except ImportError:
            numba_fn = self.function
        return numba_fn

    def __call__(self, *args, **kwargs):
        """Call the jitted function or normal, depending on flag."""
        if Numba.numba_flag:
            return self.numba_fn(*args, **kwargs)
        return self.function(*args, **kwargs)


class interactive_backend:  # pylint: disable=invalid-name
    """Context manager to change backend temporarily in ipython session.

    It uses ipython magic to change temporarily from the ipython inline backend to
    an interactive backend of choice. It cannot be used outside ipython sessions nor
    to change backends different than inline -> interactive.

    Notes
    -----
    The first time ``interactive_backend`` context manager is called, any of the available
    interactive backends can be chosen. The following times, this same backend must be used
    unless the kernel is restarted.

    Parameters
    ----------
    backend : str, optional
        Interactive backend to use. It will be passed to ``%matplotlib`` magic, refer to
        its docs to see available options.

    Examples
    --------
    Inside an ipython session (i.e. a jupyter notebook) with the inline backend set:

    .. code::

        >>> import arviz as az
        >>> idata = az.load_arviz_data("centered_eight")
        >>> az.plot_posterior(idata) # inline
        >>> with az.interactive_backend():
        ...     az.plot_density(idata) # interactive
        >>> az.plot_trace(idata) # inline

    """

    # based on matplotlib.rc_context
    def __init__(self, backend=""):
        """Initialize context manager and switch to ``backend``.

        Raises
        ------
        ImportError
            If IPython cannot be imported.
        EnvironmentError
            If there is no running IPython session.
        """
        try:
            from IPython import get_ipython
        except ImportError as err:
            raise ImportError(
                "The exception below was raised while importing Ipython, this "
                "context manager can only be used inside ipython sessions:\n{}".format(err)
            ) from err
        self.ipython = get_ipython()
        if self.ipython is None:
            raise EnvironmentError("This context manager can only be used inside ipython sessions")
        # run_line_magic is the supported replacement for the deprecated
        # InteractiveShell.magic("matplotlib <backend>") call.
        self.ipython.run_line_magic("matplotlib", backend)

    def __enter__(self):
        """Enter context manager."""
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        """Exit context manager: block on open figures, then restore inline backend."""
        plt.show(block=True)
        self.ipython.run_line_magic("matplotlib", "inline")


def conditional_jit(_func=None, **kwargs):
    """Use numba's jit decorator if numba is installed.

    Notes
    -----
    If called without arguments then return wrapped function.

        @conditional_jit
        def my_func():
            return

    else called with arguments

        @conditional_jit(nopython=True)
        def my_func():
            return

    """
    if _func is None:
        return lambda fn: functools.wraps(fn)(maybe_numba_fn(fn, **kwargs))
    lazy_numba = maybe_numba_fn(_func, **kwargs)
    return functools.wraps(_func)(lazy_numba)


def conditional_vect(function=None, **kwargs):  # noqa: D202
    """Use numba's vectorize decorator if numba is installed.

    Notes
    -----
    If called without arguments then return wrapped function.

        @conditional_vect
        def my_func():
            return

    else called with arguments

        @conditional_vect(nopython=True)
        def my_func():
            return

    """

    def wrapper(function):
        # Unlike conditional_jit, vectorization happens eagerly at decoration time.
        try:
            numba = importlib.import_module("numba")
            return numba.vectorize(**kwargs)(function)

        except ImportError:
            return function

    if function:
        return wrapper(function)
    return wrapper


def numba_check():
    """Check if numba is installed (without importing it)."""
    numba = importlib.util.find_spec("numba")
    return numba is not None


class Numba:
    """A class to toggle numba states."""

    numba_flag = numba_check()

    @classmethod
    def disable_numba(cls):
        """To disable numba."""
        cls.numba_flag = False

    @classmethod
    def enable_numba(cls):
        """To enable numba.

        Raises
        ------
        ValueError
            If numba is not installed.
        """
        if numba_check():
            cls.numba_flag = True
        else:
            raise ValueError("Numba is not installed")


def _numba_var(numba_function, standard_numpy_func, data, axis=None, ddof=0):
    """Replace the numpy methods used to calculate variance.

    Parameters
    ----------
    numba_function : function()
        Custom numba function included in stats/stats_utils.py.

    standard_numpy_func: function()
        Standard function included in the numpy library.

    data : array.
    axis : axis along which the variance is calculated.
    ddof : degrees of freedom allowed while calculating variance.

    Returns
    -------
    array:
        variance values calculate by appropriate function for numba speedup
        if Numba is installed or enabled.

    """
    func = numba_function if Numba.numba_flag else standard_numpy_func
    return func(data, axis=axis, ddof=ddof)


def _stack(x, y):
    """Stack ``x`` on top of ``y`` along the first axis; trailing shapes must match."""
    assert x.shape[1:] == y.shape[1:]
    return np.vstack((x, y))


def arange(x):
    """Jitting numpy arange."""
    return np.arange(x)


def one_de(x):
    """Jitting numpy atleast_1d."""
    if not isinstance(x, np.ndarray):
        return np.atleast_1d(x)
    if x.ndim == 0:
        return x.reshape(1)
    return x


def two_de(x):
    """Jitting numpy at_least_2d."""
    if not isinstance(x, np.ndarray):
        return np.atleast_2d(x)
    if x.ndim == 0:
        return x.reshape(1, 1)
    if x.ndim == 1:
        return x[newaxis, :]
    return x


def expand_dims(x):
    """Jitting numpy expand_dims (prepend a length-1 axis)."""
    if not isinstance(x, np.ndarray):
        return np.expand_dims(x, 0)
    return x.reshape((1,) + x.shape)


@conditional_jit(cache=True, nopython=True)
def _dot(x, y):
    return np.dot(x, y)


@conditional_jit(cache=True, nopython=True)
def _cov_1d(x):
    """Sample variance (ddof=1) of a 1-D array, computed as a dot product."""
    x = x - x.mean(axis=0)
    ddof = x.shape[0] - 1
    return np.dot(x.T, x.conj()) / ddof


# @conditional_jit(cache=True)
def _cov(data):
    """Covariance of ``data`` with a ``1e-6`` jitter added to the diagonal.

    For 2-D input, rows are variables and columns are observations (the mean is
    taken along ``axis=1``), mirroring ``np.cov`` with ``rowvar=True``.

    Raises
    ------
    ValueError
        If ``data`` has more than 2 dimensions.
    """
    if data.ndim == 1:
        return _cov_1d(data)
    if data.ndim == 2:
        x = data.astype(float)
        avg, _ = np.average(x, axis=1, weights=None, returned=True)
        ddof = x.shape[1] - 1
        if ddof <= 0:
            warnings.warn("Degrees of freedom <= 0 for slice", RuntimeWarning, stacklevel=2)
            ddof = 0.0
        x -= avg[:, None]
        prod = _dot(x, x.T.conj())
        prod *= np.true_divide(1, ddof)
        # Add the jitter before squeezing: with a single variable, squeeze() yields a
        # 0-d array and ``prod.shape[0]`` would raise IndexError.
        prod += 1e-6 * np.eye(prod.shape[0])
        return prod.squeeze()
    raise ValueError(f"{data.ndim} dimension arrays are not supported")


def flatten_inference_data_to_dict(
    data,
    var_names=None,
    groups=None,
    dimensions=None,
    group_info=False,
    var_name_format=None,
    index_origin=None,
):
    """Transform data to dictionary.

    Parameters
    ----------
    data : obj
        Any object that can be converted to an az.InferenceData object
        Refer to documentation of az.convert_to_inference_data for details
    var_names : str or list of str, optional
        Variables to be processed, if None all variables are processed.
    groups : str or list of str, optional
        Select groups for CDS. Default groups are
        {"posterior_groups", "prior_groups", "posterior_groups_warmup"}
            - posterior_groups: posterior, posterior_predictive, sample_stats
            - prior_groups: prior, prior_predictive, sample_stats_prior
            - posterior_groups_warmup: warmup_posterior, warmup_posterior_predictive,
                                       warmup_sample_stats
        If None, "posterior_groups" is used. Groups missing from ``data`` are skipped.
    dimensions : str, or list of str, optional
        Select dimensions along to slice the data. By default uses ("chain", "draw").
    group_info : bool
        Add group info for `var_name_format`
    var_name_format : str or tuple of tuple of string, optional
        Select column name format for non-scalar input.
        Predefined options are {"brackets", "underscore", "cds"}
            "brackets":
                - group_info == False: theta[0,0]
                - group_info == True: theta_posterior[0,0]
            "underscore":
                - group_info == False: theta_0_0
                - group_info == True: theta_posterior_0_0
            "cds":
                - group_info == False: theta_ARVIZ_CDS_SELECTION_0_0
                - group_info == True: theta_ARVIZ_GROUP_posterior_ARVIZ_CDS_SELECTION_0_0
            tuple:
                Structure:
                    tuple: (dim_info, group_info)
                        dim_info: (str: `.join` separator,
                                   str: dim_separator_start,
                                   str: dim_separator_end)
                        group_info: (str: group separator start, str: group separator end)
                Example: ((",", "[", "]"), ("_", ""))
                    - group_info == False: theta[0,0]
                    - group_info == True: theta_posterior[0,0]
    index_origin : int, optional
        Start parameter indices from `index_origin`. Either 0 or 1.
        Defaults to ``rcParams["data.index_origin"]``.

    Returns
    -------
    dict
        Maps each entry of ``dimensions`` to its coordinate values and each
        (possibly indexed) variable name to its flattened values.

    Raises
    ------
    TypeError
        If ``groups`` or ``var_name_format`` is a string that is not a predefined option.
    """
    from .data import convert_to_inference_data

    data = convert_to_inference_data(data)

    # A bare string would otherwise be tested with substring ``in`` below.
    if isinstance(var_names, str):
        var_names = [var_names]

    if groups is None:
        groups = ["posterior", "posterior_predictive", "sample_stats"]
    elif isinstance(groups, str):
        if groups.lower() == "posterior_groups":
            groups = ["posterior", "posterior_predictive", "sample_stats"]
        elif groups.lower() == "prior_groups":
            groups = ["prior", "prior_predictive", "sample_stats_prior"]
        elif groups.lower() == "posterior_groups_warmup":
            groups = ["warmup_posterior", "warmup_posterior_predictive", "warmup_sample_stats"]
        else:
            raise TypeError(
                (
                    "Valid predefined groups are "
                    "{posterior_groups, prior_groups, posterior_groups_warmup}"
                )
            )

    if dimensions is None:
        dimensions = "chain", "draw"
    elif isinstance(dimensions, str):
        dimensions = (dimensions,)

    if var_name_format is None:
        var_name_format = "brackets"

    if isinstance(var_name_format, str):
        var_name_format = var_name_format.lower()

    if var_name_format == "brackets":
        dim_join_separator, dim_separator_start, dim_separator_end = ",", "[", "]"
        group_separator_start, group_separator_end = "_", ""
    elif var_name_format == "underscore":
        dim_join_separator, dim_separator_start, dim_separator_end = "_", "_", ""
        group_separator_start, group_separator_end = "_", ""
    elif var_name_format == "cds":
        dim_join_separator, dim_separator_start, dim_separator_end = (
            "_",
            "_ARVIZ_CDS_SELECTION_",
            "",
        )
        group_separator_start, group_separator_end = "_ARVIZ_GROUP_", ""
    elif isinstance(var_name_format, str):
        msg = 'Invalid predefined format. Select one {"brackets", "underscore", "cds"}'
        raise TypeError(msg)
    else:
        (
            (dim_join_separator, dim_separator_start, dim_separator_end),
            (group_separator_start, group_separator_end),
        ) = var_name_format

    if index_origin is None:
        index_origin = rcParams["data.index_origin"]

    data_dict = {}
    for group in groups:
        if not hasattr(data, group):
            continue
        group_data = getattr(data, group).stack(stack_dimension=dimensions)
        group_suffix = (
            f"{group_separator_start}{group}{group_separator_end}" if group_info else ""
        )
        for var_name, var in group_data.data_vars.items():
            if var_names is not None and var_name not in var_names:
                continue
            for dim_name in dimensions:
                if dim_name not in data_dict:
                    data_dict[dim_name] = var.coords.get(dim_name).values
            var_values = var.values
            if len(var.shape) == 1:
                data_dict[f"{var_name}{group_suffix}"] = var_values
                continue
            # One column per index of the non-stacked dimensions; the stacked sample
            # dimension is the last axis here (assumes ``stack`` appends it last).
            for loc in np.ndindex(var.shape[:-1]):
                dim_join = dim_join_separator.join(str(item + index_origin) for item in loc)
                var_name_dim = (
                    f"{var_name}{group_suffix}"
                    f"{dim_separator_start}{dim_join}{dim_separator_end}"
                )
                data_dict[var_name_dim] = var_values[loc]
    return data_dict


def get_coords(data, coords):
    """Subselects xarray DataSet or DataArray object to provided coords. Raises exception if fails.

    ``data`` may also be a list/tuple of xarray objects; ``coords`` is then either a
    single mapping applied to every element or a list/tuple of mappings, one per element.

    Raises
    ------
    ValueError
        If coords name are not available in data

    KeyError
        If coords dims are not available in data

    Returns
    -------
    data: xarray
        xarray.DataSet or xarray.DataArray object, same type as input
    """
    if not isinstance(data, (list, tuple)):
        try:
            return data.sel(**coords)

        except ValueError as err:
            invalid_coords = set(coords.keys()) - set(data.coords.keys())
            raise ValueError(f"Coords {invalid_coords} are invalid coordinate keys") from err

        except KeyError as err:
            raise KeyError(
                (
                    "Coords should follow mapping format {{coord_name:[dim1, dim2]}}. "
                    "Check that coords structure is correct and"
                    " dimensions are valid. {}"
                ).format(err)
            ) from err
    if not isinstance(coords, (list, tuple)):
        coords = [coords] * len(data)
    data_subset = []
    for idx, (datum, coords_dict) in enumerate(zip(data, coords)):
        try:
            data_subset.append(get_coords(datum, coords_dict))
        except ValueError as err:
            raise ValueError(f"Error in data[{idx}]: {err}") from err
        except KeyError as err:
            raise KeyError(f"Error in data[{idx}]: {err}") from err
    return data_subset


@lru_cache(None)
def _load_static_files():
    """Lazily load the resource files into memory the first time they are needed.

    Clone from xarray.core.formatted_html_template.
    """
    return [pkg_resources.resource_string("arviz", fname).decode("utf8") for fname in STATIC_FILES]


class HtmlTemplate:
    """Contain html templates for InferenceData repr."""

    html_template = """
            <div>
              <div class='xr-header'>
                <div class="xr-obj-type">arviz.InferenceData</div>
              </div>
              <ul class="xr-sections group-sections">
              {}
              </ul>
            </div>
            """
    element_template = """
            <li class = "xr-section-item">
                  <input id="idata_{group_id}" class="xr-section-summary-in" type="checkbox">
                  <label for="idata_{group_id}" class = "xr-section-summary">{group}</label>
                  <div class="xr-section-inline-details"></div>
                  <div class="xr-section-details">
                      <ul id="xr-dataset-coord-list" class="xr-var-list">
                          <div style="padding-left:2rem;">{xr_data}<br></div>
                      </ul>
                  </div>
            </li>
            """
    # Static files are read once, at class-definition (import) time.
    _, css_style = _load_static_files()  # pylint: disable=protected-access
    specific_style = ".xr-wrap{width:700px!important;}"
    css_template = f"<style> {css_style}{specific_style} </style>"


def either_dict_or_kwargs(
    pos_kwargs,
    kw_kwargs,
    func_name,
):
    """Clone from xarray.core.utils.

    Return ``pos_kwargs`` when given, otherwise ``kw_kwargs``.

    Raises
    ------
    ValueError
        If ``pos_kwargs`` is not dict-like (lacks ``keys`` or ``__getitem__``), or if
        both ``pos_kwargs`` and a non-empty ``kw_kwargs`` are given.
    """
    if pos_kwargs is None:
        return kw_kwargs
    # dict-like means having BOTH ``keys`` and ``__getitem__`` (xarray's is_dict_like).
    if not (hasattr(pos_kwargs, "keys") and hasattr(pos_kwargs, "__getitem__")):
        raise ValueError(f"the first argument to .{func_name} must be a dictionary")
    if kw_kwargs:
        raise ValueError(
            f"cannot specify both keyword and positional arguments to .{func_name}"
        )
    return pos_kwargs


class Dask:
    """Class to toggle Dask states.

    Warnings
    --------
    Dask integration is an experimental feature still in progress. It can already be used
    but it doesn't work with all stats nor diagnostics yet.
    """

    dask_flag = False
    dask_kwargs = None

    @classmethod
    def enable_dask(cls, dask_kwargs=None):
        """To enable Dask.

        Parameters
        ----------
        dask_kwargs : dict
            Dask related kwargs passed to :func:`~arviz.wrap_xarray_ufunc`.
        """
        cls.dask_flag = True
        cls.dask_kwargs = dask_kwargs

    @classmethod
    def disable_dask(cls):
        """To disable Dask."""
        cls.dask_flag = False
        cls.dask_kwargs = None


def conditional_dask(func):
    """Conditionally pass dask kwargs to `wrap_xarray_ufunc`.

    When Dask is enabled, the wrapped function receives ``dask_kwargs`` built from
    ``Dask.dask_kwargs`` updated with any ``dask_kwargs`` given by the caller.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if not Dask.dask_flag:
            return func(*args, **kwargs)
        user_kwargs = kwargs.pop("dask_kwargs", None)
        if user_kwargs is None:
            user_kwargs = {}
        # enable_dask() without arguments leaves dask_kwargs as None; unpacking None
        # would raise TypeError.
        default_kwargs = Dask.dask_kwargs if Dask.dask_kwargs is not None else {}
        return func(*args, dask_kwargs={**default_kwargs, **user_kwargs}, **kwargs)

    return wrapper