1# pylint: disable=c-extension-no-member
2"""Bokeh KDE Plot."""
3from collections.abc import Callable
4from numbers import Integral
5
6from matplotlib import _contour
7import numpy as np
8from bokeh.models import ColumnDataSource, Range1d
9from bokeh.models.glyphs import Scatter
10from matplotlib.cm import get_cmap
11from matplotlib.colors import rgb2hex
12from matplotlib.pyplot import rcParams as mpl_rcParams
13
14from ...plot_utils import _scale_fig_size
15from .. import show_layout
16from . import backend_kwarg_defaults, create_axes_grid
17
18
def plot_kde(
    density,
    lower,
    upper,
    density_q,
    xmin,
    xmax,
    ymin,
    ymax,
    gridsize,
    values,
    values2,
    rug,
    label,
    quantiles,
    rotated,
    contour,
    fill_last,
    figsize,
    textsize,
    plot_kwargs,
    fill_kwargs,
    rug_kwargs,
    contour_kwargs,
    contourf_kwargs,
    pcolormesh_kwargs,
    is_circular,  # pylint: disable=unused-argument
    ax,
    legend,  # pylint: disable=unused-argument
    backend_kwargs,
    show,
    return_glyph,
):
    """Bokeh kde plot.

    Draws a 1D KDE curve (optionally filled, split into quantile bands, and
    with a rug) when ``values2`` is None; otherwise draws a 2D KDE either as
    filled contours (``contour=True``) or as an image.

    Parameters
    ----------
    density : array
        1D case: KDE values, drawn against
        ``np.linspace(lower, upper, len(density))``.
        2D case: 2D grid of KDE values spanning ``[xmin, xmax] x [ymin, ymax]``.
    lower, upper : float
        Limits of the evaluation grid in the 1D case.
    density_q : array
        1D case, only read when ``quantiles`` is given: per-grid-point values
        compared against the quantiles (presumably the cumulative density --
        confirm with the caller).
    xmin, xmax, ymin, ymax : float
        Limits of the 2D grid.
    gridsize : sequence of int
        2D case: number of grid points along x (``gridsize[0]``) and along y
        (``gridsize[1]``); expected to match ``density.shape``.
    values, values2 : array or None
        ``values`` feeds the rug; a non-None ``values2`` selects the 2D branch.
    rug : bool
        1D case: draw a dash marker per entry of ``values`` (or of the
        source(s) given as ``rug_kwargs["cds"]``).
    label : str or None
        1D case without quantiles: default ``legend_label`` of the line.
    quantiles : sequence of float or None
        1D case: clipped to [0, 1]; each band between consecutive quantiles is
        drawn as its own filled patch, and no line is drawn.
    rotated : bool
        1D case: swap the roles of the x and y axes.
    contour : bool
        2D case: filled contours if True, an image otherwise.
    fill_last : bool
        2D contour case: also draw the lowest band and set the figure
        background to the first colormap colour.
    figsize, textsize : optional
        Passed to ``_scale_fig_size``; only the scaled figsize is used.
    plot_kwargs, fill_kwargs, rug_kwargs : dict or None
        Keyword arguments for the line, the fill patches and the rug glyphs.
        Defaults are added to ``plot_kwargs`` and ``fill_kwargs`` in place.
    contour_kwargs, contourf_kwargs, pcolormesh_kwargs : dict or None
        Keyword arguments for the 2D glyphs. ``levels`` (int or sequence) and
        ``cmap`` are consumed here; these dicts are copied, never modified.
    ax : figure or None
        Bokeh figure to draw on; created with ``create_axes_grid`` when None.
    backend_kwargs : dict or None
        Merged over ``backend_kwarg_defaults()``; used when creating ``ax``.
    show : bool
        Forwarded to ``show_layout``.
    return_glyph : bool
        Also return the list of created glyphs/renderers.

    Returns
    -------
    ax, or ``(ax, glyphs)`` when ``return_glyph`` is True.
    """
    if backend_kwargs is None:
        backend_kwargs = {}

    backend_kwargs = {
        **backend_kwarg_defaults(),
        **backend_kwargs,
    }

    figsize, *_ = _scale_fig_size(figsize, textsize)

    if ax is None:
        ax = create_axes_grid(
            1,
            figsize=figsize,
            squeeze=True,
            backend_kwargs=backend_kwargs,
        )

    glyphs = []
    if values2 is None:
        # First colour of the active matplotlib property cycle.
        default_color = mpl_rcParams["axes.prop_cycle"].by_key()["color"][0]

        if plot_kwargs is None:
            plot_kwargs = {}
        plot_kwargs.setdefault("line_color", default_color)

        if fill_kwargs is None:
            fill_kwargs = {}
        fill_kwargs.setdefault("fill_color", default_color)

        if rug:
            # Copy so popping "cds"/"y" does not touch the caller's dict.
            rug_kwargs = {} if rug_kwargs is None else rug_kwargs.copy()
            if "cds" in rug_kwargs:
                cds_rug = rug_kwargs.pop("cds")
                rug_varname = rug_kwargs.pop("y", "y")
            else:
                rug_varname = "y"
                cds_rug = ColumnDataSource({rug_varname: np.asarray(values)})

            rug_kwargs.setdefault("size", 8)
            rug_kwargs.setdefault("line_color", plot_kwargs["line_color"])
            rug_kwargs.setdefault("line_width", 1)
            rug_kwargs.setdefault("line_alpha", 0.35)
            if not rotated:
                # Turn the dash marker so it stands perpendicular to the x axis.
                rug_kwargs.setdefault("angle", np.pi / 2)

            # A dict of sources yields one rug per source; record every glyph
            # (not only the last one) so return_glyph reports all of them.
            rug_sources = cds_rug.values() if isinstance(cds_rug, dict) else [cds_rug]
            for rug_source in rug_sources:
                if not rotated:
                    glyph = Scatter(x=rug_varname, y=0.0, marker="dash", **rug_kwargs)
                else:
                    glyph = Scatter(x=0.0, y=rug_varname, marker="dash", **rug_kwargs)
                ax.add_glyph(rug_source, glyph)
                glyphs.append(glyph)

        x = np.linspace(lower, upper, len(density))

        if quantiles is not None:
            fill_kwargs.setdefault("fill_alpha", 0.75)
            fill_kwargs.setdefault("line_color", None)

            # Always cover the full [0, 1] range so bands tile the whole curve.
            quantiles = sorted(np.clip(quantiles, 0, 1))
            if quantiles[0] != 0:
                quantiles = [0] + quantiles
            if quantiles[-1] != 1:
                quantiles = quantiles + [1]

            for quant_0, quant_1 in zip(quantiles[:-1], quantiles[1:]):
                idx = (density_q > quant_0) & (density_q < quant_1)
                if idx.sum():
                    # Closed polygon: along the baseline, up, back along the curve.
                    patch_x = np.concatenate((x[idx], [x[idx][-1]], x[idx][::-1], [x[idx][0]]))
                    patch_y = np.concatenate(
                        (np.zeros_like(density[idx]), [density[idx][-1]], density[idx][::-1], [0])
                    )
                    if not rotated:
                        patch = ax.patch(patch_x, patch_y, **fill_kwargs)
                    else:
                        patch = ax.patch(patch_y, patch_x, **fill_kwargs)
                    glyphs.append(patch)
        else:
            # Only fill under the curve when a non-zero fill_alpha was requested.
            if fill_kwargs.get("fill_alpha", False):
                patch_x = np.concatenate((x, [x[-1]], x[::-1], [x[0]]))
                patch_y = np.concatenate(
                    (np.zeros_like(density), [density[-1]], density[::-1], [0])
                )
                if not rotated:
                    patch = ax.patch(patch_x, patch_y, **fill_kwargs)
                else:
                    patch = ax.patch(patch_y, patch_x, **fill_kwargs)
                glyphs.append(patch)

            if label is not None:
                plot_kwargs.setdefault("legend_label", label)
            if not rotated:
                line = ax.line(x, density, **plot_kwargs)
            else:
                line = ax.line(density, x, **plot_kwargs)
            glyphs.append(line)

    else:
        # Work on copies: "levels" and "cmap" are popped below and the two
        # contour dicts are merged, none of which may leak into caller dicts.
        contour_kwargs = {} if contour_kwargs is None else dict(contour_kwargs)
        contourf_kwargs = {} if contourf_kwargs is None else dict(contourf_kwargs)
        pcolormesh_kwargs = {} if pcolormesh_kwargs is None else dict(pcolormesh_kwargs)

        # A complex step in np.mgrid is a point count (stop inclusive); use the
        # per-axis size so the grid matches a non-square density array.
        x_steps = complex(gridsize[0])
        y_steps = complex(gridsize[1])
        x_x, y_y = np.mgrid[xmin:xmax:x_steps, ymin:ymax:y_steps]

        if contour:

            scaled_density, *scaled_density_args = _scale_axis(density)

            contour_generator = _contour.QuadContourGenerator(
                x_x, y_y, scaled_density, None, True, 0
            )

            # contour_kwargs["levels"] wins over contourf_kwargs["levels"]; both
            # are removed so they are not forwarded to ax.patch.
            levels = contourf_kwargs.pop("levels", 9)
            levels = contour_kwargs.pop("levels", levels)

            if isinstance(levels, Integral):
                # Evenly spaced band boundaries on the [0, 1] scaled density.
                levels_scaled = np.linspace(0, 1, levels + 2)
            else:
                # Explicit levels are in density units: map them onto [0, 1].
                levels_scaled_nonclip, *_ = _scale_axis(np.asarray(levels), scaled_density_args)
                levels_scaled = np.clip(levels_scaled_nonclip, 0, 1)

            cmap = contourf_kwargs.pop("cmap", "viridis")
            if isinstance(cmap, str):
                cmap = get_cmap(cmap)
            if isinstance(cmap, Callable):
                colors = [rgb2hex(item) for item in cmap(np.linspace(0, 1, len(levels_scaled) + 1))]
            else:
                colors = cmap

            contour_kwargs.update(contourf_kwargs)
            contour_kwargs.setdefault("line_alpha", 0.25)
            contour_kwargs.setdefault("fill_alpha", 1)

            # colors[0] is reserved for the background (see fill_last below).
            for i, (level, level_upper, color) in enumerate(
                zip(levels_scaled[:-1], levels_scaled[1:], colors[1:])
            ):
                if not fill_last and (i == 0):
                    continue
                contour_kwargs_ = contour_kwargs.copy()
                contour_kwargs_.setdefault("line_color", color)
                contour_kwargs_.setdefault("fill_color", color)
                vertices, _ = contour_generator.create_filled_contour(level, level_upper)
                for seg in vertices:
                    # ax.multi_polygon would be better, but input is
                    # currently not suitable
                    # seg is 1 line that defines an area
                    # multi_polygon would need inner and outer edges
                    # as a line
                    patch = ax.patch(*seg.T, **contour_kwargs_)
                    glyphs.append(patch)

            if fill_last:
                ax.background_fill_color = colors[0]

            ax.xgrid.grid_line_color = None
            ax.ygrid.grid_line_color = None

            ax.x_range = Range1d(xmin, xmax)
            ax.y_range = Range1d(ymin, ymax)

        else:

            cmap = pcolormesh_kwargs.pop("cmap", "viridis")
            if isinstance(cmap, str):
                cmap = get_cmap(cmap)
            if isinstance(cmap, Callable):
                colors = [rgb2hex(item) for item in cmap(np.linspace(0, 1, 256))]
            else:
                colors = cmap

            # Bokeh's dw/dh are the full data-space width/height covered by the
            # image (not the size of one pixel), so they span the whole grid.
            image = ax.image(
                image=[density.T],
                x=xmin,
                y=ymin,
                dw=xmax - xmin,
                dh=ymax - ymin,
                palette=colors,
                **pcolormesh_kwargs
            )
            glyphs.append(image)
            ax.x_range.range_padding = ax.y_range.range_padding = 0

    show_layout(ax, show)

    if return_glyph:
        return ax, glyphs

    return ax
