# pylint: disable=c-extension-no-member
"""Bokeh KDE Plot."""
from collections.abc import Callable
from numbers import Integral

# NOTE(review): ``matplotlib._contour`` is private and ``matplotlib.cm.get_cmap``
# is deprecated; both appear to be gone in recent matplotlib releases -- TODO
# confirm the supported matplotlib version range.
from matplotlib import _contour
import numpy as np
from bokeh.models import ColumnDataSource, Range1d
from bokeh.models.glyphs import Scatter
from matplotlib.cm import get_cmap
from matplotlib.colors import rgb2hex
from matplotlib.pyplot import rcParams as mpl_rcParams

from ...plot_utils import _scale_fig_size
from .. import show_layout
from . import backend_kwarg_defaults, create_axes_grid


def plot_kde(
    density,
    lower,
    upper,
    density_q,
    xmin,
    xmax,
    ymin,
    ymax,
    gridsize,
    values,
    values2,
    rug,
    label,
    quantiles,
    rotated,
    contour,
    fill_last,
    figsize,
    textsize,
    plot_kwargs,
    fill_kwargs,
    rug_kwargs,
    contour_kwargs,
    contourf_kwargs,
    pcolormesh_kwargs,
    is_circular,  # pylint: disable=unused-argument
    ax,
    legend,  # pylint: disable=unused-argument
    backend_kwargs,
    show,
    return_glyph,
):
    """Bokeh kde plot.

    Two modes, selected by ``values2``:

    * ``values2 is None`` -- 1D KDE: ``density`` is drawn as a line over
      ``np.linspace(lower, upper, len(density))``, optionally with a filled
      area under it (split into bands at ``quantiles`` when given) and a rug
      of ``values``.
    * otherwise -- 2D KDE: ``density`` is a grid over
      ``[xmin, xmax] x [ymin, ymax]`` (axis 0 along x, axis 1 along y) drawn
      either as filled contour bands (``contour=True``) or as an image.

    All ``*_kwargs`` dicts are shallow-copied before use, so the caller's
    dicts are never modified and can safely be reused across calls.

    Parameters
    ----------
    density : array
        1D density values, or the 2D density grid in the 2D mode.
    lower, upper : float
        Range of the 1D evaluation grid.
    density_q : array
        Compared elementwise against ``quantiles`` to select the points of
        each filled band; presumably the cumulative density at each grid
        point (TODO confirm against the caller).
    xmin, xmax, ymin, ymax : float
        Extent of the 2D grid.
    gridsize : tuple of int
        Number of grid points along x and y for the 2D contour grid.
    values : array
        Raw samples, used for the rug in the 1D mode.
    values2 : array or None
        Its presence selects the 2D mode; not otherwise read here.
    rug : bool
        Draw a rug (dash markers) below the 1D curve.
    label : str or None
        Legend label for the 1D line.
    quantiles : sequence of float or None
        Clipped to [0, 1]; each interval between consecutive quantiles is
        filled as a separate patch.
    rotated : bool
        Swap the x and y axes in the 1D mode.
    contour : bool
        2D mode only: filled contour bands instead of an image.
    fill_last : bool
        2D contour mode: when False the lowest band is skipped; when True it
        is drawn and the figure background is set to the first colormap color.
    figsize, textsize
        Passed to ``_scale_fig_size``.
    plot_kwargs, fill_kwargs, rug_kwargs : dict or None
        Extra glyph options for the 1D line, fill patches and rug markers.
        ``rug_kwargs`` may carry ``"cds"`` (a data source, or a dict of them)
        and ``"y"`` (the column name in that source).
    contour_kwargs, contourf_kwargs : dict or None
        Merged options for the contour patches; ``"levels"`` (int or
        sequence, default 9) and ``"cmap"`` (in ``contourf_kwargs``) are
        consumed here.
    pcolormesh_kwargs : dict or None
        Extra options for the image; ``"cmap"`` is consumed here.
    is_circular, legend
        Unused by this backend.
    ax : bokeh figure or None
        Created with ``create_axes_grid`` when None.
    backend_kwargs : dict or None
        Overrides for ``backend_kwarg_defaults()``.
    show : bool
        Forwarded to ``show_layout``.
    return_glyph : bool
        Also return the list of glyphs created.

    Returns
    -------
    ax, or ``(ax, glyphs)`` when ``return_glyph`` is True.
    """
    if backend_kwargs is None:
        backend_kwargs = {}

    backend_kwargs = {
        **backend_kwarg_defaults(),
        **backend_kwargs,
    }

    figsize, *_ = _scale_fig_size(figsize, textsize)

    if ax is None:
        ax = create_axes_grid(
            1,
            figsize=figsize,
            squeeze=True,
            backend_kwargs=backend_kwargs,
        )

    glyphs = []
    if values2 is None:
        # Copy so the setdefault calls below never leak into the caller's dicts.
        plot_kwargs = {} if plot_kwargs is None else dict(plot_kwargs)
        fill_kwargs = {} if fill_kwargs is None else dict(fill_kwargs)

        default_color = mpl_rcParams["axes.prop_cycle"].by_key()["color"][0]
        plot_kwargs.setdefault("line_color", default_color)
        fill_kwargs.setdefault("fill_color", default_color)

        if rug:
            rug_kwargs = {} if rug_kwargs is None else dict(rug_kwargs)
            if "cds" in rug_kwargs:
                cds_rug = rug_kwargs.pop("cds")
                rug_varname = rug_kwargs.pop("y", "y")
            else:
                rug_varname = "y"
                cds_rug = ColumnDataSource({rug_varname: np.asarray(values)})

            rug_kwargs.setdefault("size", 8)
            rug_kwargs.setdefault("line_color", plot_kwargs["line_color"])
            rug_kwargs.setdefault("line_width", 1)
            rug_kwargs.setdefault("line_alpha", 0.35)
            if not rotated:
                # Rotate the dash marker so it stands vertically on the x axis.
                rug_kwargs.setdefault("angle", np.pi / 2)

            # ``cds`` may be one data source or a dict of them; draw (and
            # record) one glyph per source.
            sources = cds_rug.values() if isinstance(cds_rug, dict) else [cds_rug]
            for source in sources:
                if not rotated:
                    glyph = Scatter(x=rug_varname, y=0.0, marker="dash", **rug_kwargs)
                else:
                    glyph = Scatter(x=0.0, y=rug_varname, marker="dash", **rug_kwargs)
                ax.add_glyph(source, glyph)
                glyphs.append(glyph)

        x = np.linspace(lower, upper, len(density))

        if quantiles is not None:
            fill_kwargs.setdefault("fill_alpha", 0.75)
            fill_kwargs.setdefault("line_color", None)

            # Make sure the bands cover the whole [0, 1] range.
            quantiles = sorted(np.clip(quantiles, 0, 1))
            if quantiles[0] != 0:
                quantiles = [0] + quantiles
            if quantiles[-1] != 1:
                quantiles = quantiles + [1]

            for quant_0, quant_1 in zip(quantiles[:-1], quantiles[1:]):
                idx = (density_q > quant_0) & (density_q < quant_1)
                if idx.any():
                    glyphs.append(_fill_under(ax, x[idx], density[idx], rotated, fill_kwargs))
        elif fill_kwargs.get("fill_alpha", False):
            # Without quantiles the area is only filled when an alpha is given.
            glyphs.append(_fill_under(ax, x, density, rotated, fill_kwargs))

        if label is not None:
            plot_kwargs.setdefault("legend_label", label)
        if not rotated:
            line = ax.line(x, density, **plot_kwargs)
        else:
            line = ax.line(density, x, **plot_kwargs)
        glyphs.append(line)

    else:
        # Copy so the pop/update/setdefault calls below never leak into the
        # caller's dicts (e.g. "levels"/"cmap" disappearing on a second call).
        contour_kwargs = {} if contour_kwargs is None else dict(contour_kwargs)
        contourf_kwargs = {} if contourf_kwargs is None else dict(contourf_kwargs)
        pcolormesh_kwargs = {} if pcolormesh_kwargs is None else dict(pcolormesh_kwargs)

        if contour:
            # A complex step makes mgrid produce that many points, end-inclusive;
            # x varies along axis 0 and y along axis 1.
            x_steps = complex(gridsize[0])
            y_steps = complex(gridsize[1])
            x_x, y_y = np.mgrid[xmin:xmax:x_steps, ymin:ymax:y_steps]

            # Contour on the density rescaled to [0, 1]; user levels are mapped
            # into the same scale below.
            scaled_density, *scaled_density_args = _scale_axis(density)

            contour_generator = _contour.QuadContourGenerator(
                x_x, y_y, scaled_density, None, True, 0
            )

            levels = 9
            if "levels" in contourf_kwargs:
                levels = contourf_kwargs.pop("levels")
            if "levels" in contour_kwargs:
                levels = contour_kwargs.pop("levels")

            if isinstance(levels, Integral):
                # ``levels`` interior boundaries plus the 0 and 1 end points.
                levels_scaled = np.linspace(0, 1, levels + 2)
            else:
                levels_scaled_nonclip, *_ = _scale_axis(np.asarray(levels), scaled_density_args)
                levels_scaled = np.clip(levels_scaled_nonclip, 0, 1)

            colors = _cmap_to_colors(contourf_kwargs.pop("cmap", "viridis"), len(levels_scaled) + 1)

            contour_kwargs.update(contourf_kwargs)
            contour_kwargs.setdefault("line_alpha", 0.25)
            contour_kwargs.setdefault("fill_alpha", 1)

            for i, (level, level_upper, color) in enumerate(
                zip(levels_scaled[:-1], levels_scaled[1:], colors[1:])
            ):
                if not fill_last and i == 0:
                    continue
                band_kwargs = contour_kwargs.copy()
                band_kwargs.setdefault("line_color", color)
                band_kwargs.setdefault("fill_color", color)
                vertices, _ = contour_generator.create_filled_contour(level, level_upper)
                for seg in vertices:
                    # ax.multi_polygon would be better, but each seg is a single
                    # outline of an area, whereas multi_polygon needs separate
                    # inner and outer edges.
                    glyphs.append(ax.patch(*seg.T, **band_kwargs))

            if fill_last:
                ax.background_fill_color = colors[0]

            ax.xgrid.grid_line_color = None
            ax.ygrid.grid_line_color = None

            ax.x_range = Range1d(xmin, xmax)
            ax.y_range = Range1d(ymin, ymax)

        else:
            colors = _cmap_to_colors(pcolormesh_kwargs.pop("cmap", "viridis"), 256)

            # Bokeh's dw/dh are the full data-space width/height the image
            # occupies (pixels are stretched to fit), not the size of one cell.
            image = ax.image(
                image=[density.T],
                x=xmin,
                y=ymin,
                dw=xmax - xmin,
                dh=ymax - ymin,
                palette=colors,
                **pcolormesh_kwargs,
            )
            glyphs.append(image)
            ax.x_range.range_padding = ax.y_range.range_padding = 0

    show_layout(ax, show)

    if return_glyph:
        return ax, glyphs

    return ax


