1"""Low level converters usually used by other functions."""
2import datetime
3import functools
4import re
5import warnings
6from copy import deepcopy
7from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union
8
9import numpy as np
10import pkg_resources
11import xarray as xr
12
13try:
14    import ujson as json
15except ImportError:
16    # mypy struggles with conditional imports expressed as catching ImportError:
17    # https://github.com/python/mypy/issues/1153
18    import json  # type: ignore
19
20from .. import __version__, utils
21from ..rcparams import rcParams
22
23CoordSpec = Dict[str, List[Any]]
24DimSpec = Dict[str, List[str]]
25RequiresArgTypeT = TypeVar("RequiresArgTypeT")
26RequiresReturnTypeT = TypeVar("RequiresReturnTypeT")
27
28
29class requires:  # pylint: disable=invalid-name
30    """Decorator to return None if an object does not have the required attribute.
31
32    If the decorator is called various times on the same function with different
33    attributes, it will return None if one of them is missing. If instead a list
34    of attributes is passed, it will return None if all attributes in the list are
35    missing. Both functionalities can be combined as desired.
36    """
37
38    def __init__(self, *props: Union[str, List[str]]) -> None:
39        self.props: Tuple[Union[str, List[str]], ...] = props
40
41    # Until typing.ParamSpec (https://www.python.org/dev/peps/pep-0612/) is available
42    # in all our supported Python versions, there is no way to simultaneously express
43    # the following two properties:
44    # - the input function may take arbitrary args/kwargs, and
45    # - the output function takes those same arbitrary args/kwargs, but has a different return type.
46    # We either have to limit the input function to e.g. only allowing a "self" argument,
47    # or we have to adopt the current approach of annotating the returned function as if
48    # it was defined as "def f(*args: Any, **kwargs: Any) -> Optional[RequiresReturnTypeT]".
49    #
50    # Since all functions decorated with @requires currently only accept a single argument,
51    # we choose to limit application of @requires to only functions of one argument.
52    # When typing.ParamSpec is available, this definition can be updated to use it.
53    # See https://github.com/arviz-devs/arviz/pull/1504 for more discussion.
54    def __call__(
55        self, func: Callable[[RequiresArgTypeT], RequiresReturnTypeT]
56    ) -> Callable[[RequiresArgTypeT], Optional[RequiresReturnTypeT]]:  # noqa: D202
57        """Wrap the decorated function."""
58
59        def wrapped(cls: RequiresArgTypeT) -> Optional[RequiresReturnTypeT]:
60            """Return None if not all props are available."""
61            for prop in self.props:
62                prop = [prop] if isinstance(prop, str) else prop
63                if all((getattr(cls, prop_i) is None for prop_i in prop)):
64                    return None
65            return func(cls)
66
67        return wrapped
68
69
70def generate_dims_coords(
71    shape,
72    var_name,
73    dims=None,
74    coords=None,
75    default_dims=None,
76    index_origin=None,
77    skip_event_dims=None,
78):
79    """Generate default dimensions and coordinates for a variable.
80
81    Parameters
82    ----------
83    shape : tuple[int]
84        Shape of the variable
85    var_name : str
86        Name of the variable. If no dimension name(s) is provided, ArviZ
87        will generate a default dimension name using ``var_name``, e.g.,
88        ``"foo_dim_0"`` for the first dimension if ``var_name`` is ``"foo"``.
89    dims : list
90        List of dimensions for the variable
91    coords : dict[str] -> list[str]
92        Map of dimensions to coordinates
93    default_dims : list[str]
94        Dimension names that are not part of the variable's shape. For example,
95        when manipulating Monte Carlo traces, the ``default_dims`` would be
96        ``["chain" , "draw"]`` which ArviZ uses as its own names for dimensions
97        of MCMC traces.
98    index_origin : int, optional
99        Starting value of integer coordinate values. Defaults to the value in rcParam
100        ``data.index_origin``.
101    skip_event_dims : bool, default False
102
103    Returns
104    -------
105    list[str]
106        Default dims
107    dict[str] -> list[str]
108        Default coords
109    """
110    if index_origin is None:
111        index_origin = rcParams["data.index_origin"]
112    if default_dims is None:
113        default_dims = []
114    if dims is None:
115        dims = []
116    if skip_event_dims is None:
117        skip_event_dims = False
118
119    if coords is None:
120        coords = {}
121
122    coords = deepcopy(coords)
123    dims = deepcopy(dims)
124
125    ndims = len([dim for dim in dims if dim not in default_dims])
126    if ndims > len(shape):
127        if skip_event_dims:
128            dims = dims[: len(shape)]
129        else:
130            warnings.warn(
131                (
132                    "In variable {var_name}, there are "
133                    + "more dims ({dims_len}) given than exist ({shape_len}). "
134                    + "Passed array should have shape ({defaults}*shape)"
135                ).format(
136                    var_name=var_name,
137                    dims_len=len(dims),
138                    shape_len=len(shape),
139                    defaults=",".join(default_dims) + ", " if default_dims is not None else "",
140                ),
141                UserWarning,
142            )
143    if skip_event_dims:
144        # this is needed in case the reduction keeps the dimension with size 1
145        for i, (dim, dim_size) in enumerate(zip(dims, shape)):
146            if (dim in coords) and (dim_size != len(coords[dim])):
147                dims = dims[:i]
148                break
149
150    for idx, dim_len in enumerate(shape):
151        if (len(dims) < idx + 1) or (dims[idx] is None):
152            dim_name = f"{var_name}_dim_{idx}"
153            if len(dims) < idx + 1:
154                dims.append(dim_name)
155            else:
156                dims[idx] = dim_name
157        dim_name = dims[idx]
158        if dim_name not in coords:
159            coords[dim_name] = np.arange(index_origin, dim_len + index_origin)
160    coords = {key: coord for key, coord in coords.items() if any(key == dim for dim in dims)}
161    return dims, coords
162
163
164def numpy_to_data_array(
165    ary,
166    *,
167    var_name="data",
168    coords=None,
169    dims=None,
170    default_dims=None,
171    index_origin=None,
172    skip_event_dims=None,
173):
174    """Convert a numpy array to an xarray.DataArray.
175
176    By default, the first two dimensions will be (chain, draw), and any remaining
177    dimensions will be "shape".
178    * If the numpy array is 1d, this dimension is interpreted as draw
179    * If the numpy array is 2d, it is interpreted as (chain, draw)
180    * If the numpy array is 3 or more dimensions, the last dimensions are kept as shapes.
181
182    To modify this behaviour, use ``default_dims``.
183
184    Parameters
185    ----------
186    ary : np.ndarray
187        A numpy array. If it has 2 or more dimensions, the first dimension should be
188        independent chains from a simulation. Use `np.expand_dims(ary, 0)` to add a
189        single dimension to the front if there is only 1 chain.
190    var_name : str
191        If there are no dims passed, this string is used to name dimensions
192    coords : dict[str, iterable]
193        A dictionary containing the values that are used as index. The key
194        is the name of the dimension, the values are the index values.
195    dims : List(str)
196        A list of coordinate names for the variable
197    default_dims : list of str, optional
198        Passed to :py:func:`generate_dims_coords`. Defaults to ``["chain", "draw"]``, and
199        an empty list is accepted
200    index_origin : int, optional
201        Passed to :py:func:`generate_dims_coords`
202    skip_event_dims : bool
203
204    Returns
205    -------
206    xr.DataArray
207        Will have the same data as passed, but with coordinates and dimensions
208    """
209    # manage and transform copies
210    if default_dims is None:
211        default_dims = ["chain", "draw"]
212    if "chain" in default_dims and "draw" in default_dims:
213        ary = utils.two_de(ary)
214        n_chains, n_samples, *_ = ary.shape
215        if n_chains > n_samples:
216            warnings.warn(
217                "More chains ({n_chains}) than draws ({n_samples}). "
218                "Passed array should have shape (chains, draws, *shape)".format(
219                    n_chains=n_chains, n_samples=n_samples
220                ),
221                UserWarning,
222            )
223    else:
224        ary = utils.one_de(ary)
225
226    dims, coords = generate_dims_coords(
227        ary.shape[len(default_dims) :],
228        var_name,
229        dims=dims,
230        coords=coords,
231        default_dims=default_dims,
232        index_origin=index_origin,
233        skip_event_dims=skip_event_dims,
234    )
235
236    # reversed order for default dims: 'chain', 'draw'
237    if "draw" not in dims and "draw" in default_dims:
238        dims = ["draw"] + dims
239    if "chain" not in dims and "chain" in default_dims:
240        dims = ["chain"] + dims
241
242    index_origin = rcParams["data.index_origin"]
243    if "chain" not in coords and "chain" in default_dims:
244        coords["chain"] = np.arange(index_origin, n_chains + index_origin)
245    if "draw" not in coords and "draw" in default_dims:
246        coords["draw"] = np.arange(index_origin, n_samples + index_origin)
247
248    # filter coords based on the dims
249    coords = {key: xr.IndexVariable((key,), data=coords[key]) for key in dims}
250    return xr.DataArray(ary, coords=coords, dims=dims)
251
252
253def dict_to_dataset(
254    data,
255    *,
256    attrs=None,
257    library=None,
258    coords=None,
259    dims=None,
260    default_dims=None,
261    index_origin=None,
262    skip_event_dims=None,
263):
264    """Convert a dictionary of numpy arrays to an xarray.Dataset.
265
266    Parameters
267    ----------
268    data : dict[str] -> ndarray
269        Data to convert. Keys are variable names.
270    attrs : dict
271        Json serializable metadata to attach to the dataset, in addition to defaults.
272    library : module
273        Library used for performing inference. Will be attached to the attrs metadata.
274    coords : dict[str] -> ndarray
275        Coordinates for the dataset
276    dims : dict[str] -> list[str]
277        Dimensions of each variable. The keys are variable names, values are lists of
278        coordinates.
279    default_dims : list of str, optional
280        Passed to :py:func:`numpy_to_data_array`
281    index_origin : int, optional
282        Passed to :py:func:`numpy_to_data_array`
283    skip_event_dims : bool
284        If True, cut extra dims whenever present to match the shape of the data.
285        Necessary for PPLs which have the same name in both observed data and log
286        likelihood groups, to account for their different shapes when observations are
287        multivariate.
288
289    Returns
290    -------
291    xr.Dataset
292
293    Examples
294    --------
295    dict_to_dataset({'x': np.random.randn(4, 100), 'y': np.random.rand(4, 100)})
296
297    """
298    if dims is None:
299        dims = {}
300
301    data_vars = {}
302    for key, values in data.items():
303        data_vars[key] = numpy_to_data_array(
304            values,
305            var_name=key,
306            coords=coords,
307            dims=dims.get(key),
308            default_dims=default_dims,
309            index_origin=index_origin,
310            skip_event_dims=skip_event_dims,
311        )
312    return xr.Dataset(data_vars=data_vars, attrs=make_attrs(attrs=attrs, library=library))
313
314
315def make_attrs(attrs=None, library=None):
316    """Make standard attributes to attach to xarray datasets.
317
318    Parameters
319    ----------
320    attrs : dict (optional)
321        Additional attributes to add or overwrite
322
323    Returns
324    -------
325    dict
326        attrs
327    """
328    default_attrs = {
329        "created_at": datetime.datetime.utcnow().isoformat(),
330        "arviz_version": __version__,
331    }
332    if library is not None:
333        library_name = library.__name__
334        default_attrs["inference_library"] = library_name
335        try:
336            version = pkg_resources.get_distribution(library_name).version
337            default_attrs["inference_library_version"] = version
338        except pkg_resources.DistributionNotFound:
339            if hasattr(library, "__version__"):
340                version = library.__version__
341                default_attrs["inference_library_version"] = version
342
343    if attrs is not None:
344        default_attrs.update(attrs)
345    return default_attrs
346
347
348def _extend_xr_method(func, doc="", description="", examples="", see_also=""):
349    """Make wrapper to extend methods from xr.Dataset to InferenceData Class.
350
351    Parameters
352    ----------
353    func : callable
354        An xr.Dataset function
355    doc : str
356        docstring for the func
357    description : str
358        the description of the func to be added in docstring
359    examples : str
360        the examples of the func to be added in docstring
361    see_also : str, list
362        the similar methods of func to be included in See Also section of docstring
363
364    """
365    # pydocstyle requires a non empty line
366
367    @functools.wraps(func)
368    def wrapped(self, *args, **kwargs):
369        _filter = kwargs.pop("filter_groups", None)
370        _groups = kwargs.pop("groups", None)
371        _inplace = kwargs.pop("inplace", False)
372
373        out = self if _inplace else deepcopy(self)
374
375        groups = self._group_names(_groups, _filter)  # pylint: disable=protected-access
376        for group in groups:
377            xr_data = getattr(out, group)
378            xr_data = func(xr_data, *args, **kwargs)  # pylint: disable=not-callable
379            setattr(out, group, xr_data)
380
381        return None if _inplace else out
382
383    description_default = """{method_name} method is extended from xarray.Dataset methods.
384
385    {description}For more info see :meth:`xarray:xarray.Dataset.{method_name}`
386    """.format(
387        description=description, method_name=func.__name__  # pylint: disable=no-member
388    )
389    params = """
390    Parameters
391    ----------
392    groups: str or list of str, optional
393        Groups where the selection is to be applied. Can either be group names
394        or metagroup names.
395    filter_groups: {None, "like", "regex"}, optional, default=None
396        If `None` (default), interpret groups as the real group or metagroup names.
397        If "like", interpret groups as substrings of the real group or metagroup names.
398        If "regex", interpret groups as regular expressions on the real group or
399        metagroup names. A la `pandas.filter`.
400    inplace: bool, optional
401        If ``True``, modify the InferenceData object inplace,
402        otherwise, return the modified copy.
403    """
404
405    if not isinstance(see_also, str):
406        see_also = "\n".join(see_also)
407    see_also_basic = """
408    See Also
409    --------
410    xarray.Dataset.{method_name}
411    {custom_see_also}
412    """.format(
413        method_name=func.__name__, custom_see_also=see_also  # pylint: disable=no-member
414    )
415    wrapped.__doc__ = (
416        description_default + params + examples + see_also_basic if doc is None else doc
417    )
418
419    return wrapped
420
421
422def _make_json_serializable(data: dict) -> dict:
423    """Convert `data` with numpy.ndarray-like values to JSON-serializable form."""
424    ret = {}
425    for key, value in data.items():
426        try:
427            json.dumps(value)
428        except (TypeError, OverflowError):
429            pass
430        else:
431            ret[key] = value
432            continue
433        if isinstance(value, dict):
434            ret[key] = _make_json_serializable(value)
435        elif isinstance(value, np.ndarray):
436            ret[key] = np.asarray(value).tolist()
437        else:
438            raise TypeError(
439                f"Value associated with variable `{type(value)}` is not JSON serializable."
440            )
441    return ret
442
443
444def infer_stan_dtypes(stan_code):
445    """Infer Stan integer variables from generated quantities block."""
446    # Remove old deprecated comments
447    stan_code = "\n".join(
448        line if "#" not in line else line[: line.find("#")] for line in stan_code.splitlines()
449    )
450    pattern_remove_comments = re.compile(
451        r'//.*?$|/\*.*?\*/|\'(?:\\.|[^\\\'])*\'|"(?:\\.|[^\\"])*"', re.DOTALL | re.MULTILINE
452    )
453    stan_code = re.sub(pattern_remove_comments, "", stan_code)
454
455    # Check generated quantities
456    if "generated quantities" not in stan_code:
457        return {}
458
459    # Extract generated quantities block
460    gen_quantities_location = stan_code.index("generated quantities")
461    block_start = gen_quantities_location + stan_code[gen_quantities_location:].index("{")
462
463    curly_bracket_count = 0
464    block_end = None
465    for block_end, char in enumerate(stan_code[block_start:], block_start + 1):
466        if char == "{":
467            curly_bracket_count += 1
468        elif char == "}":
469            curly_bracket_count -= 1
470
471            if curly_bracket_count == 0:
472                break
473
474    stan_code = stan_code[block_start:block_end]
475
476    stan_integer = r"int"
477    stan_limits = r"(?:\<[^\>]+\>)*"  # ignore group: 0 or more <....>
478    stan_param = r"([^;=\s\[]+)"  # capture group: ends= ";", "=", "[" or whitespace
479    stan_ws = r"\s*"  # 0 or more whitespace
480    stan_ws_one = r"\s+"  # 1 or more whitespace
481    pattern_int = re.compile(
482        "".join((stan_integer, stan_ws_one, stan_limits, stan_ws, stan_param)), re.IGNORECASE
483    )
484    dtypes = {key.strip(): "int" for key in re.findall(pattern_int, stan_code)}
485    return dtypes
486