1"""Utilities for plotting."""
2import importlib
3import warnings
4from typing import Any, Dict
5
6import matplotlib as mpl
7import numpy as np
8import packaging
9from matplotlib.colors import to_hex
10from scipy.stats import mode, rankdata
11from scipy.interpolate import CubicSpline
12
13
14from ..rcparams import rcParams
15from ..stats.density_utils import kde
16from ..stats import hdi
17
# Type alias: a dict with str keys and arbitrary values (presumably keyword
# arguments forwarded to plotting calls, as the name suggests).
KwargSpec = Dict[str, Any]
19
20
def make_2d(ary):
    """Convert any array-like into a 2d numpy array.

    The first dimension is kept as-is and all remaining dimensions are
    flattened (in Fortran order) into the second one. 0d and 1d inputs become
    arrays with a single column.

    Parameters
    ----------
    ary : array_like
        Input data. Plain Python sequences and scalars are accepted as well as
        numpy arrays; ndarray subclasses are preserved.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(n, m)`` where ``n`` is the length of the first
        dimension and ``m`` the product of the remaining ones (1 for 0d/1d input).
    """
    # ``np.atleast_1d`` turns lists/scalars into arrays so that ``reshape`` is
    # always available; previously only its shape was used and the raw input
    # was reshaped, which failed for anything that is not already an array.
    ary = np.atleast_1d(ary)
    return ary.reshape(ary.shape[0], -1, order="F")
29
30
def _scale_fig_size(figsize, textsize, rows=1, cols=1):
    """Compute a figure size and matching text, line and marker sizes.

    Base values come from matplotlib's ``rcParams``. Relative font sizes
    stored as strings (e.g. "medium") are replaced by fixed fallbacks:
    15 for axes labels, 16 for titles and 14 for tick labels.

    Parameters
    ----------
    figsize : tuple of (float, float) or None
        Figure ``(width, height)`` in inches. When None, the rcParams figure
        size is multiplied by ``cols``/``rows``, and by an extra 1.15 unless
        the grid is a single panel.
    textsize : float or None
        Target tick label font size. When given, all sizes are scaled by
        ``textsize`` divided by the base tick label size. When None, a single
        panel is scaled by the square root of its area relative to the
        rcParams figure area, and multi-panel grids are not scaled.
    rows : int
        Number of rows.
    cols : int
        Number of columns.

    Returns
    -------
    figsize : tuple of (float, float)
        Size of figure in inches.
    ax_labelsize : float
        Font size for axes labels.
    titlesize : float
        Font size for titles.
    xt_labelsize : float
        Font size for axes ticks.
    linewidth : float
        Line width.
    markersize : float
        Marker size.
    """
    rc = mpl.rcParams

    def _base_size(key, fallback):
        # Relative sizes such as "medium" cannot be scaled numerically.
        size = rc[key]
        return fallback if isinstance(size, str) else size

    base_width, base_height = tuple(rc["figure.figsize"])
    label_size = _base_size("axes.labelsize", 15)
    title_size = _base_size("axes.titlesize", 16)
    tick_size = _base_size("xtick.labelsize", 14)
    line_width = rc["lines.linewidth"]
    marker_size = rc["lines.markersize"]

    single_panel = rows == cols == 1
    if figsize is not None:
        width, height = figsize
    else:
        grid_factor = 1 if single_panel else 1.15
        width = base_width * cols * grid_factor
        height = base_height * rows * grid_factor

    if textsize is not None:
        scale = textsize / tick_size
    elif single_panel:
        scale = ((width * height) / (base_width * base_height)) ** 0.5
    else:
        scale = 1

    scaled_sizes = (
        size * scale for size in (label_size, title_size, tick_size, line_width, marker_size)
    )
    return ((width, height), *scaled_sizes)
