"""Matplotlib Violinplot."""
import matplotlib.pyplot as plt
import numpy as np

from ....stats import hdi
from ....stats.density_utils import get_bins, histogram, kde
from ...plot_utils import _scale_fig_size
from . import backend_kwarg_defaults, backend_show, create_axes_grid, matplotlib_kwarg_dealiaser


def plot_violin(
    ax,
    plotters,
    figsize,
    rows,
    cols,
    sharex,
    sharey,
    shade_kwargs,
    shade,
    rug,
    rug_kwargs,
    bw,
    textsize,
    labeller,
    circular,
    hdi_prob,
    quartiles,
    backend_kwargs,
    show,
):
    """Matplotlib violin plot.

    Draw one violin per entry of ``plotters``, each on its own axes, together
    with an HDI line, a median marker and, optionally, the interquartile range.

    Parameters
    ----------
    ax : matplotlib axes, array of axes, or None
        Axes to draw on. When None, a grid is created with ``create_axes_grid``
        using ``rows`` and ``cols``.
    plotters : iterable of (var_name, selection, isel, values) tuples
        One entry per violin; ``values`` is flattened before plotting.
    figsize, textsize
        Passed to ``_scale_fig_size`` to derive figure size, label sizes and
        line width.
    rows, cols : int
        Grid layout. ``cols`` is also used to find the axes in the first
        column; only those keep their left spine and y tick marks.
    sharex, sharey : bool
        Defaults for the corresponding ``backend_kwargs`` entries.
    shade_kwargs : dict or None
        Extra keywords for ``fill_betweenx`` (continuous data) or ``barh``
        (integer data). An ``alpha`` given here overrides ``shade``.
    shade : float
        Default alpha of the violin body.
    rug : bool
        If True, draw only the right half of each violin and a jittered rug
        of the raw values on the left half.
    rug_kwargs : dict or None
        Extra keywords for the rug ``plot`` call.
    bw : str or float
        Bandwidth for ``kde``; ``"default"`` picks one based on ``circular``.
    labeller
        Object whose ``make_label_vert`` builds each axes title.
    circular : bool
        Whether continuous data is circular (forwarded to ``kde``).
    hdi_prob : float
        Probability mass of the HDI line.
    quartiles : bool
        If True, draw a thick line over the 25th-75th percentile range.
    backend_kwargs : dict or None
        Keywords for ``create_axes_grid``, merged over
        ``backend_kwarg_defaults()``.
    show : bool or None
        Forwarded to ``backend_show`` to decide whether to call ``plt.show()``.

    Returns
    -------
    numpy.ndarray
        ``np.atleast_1d`` of the axes that were drawn on.
    """
    if backend_kwargs is None:
        backend_kwargs = {}

    backend_kwargs = {
        **backend_kwarg_defaults(),
        **backend_kwargs,
    }

    (figsize, ax_labelsize, _, xt_labelsize, linewidth, _) = _scale_fig_size(
        figsize, textsize, rows, cols
    )
    backend_kwargs.setdefault("figsize", figsize)
    backend_kwargs.setdefault("sharex", sharex)
    backend_kwargs.setdefault("sharey", sharey)
    backend_kwargs.setdefault("squeeze", True)

    shade_kwargs = matplotlib_kwarg_dealiaser(shade_kwargs, "hexbin")
    rug_kwargs = matplotlib_kwarg_dealiaser(rug_kwargs, "plot")
    rug_kwargs.setdefault("alpha", 0.1)
    rug_kwargs.setdefault("marker", ".")
    rug_kwargs.setdefault("linestyle", "")

    if ax is None:
        fig, ax = create_axes_grid(
            len(plotters),
            rows,
            cols,
            backend_kwargs=backend_kwargs,
        )
        # Turn constrained layout off so that subplots_adjust(wspace=0) takes
        # effect and neighbouring violins are packed side by side.
        fig.set_constrained_layout(False)
        fig.subplots_adjust(wspace=0)

    ax = np.atleast_1d(ax)

    current_col = 0
    for (var_name, selection, isel, x), ax_ in zip(plotters, ax.flatten()):
        val = x.flatten()
        # Signed and unsigned integers are discrete: draw a histogram-based violin.
        if val.dtype.kind in ("i", "u"):
            dens = cat_hist(val, rug, shade, ax_, **shade_kwargs)
        else:
            dens = _violinplot(val, rug, shade, bw, circular, ax_, **shade_kwargs)

        if rug:
            # Jitter the rug points horizontally over the left (x < 0) half,
            # which the violin leaves empty when rug is enabled.
            rug_x = -np.abs(np.random.normal(scale=np.max(dens) / 3.5, size=len(val)))
            ax_.plot(rug_x, val, **rug_kwargs)

        per = np.nanpercentile(val, [25, 75, 50])
        hdi_interval = hdi(val, hdi_prob, multimodal=False, skipna=True)

        if quartiles:
            ax_.plot([0, 0], per[:2], lw=linewidth * 3, color="k", solid_capstyle="round")
        ax_.plot([0, 0], hdi_interval, lw=linewidth, color="k", solid_capstyle="round")
        # White dot at the median (per[-1] is the 50th percentile).
        ax_.plot(0, per[-1], "wo", ms=linewidth * 1.5)

        ax_.set_title(labeller.make_label_vert(var_name, selection, isel), fontsize=ax_labelsize)
        ax_.set_xticks([])
        ax_.tick_params(labelsize=xt_labelsize)
        # Hide the x grid explicitly; grid(None, ...) would toggle it instead.
        ax_.grid(False, axis="x")
        if current_col != 0:
            # Only axes in the first column keep their left spine and y tick marks.
            ax_.spines["left"].set_visible(False)
            ax_.yaxis.set_ticks_position("none")
        current_col += 1
        if current_col == cols:
            current_col = 0

    if backend_show(show):
        plt.show()

    return ax


