"""Low level converters usually used by other functions."""
import datetime
import functools
import re
import warnings
from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union

import numpy as np
import pkg_resources
import xarray as xr

try:
    import ujson as json
except ImportError:
    # mypy struggles with conditional imports expressed as catching ImportError:
    # https://github.com/python/mypy/issues/1153
    import json  # type: ignore

from .. import __version__, utils
from ..rcparams import rcParams

CoordSpec = Dict[str, List[Any]]
DimSpec = Dict[str, List[str]]
RequiresArgTypeT = TypeVar("RequiresArgTypeT")
RequiresReturnTypeT = TypeVar("RequiresReturnTypeT")


class requires:  # pylint: disable=invalid-name
    """Decorator to return None if an object does not have the required attribute.

    If the decorator is called various times on the same function with different
    attributes, it will return None if one of them is missing. If instead a list
    of attributes is passed, it will return None if all attributes in the list are
    missing. Both functionalities can be combined as desired.

    An attribute counts as "missing" when it is ``None``; it must still exist on the
    object (``getattr`` is called without a default).
    """

    def __init__(self, *props: Union[str, List[str]]) -> None:
        self.props: Tuple[Union[str, List[str]], ...] = props

    # Until typing.ParamSpec (https://www.python.org/dev/peps/pep-0612/) is available
    # in all our supported Python versions, there is no way to simultaneously express
    # the following two properties:
    # - the input function may take arbitrary args/kwargs, and
    # - the output function takes those same arbitrary args/kwargs, but has a different return type.
    # We either have to limit the input function to e.g. only allowing a "self" argument,
    # or we have to adopt the current approach of annotating the returned function as if
    # it was defined as "def f(*args: Any, **kwargs: Any) -> Optional[RequiresReturnTypeT]".
    #
    # Since all functions decorated with @requires currently only accept a single argument,
    # we choose to limit application of @requires to only functions of one argument.
    # When typing.ParamSpec is available, this definition can be updated to use it.
    # See https://github.com/arviz-devs/arviz/pull/1504 for more discussion.
    def __call__(
        self, func: Callable[[RequiresArgTypeT], RequiresReturnTypeT]
    ) -> Callable[[RequiresArgTypeT], Optional[RequiresReturnTypeT]]:  # noqa: D202
        """Wrap the decorated function."""

        # functools.wraps keeps the decorated function's __name__/__doc__, so
        # introspection and generated docs show the converter, not "wrapped".
        @functools.wraps(func)
        def wrapped(cls: RequiresArgTypeT) -> Optional[RequiresReturnTypeT]:
            """Return None if not all props are available."""
            for prop in self.props:
                # A plain string is a single required attribute; a list means
                # "at least one of these must be non-None".
                prop = [prop] if isinstance(prop, str) else prop
                if all(getattr(cls, prop_i) is None for prop_i in prop):
                    return None
            return func(cls)

        return wrapped


def generate_dims_coords(
    shape,
    var_name,
    dims=None,
    coords=None,
    default_dims=None,
    index_origin=None,
    skip_event_dims=None,
):
    """Generate default dimensions and coordinates for a variable.

    Parameters
    ----------
    shape : tuple[int]
        Shape of the variable
    var_name : str
        Name of the variable. If no dimension name(s) is provided, ArviZ
        will generate a default dimension name using ``var_name``, e.g.,
        ``"foo_dim_0"`` for the first dimension if ``var_name`` is ``"foo"``.
    dims : list
        List of dimensions for the variable
    coords : dict[str] -> list[str]
        Map of dimensions to coordinates
    default_dims : list[str]
        Dimension names that are not part of the variable's shape. For example,
        when manipulating Monte Carlo traces, the ``default_dims`` would be
        ``["chain" , "draw"]`` which ArviZ uses as its own names for dimensions
        of MCMC traces.
    index_origin : int, optional
        Starting value of integer coordinate values. Defaults to the value in rcParam
        ``data.index_origin``.
    skip_event_dims : bool, default False
        If True, silently truncate ``dims`` to ``len(shape)`` instead of warning when
        more dims than axes are given, and drop every dim from the first one whose
        provided coords have a length different from the matching axis.

    Returns
    -------
    list[str]
        Default dims
    dict[str] -> list[str]
        Default coords
    """
    if index_origin is None:
        index_origin = rcParams["data.index_origin"]
    if default_dims is None:
        default_dims = []
    if dims is None:
        dims = []
    if skip_event_dims is None:
        skip_event_dims = False

    if coords is None:
        coords = {}

    # Work on copies so the caller's dims/coords are never mutated.
    coords = deepcopy(coords)
    dims = deepcopy(dims)

    ndims = len([dim for dim in dims if dim not in default_dims])
    if ndims > len(shape):
        if skip_event_dims:
            dims = dims[: len(shape)]
        else:
            warnings.warn(
                (
                    "In variable {var_name}, there are "
                    + "more dims ({dims_len}) given than exist ({shape_len}). "
                    + "Passed array should have shape ({defaults}*shape)"
                ).format(
                    var_name=var_name,
                    # Report the same count the condition above compared against.
                    dims_len=ndims,
                    shape_len=len(shape),
                    # default_dims is always a list here; only prefix it when non-empty
                    # (otherwise the message would read "shape (, *shape)").
                    defaults=",".join(default_dims) + ", " if default_dims else "",
                ),
                UserWarning,
            )
    if skip_event_dims:
        # this is needed in case the reduction keeps the dimension with size 1
        for i, (dim, dim_size) in enumerate(zip(dims, shape)):
            if (dim in coords) and (dim_size != len(coords[dim])):
                dims = dims[:i]
                break

    for idx, dim_len in enumerate(shape):
        if (len(dims) < idx + 1) or (dims[idx] is None):
            dim_name = f"{var_name}_dim_{idx}"
            if len(dims) < idx + 1:
                dims.append(dim_name)
            else:
                dims[idx] = dim_name
        dim_name = dims[idx]
        if dim_name not in coords:
            coords[dim_name] = np.arange(index_origin, dim_len + index_origin)
    # Keep only the coords that belong to this variable's dims.
    coords = {key: coord for key, coord in coords.items() if key in dims}
    return dims, coords