96
97
def default_grid(n_items, grid=None, max_cols=4, min_cols=3):  # noqa: D202
    """Choose a (rows, cols) layout for ``n_items`` subplots.

    The automatic layout aims for a roughly square grid, i.e. about
    ``sqrt(n_items)`` columns, restricted to ``[min_cols, max_cols]``. Column
    counts near that ideal are tried first, preferring one that divides
    ``n_items`` exactly.

    Parameters
    ----------
    n_items : int
        Number of panels required.
    grid : tuple of (int, int), optional
        Explicit ``(rows, cols)`` layout. When given it is validated and returned.
    max_cols : int
        Maximum number of columns, inclusive.
    min_cols : int
        Minimum number of columns, inclusive.

    Returns
    -------
    (int, int)
        Rows and columns, so that ``rows * cols >= n_items``.

    Raises
    ------
    ValueError
        If an explicit ``grid`` has fewer cells than ``n_items``.

    Warns
    -----
    UserWarning
        If an explicit ``grid`` leaves at least a full row's worth of cells unused.
    """
    if grid is not None:
        rows, cols = grid
        n_cells = rows * cols
        if n_cells < n_items:
            raise ValueError("The number of rows times columns is less than the number of subplots")
        if n_cells - n_items >= cols:
            warnings.warn("The number of rows times columns is larger than necessary")
        return rows, cols

    # Few enough panels to fit on a single row.
    if n_items <= max_cols:
        return 1, n_items

    ideal = np.clip(round(n_items ** 0.5), min_cols, max_cols)
    for offset in (0, 1, -1, 2, -2):
        cols = np.clip(ideal + offset, min_cols, max_cols)
        if n_items % cols == 0:
            return n_items // cols, cols
    # No exact fit: keep the ideal column count and add a partially filled row.
    return n_items // ideal + 1, ideal
143
144
def format_sig_figs(value, default=None):
    """Get a default number of significant figures.

    Returns the number of digits in the integer part of ``value`` or
    ``default``, whichever is bigger. Zero gives 1. Non-finite values
    (nan, +/-inf) have no integer part and give ``default``.

    Parameters
    ----------
    value : float
        Number whose integer part is inspected.
    default : int, optional
        Minimum number of significant figures. Defaults to 2.

    Returns
    -------
    int

    Examples
    --------
    Values formatted with the resulting number of significant figures
    (as ``round_num`` does)::

        0.1234 --> 0.12
        1.234  --> 1.2
        12.34  --> 12
        123.4  --> 123
    """
    if default is None:
        default = 2
    if value == 0:
        return 1
    if not np.isfinite(value):
        # int(np.log10(nan)) raises ValueError and int(np.log10(inf)) raises
        # OverflowError; fall back to the default precision instead.
        return default
    return max(int(np.log10(np.abs(value))) + 1, default)
162
163
def round_num(n, round_to):
    """
    Format ``n`` as a string with at least ``round_to`` significant figures.

    The precision is chosen by ``format_sig_figs``, so the digits of the
    integer part are kept even when there are more of them than ``round_to``.

    Parameters
    ----------
    n : float
        Number to format.
    round_to : int
        Minimum number of significant figures.

    Returns
    -------
    str
    """
    precision = format_sig_figs(n, round_to)
    return f"{n:.{precision}g}"
177
178
def color_from_dim(dataarray, dim_name):
    """Return colors and color mapping of a DataArray using coord values as color code.

    Parameters
    ----------
    dataarray : xarray.DataArray
    dim_name : str
        Dimension whose coordinates will be used as color code.

    Returns
    -------
    colors : list of floats
        Color (as a float in ``[0, 1)`` for use with a cmap) for each element
        in the dataarray.
    color_mapping : dict
        Mapping from coord values to corresponding color. Distinct coord values
        are numbered in order of first appearance along ``dim_name``.
    """
    present_dims = dataarray.dims
    coord_values = dataarray[dim_name].values
    # ``dict.fromkeys`` de-duplicates while keeping first-appearance order.
    # Iterating a ``set`` instead would make the order, and thus the colors,
    # depend on hash randomization (e.g. for string coordinates).
    unique_coords = list(dict.fromkeys(coord_values))
    n_unique = len(unique_coords)
    color_mapping = {coord: num / n_unique for num, coord in enumerate(unique_coords)}
    if len(present_dims) > 1:
        # Each index entry is a tuple with one value per dimension.
        multi_coords = dataarray.coords.to_index()
        coord_idx = present_dims.index(dim_name)
        colors = [color_mapping[coord[coord_idx]] for coord in multi_coords]
    else:
        colors = [color_mapping[coord] for coord in coord_values]
    return colors, color_mapping
