1# pylint: disable=too-many-nested-blocks
2"""General utilities."""
3import functools
4import importlib
5import re
6import warnings
7from functools import lru_cache
8
9import matplotlib.pyplot as plt
10import numpy as np
11import pkg_resources
12from numpy import newaxis
13
14from .rcparams import rcParams
15
# Resource paths (relative to the "arviz" package) read by ``_load_static_files``.
# Order matters: consumers unpack the result as (icons HTML, CSS).
STATIC_FILES = ("static/html/icons-svg-inline.html", "static/css/style.css")
17
18
def _var_names(var_names, data, filter_vars=None):
    """Resolve a ``var_names`` request against the variables present in ``data``.

    Parameters
    ----------
    var_names: str, list, or None
        Requested variables. ``None`` is returned untouched.
    data : xarray.Dataset or list/tuple of xarray.Dataset
        Posterior data in an xarray. When several datasets are given, the union of
        their variable names (in order of first appearance) is used.
    filter_vars: {None, "like", "regex"}, optional, default=None
        If `None` (default), interpret var_names as the real variables names. If "like",
        interpret var_names as substrings of the real variables names. If "regex",
        interpret var_names as regular expressions on the real variables names. A la
        `pandas.filter`.

    Returns
    -------
    var_name: list or None

    Raises
    ------
    ValueError
        If ``filter_vars`` is not one of the accepted values.
    KeyError
        If some requested variable cannot be found in ``data``.
    """
    if filter_vars not in {None, "like", "regex"}:
        raise ValueError(
            f"'filter_vars' can only be None, 'like', or 'regex', got: '{filter_vars}'"
        )

    # nothing requested: nothing to resolve
    if var_names is None:
        return var_names

    # treat a single dataset as a one-element collection
    datasets = data if isinstance(data, (list, tuple)) else [data]
    all_vars = []
    for dataset in datasets:
        for name in dataset.data_vars:
            if name not in all_vars:
                all_vars.append(name)

    tilde_vars = [name for name in all_vars if name.startswith("~")]
    if tilde_vars:
        warnings.warn(
            """ArviZ treats '~' as a negation character for variable selection.
                   Your model has variables names starting with '~', {0}. Please double check
                   your results to ensure all variables are included""".format(
                ", ".join(tilde_vars)
            )
        )

    try:
        # the '~' warning has already been emitted above, so silence it downstream
        selected = _subset_list(var_names, all_vars, filter_items=filter_vars, warn=False)
    except KeyError as err:
        msg = " ".join(("var names:", f"{err}", "in dataset"))
        raise KeyError(msg) from err
    return selected
