1"""Matplotlib Violinplot."""
2import matplotlib.pyplot as plt
3import numpy as np
4
5from ....stats import hdi
6from ....stats.density_utils import get_bins, histogram, kde
7from ...plot_utils import _scale_fig_size
8from . import backend_kwarg_defaults, backend_show, create_axes_grid, matplotlib_kwarg_dealiaser
9
10
def plot_violin(
    ax,
    plotters,
    figsize,
    rows,
    cols,
    sharex,
    sharey,
    shade_kwargs,
    shade,
    rug,
    rug_kwargs,
    bw,
    textsize,
    labeller,
    circular,
    hdi_prob,
    quartiles,
    backend_kwargs,
    show,
):
    """Matplotlib violin plot.

    For every ``(var_name, selection, isel, x)`` entry in ``plotters`` the
    flattened samples are drawn on their own axes. Integer samples (signed or
    unsigned) are drawn as a discrete horizontal histogram (``cat_hist``). All
    other dtypes are drawn as a KDE violin (``_violinplot``). On top of each
    violin the HDI interval, optionally the interquartile range, and the median
    are drawn at ``x = 0``.

    Parameters
    ----------
    ax : matplotlib axes, array of axes, or None
        Axes to draw on. When None, a new grid is created with
        ``create_axes_grid`` and its columns are packed with no horizontal gap.
    plotters : sequence of tuple
        ``(var_name, selection, isel, x)`` entries; ``x`` must support
        ``.flatten()``. Entries beyond the number of available axes are
        silently skipped (``zip`` truncation).
    figsize : tuple or None
        Passed to ``_scale_fig_size``; the scaled size is the default
        ``figsize`` backend kwarg.
    rows, cols : int
        Shape of the axes grid. ``cols`` also decides which axes start a new
        row; only those keep their left spine and y ticks.
    sharex, sharey : bool
        Defaults for the corresponding ``backend_kwargs`` entries.
    shade_kwargs : dict or None
        Extra keyword arguments for the violin fill / histogram bars, dealiased
        with ``matplotlib_kwarg_dealiaser``.
    shade : float
        Alpha used for the violin fill / histogram bars.
    rug : bool
        If True, draw a jittered rug of the raw samples at negative x; the
        helpers then draw only the positive side of each violin.
    rug_kwargs : dict or None
        Extra keyword arguments for the rug ``ax.plot`` call. The defaults are
        ``alpha=0.1``, ``marker="."`` and ``linestyle=""``.
    bw
        Bandwidth forwarded to the continuous KDE helper.
    textsize : float or None
        Passed to ``_scale_fig_size`` to derive label sizes and line widths.
    labeller
        Object providing ``make_label_vert`` for the axes titles.
    circular : bool
        Forwarded to the continuous KDE helper.
    hdi_prob : float
        Probability passed to ``hdi`` for the thin interval line.
    quartiles : bool
        If True, draw a thick line spanning the 25th-75th percentiles.
    backend_kwargs : dict or None
        Keyword arguments for ``create_axes_grid``, merged over
        ``backend_kwarg_defaults()``.
    show : bool or None
        Passed to ``backend_show``; ``plt.show()`` is called when it is truthy.

    Returns
    -------
    numpy.ndarray
        The axes, wrapped with ``np.atleast_1d``.
    """
    if backend_kwargs is None:
        backend_kwargs = {}

    backend_kwargs = {
        **backend_kwarg_defaults(),
        **backend_kwargs,
    }

    (figsize, ax_labelsize, _, xt_labelsize, linewidth, _) = _scale_fig_size(
        figsize, textsize, rows, cols
    )
    backend_kwargs.setdefault("figsize", figsize)
    backend_kwargs.setdefault("sharex", sharex)
    backend_kwargs.setdefault("sharey", sharey)
    backend_kwargs.setdefault("squeeze", True)

    shade_kwargs = matplotlib_kwarg_dealiaser(shade_kwargs, "hexbin")
    rug_kwargs = matplotlib_kwarg_dealiaser(rug_kwargs, "plot")
    rug_kwargs.setdefault("alpha", 0.1)
    rug_kwargs.setdefault("marker", ".")
    rug_kwargs.setdefault("linestyle", "")

    if ax is None:
        fig, ax = create_axes_grid(
            len(plotters),
            rows,
            cols,
            backend_kwargs=backend_kwargs,
        )
        # Pack columns edge to edge; the inner columns' left spines and y
        # ticks are hidden in the loop below.
        fig.set_constrained_layout(False)
        fig.subplots_adjust(wspace=0)

    ax = np.atleast_1d(ax)

    current_col = 0
    for (var_name, selection, isel, x), ax_ in zip(plotters, ax.flatten()):
        val = x.flatten()
        # Check the array dtype rather than ``val[0]`` so empty input does not
        # raise IndexError here, and treat unsigned ints as discrete as well.
        if val.dtype.kind in ("i", "u"):
            dens = cat_hist(val, rug, shade, ax_, **shade_kwargs)
        else:
            dens = _violinplot(val, rug, shade, bw, circular, ax_, **shade_kwargs)

        if rug:
            # Jitter the rug to the left of x = 0; the spread scales with the
            # tallest density value so it stays proportional to the violin.
            rug_x = -np.abs(np.random.normal(scale=np.max(dens) / 3.5, size=len(val)))
            ax_.plot(rug_x, val, **rug_kwargs)

        # Order is [25th, 75th, 50th] so per[:2] is the IQR and per[-1] the median.
        per = np.nanpercentile(val, [25, 75, 50])
        hdi_probs = hdi(val, hdi_prob, multimodal=False, skipna=True)

        if quartiles:
            ax_.plot([0, 0], per[:2], lw=linewidth * 3, color="k", solid_capstyle="round")
        ax_.plot([0, 0], hdi_probs, lw=linewidth, color="k", solid_capstyle="round")
        ax_.plot(0, per[-1], "wo", ms=linewidth * 1.5)

        ax_.set_title(labeller.make_label_vert(var_name, selection, isel), fontsize=ax_labelsize)
        ax_.set_xticks([])
        ax_.tick_params(labelsize=xt_labelsize)
        ax_.grid(None, axis="x")
        # Only the first axes of each row keeps its y axis decorations.
        if current_col != 0:
            ax_.spines["left"].set_visible(False)
            ax_.yaxis.set_ticks_position("none")
        current_col += 1
        if current_col == cols:
            current_col = 0

    if backend_show(show):
        plt.show()

    return ax