206
207
def vectorized_to_hex(c_values, keep_alpha=False):
    """Convert a matplotlib color, or a sequence of colors, to hex.

    Parameters
    ----------
    c_values : matplotlib color or sequence of matplotlib colors
        Tried first as a single color; if ``to_hex`` rejects it with a
        ``ValueError``, it is treated as a sequence of colors.
    keep_alpha : bool
        Whether alpha values should be kept in the final hex values.

    Returns
    -------
    str or list of str
        A single hex value, or one hex value per input color.
    """
    try:
        return to_hex(c_values, keep_alpha)
    except ValueError:
        return [to_hex(color, keep_alpha) for color in c_values]
228
229
def format_coords_as_labels(dataarray, skip_dims=None):
    """Format 1d or multi-d dataarray coords as strings.

    Parameters
    ----------
    dataarray : xarray.DataArray
        DataArray whose coordinates will be converted to labels.
    skip_dims : str or list_like, optional
        Dimensions whose values should not be included in the labels.

    Returns
    -------
    numpy.ndarray
        Object array with one string label per entry of the coordinate index.
        Multi-dimensional entries are joined with ``", "``.
    """
    coord_index = dataarray.coords.to_index()
    if skip_dims is not None:
        coord_index = coord_index.droplevel(skip_dims).drop_duplicates()
    coord_values = coord_index.values
    # Build a fresh object array rather than writing into ``coord_values`` in
    # place. That buffer may be shared with the index itself, and for numeric
    # or datetime dtypes numpy would cast the strings straight back, so no
    # string labels would be produced.
    labels = np.empty(len(coord_values), dtype=object)
    if len(coord_values) and isinstance(coord_values[0], tuple):
        fmt = ", ".join(["{}" for _ in coord_values[0]])
        labels[:] = [fmt.format(*x) for x in coord_values]
    else:
        labels[:] = [f"{s}" for s in coord_values]
    return labels
251
252
def set_xticklabels(ax, coord_labels):
    """Label the x axis ticks of ``ax`` with entries of ``coord_labels``.

    The major locator is limited to 9 bins with steps of 1, 2, 5 or 10. The
    resulting tick positions are truncated to integers and used as indices
    into ``coord_labels``, which must therefore support integer-array
    indexing (e.g. a numpy array). Positions outside
    ``[0, len(coord_labels))`` are dropped. If more ticks than labels remain,
    one tick per label is used instead.
    """
    n_labels = len(coord_labels)
    ax.xaxis.get_major_locator().set_params(nbins=9, steps=[1, 2, 5, 10])
    tick_idx = ax.get_xticks().astype(np.int64)
    in_range = (tick_idx >= 0) & (tick_idx < n_labels)
    tick_idx = tick_idx[in_range]
    if len(tick_idx) <= n_labels:
        ax.set_xticks(tick_idx)
        ax.set_xticklabels(coord_labels[tick_idx])
    else:
        ax.set_xticks(np.arange(n_labels))
        ax.set_xticklabels(coord_labels)
264
265
def filter_plotters_list(plotters, plot_kind):
    """Cut ``plotters`` to at most ``rcParams["plot.max_subplots"]`` items.

    A limit of None means no truncation. A ``UserWarning`` naming
    ``plot_kind`` is emitted whenever items are dropped.
    """
    limit = rcParams["plot.max_subplots"]
    n_plotters = len(plotters)
    if limit is None or n_plotters <= limit:
        return plotters
    warnings.warn(
        f"rcParams['plot.max_subplots'] ({limit}) is smaller than the number "
        f"of variables to plot ({n_plotters}) in {plot_kind}, generating only "
        f"{limit} plots",
        UserWarning,
    )
    return plotters[:limit]