69
70
def _subset_list(subset, whole_list, filter_items=None, warn=True):
    """Handle list subsetting (var_names, groups...) across arviz.

    Parameters
    ----------
    subset : str, list, or None
        Requested elements. Elements starting with ``~`` that are not themselves
        present in ``whole_list`` are interpreted as exclusions. If any exclusion is
        present, the result is ``whole_list`` minus the excluded elements.
    whole_list : list
        List from which to select a subset according to subset elements and
        filter_items value.
    filter_items : {None, "like", "regex"}, optional
        If `None` (default), interpret `subset` as the exact elements in `whole_list`
        names. If "like", interpret `subset` as substrings of the elements in
        `whole_list`. If "regex", interpret `subset` as regular expressions to match
        elements in `whole_list`. A la `pandas.filter`.
    warn : bool, optional
        Whether to warn when ``whole_list`` itself contains elements starting with ``~``.

    Returns
    -------
    list or None
        A subset of ``whole_list`` fulfilling the requests imposed by ``subset``
        and ``filter_items``. With "like"/"regex" every element of ``whole_list``
        appears at most as often as it does in ``whole_list``, even when it matches
        several patterns.

    Raises
    ------
    KeyError
        If some selected element is not present in ``whole_list``.
    """
    if subset is None:
        return subset

    if isinstance(subset, str):
        subset = [subset]

    whole_list_tilde = [item for item in whole_list if item.startswith("~")]
    if whole_list_tilde and warn:
        warnings.warn(
            "ArviZ treats '~' as a negation character for selection. There are "
            "elements in `whole_list` starting with '~', {0}. Please double check "
            "your results to ensure all elements are included".format(
                ", ".join(whole_list_tilde)
            )
        )

    # "~name" is an exclusion unless an element literally called "~name" exists
    excluded_items = [
        item[1:] for item in subset if item.startswith("~") and item not in whole_list
    ]
    filter_items = str(filter_items).lower()
    not_found = []

    if excluded_items:
        if filter_items in ("like", "regex"):
            # replace every exclusion pattern by the real elements it matches
            for pattern in excluded_items[:]:
                excluded_items.remove(pattern)
                if filter_items == "like":
                    real_items = [real_item for real_item in whole_list if pattern in real_item]
                else:
                    # i.e filter_items == "regex"
                    real_items = [
                        real_item for real_item in whole_list if re.search(pattern, real_item)
                    ]
                if not real_items:
                    not_found.append(pattern)
                excluded_items.extend(real_items)
        not_found.extend([item for item in excluded_items if item not in whole_list])
        if not_found:
            warnings.warn(
                f"Items starting with ~: {not_found} have not been found and will be ignored"
            )
        subset = [item for item in whole_list if item not in excluded_items]
    elif filter_items == "like":
        # ``any`` keeps an element once even if several substrings match it
        patterns = subset
        subset = [item for item in whole_list if any(name in item for name in patterns)]
    elif filter_items == "regex":
        patterns = subset
        subset = [
            item for item in whole_list if any(re.search(name, item) for name in patterns)
        ]

    existing_items = np.isin(subset, whole_list)
    if not np.all(existing_items):
        raise KeyError(f"{np.array(subset)[~existing_items]} are not present")

    return subset
145
146
class lazy_property:  # pylint: disable=invalid-name
    """Descriptor computing an attribute on first access and storing it on the instance.

    Used to load numba first time it is needed.
    """

    def __init__(self, fget):
        """Keep the getter ``fget`` that produces the value."""
        self.fget = fget

        # make the descriptor look like the getter (name, docstring, ...)
        functools.update_wrapper(self, fget)

    def __get__(self, obj, cls):
        """Evaluate the getter for ``obj`` and store the result under the getter's name."""
        if obj is not None:
            result = self.fget(obj)
            # the instance attribute set here takes precedence on later lookups,
            # because this descriptor defines no ``__set__``
            setattr(obj, self.fget.__name__, result)
            return result

        # accessed through the class: expose the descriptor itself
        return self
165
166
class maybe_numba_fn:  # pylint: disable=invalid-name
    """Callable wrapper that dispatches to a lazily jit-compiled version when enabled."""

    def __init__(self, function, **kwargs):
        """Remember the plain function and the keywords to hand to ``numba.jit``."""
        self.kwargs = kwargs
        self.function = function

    @lazy_property
    def numba_fn(self):
        """Compile on first access; fall back to the plain function without numba."""
        try:
            jit = importlib.import_module("numba").jit
            return jit(**self.kwargs)(self.function)
        except ImportError:
            return self.function

    def __call__(self, *args, **kwargs):
        """Run the compiled function if ``Numba.numba_flag`` is set, else the plain one."""
        target = self.numba_fn if Numba.numba_flag else self.function
        return target(*args, **kwargs)