def numpy_to_data_array(
    ary,
    *,
    var_name="data",
    coords=None,
    dims=None,
    default_dims=None,
    index_origin=None,
    skip_event_dims=None,
):
    """Convert a numpy array to an xarray.DataArray.

    By default, the first two dimensions will be (chain, draw), and any remaining
    dimensions will be "shape".
    * If the numpy array is 1d, this dimension is interpreted as draw
    * If the numpy array is 2d, it is interpreted as (chain, draw)
    * If the numpy array is 3 or more dimensions, the last dimensions are kept as shapes.

    To modify this behaviour, use ``default_dims``.

    Parameters
    ----------
    ary : np.ndarray
        A numpy array. If it has 2 or more dimensions, the first dimension should be
        independent chains from a simulation. Use `np.expand_dims(ary, 0)` to add a
        single dimension to the front if there is only 1 chain.
    var_name : str
        If there are no dims passed, this string is used to name dimensions
    coords : dict[str, iterable]
        A dictionary containing the values that are used as index. The key
        is the name of the dimension, the values are the index values.
    dims : List(str)
        A list of coordinate names for the variable
    default_dims : list of str, optional
        Passed to :py:func:`generate_dims_coords`. Defaults to ``["chain", "draw"]``, and
        an empty list is accepted
    index_origin : int, optional
        Passed to :py:func:`generate_dims_coords`. It does not affect default
        ``chain``/``draw`` coordinates, which always start at rcParam
        ``data.index_origin``.
    skip_event_dims : bool
        Passed to :py:func:`generate_dims_coords`.

    Returns
    -------
    xr.DataArray
        Will have the same data as passed, but with coordinates and dimensions
    """
    # manage and transform copies
    if default_dims is None:
        default_dims = ["chain", "draw"]
    if "chain" in default_dims and "draw" in default_dims:
        ary = utils.two_de(ary)
        n_chains, n_samples, *_ = ary.shape
        if n_chains > n_samples:
            warnings.warn(
                "More chains ({n_chains}) than draws ({n_samples}). "
                "Passed array should have shape (chains, draws, *shape)".format(
                    n_chains=n_chains, n_samples=n_samples
                ),
                UserWarning,
            )
    else:
        ary = utils.one_de(ary)
        # At most one of "chain"/"draw" is a default dim here; whichever it is gets
        # prepended below as the leading dim, so its length is the first axis.
        # (Previously these names were left unbound, raising UnboundLocalError for
        # default_dims like ["draw"].)
        n_chains = n_samples = ary.shape[0]

    dims, coords = generate_dims_coords(
        ary.shape[len(default_dims) :],
        var_name,
        dims=dims,
        coords=coords,
        default_dims=default_dims,
        index_origin=index_origin,
        skip_event_dims=skip_event_dims,
    )

    # reversed order for default dims: 'chain', 'draw'
    if "draw" not in dims and "draw" in default_dims:
        dims = ["draw"] + dims
    if "chain" not in dims and "chain" in default_dims:
        dims = ["chain"] + dims

    # Chain/draw coordinates always follow the global rcParam, independent of the
    # ``index_origin`` argument (which only reaches generate_dims_coords).
    # NOTE(review): preserved from the original implementation — confirm intended.
    chain_draw_origin = rcParams["data.index_origin"]
    if "chain" not in coords and "chain" in default_dims:
        coords["chain"] = np.arange(chain_draw_origin, n_chains + chain_draw_origin)
    if "draw" not in coords and "draw" in default_dims:
        coords["draw"] = np.arange(chain_draw_origin, n_samples + chain_draw_origin)

    # filter coords based on the dims
    coords = {key: xr.IndexVariable((key,), data=coords[key]) for key in dims}
    return xr.DataArray(ary, coords=coords, dims=dims)