281
282
283def get_plotting_function(plot_name, plot_module, backend):
284    """Return plotting function for correct backend."""
285    _backend = {
286        "mpl": "matplotlib",
287        "bokeh": "bokeh",
288        "matplotlib": "matplotlib",
289    }
290
291    if backend is None:
292        backend = rcParams["plot.backend"]
293    backend = backend.lower()
294
295    try:
296        backend = _backend[backend]
297    except KeyError as err:
298        raise KeyError(
299            "Backend {} is not implemented. Try backend in {}".format(
300                backend, set(_backend.values())
301            )
302        ) from err
303
304    if backend == "bokeh":
305        try:
306            import bokeh
307
308            assert packaging.version.parse(bokeh.__version__) >= packaging.version.parse("1.4.0")
309
310        except (ImportError, AssertionError) as err:
311            raise ImportError(
312                "'bokeh' backend needs Bokeh (1.4.0+) installed." " Please upgrade or install"
313            ) from err
314
315    # Perform import of plotting method
316    # TODO: Convert module import to top level for all plots
317    module = importlib.import_module(f"arviz.plots.backends.{backend}.{plot_module}")
318
319    plotting_method = getattr(module, plot_name)
320
321    return plotting_method
322
323
def calculate_point_estimate(point_estimate, values, bw="default", circular=False, skipna=False):
    """Validate and calculate the point estimate.

    Parameters
    ----------
    point_estimate : Optional[str]
        Point estimate to compute: 'mean', 'median', 'mode' or None.
        'auto' falls back to the default set in ``rcParams["plot.point_estimate"]``.
    values : 1-d array
    bw : Optional[float or str]
        Only used for the 'mode' of float data.
        If numeric, indicates the bandwidth and must be positive.
        If str, indicates the method to estimate the bandwidth and must be
        one of "scott", "silverman", "isj" or "experimental" when `circular` is False
        and "taylor" (for now) when `circular` is True.
        Defaults to "default" which means "experimental" when variable is not circular
        and "taylor" when it is.
    circular : Optional[bool]
        If True, it interprets the values passed are from a circular variable measured in radians
        and a circular KDE is used. Only used for the 'mode' of float data. Defaults to False.
    skipna : bool
        If True, nan values are ignored when computing the mean or the median.
        Defaults to False.

    Returns
    -------
    point_value : scalar or None
        Best estimate of data distribution, or None when ``point_estimate`` is None.

    Raises
    ------
    ValueError
        If ``point_estimate`` is not 'auto', 'mean', 'median', 'mode' or None.
    """
    point_value = None
    if point_estimate == "auto":
        point_estimate = rcParams["plot.point_estimate"]
    elif point_estimate not in ("mean", "median", "mode", None):
        raise ValueError(
            "Point estimate should be 'mean', 'median', 'mode' or None, not {}".format(
                point_estimate
            )
        )
    if point_estimate == "mean":
        point_value = np.nanmean(values) if skipna else np.mean(values)
    elif point_estimate == "mode":
        if values.dtype.kind == "f":
            if bw == "default":
                bw = "taylor" if circular else "experimental"
            # Continuous data: the mode is the location of the KDE maximum.
            x, density = kde(values, circular=circular, bw=bw)
            point_value = x[np.argmax(density)]
        else:
            # Discrete data: most frequent value. scipy >= 1.11 returns a
            # scalar mode for 1-d input (``keepdims=False`` by default) while
            # older versions return a length-1 array, so the previous
            # ``[0][0]`` indexing fails on new scipy. ``np.ravel(...)[0]``
            # works with both.
            point_value = np.ravel(mode(values)[0])[0]
    elif point_estimate == "median":
        point_value = np.nanmedian(values) if skipna else np.median(values)

    return point_value