191
192
class interactive_backend:  # pylint: disable=invalid-name
    """Context manager to change backend temporarily in ipython session.

    It uses ipython magic to change temporarily from the ipython inline backend to
    an interactive backend of choice. It cannot be used outside ipython sessions nor
    to change backends different than inline -> interactive.

    Notes
    -----
    The first time ``interactive_backend`` context manager is called, any of the available
    interactive backends can be chosen. The following times, this same backend must be used
    unless the kernel is restarted.

    Parameters
    ----------
    backend : str, optional
        Interactive backend to use. It will be passed to ``%matplotlib`` magic, refer to
        its docs to see available options.

    Raises
    ------
    ImportError
        If IPython cannot be imported.
    EnvironmentError
        If there is no active ipython session.

    Examples
    --------
    Inside an ipython session (i.e. a jupyter notebook) with the inline backend set:

    .. code::

        >>> import arviz as az
        >>> idata = az.load_arviz_data("centered_eight")
        >>> az.plot_posterior(idata) # inline
        >>> with az.interactive_backend():
        ...     az.plot_density(idata) # interactive
        >>> az.plot_trace(idata) # inline

    """

    # based on matplotlib.rc_context
    def __init__(self, backend=""):
        """Initialize context manager and switch to the requested backend."""
        try:
            from IPython import get_ipython
        except ImportError as err:
            raise ImportError(
                "The exception below was risen while importing Ipython, this "
                "context manager can only be used inside ipython sessions:\n{}".format(err)
            ) from err
        self.ipython = get_ipython()
        if self.ipython is None:
            raise EnvironmentError("This context manager can only be used inside ipython sessions")
        # ``run_line_magic`` is the supported replacement for the deprecated ``magic``
        self.ipython.run_line_magic("matplotlib", backend)

    def __enter__(self):
        """Enter context manager."""
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        """Show pending figures, then switch back to the inline backend."""
        try:
            plt.show(block=True)
        finally:
            # restore the inline backend even if showing the figures failed
            self.ipython.run_line_magic("matplotlib", "inline")
250
251
def conditional_jit(_func=None, **kwargs):
    """Decorate a function so that it is jit-compiled lazily when numba can be used.

    Notes
    -----
        Works as a bare decorator, returning the wrapped function:

        @conditional_jit
        def my_func():
            return

        and as a decorator factory taking the keywords for the compilation:

        @conditional_jit(nopython=True)
        def my_func():
            return

    """

    def decorate(fn):
        # wrap in ``maybe_numba_fn`` while keeping the metadata of ``fn``
        return functools.wraps(fn)(maybe_numba_fn(fn, **kwargs))

    if _func is not None:
        return decorate(_func)
    return decorate
275
276
def conditional_vect(function=None, **kwargs):  # noqa: D202
    """Apply numba's vectorize decorator when numba can be imported.

    Notes
    -----
        Works as a bare decorator, returning the wrapped function:
        @conditional_vect
        def my_func():
            return
        and as a decorator factory taking keywords for ``numba.vectorize``:
        @conditional_vect(nopython=True)
        def my_func():
            return

    """

    def decorate(target):
        try:
            vectorize = importlib.import_module("numba").vectorize
            return vectorize(**kwargs)(target)
        except ImportError:
            # numba unavailable: hand back the function unchanged
            return target

    return decorate(function) if function else decorate
305
306
def numba_check():
    """Check if numba is installed.

    Returns
    -------
    bool
        True if an import spec for ``numba`` can be found, False otherwise.
    """
    # ``import importlib`` alone does not guarantee the ``util`` submodule is loaded,
    # so import the helper explicitly
    from importlib.util import find_spec

    return find_spec("numba") is not None
311
312
class Numba:
    """Hold and toggle the flag selecting numba code paths."""

    # starts enabled only when numba can be found at import time
    numba_flag = numba_check()

    @classmethod
    def disable_numba(cls):
        """Switch the numba flag off."""
        cls.numba_flag = False

    @classmethod
    def enable_numba(cls):
        """Switch the numba flag on, raising ``ValueError`` if numba is missing."""
        if not numba_check():
            raise ValueError("Numba is not installed")
        cls.numba_flag = True
330
331
def _numba_var(numba_function, standard_numpy_func, data, axis=None, ddof=0):
    """Compute a variance with either the numba or the numpy implementation.

    Parameters
    ----------
    numba_function : function()
        Custom numba function included in stats/stats_utils.py.

    standard_numpy_func: function()
        Standard function included in the numpy library.

    data : array.
    axis : axis along which the variance is calculated.
    ddof : degrees of freedom allowed while calculating variance.

    Returns
    -------
    array:
        Result of ``numba_function`` when ``Numba.numba_flag`` is set, otherwise
        result of ``standard_numpy_func``; both are called with ``data``, ``axis``
        and ``ddof``.

    """
    var_func = numba_function if Numba.numba_flag else standard_numpy_func
    return var_func(data, axis=axis, ddof=ddof)