def dict_to_dataset(
    data,
    *,
    attrs=None,
    library=None,
    coords=None,
    dims=None,
    default_dims=None,
    index_origin=None,
    skip_event_dims=None,
):
    """Convert a dictionary of numpy arrays to an xarray.Dataset.

    Parameters
    ----------
    data : dict[str] -> ndarray
        Data to convert. Keys are variable names.
    attrs : dict
        Json serializable metadata to attach to the dataset, in addition to defaults.
    library : module
        Library used for performing inference. Will be attached to the attrs metadata.
    coords : dict[str] -> ndarray
        Coordinates for the dataset
    dims : dict[str] -> list[str]
        Dimensions of each variable. The keys are variable names, values are lists of
        coordinates.
    default_dims : list of str, optional
        Passed to :py:func:`numpy_to_data_array`
    index_origin : int, optional
        Passed to :py:func:`numpy_to_data_array`
    skip_event_dims : bool
        If True, cut extra dims whenever present to match the shape of the data.
        Necessary for PPLs which have the same name in both observed data and log
        likelihood groups, to account for their different shapes when observations are
        multivariate.

    Returns
    -------
    xr.Dataset

    Examples
    --------
    dict_to_dataset({'x': np.random.randn(4, 100), 'y': np.random.rand(4, 100)})

    """
    if dims is None:
        dims = {}

    # Each variable is converted independently; the shared coords dict is copied
    # inside generate_dims_coords, so one variable cannot alter another's coords.
    data_vars = {}
    for key, values in data.items():
        data_vars[key] = numpy_to_data_array(
            values,
            var_name=key,
            coords=coords,
            dims=dims.get(key),
            default_dims=default_dims,
            index_origin=index_origin,
            skip_event_dims=skip_event_dims,
        )
    return xr.Dataset(data_vars=data_vars, attrs=make_attrs(attrs=attrs, library=library))


def make_attrs(attrs=None, library=None):
    """Make standard attributes to attach to xarray datasets.

    Parameters
    ----------
    attrs : dict (optional)
        Additional attributes to add or overwrite
    library : module (optional)
        Inference library; its ``__name__`` and, when it can be determined, its
        version are recorded as ``inference_library`` / ``inference_library_version``.

    Returns
    -------
    dict
        attrs
    """
    # datetime.utcnow() is deprecated (Python 3.12); take an aware UTC time and drop
    # the tzinfo so the ISO string keeps its previous naive format (no "+00:00").
    created_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
    default_attrs = {
        "created_at": created_at.isoformat(),
        "arviz_version": __version__,
    }
    if library is not None:
        library_name = library.__name__
        default_attrs["inference_library"] = library_name
        try:
            version = pkg_resources.get_distribution(library_name).version
            default_attrs["inference_library_version"] = version
        except pkg_resources.DistributionNotFound:
            # Not an installed distribution; fall back to the module attribute if any.
            if hasattr(library, "__version__"):
                version = library.__version__
                default_attrs["inference_library_version"] = version

    if attrs is not None:
        default_attrs.update(attrs)
    return default_attrs


