1"""Matplotlib distplot."""
2import matplotlib.pyplot as plt
3from matplotlib import _pylab_helpers
4import numpy as np
5
6from ....stats.density_utils import get_bins
7from ...kdeplot import plot_kde
8from ...plot_utils import _scale_fig_size
9from . import backend_kwarg_defaults, backend_show, create_axes_grid, matplotlib_kwarg_dealiaser
10
11
def plot_dist(
    values,
    values2,
    color,
    kind,
    cumulative,
    label,
    rotated,
    rug,
    bw,
    quantiles,
    contour,
    fill_last,
    figsize,
    textsize,
    plot_kwargs,
    fill_kwargs,
    rug_kwargs,
    contour_kwargs,
    contourf_kwargs,
    pcolormesh_kwargs,
    hist_kwargs,
    is_circular,
    ax,
    backend_kwargs,
    show,
):
    """Matplotlib distplot.

    Draw ``values`` either as a histogram (``kind == "hist"``) or through
    ``plot_kde`` (``kind == "kde"``). Any other ``kind`` draws nothing and the
    axes are returned untouched.

    Parameters
    ----------
    values, values2
        Data to plot. ``values2`` is forwarded as is; the histogram path
        rejects a non-``None`` ``values2`` with ``NotImplementedError``.
    color
        Default ``color`` for the histogram or the KDE line; an explicit
        ``color`` in ``hist_kwargs`` / ``plot_kwargs`` takes precedence.
    kind : str
        ``"hist"`` or ``"kde"``.
    cumulative, label
        Used as histogram defaults (``cumulative``, ``label``) or forwarded to
        ``plot_kde``. A legend is requested from ``plot_kde`` only when
        ``label`` is not ``None``.
    rotated : bool
        Histogram orientation default: horizontal when true, vertical
        otherwise. Also forwarded to ``plot_kde``.
    rug, bw, quantiles, contour, fill_last, fill_kwargs, rug_kwargs, \
contour_kwargs, contourf_kwargs, pcolormesh_kwargs
        Forwarded unchanged to ``plot_kde``; unused for histograms.
    figsize, textsize
        Passed to ``_scale_fig_size``; the resulting figure size is only a
        default and is overridden by ``backend_kwargs["figsize"]``.
        ``textsize`` is also forwarded to ``plot_kde``.
    plot_kwargs, hist_kwargs
        Keyword arguments for the KDE line and for ``ax.hist`` respectively.
        Both are passed through ``matplotlib_kwarg_dealiaser`` first.
    is_circular
        Used as the default ``polar`` entry of ``subplot_kw`` when new axes
        are created, and forwarded to the drawing routine.
    ax
        Axes to draw on. When ``None``, the current axes of the active figure
        are used if a figure exists; otherwise new axes are created with
        ``create_axes_grid``.
    backend_kwargs : dict or None
        Overrides for ``backend_kwarg_defaults()``. Not modified.
    show
        Passed to ``backend_show``; ``plt.show()`` is called when it returns
        a true value. It is also forwarded to ``plot_kde``.

    Returns
    -------
    The axes that were drawn on.
    """
    # Merge into a fresh dict so the caller's ``backend_kwargs`` stays intact.
    backend_kwargs = {
        **backend_kwarg_defaults(),
        **(backend_kwargs or {}),
    }

    figsize, *_ = _scale_fig_size(figsize, textsize)

    backend_kwargs.setdefault("figsize", figsize)
    backend_kwargs["squeeze"] = True
    # The merge above is shallow: copy the nested dict before adding the
    # ``polar`` default, otherwise a caller-supplied ``subplot_kw`` would be
    # modified in place.
    subplot_kw = dict(backend_kwargs.get("subplot_kw") or {})
    subplot_kw.setdefault("polar", is_circular)
    backend_kwargs["subplot_kw"] = subplot_kw

    if ax is None:
        # Reuse the active figure when there is one instead of opening a new one.
        fig_manager = _pylab_helpers.Gcf.get_active()
        if fig_manager is not None:
            ax = fig_manager.canvas.figure.gca()
        else:
            _, ax = create_axes_grid(
                1,
                backend_kwargs=backend_kwargs,
            )

    if kind == "hist":
        hist_kwargs = matplotlib_kwarg_dealiaser(hist_kwargs, "hist")
        hist_kwargs.setdefault("cumulative", cumulative)
        hist_kwargs.setdefault("color", color)
        hist_kwargs.setdefault("label", label)
        hist_kwargs.setdefault("rwidth", 0.9)
        # ``align="left"`` centres every bar on the left edge of its bin. That
        # is kept for integer/boolean data (presumably so bars sit on the
        # values themselves -- depends on ``get_bins``; confirm), but for any
        # other dtype it would draw each bar half a bin to the left of the
        # interval it counts, so those use matplotlib's ``"mid"`` alignment.
        is_discrete = np.asarray(values).dtype.kind in "biu"
        hist_kwargs.setdefault("align", "left" if is_discrete else "mid")
        hist_kwargs.setdefault("density", True)

        if rotated:
            hist_kwargs.setdefault("orientation", "horizontal")
        else:
            hist_kwargs.setdefault("orientation", "vertical")

        ax = _histplot_mpl_op(
            values=values,
            values2=values2,
            rotated=rotated,
            ax=ax,
            hist_kwargs=hist_kwargs,
            is_circular=is_circular,
        )

    elif kind == "kde":
        plot_kwargs = matplotlib_kwarg_dealiaser(plot_kwargs, "plot")
        plot_kwargs.setdefault("color", color)
        legend = label is not None

        ax = plot_kde(
            values,
            values2,
            cumulative=cumulative,
            rug=rug,
            label=label,
            bw=bw,
            quantiles=quantiles,
            rotated=rotated,
            contour=contour,
            legend=legend,
            fill_last=fill_last,
            textsize=textsize,
            plot_kwargs=plot_kwargs,
            fill_kwargs=fill_kwargs,
            rug_kwargs=rug_kwargs,
            contour_kwargs=contour_kwargs,
            contourf_kwargs=contourf_kwargs,
            pcolormesh_kwargs=pcolormesh_kwargs,
            ax=ax,
            backend="matplotlib",
            backend_kwargs=backend_kwargs,
            is_circular=is_circular,
            show=show,
        )

    if backend_show(show):
        plt.show()

    return ax