358
359
def _stack(x, y):
    """Stack ``x`` and ``y`` with ``np.vstack`` after asserting their trailing shapes match."""
    trailing_x, trailing_y = x.shape[1:], y.shape[1:]
    assert trailing_x == trailing_y
    return np.vstack([x, y])
363
364
def arange(x):
    """Return ``np.arange(x)`` (thin wrapper so that numpy arange can be jitted)."""
    values = np.arange(x)
    return values
368
369
def one_de(x):
    """Return ``x`` with at least one dimension (jit-friendly ``np.atleast_1d``)."""
    if isinstance(x, np.ndarray):
        # only 0-d arrays need reshaping; anything else is returned as is
        return x.reshape(1) if x.ndim == 0 else x
    return np.atleast_1d(x)
379
380
def two_de(x):
    """Return ``x`` with at least two dimensions (jit-friendly ``np.atleast_2d``)."""
    if isinstance(x, np.ndarray):
        if x.ndim == 0:
            return x.reshape(1, 1)
        if x.ndim == 1:
            # prepend a length-one axis
            return x[newaxis, :]
        return x
    return np.atleast_2d(x)
392
393
def expand_dims(x):
    """Insert a leading length-one axis (jit-friendly ``np.expand_dims(x, 0)``)."""
    if isinstance(x, np.ndarray):
        return x.reshape((1,) + x.shape)
    return np.expand_dims(x, 0)
400
401
@conditional_jit(cache=True, nopython=True)
def _dot(x, y):
    """Return ``np.dot(x, y)``; decorated with ``conditional_jit`` (nopython, cached)."""
    return np.dot(x, y)
405
406
@conditional_jit(cache=True, nopython=True)
def _cov_1d(x):
    """Return ``dot(x.T, conj(x)) / (n - 1)`` after centering ``x`` along axis 0.

    ``n`` is ``x.shape[0]``. Decorated with ``conditional_jit`` (nopython, cached).
    """
    # subtract the mean taken along the first axis
    x = x - x.mean(axis=0)
    # despite its name, ``ddof`` holds the divisor ``n - 1``
    ddof = x.shape[0] - 1
    return np.dot(x.T, x.conj()) / ddof
412
413
414# @conditional_jit(cache=True)
def _cov(data):
    """Estimate the covariance of ``data`` with a small diagonal regularisation.

    Parameters
    ----------
    data : numpy.ndarray
        Either a 1-D array, which is delegated to ``_cov_1d``, or a 2-D array that is
        centered and normalised along axis 1 (rows are treated as variables, columns
        as observations).

    Returns
    -------
    numpy.ndarray
        For 1-D input, the result of ``_cov_1d``. For 2-D input with ``k`` rows, the
        ``(k, k)`` covariance with ``1e-6`` added to the diagonal, squeezed (so a
        single-row input yields a 0-d array).

    Raises
    ------
    ValueError
        If ``data`` has more than two dimensions (or none).
    """
    if data.ndim == 1:
        return _cov_1d(data)
    if data.ndim != 2:
        raise ValueError(f"{data.ndim} dimension arrays are not supported")

    # ``astype`` returns a copy, so the in-place centering below leaves ``data`` intact
    x = data.astype(float)
    avg = np.average(x, axis=1)
    ddof = x.shape[1] - 1
    if ddof <= 0:
        warnings.warn("Degrees of freedom <= 0 for slice", RuntimeWarning, stacklevel=2)
        ddof = 0.0
    x -= avg[:, None]
    prod = _dot(x, x.T.conj())
    prod *= np.true_divide(1, ddof)
    # add the jitter while ``prod`` is still (k, k): squeezing first turns a 1x1
    # product into a 0-d array, for which ``prod.shape[0]`` raises IndexError
    prod += 1e-6 * np.eye(prod.shape[0])
    return prod.squeeze()