258
259
def _scale_axis(arr, args=None):
    """Linearly map ``arr`` so that ``amin`` goes to 0 and ``amax`` goes to 1.

    Parameters
    ----------
    arr : array-like
        Values to scale.
    args : sequence of two numbers, optional
        ``(amin, amax)`` bounds to scale with. When None or empty, the minimum
        and maximum of ``arr`` are used.

    Returns
    -------
    scaled_arr : ndarray
        ``(arr - amin) / (amax - amin)`` as a new array; ``arr`` is not
        modified. If ``amax == amin`` numpy's division by zero yields nan/inf.
    amin, amax
        The bounds that were used, so ``_rescale_axis`` can undo the mapping.
    """
    arr = np.asarray(arr)
    if args is None or len(args) == 0:
        amin, amax = arr.min(), arr.max()
    else:
        amin, amax = args
    # Out-of-place true division: an in-place ``/=`` raises for integer input
    # because the float quotient cannot be cast back into an integer array.
    scaled_arr = (arr - amin) / (amax - amin)
    return scaled_arr, amin, amax
268
269
def _rescale_axis(arr, args):
    """Undo ``_scale_axis``: map values from [0, 1] back onto ``[amin, amax]``.

    ``args`` is the ``(amin, amax)`` pair; the result is
    ``arr * (amax - amin) + amin``.
    """
    low, high = args
    span = high - low
    result = arr * span
    # Add in place so the result keeps the dtype produced by the product.
    result += low
    return result
275