"""Utilities for plotting."""
import importlib
import warnings
from typing import Any, Dict

import matplotlib as mpl
import numpy as np
import packaging.version  # plain `import packaging` does not load the `version` submodule
from matplotlib.colors import to_hex
from scipy.stats import mode, rankdata
from scipy.interpolate import CubicSpline


from ..rcparams import rcParams
from ..stats.density_utils import kde
from ..stats import hdi

KwargSpec = Dict[str, Any]


def make_2d(ary):
    """Convert any array into a 2d numpy array.

    In case the array is already more than 2 dimensional, will ravel the
    dimensions after the first (in Fortran order). A 0d or 1d input is
    promoted with ``np.atleast_1d`` first, so a length-n input becomes (n, 1).
    """
    # Reshape the promoted array (not the raw input) so array-likes such as lists work too.
    ary = np.atleast_1d(ary)
    dim_0 = ary.shape[0]
    return ary.reshape(dim_0, -1, order="F")


def _scale_fig_size(figsize, textsize, rows=1, cols=1):
    """Scale figure properties according to rows and cols.

    Parameters
    ----------
    figsize : (float, float) or None
        Size of figure in inches. If None, the matplotlib rcParams size is scaled
        by the number of rows and columns.
    textsize : float or None
        fontsize
    rows : int
        Number of rows
    cols : int
        Number of columns

    Returns
    -------
    figsize : (float, float)
        Size of figure in inches
    ax_labelsize : int
        fontsize for axes label
    titlesize : int
        fontsize for title
    xt_labelsize : int
        fontsize for axes ticks
    linewidth : int
        linewidth
    markersize : int
        markersize
    """
    params = mpl.rcParams
    rc_width, rc_height = tuple(params["figure.figsize"])
    rc_ax_labelsize = params["axes.labelsize"]
    rc_titlesize = params["axes.titlesize"]
    rc_xt_labelsize = params["xtick.labelsize"]
    rc_linewidth = params["lines.linewidth"]
    rc_markersize = params["lines.markersize"]
    # rcParams may hold relative sizes such as "medium"; replace them with fixed defaults.
    if isinstance(rc_ax_labelsize, str):
        rc_ax_labelsize = 15
    if isinstance(rc_titlesize, str):
        rc_titlesize = 16
    if isinstance(rc_xt_labelsize, str):
        rc_xt_labelsize = 14

    if figsize is None:
        width, height = rc_width, rc_height
        sff = 1 if (rows == cols == 1) else 1.15
        width = width * cols * sff
        height = height * rows * sff
    else:
        width, height = figsize

    if textsize is not None:
        scale_factor = textsize / rc_xt_labelsize
    elif rows == cols == 1:
        scale_factor = ((width * height) / (rc_width * rc_height)) ** 0.5
    else:
        scale_factor = 1

    ax_labelsize = rc_ax_labelsize * scale_factor
    titlesize = rc_titlesize * scale_factor
    xt_labelsize = rc_xt_labelsize * scale_factor
    linewidth = rc_linewidth * scale_factor
    markersize = rc_markersize * scale_factor

    return (width, height), ax_labelsize, titlesize, xt_labelsize, linewidth, markersize


def default_grid(n_items, grid=None, max_cols=4, min_cols=3):  # noqa: D202
    """Make a grid for subplots.

    Tries to get as close to sqrt(n_items) x sqrt(n_items) as it can,
    but allows for custom logic

    Parameters
    ----------
    n_items : int
        Number of panels required
    grid : tuple
        Number of rows and columns
    max_cols : int
        Maximum number of columns, inclusive
    min_cols : int
        Minimum number of columns, inclusive

    Returns
    -------
    (int, int)
        Rows and columns, so that rows * columns >= n_items

    Raises
    ------
    ValueError
        If a user-supplied ``grid`` has fewer cells than ``n_items``.
    """

    if grid is None:

        def in_bounds(val):
            return np.clip(val, min_cols, max_cols)

        if n_items <= max_cols:
            return 1, n_items
        ideal = in_bounds(round(n_items ** 0.5))

        # Prefer a column count near the ideal that divides n_items exactly.
        for offset in (0, 1, -1, 2, -2):
            cols = in_bounds(ideal + offset)
            rows, extra = divmod(n_items, cols)
            if extra == 0:
                return rows, cols
        return n_items // ideal + 1, ideal
    else:
        rows, cols = grid
        if rows * cols < n_items:
            raise ValueError("The number of rows times columns is less than the number of subplots")
        if (rows * cols) - n_items >= cols:
            warnings.warn("The number of rows times columns is larger than necessary")
        return rows, cols