433
434
def flatten_inference_data_to_dict(
    data,
    var_names=None,
    groups=None,
    dimensions=None,
    group_info=False,
    var_name_format=None,
    index_origin=None,
):
    """Transform data to dictionary.

    Parameters
    ----------
    data : obj
        Any object that can be converted to an az.InferenceData object
        Refer to documentation of az.convert_to_inference_data for details
    var_names : str or list of str, optional
        Variables to be processed, if None all variables are processed. A single
        string is treated as one variable name.
    groups : str or list of str, optional
        Select groups for CDS. A string must be one of the predefined group sets
        {"posterior_groups", "prior_groups", "posterior_groups_warmup"}
            - posterior_groups: posterior, posterior_predictive, sample_stats
            - prior_groups: prior, prior_predictive, sample_stats_prior
            - posterior_groups_warmup: warmup_posterior, warmup_posterior_predictive,
                                       warmup_sample_stats
        Any other value is iterated as group names. Defaults to "posterior_groups".
        Groups not present in the data are skipped.
    dimensions : str, or list of str, optional
        Select dimensions along to slice the data. By default uses ("chain", "draw").
    group_info : bool
        Add group info for `var_name_format`
    var_name_format : str or tuple of tuple of string, optional
        Select column name format for non-scalar input.
        Predefined options are {"brackets", "underscore", "cds"}
            "brackets":
                - group_info == False: theta[0,0]
                - group_info == True: theta_posterior[0,0]
            "underscore":
                - group_info == False: theta_0_0
                - group_info == True: theta_posterior_0_0
            "cds":
                - group_info == False: theta_ARVIZ_CDS_SELECTION_0_0
                - group_info == True: theta_ARVIZ_GROUP_posterior_ARVIZ_CDS_SELECTION_0_0
            tuple:
                Structure:
                    tuple: (dim_info, group_info)
                        dim_info: (str: `.join` separator,
                                   str: dim_separator_start,
                                   str: dim_separator_end)
                        group_info: (str: group separator start, str: group separator end)
                Example: ((",", "[", "]"), ("_", ""))
                    - group_info == False: theta[0,0]
                    - group_info == True: theta_posterior[0,0]
    index_origin : int, optional
        Start parameter indices from `index_origin`. Either 0 or 1. Defaults to
        ``rcParams["data.index_origin"]``.

    Returns
    -------
    dict
        Maps each name in ``dimensions`` to its coordinate values and each generated
        variable name to the corresponding values along the stacked dimensions.

    Raises
    ------
    TypeError
        If ``groups`` or ``var_name_format`` is a string that is not a predefined option.
    """
    from .data import convert_to_inference_data

    data = convert_to_inference_data(data)

    # a bare string would otherwise be matched by substring in the filter below
    if isinstance(var_names, str):
        var_names = [var_names]

    predefined_groups = {
        "posterior_groups": ["posterior", "posterior_predictive", "sample_stats"],
        "prior_groups": ["prior", "prior_predictive", "sample_stats_prior"],
        "posterior_groups_warmup": [
            "warmup_posterior",
            "warmup_posterior_predictive",
            "warmup_sample_stats",
        ],
    }
    if groups is None:
        groups = predefined_groups["posterior_groups"]
    elif isinstance(groups, str):
        group_key = groups.lower()
        if group_key not in predefined_groups:
            raise TypeError(
                (
                    "Valid predefined groups are "
                    "{posterior_groups, prior_groups, posterior_groups_warmup}"
                )
            )
        groups = predefined_groups[group_key]

    if dimensions is None:
        dimensions = "chain", "draw"
    elif isinstance(dimensions, str):
        dimensions = (dimensions,)

    # each format is ((join, start, end) for indices, (start, end) for the group)
    predefined_formats = {
        "brackets": ((",", "[", "]"), ("_", "")),
        "underscore": (("_", "_", ""), ("_", "")),
        "cds": (("_", "_ARVIZ_CDS_SELECTION_", ""), ("_ARVIZ_GROUP_", "")),
    }
    if var_name_format is None:
        var_name_format = "brackets"
    if isinstance(var_name_format, str):
        format_key = var_name_format.lower()
        if format_key not in predefined_formats:
            msg = 'Invalid predefined format. Select one {"brackets", "underscore", "cds"}'
            raise TypeError(msg)
        var_name_format = predefined_formats[format_key]
    (
        (dim_join_separator, dim_separator_start, dim_separator_end),
        (group_separator_start, group_separator_end),
    ) = var_name_format

    if index_origin is None:
        index_origin = rcParams["data.index_origin"]

    data_dict = {}
    for group in groups:
        if not hasattr(data, group):
            continue
        # the stacked dimension becomes the last axis of every variable
        group_data = getattr(data, group).stack(stack_dimension=dimensions)
        group_label = (
            f"{group_separator_start}{group}{group_separator_end}" if group_info else ""
        )
        for var_name, var in group_data.data_vars.items():
            if var_names is not None and var_name not in var_names:
                continue
            for dim_name in dimensions:
                if dim_name not in data_dict:
                    data_dict[dim_name] = var.coords.get(dim_name).values
            # only materialise values of the variables that are kept
            var_values = var.values
            base_name = f"{var_name}{group_label}"
            if len(var.shape) == 1:
                # only the stacked dimension is left: no index suffix needed
                data_dict[base_name] = var_values
                continue
            for loc in np.ndindex(var.shape[:-1]):
                dim_join = dim_join_separator.join(str(item + index_origin) for item in loc)
                var_name_dim = f"{base_name}{dim_separator_start}{dim_join}{dim_separator_end}"
                data_dict[var_name_dim] = var_values[loc]
    return data_dict