383
384
def plot_point_interval(
    ax,
    values,
    point_estimate,
    hdi_prob,
    quartiles,
    linewidth,
    markersize,
    markercolor,
    marker,
    rotated,
    intervalcolor,
    backend="matplotlib",
):
    """Plot point intervals.

    Draws the HDI of ``values`` (computed with ``multimodal=False``) as a line
    and, when ``quartiles`` is True, the 25%-75% quantile interval on top of
    it. The requested point estimate is added as a marker. Everything is
    placed at 0 on the other axis.

    Parameters
    ----------
    ax : axes
        Matplotlib axes, or a Bokeh figure when ``backend`` is not "matplotlib".
    values : array-like
        Values to plot.
    point_estimate : str
        Plot point estimate per variable.
    hdi_prob : float
        Probability mass of the HDI that is drawn as the outer interval.
    quartiles : bool
        If True then the quartile interval will be plotted with the HDI.
    linewidth : int
        Line width throughout.
    markersize : int
        Markersize throughout.
    markercolor : string
        Color of the marker.
    marker : string
        Shape of the marker (matplotlib only; Bokeh always draws a circle).
    rotated : bool
        Whether to rotate the dot plot by 90 degrees.
    intervalcolor : string
        Color of the interval.
    backend : string, optional
        Matplotlib or Bokeh.

    Returns
    -------
    ax
        The same axes object.
    """
    tail = (1 - hdi_prob) / 2
    probs = [tail, 0.25, 0.75, 1 - tail] if quartiles else [tail, 1 - tail]
    bounds = np.quantile(values, probs)
    # The outermost pair is replaced by the HDI; only the quartiles stay quantiles.
    bounds[0], bounds[-1] = hdi(values.flatten(), hdi_prob, multimodal=False)

    n_intervals = len(bounds) // 2
    # Widths go from ``linewidth`` (outermost) to ``2 * linewidth`` (innermost);
    # a lone interval is drawn at ``2 * linewidth``.
    widths = np.linspace(2 * linewidth, linewidth, n_intervals, endpoint=True)[::-1]
    segments = [(width, bounds[k], bounds[-(k + 1)]) for k, width in enumerate(widths)]

    if backend == "matplotlib":
        for width, lower, upper in segments:
            draw = ax.vlines if rotated else ax.hlines
            draw(0, lower, upper, linewidth=width, color=intervalcolor)

        if point_estimate:
            point_value = calculate_point_estimate(point_estimate, values)
            x_pos, y_pos = (0, point_value) if rotated else (point_value, 0)
            ax.plot(x_pos, y_pos, marker, markersize=markersize, color=markercolor)
    else:
        for width, lower, upper in segments:
            span, zeros = [lower, upper], [0, 0]
            xs, ys = (zeros, span) if rotated else (span, zeros)
            ax.line(xs, ys, line_width=width, color=intervalcolor)

        if point_estimate:
            point_value = calculate_point_estimate(point_estimate, values)
            x_pos, y_pos = (0, point_value) if rotated else (point_value, 0)
            ax.circle(x=x_pos, y=y_pos, size=markersize, fill_color=markercolor)

    return ax
515
516
def is_valid_quantile(value):
    """Check if value is a number strictly between 0 and 1.

    Parameters
    ----------
    value : object
        Anything ``float()`` may accept (numbers, numeric strings, ...).

    Returns
    -------
    bool
        True if ``value`` converts to a float in the open interval (0, 1).
        False otherwise, including when it cannot be converted at all.
    """
    try:
        value = float(value)
    except (TypeError, ValueError):
        # TypeError: e.g. None or a multi-element list.
        # ValueError: e.g. a non-numeric string.
        return False
    return 0 < value < 1
524
525
def sample_reference_distribution(dist, shape):
    """Generate samples from a scipy distribution and estimate a KDE per column.

    Parameters
    ----------
    dist : scipy distribution
        Object with an ``rvs(size=...)`` method.
    shape : tuple of int
        Passed as ``size`` to ``dist.rvs``. The draws are indexed as 2d, with
        one KDE computed for each of the ``shape[1]`` columns.

    Returns
    -------
    x_ss : numpy.ndarray
        KDE evaluation grids, one column per sampled column.
    densities : numpy.ndarray
        KDE density values, one column per sampled column.
    """
    draws = dist.rvs(size=shape)
    column_kdes = [kde(draws[:, col]) for col in range(shape[1])]
    grids = np.array([grid for grid, _ in column_kdes])
    dens = np.array([density for _, density in column_kdes])
    return grids.T, dens.T