def format_sig_figs(value, default=None):
    """Get a default number of significant figures.

    Gives the integer part or `default`, whichever is bigger. Zero yields 1 and
    non-finite values (inf, nan) yield `default`.

    Examples
    --------
    0.1234 --> 0.12
    1.234  --> 1.2
    12.34  --> 12
    123.4  --> 123
    """
    if default is None:
        default = 2
    if value == 0:
        return 1
    if not np.isfinite(value):
        # log10(inf) / log10(nan) cannot be converted to int.
        return default
    return max(int(np.log10(np.abs(value))) + 1, default)


def round_num(n, round_to):
    """
    Return a string representing a number with `round_to` significant figures.

    Parameters
    ----------
    n : float
        number to round
    round_to : int
        number of significant figures
    """
    sig_figs = format_sig_figs(n, round_to)
    return "{n:.{sig_figs}g}".format(n=n, sig_figs=sig_figs)


def color_from_dim(dataarray, dim_name):
    """Return colors and color mapping of a DataArray using coord values as color code.

    Parameters
    ----------
    dataarray : xarray.DataArray
    dim_name : str
        dimension whose coordinates will be used as color code.

    Returns
    -------
    colors : array of floats
        Array of colors (as floats for use with a cmap) for each element in the dataarray.
    color_mapping : mapping coord_value -> float
        Mapping from coord values to corresponding color
    """
    present_dims = dataarray.dims
    coord_values = dataarray[dim_name].values
    # dict.fromkeys de-duplicates like a set but keeps first-appearance order, so the
    # colour assigned to each coord is reproducible across runs (string hashing is
    # randomized per process, which made set iteration order vary).
    unique_coords = list(dict.fromkeys(coord_values))
    n_unique = len(unique_coords)
    color_mapping = {coord: num / n_unique for num, coord in enumerate(unique_coords)}
    if len(present_dims) > 1:
        multi_coords = dataarray.coords.to_index()
        coord_idx = present_dims.index(dim_name)
        colors = [color_mapping[coord[coord_idx]] for coord in multi_coords]
    else:
        colors = [color_mapping[coord] for coord in coord_values]
    return colors, color_mapping


def vectorized_to_hex(c_values, keep_alpha=False):
    """Convert a color (including vector of colors) to hex.

    Parameters
    ----------
    c_values : Matplotlib color or sequence of Matplotlib colors
        Color(s) to convert.
    keep_alpha : boolean
        to select if alpha values should be kept in the final hex values.

    Returns
    -------
    rgba_hex : str or list of str
        A single hex string if ``c_values`` is one color, otherwise a list of hex values.
    """
    try:
        hex_color = to_hex(c_values, keep_alpha)

    except ValueError:
        # Not a single color: convert element-wise.
        hex_color = [to_hex(color, keep_alpha) for color in c_values]
    return hex_color


def format_coords_as_labels(dataarray, skip_dims=None):
    """Format 1d or multi-d dataarray coords as strings.

    Parameters
    ----------
    dataarray : xarray.DataArray
        DataArray whose coordinates will be converted to labels.
    skip_dims : str or list_like, optional
        Dimensions whose values should not be included in the labels

    Returns
    -------
    numpy.ndarray of object
        One string label per (remaining) coordinate combination.
    """
    if skip_dims is None:
        coord_labels = dataarray.coords.to_index()
    else:
        coord_labels = dataarray.coords.to_index().droplevel(skip_dims).drop_duplicates()
    coord_values = coord_labels.values
    if len(coord_values) == 0:
        return np.array([], dtype=object)
    if isinstance(coord_values[0], tuple):
        fmt = ", ".join(["{}" for _ in coord_values[0]])
        labels = [fmt.format(*x) for x in coord_values]
    else:
        labels = [f"{s}" for s in coord_values]
    # Build a fresh object array instead of writing into the index buffer: for numeric
    # or datetime coords in-place assignment would cast the strings back to numbers,
    # and the buffer may be read-only.
    out = np.empty(len(labels), dtype=object)
    out[:] = labels
    return out