608
609
def get_coords(data, coords):
    """Select ``coords`` from an xarray object, or from each object of a list/tuple.

    When ``data`` is a list or tuple, ``coords`` may be a list or tuple paired with it
    element by element, or a single mapping applied to every element.

    Raises
    ------
    ValueError
        If coords name are not available in data

    KeyError
        If coords dims are not available in data

    Returns
    -------
    data: xarray
        xarray.DataSet or xarray.DataArray object, same type as input. For list or
        tuple input, a list with one selection per element.
    """
    if isinstance(data, (list, tuple)):
        per_datum = coords if isinstance(coords, (list, tuple)) else [coords] * len(data)
        selections = []
        for idx, (datum, coords_dict) in enumerate(zip(data, per_datum)):
            try:
                selected = get_coords(datum, coords_dict)
            except ValueError as err:
                raise ValueError(f"Error in data[{idx}]: {err}") from err
            except KeyError as err:
                raise KeyError(f"Error in data[{idx}]: {err}") from err
            else:
                selections.append(selected)
        return selections

    try:
        return data.sel(**coords)
    except ValueError as err:
        # report which of the requested keys are not coordinates of ``data``
        unknown = set(coords.keys()) - set(data.coords.keys())
        raise ValueError(f"Coords {unknown} are invalid coordinate keys") from err
    except KeyError as err:
        message = (
            "Coords should follow mapping format {{coord_name:[dim1, dim2]}}. "
            "Check that coords structure is correct and"
            " dimensions are valid. {}"
        ).format(err)
        raise KeyError(message) from err
653
654
@lru_cache(None)
def _load_static_files():
    """Read the files listed in ``STATIC_FILES`` from the arviz package as utf8 text.

    The result is cached, so the resources are read only the first time they are needed.

    Clone from xarray.core.formatted_html_template.
    """
    contents = []
    for fname in STATIC_FILES:
        raw = pkg_resources.resource_string("arviz", fname)
        contents.append(raw.decode("utf8"))
    return contents