102
103
def _violinplot(val, rug, shade, bw, circular, ax, **shade_kwargs):
    """Draw a continuous violin from a KDE of ``val`` onto ``ax``.

    Parameters
    ----------
    val : numpy.ndarray
        Flattened samples.
    rug : bool
        If True, only the positive side of the density is drawn. Otherwise the
        density is mirrored around ``x = 0``.
    shade : float
        Default fill alpha. An ``alpha`` entry in ``shade_kwargs`` takes
        precedence.
    bw
        Bandwidth passed to ``kde``. ``"default"`` resolves to ``"taylor"`` for
        circular data and to ``"experimental"`` otherwise.
    circular : bool
        Passed to ``kde``.
    ax : matplotlib axes
        Axes to draw on.
    **shade_kwargs
        Forwarded to ``ax.fill_betweenx``. The edge line width defaults to 0
        unless one of ``lw``, ``linewidth`` or ``linewidths`` is supplied.

    Returns
    -------
    numpy.ndarray
        The plotted density values. When ``rug`` is False this is the mirrored
        array, including the negated half.
    """
    if bw == "default":
        bw = "taylor" if circular else "experimental"
    x, density = kde(val, circular=circular, bw=bw)

    if not rug:
        # Trace up the left side and back down the right side to close the violin.
        x = np.concatenate([x, x[::-1]])
        density = np.concatenate([-density, density[::-1]])

    # Merge defaults under the user's kwargs. Passing ``alpha=``/``lw=``
    # explicitly alongside ``**shade_kwargs`` raised TypeError whenever the
    # user also supplied alpha or a line-width alias.
    fill_kwargs = {"alpha": shade, **shade_kwargs}
    if not any(key in fill_kwargs for key in ("lw", "linewidth", "linewidths")):
        fill_kwargs["lw"] = 0
    ax.fill_betweenx(x, density, **fill_kwargs)
    return density
119
120
def cat_hist(val, rug, shade, ax, **shade_kwargs):
    """Draw a discrete violin as a horizontal histogram of ``val`` onto ``ax``.

    Parameters
    ----------
    val : numpy.ndarray
        Flattened integer samples.
    rug : bool
        If True, bars start at ``x = 0`` and extend to the right only.
        Otherwise they are centred on ``x = 0``.
    shade : float
        Default bar alpha. An ``alpha`` entry in ``shade_kwargs`` takes
        precedence.
    ax : matplotlib axes
        Axes to draw on.
    **shade_kwargs
        Forwarded to ``ax.barh``.

    Returns
    -------
    numpy.ndarray
        The binned densities used as bar lengths.
    """
    bins = get_bins(val)
    _, binned_d, _ = histogram(val, bins=bins)

    vmin, vmax = np.min(val), np.max(val)
    if vmax > vmin:
        # Spread the bars evenly over the observed [min, max] range.
        bin_edges = np.linspace(vmin, vmax, len(bins))
        heights = np.diff(bin_edges)
        centers = bin_edges[:-1] + heights.mean() / 2
    else:
        # All samples are equal. linspace would collapse every edge onto one
        # value and produce zero-height (invisible) bars, so keep the
        # histogram's own bin widths instead. The first bar is centred on the
        # constant value (assumes get_bins starts its edges at the minimum,
        # TODO confirm).
        heights = np.diff(bins)
        centers = vmin + (bins[:-1] - bins[0])

    # Mirror the bars around x = 0 unless only the positive side is wanted.
    left = None if rug else -0.5 * binned_d

    # Let an explicit ``alpha`` in shade_kwargs override ``shade`` instead of
    # raising a duplicate-keyword TypeError.
    bar_kwargs = {"alpha": shade, **shade_kwargs}
    ax.barh(centers, binned_d, height=heights, left=left, **bar_kwargs)
    return binned_d
137