def _fill_under(ax, x, y, rotated, fill_kwargs):
    """Draw and return a closed patch between the curve ``(x, y)`` and zero.

    The outline runs along the zero baseline, up the right edge, back along
    the curve and down to the start point. ``rotated`` swaps x and y.
    """
    patch_x = np.concatenate((x, [x[-1]], x[::-1], [x[0]]))
    patch_y = np.concatenate((np.zeros_like(y), [y[-1]], y[::-1], [0]))
    if rotated:
        return ax.patch(patch_y, patch_x, **fill_kwargs)
    return ax.patch(patch_x, patch_y, **fill_kwargs)


def _cmap_to_colors(cmap, n_colors):
    """Turn ``cmap`` into a list of colors usable by Bokeh.

    A string is resolved with matplotlib's ``get_cmap``. A callable colormap
    is sampled at ``n_colors`` evenly spaced points in [0, 1] and converted to
    hex strings. Anything else (e.g. an explicit palette) is returned as-is.
    """
    if isinstance(cmap, str):
        cmap = get_cmap(cmap)
    if isinstance(cmap, Callable):
        return [rgb2hex(item) for item in cmap(np.linspace(0, 1, n_colors))]
    return cmap


def _scale_axis(arr, args=None):
    """Linearly map ``arr`` so that ``amin -> 0`` and ``amax -> 1``.

    ``args`` is an optional ``(amin, amax)`` pair; when falsy the bounds are
    taken from ``arr.min()`` and ``arr.max()``. Returns
    ``(scaled_arr, amin, amax)`` so the same mapping can be reused.
    """
    if args:
        amin, amax = args
    else:
        amin, amax = arr.min(), arr.max()
    # Out-of-place arithmetic: integer inputs are promoted to float instead of
    # failing numpy's in-place casting rule.
    scaled_arr = (arr - amin) / (amax - amin)
    return scaled_arr, amin, amax


def _rescale_axis(arr, args):
    """Invert ``_scale_axis``: map ``arr`` from [0, 1] back to ``(amin, amax)``."""
    amin, amax = args
    return arr * (amax - amin) + amin