662
663
class HtmlTemplate:
    """Contain html templates for InferenceData repr."""

    # Outer wrapper; the single positional ``{}`` placeholder sits inside the
    # ``<ul>`` of group sections.
    html_template = """
            <div>
              <div class='xr-header'>
                <div class="xr-obj-type">arviz.InferenceData</div>
              </div>
              <ul class="xr-sections group-sections">
              {}
              </ul>
            </div>
            """
    # One collapsible list item; expects the named fields ``group_id``, ``group``
    # and ``xr_data``.
    element_template = """
            <li class = "xr-section-item">
                  <input id="idata_{group_id}" class="xr-section-summary-in" type="checkbox">
                  <label for="idata_{group_id}" class = "xr-section-summary">{group}</label>
                  <div class="xr-section-inline-details"></div>
                  <div class="xr-section-details">
                      <ul id="xr-dataset-coord-list" class="xr-var-list">
                          <div style="padding-left:2rem;">{xr_data}<br></div>
                      </ul>
                  </div>
            </li>
            """
    # Static files are read when the class body executes; only the second one
    # (the CSS) is kept.
    _, css_style = _load_static_files()  # pylint: disable=protected-access
    # Extra rule appended after the loaded CSS.
    specific_style = ".xr-wrap{width:700px!important;}"
    # Complete ``<style>`` element combining the loaded CSS and the extra rule.
    css_template = f"<style> {css_style}{specific_style} </style>"
692
693
def either_dict_or_kwargs(
    pos_kwargs,
    kw_kwargs,
    func_name,
):
    """Return whichever of a positional mapping or keyword arguments was provided.

    Clone from xarray.core.utils.

    Parameters
    ----------
    pos_kwargs : dict-like or None
        Mapping passed positionally. It must expose both ``keys`` and ``__getitem__``.
    kw_kwargs : dict
        Keyword arguments collected by the caller.
    func_name : str
        Name of the calling function, only used in error messages.

    Returns
    -------
    dict-like
        ``kw_kwargs`` if ``pos_kwargs`` is None, otherwise ``pos_kwargs``.

    Raises
    ------
    ValueError
        If ``pos_kwargs`` is not dict-like, or if both ``pos_kwargs`` and a non-empty
        ``kw_kwargs`` are given.
    """
    if pos_kwargs is None:
        return kw_kwargs

    # both attributes are required; the negation must apply to the whole conjunction
    is_dict_like = hasattr(pos_kwargs, "keys") and hasattr(pos_kwargs, "__getitem__")
    if not is_dict_like:
        raise ValueError(f"the first argument to .{func_name} must be a dictionary")
    if kw_kwargs:
        raise ValueError(
            f"cannot specify both keyword and positional arguments to .{func_name}"
        )
    return pos_kwargs
710
711
class Dask:
    """Class to toggle Dask states.

    Warnings
    --------
    Dask integration is an experimental feature still in progress. It can already be used
    but it doesn't work with all stats nor diagnostics yet.
    """

    # whether functions decorated with ``conditional_dask`` receive dask kwargs
    dask_flag = False
    # default dask kwargs, merged with the ones given at call time
    dask_kwargs = None

    @classmethod
    def enable_dask(cls, dask_kwargs=None):
        """To enable Dask.

        Parameters
        ----------
        dask_kwargs : dict, optional
            Dask related kwargs passed to :func:`~arviz.wrap_xarray_ufunc`.
        """
        cls.dask_flag = True
        cls.dask_kwargs = dask_kwargs

    @classmethod
    def disable_dask(cls):
        """To disable Dask."""
        cls.dask_flag = False
        cls.dask_kwargs = None


def conditional_dask(func):
    """Conditionally pass dask kwargs to `wrap_xarray_ufunc`.

    While ``Dask.dask_flag`` is set, the decorated function is called with
    ``dask_kwargs`` set to the defaults stored in ``Dask.dask_kwargs`` updated with the
    ``dask_kwargs`` given in the call (call-time values win). Otherwise the call is
    forwarded unchanged.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if not Dask.dask_flag:
            return func(*args, **kwargs)

        user_kwargs = kwargs.pop("dask_kwargs", None) or {}
        # ``Dask.enable_dask()`` without arguments leaves the defaults as None
        default_kwargs = Dask.dask_kwargs or {}
        return func(*args, dask_kwargs={**default_kwargs, **user_kwargs}, **kwargs)

    return wrapper
759