def _extend_xr_method(func, doc="", description="", examples="", see_also=""):
    """Make wrapper to extend methods from xr.Dataset to InferenceData Class.

    Parameters
    ----------
    func : callable
        An xr.Dataset function
    doc : str
        docstring for the func. When empty or None, a docstring is generated from
        ``description``, ``examples`` and ``see_also``.
    description : str
        the description of the func to be added in docstring
    examples : str
        the examples of the func to be added in docstring
    see_also : str, list
        the similar methods of func to be included in See Also section of docstring

    """
    # pydocstyle requires a non empty line

    @functools.wraps(func)
    def wrapped(self, *args, **kwargs):
        _filter = kwargs.pop("filter_groups", None)
        _groups = kwargs.pop("groups", None)
        _inplace = kwargs.pop("inplace", False)

        out = self if _inplace else deepcopy(self)

        groups = self._group_names(_groups, _filter)  # pylint: disable=protected-access
        for group in groups:
            xr_data = getattr(out, group)
            xr_data = func(xr_data, *args, **kwargs)  # pylint: disable=not-callable
            setattr(out, group, xr_data)

        return None if _inplace else out

    description_default = """{method_name} method is extended from xarray.Dataset methods.

    {description}For more info see :meth:`xarray:xarray.Dataset.{method_name}`
    """.format(
        description=description, method_name=func.__name__  # pylint: disable=no-member
    )
    params = """
    Parameters
    ----------
    groups: str or list of str, optional
        Groups where the selection is to be applied. Can either be group names
        or metagroup names.
    filter_groups: {None, "like", "regex"}, optional, default=None
        If `None` (default), interpret groups as the real group or metagroup names.
        If "like", interpret groups as substrings of the real group or metagroup names.
        If "regex", interpret groups as regular expressions on the real group or
        metagroup names. A la `pandas.filter`.
    inplace: bool, optional
        If ``True``, modify the InferenceData object inplace,
        otherwise, return the modified copy.
    """

    if not isinstance(see_also, str):
        see_also = "\n".join(see_also)
    see_also_basic = """
    See Also
    --------
    xarray.Dataset.{method_name}
    {custom_see_also}
    """.format(
        method_name=func.__name__, custom_see_also=see_also  # pylint: disable=no-member
    )
    # The default ``doc`` is "", so test for falsiness rather than ``is None``;
    # otherwise the generated docstring would never be used by default.
    wrapped.__doc__ = (
        doc if doc else description_default + params + examples + see_also_basic
    )

    return wrapped


def _make_json_serializable(data: dict) -> dict:
    """Convert `data` with numpy.ndarray-like values to JSON-serializable form.

    Values that already serialize are kept as-is, nested dicts are converted
    recursively, and numpy arrays / numeric numpy scalars are converted to builtin
    Python lists / scalars.

    Raises
    ------
    TypeError
        If a value cannot be made JSON serializable.
    """
    ret = {}
    for key, value in data.items():
        try:
            json.dumps(value)
        except (TypeError, OverflowError):
            pass
        else:
            ret[key] = value
            continue
        if isinstance(value, dict):
            ret[key] = _make_json_serializable(value)
        elif isinstance(value, (np.ndarray, np.integer, np.floating, np.bool_)):
            # .tolist() yields nested builtin lists (arrays) or builtin scalars.
            ret[key] = value.tolist()
        else:
            raise TypeError(
                f"Value associated with variable `{key}` of type `{type(value)}` "
                "is not JSON serializable."
            )
    return ret


def infer_stan_dtypes(stan_code):
    """Infer Stan integer variables from generated quantities block.

    Returns a dict mapping each variable declared as ``int`` in the generated
    quantities block to ``"int"``; an empty dict if there is no such block.
    """
    # Remove old deprecated comments
    stan_code = "\n".join(
        line if "#" not in line else line[: line.find("#")] for line in stan_code.splitlines()
    )
    # Strip // and /* */ comments as well as string literals.
    pattern_remove_comments = re.compile(
        r'//.*?$|/\*.*?\*/|\'(?:\\.|[^\\\'])*\'|"(?:\\.|[^\\"])*"', re.DOTALL | re.MULTILINE
    )
    stan_code = re.sub(pattern_remove_comments, "", stan_code)

    # Check generated quantities
    if "generated quantities" not in stan_code:
        return {}

    # Extract generated quantities block
    gen_quantities_location = stan_code.index("generated quantities")
    block_start = gen_quantities_location + stan_code[gen_quantities_location:].index("{")

    # Scan for the matching closing brace; block_end ends one past it.
    curly_bracket_count = 0
    block_end = None
    for block_end, char in enumerate(stan_code[block_start:], block_start + 1):
        if char == "{":
            curly_bracket_count += 1
        elif char == "}":
            curly_bracket_count -= 1

            if curly_bracket_count == 0:
                break

    stan_code = stan_code[block_start:block_end]

    # `\bint\b` matches only the `int` keyword (not `print`, `to_int`, ...), and
    # bounds may follow it directly, as in `int<lower=0> y;`. Stan is case-sensitive.
    stan_integer = r"\bint\b"
    stan_limits = r"(?:\<[^\>]+\>)*"  # ignore group: 0 or more <....>
    stan_param = r"([^;=\s\[]+)"  # capture group: ends= ";", "=", "[" or whitespace
    stan_ws = r"\s*"  # 0 or more whitespace
    pattern_int = re.compile("".join((stan_integer, stan_ws, stan_limits, stan_ws, stan_param)))
    dtypes = {key.strip(): "int" for key in re.findall(pattern_int, stan_code)}
    return dtypes