def _violinplot(val, rug, shade, bw, circular, ax, **shade_kwargs):
    """Auxiliary function to plot violinplots.

    Estimate the density of ``val`` with ``kde`` and fill it with
    ``ax.fill_betweenx`` (sample values on the y axis, density on the x axis).

    Parameters
    ----------
    val : numpy.ndarray
        Flattened continuous samples.
    rug : bool
        If True, draw only the right half of the violin; otherwise mirror the
        density around x=0.
    shade : float
        Default alpha of the filled area; an ``alpha`` in ``shade_kwargs``
        takes precedence.
    bw : str or float
        Bandwidth passed to ``kde``. ``"default"`` selects ``"taylor"`` for
        circular data and ``"experimental"`` otherwise.
    circular : bool
        Forwarded to ``kde``.
    ax : matplotlib axes
        Axes to draw on.
    **shade_kwargs
        Extra keywords for ``fill_betweenx``. A ``linewidth`` here overrides
        the default of 0.

    Returns
    -------
    numpy.ndarray
        The density values used as horizontal extent (mirrored, so including
        negative values, when ``rug`` is False).
    """
    if bw == "default":
        bw = "taylor" if circular else "experimental"
    x, density = kde(val, circular=circular, bw=bw)

    if not rug:
        # Close the outline: go up the right side, then back down the left.
        x = np.concatenate([x, x[::-1]])
        density = np.concatenate([-density, density[::-1]])

    # Defaults first so user-supplied alpha/linewidth override them instead of
    # raising a duplicate-keyword (or lw/linewidth alias) TypeError.
    fill_kwargs = {"alpha": shade, "linewidth": 0, **shade_kwargs}
    ax.fill_betweenx(x, density, **fill_kwargs)
    return density


def cat_hist(val, rug, shade, ax, **shade_kwargs):
    """Auxiliary function to plot discrete-violinplots.

    Bin the integer samples in ``val`` with ``get_bins``/``histogram`` and draw
    the binned densities as horizontal bars.

    Parameters
    ----------
    val : numpy.ndarray
        Flattened integer samples.
    rug : bool
        If True, bars start at x=0 (right half only); otherwise they are
        centered on x=0.
    shade : float
        Default alpha of the bars; an ``alpha`` in ``shade_kwargs`` takes
        precedence.
    ax : matplotlib axes
        Axes to draw on.
    **shade_kwargs
        Extra keywords for ``ax.barh``.

    Returns
    -------
    numpy.ndarray
        The binned densities used as bar widths.
    """
    bins = get_bins(val)
    _, binned_d, _ = histogram(val, bins=bins)

    # NOTE(review): bars are placed on an evenly spaced grid of len(bins)
    # points over [min(val), max(val)] rather than on ``bins`` itself, so bar
    # centers may not coincide exactly with the integer values -- confirm this
    # is intended.
    bin_edges = np.linspace(np.min(val), np.max(val), len(bins))
    heights = np.diff(bin_edges)
    centers = bin_edges[:-1] + heights.mean() / 2

    left = None if rug else -0.5 * binned_d

    # Defaults first so a user-supplied alpha overrides ``shade`` instead of
    # raising a duplicate-keyword TypeError.
    bar_kwargs = {"alpha": shade, **shade_kwargs}
    ax.barh(centers, binned_d, height=heights, left=left, **bar_kwargs)
    return binned_d