def set_xticklabels(ax, coord_labels):
    """Set xticklabels to label list using Matplotlib default formatter."""
    # Array form is needed for the fancy indexing below.
    coord_labels = np.asarray(coord_labels)
    ax.xaxis.get_major_locator().set_params(nbins=9, steps=[1, 2, 5, 10])
    xticks = ax.get_xticks().astype(np.int64)
    xticks = xticks[(xticks >= 0) & (xticks < len(coord_labels))]
    if len(xticks) > len(coord_labels):
        ax.set_xticks(np.arange(len(coord_labels)))
        ax.set_xticklabels(coord_labels)
    else:
        ax.set_xticks(xticks)
        ax.set_xticklabels(coord_labels[xticks])


def filter_plotters_list(plotters, plot_kind):
    """Cut list of plotters so that it is at most of length "plot.max_subplots"."""
    max_plots = rcParams["plot.max_subplots"]
    max_plots = len(plotters) if max_plots is None else max_plots
    if len(plotters) > max_plots:
        warnings.warn(
            "rcParams['plot.max_subplots'] ({max_plots}) is smaller than the number "
            "of variables to plot ({len_plotters}) in {plot_kind}, generating only "
            "{max_plots} plots".format(
                max_plots=max_plots, len_plotters=len(plotters), plot_kind=plot_kind
            ),
            UserWarning,
        )
        return plotters[:max_plots]
    return plotters


def get_plotting_function(plot_name, plot_module, backend):
    """Return plotting function for correct backend.

    Parameters
    ----------
    plot_name : str
        Name of the function to fetch from the backend module.
    plot_module : str
        Module name under ``arviz.plots.backends.<backend>``.
    backend : str or None
        "mpl"/"matplotlib" or "bokeh" (case-insensitive). None uses rcParams["plot.backend"].

    Raises
    ------
    KeyError
        If the backend name is not supported.
    ImportError
        If the bokeh backend is requested but Bokeh 1.4.0+ is not installed.
    """
    _backend = {
        "mpl": "matplotlib",
        "bokeh": "bokeh",
        "matplotlib": "matplotlib",
    }

    if backend is None:
        backend = rcParams["plot.backend"]
    backend = backend.lower()

    try:
        backend = _backend[backend]
    except KeyError as err:
        raise KeyError(
            "Backend {} is not implemented. Try backend in {}".format(
                backend, set(_backend.values())
            )
        ) from err

    if backend == "bokeh":
        bokeh_error = "'bokeh' backend needs Bokeh (1.4.0+) installed." " Please upgrade or install"
        try:
            import bokeh
        except ImportError as err:
            raise ImportError(bokeh_error) from err

        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if packaging.version.parse(bokeh.__version__) < packaging.version.parse("1.4.0"):
            raise ImportError(bokeh_error)

    # Perform import of plotting method
    # TODO: Convert module import to top level for all plots
    module = importlib.import_module(f"arviz.plots.backends.{backend}.{plot_module}")

    plotting_method = getattr(module, plot_name)

    return plotting_method