123
124
def _histplot_mpl_op(values, values2, rotated, ax, hist_kwargs, is_circular):
    """Add a histogram for the data to the axes.

    Parameters
    ----------
    values : array-like
        Data to bin. It is flattened before being handed to ``ax.hist``.
    values2 : array-like or None
        Second variable. Two-dimensional histograms are not implemented, so
        anything other than ``None`` raises.
    rotated : bool
        When true the bin ticks are set on the y axis; otherwise on the x axis
        (skipped for circular plots).
    ax : matplotlib axes
        Axes to draw on.
    hist_kwargs : dict
        Keyword arguments for ``ax.hist``. An optional ``"bins"`` entry
        replaces the bins computed by ``get_bins``. The dict is not modified.
    is_circular : bool or str
        Falsy for linear data. ``"degrees"`` converts both values and bins
        from degrees to radians before plotting; any other true value
        relabels the x ticks with radian values instead.

    Returns
    -------
    The axes passed in as ``ax``.

    Raises
    ------
    NotImplementedError
        If ``values2`` is not ``None``.
    """
    # Reject unsupported input before touching the axes, so a failed call does
    # not leave relabelled ticks behind on a caller-owned axes.
    if values2 is not None:
        raise NotImplementedError("Insert hexbin plot here")

    # Work on a copy: removing ``bins`` must not alter the caller's dict.
    hist_kwargs = dict(hist_kwargs or {})
    bins = hist_kwargs.pop("bins", None)

    if is_circular == "degrees":
        # Bins are computed on the degree values and converted together with
        # the data, so both end up in radians.
        if bins is None:
            bins = get_bins(values)
        values = np.deg2rad(values)
        bins = np.deg2rad(bins)

    elif is_circular:
        labels = [
            "0",
            f"{np.pi/4:.2f}",
            f"{np.pi/2:.2f}",
            f"{3*np.pi/4:.2f}",
            f"{np.pi:.2f}",
            f"{-3*np.pi/4:.2f}",
            f"{-np.pi/2:.2f}",
            f"{-np.pi/4:.2f}",
        ]

        # NOTE(review): labels are assigned positionally to whatever ticks the
        # axes already has; this presumably relies on eight ticks spaced pi/4
        # apart -- confirm for axes supplied by the caller.
        ax.set_xticklabels(labels)

    if bins is None:
        bins = get_bins(values)

    n, bins, _ = ax.hist(np.asarray(values).flatten(), bins=bins, **hist_kwargs)

    # One tick per bin, placed on the bin's left edge.
    if rotated:
        ax.set_yticks(bins[:-1])
    elif not is_circular:
        ax.set_xticks(bins[:-1])

    if is_circular:
        # Leave headroom above the tallest bar and hide the radial labels.
        ax.set_ylim(0, 1.5 * n.max())
        ax.set_yticklabels([])

    if hist_kwargs.get("label") is not None:
        ax.legend()

    return ax
170