536
537
def set_bokeh_circular_ticks_labels(ax, hist, labels):
    """Place ticks and ticklabels on Bokeh's circular histogram.

    Draws radial tick lines, four concentric reference circles and the tick
    labels on ``ax``.

    Parameters
    ----------
    ax : Bokeh figure
        Plot to draw on (its ``annular_wedge``, ``circle`` and ``text`` glyph
        methods are used).
    hist : numpy.ndarray
        Histogram heights; ``np.max(hist)`` sets the radius of the drawing.
    labels : sequence of str
        Tick labels. NOTE(review): the label coordinates below are hard-coded
        for exactly 8 labels (one every 45 degrees, counter-clockwise starting
        at angle 0). Other lengths give mismatched coordinate and text lists,
        so confirm that callers always pass 8 labels.

    Returns
    -------
    ax
        The same plot object.
    """
    # One tick angle per label, evenly spaced over [-pi, pi).
    ticks = np.linspace(-np.pi, np.pi, len(labels), endpoint=False)
    # Zero-width wedges (start_angle == end_angle) from the centre out to 110%
    # of the tallest bar, presumably rendered as radial tick lines.
    ax.annular_wedge(
        x=0,
        y=0,
        inner_radius=0,
        outer_radius=np.max(hist) * 1.1,
        start_angle=ticks,
        end_angle=ticks,
        line_color="grey",
    )

    # Four unfilled reference circles, radii evenly spaced from 0 to 110% of the tallest bar.
    radii_circles = np.linspace(0, np.max(hist) * 1.1, 4)
    ax.circle(0, 0, radius=radii_circles, fill_color=None, line_color="grey")

    # Labels are placed around radius 105% of the tallest bar, pushed further out
    # by ``offset``. ``ticks_labels_pos_2`` is the x/y component of that radius
    # at 45 degrees (cos 45 = sqrt(2) / 2).
    offset = np.max(hist * 1.05) * 0.15
    ticks_labels_pos_1 = np.max(hist * 1.05)
    ticks_labels_pos_2 = ticks_labels_pos_1 * np.sqrt(2) / 2

    # x then y coordinates of the 8 labels at 0, 45, 90, ..., 315 degrees.
    ax.text(
        [
            ticks_labels_pos_1 + offset,
            ticks_labels_pos_2 + offset,
            0,
            -ticks_labels_pos_2 - offset,
            -ticks_labels_pos_1 - offset,
            -ticks_labels_pos_2 - offset,
            0,
            ticks_labels_pos_2 + offset,
        ],
        [
            0,
            ticks_labels_pos_2 + offset / 2,
            ticks_labels_pos_1 + offset,
            ticks_labels_pos_2 + offset / 2,
            0,
            -ticks_labels_pos_2 - offset,
            -ticks_labels_pos_1 - offset,
            -ticks_labels_pos_2 - offset,
        ],
        text=labels,
        text_align="center",
    )

    return ax
584
585
def compute_ranks(ary):
    """Compute ranks for continuous and discrete variables.

    Integer arrays are first replaced by a cubic spline through their
    flattened values, evaluated at slightly shifted points. This perturbs
    tied values (presumably to break ties between discrete draws) before
    ranking. Constant integer arrays skip this step, since the spline needs
    distinct end points and there is nothing to perturb.

    Parameters
    ----------
    ary : numpy.ndarray
        Values to rank.

    Returns
    -------
    numpy.ndarray
        Average ranks over all elements, reshaped to ``ary.shape``.
    """
    if ary.dtype.kind == "i":
        ary_shape = ary.shape
        flat = ary.flatten()
        # CubicSpline needs strictly increasing knots, which requires
        # min < max; constant (or empty) arrays previously raised ValueError.
        if flat.size and flat.min() < flat.max():
            min_ary, max_ary = flat.min(), flat.max()
            x = np.linspace(min_ary, max_ary, len(flat))
            csi = CubicSpline(x, flat)
            ary = csi(np.linspace(min_ary + 0.001, max_ary - 0.001, len(flat))).reshape(ary_shape)
    ranks = rankdata(ary, method="average").reshape(ary.shape)

    return ranks
598