def calculate_point_estimate(point_estimate, values, bw="default", circular=False, skipna=False):
    """Validate and calculate the point estimate.

    Parameters
    ----------
    point_estimate : Optional[str]
        Plot point estimate per variable. Values should be 'mean', 'median', 'mode', 'auto'
        or None. 'auto' falls back to the value set in rcParams["plot.point_estimate"].
    values : 1-d array
    bw: Optional[float or str]
        If numeric, indicates the bandwidth and must be positive.
        If str, indicates the method to estimate the bandwidth and must be
        one of "scott", "silverman", "isj" or "experimental" when `circular` is False
        and "taylor" (for now) when `circular` is True.
        Defaults to "default" which means "experimental" when variable is not circular
        and "taylor" when it is.
    circular: Optional[bool]
        If True, it interprets the values passed are from a circular variable measured in radians
        and a circular KDE is used. Only valid for 1D KDE. Defaults to False.
    skipna : bool, optional
        If True, ignore nan values when computing the point estimate. Defaults to False.

    Returns
    -------
    point_value : float
        best estimate of data distribution (None if ``point_estimate`` is None)

    Raises
    ------
    ValueError
        If ``point_estimate`` is not one of the accepted values.
    """
    point_value = None
    if point_estimate == "auto":
        point_estimate = rcParams["plot.point_estimate"]
    elif point_estimate not in ("mean", "median", "mode", None):
        raise ValueError(
            "Point estimate should be 'mean', 'median', 'mode' or None, not {}".format(
                point_estimate
            )
        )
    if point_estimate == "mean":
        point_value = np.nanmean(values) if skipna else np.mean(values)
    elif point_estimate == "mode":
        if values.dtype.kind == "f":
            if bw == "default":
                bw = "taylor" if circular else "experimental"
            if skipna:
                # The KDE cannot handle nan; drop them when asked to skip.
                values = np.asarray(values)
                values = values[~np.isnan(values)]
            x, density = kde(values, circular=circular, bw=bw)
            point_value = x[np.argmax(density)]
        else:
            # SciPy < 1.11 returns the mode as an array (keepdims=True behaviour),
            # SciPy >= 1.11 as a scalar; ravel handles both and keeps the old result.
            point_value = np.ravel(mode(values)[0])[0]
    elif point_estimate == "median":
        point_value = np.nanmedian(values) if skipna else np.median(values)

    return point_value


def plot_point_interval(
    ax,
    values,
    point_estimate,
    hdi_prob,
    quartiles,
    linewidth,
    markersize,
    markercolor,
    marker,
    rotated,
    intervalcolor,
    backend="matplotlib",
):
    """Plot point intervals.

    Translates the data and represents them as point and interval summaries.

    Parameters
    ----------
    ax : axes
        Matplotlib axes (or Bokeh figure when ``backend`` is not "matplotlib")
    values : array-like
        Values to plot
    point_estimate : str
        Plot point estimate per variable.
    linewidth : int
        Line width throughout.
    quartiles : bool
        If True then the quartile interval will be plotted with the HDI.
    markersize : int
        Markersize throughout.
    markercolor: string
        Color of the marker.
    marker: string
        Shape of the marker.
    hdi_prob : float
        Valid only when point_interval is True. Plots HDI for chosen percentage of density.
    rotated : bool
        Whether to rotate the dot plot by 90 degrees.
    intervalcolor : string
        Color of the interval.
    backend : string, optional
        Matplotlib or Bokeh.

    Returns
    -------
    ax
        The same axes/figure that was passed in.
    """
    endpoint = (1 - hdi_prob) / 2
    if quartiles:
        qlist_interval = [endpoint, 0.25, 0.75, 1 - endpoint]
    else:
        qlist_interval = [endpoint, 1 - endpoint]
    quantiles_interval = np.quantile(values, qlist_interval)

    # The outermost interval is replaced by the HDI bounds.
    quantiles_interval[0], quantiles_interval[-1] = hdi(
        values.flatten(), hdi_prob, multimodal=False
    )
    mid = len(quantiles_interval) // 2
    # (width, j) pairs: widths run from linewidth for the outermost interval (j=0)
    # to 2 * linewidth for the innermost; a single interval is drawn at 2 * linewidth.
    param_iter = zip(np.linspace(2 * linewidth, linewidth, mid, endpoint=True)[-1::-1], range(mid))
    use_mpl = backend == "matplotlib"

    for width, j in param_iter:
        low, high = quantiles_interval[j], quantiles_interval[-(j + 1)]
        if use_mpl:
            if rotated:
                ax.vlines(0, low, high, linewidth=width, color=intervalcolor)
            else:
                ax.hlines(0, low, high, linewidth=width, color=intervalcolor)
        elif rotated:
            ax.line([0, 0], [low, high], line_width=width, color=intervalcolor)
        else:
            ax.line([low, high], [0, 0], line_width=width, color=intervalcolor)

    if point_estimate:
        point_value = calculate_point_estimate(point_estimate, values)
        x_pos, y_pos = (0, point_value) if rotated else (point_value, 0)
        if use_mpl:
            ax.plot(x_pos, y_pos, marker, markersize=markersize, color=markercolor)
        else:
            ax.circle(x=x_pos, y=y_pos, size=markersize, fill_color=markercolor)

    return ax


def is_valid_quantile(value):
    """Check if value is a number strictly between 0 and 1.

    Anything that cannot be converted with ``float`` (e.g. None, strings, lists)
    is reported as invalid rather than raising.
    """
    try:
        value = float(value)
    except (TypeError, ValueError):
        return False
    return 0 < value < 1


def sample_reference_distribution(dist, shape):
    """Generate samples from a scipy distribution with a given shape.

    Draws ``dist.rvs(size=shape)`` and returns the KDE grid and density of each
    column, stacked as columns of two arrays.
    """
    x_ss = []
    densities = []
    dist_rvs = dist.rvs(size=shape)
    for idx in range(shape[1]):
        x_s, density = kde(dist_rvs[:, idx])
        x_ss.append(x_s)
        densities.append(density)
    return np.array(x_ss).T, np.array(densities).T


def set_bokeh_circular_ticks_labels(ax, hist, labels):
    """Place ticks and ticklabels on Bokeh's circular histogram.

    The hard-coded label positions below lay out eight labels around the circle.
    """
    ticks = np.linspace(-np.pi, np.pi, len(labels), endpoint=False)
    ax.annular_wedge(
        x=0,
        y=0,
        inner_radius=0,
        outer_radius=np.max(hist) * 1.1,
        start_angle=ticks,
        end_angle=ticks,
        line_color="grey",
    )

    radii_circles = np.linspace(0, np.max(hist) * 1.1, 4)
    ax.circle(0, 0, radius=radii_circles, fill_color=None, line_color="grey")

    offset = np.max(hist * 1.05) * 0.15
    ticks_labels_pos_1 = np.max(hist * 1.05)
    ticks_labels_pos_2 = ticks_labels_pos_1 * np.sqrt(2) / 2

    ax.text(
        [
            ticks_labels_pos_1 + offset,
            ticks_labels_pos_2 + offset,
            0,
            -ticks_labels_pos_2 - offset,
            -ticks_labels_pos_1 - offset,
            -ticks_labels_pos_2 - offset,
            0,
            ticks_labels_pos_2 + offset,
        ],
        [
            0,
            ticks_labels_pos_2 + offset / 2,
            ticks_labels_pos_1 + offset,
            ticks_labels_pos_2 + offset / 2,
            0,
            -ticks_labels_pos_2 - offset,
            -ticks_labels_pos_1 - offset,
            -ticks_labels_pos_2 - offset,
        ],
        text=labels,
        text_align="center",
    )

    return ax


def compute_ranks(ary):
    """Compute ranks for continuous and discrete variables.

    Integer arrays are first passed through a cubic spline over a slightly shrunk
    range so ties are broken before ranking. Constant or empty integer arrays skip
    that step (the spline needs strictly increasing knots).
    """
    if ary.dtype.kind == "i" and ary.size > 0:
        ary_shape = ary.shape
        ary = ary.flatten()
        min_ary, max_ary = ary.min(), ary.max()
        if min_ary < max_ary:
            x = np.linspace(min_ary, max_ary, len(ary))
            csi = CubicSpline(x, ary)
            ary = csi(np.linspace(min_ary + 0.001, max_ary - 0.001, len(ary)))
        ary = ary.reshape(ary_shape)
    ranks = rankdata(ary, method="average").reshape(ary.shape)

    return ranks