"""Matplotlib traceplot."""
import warnings
from itertools import cycle

from matplotlib import gridspec
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.lines import Line2D
import matplotlib.ticker as mticker

from ....stats.density_utils import get_bins
from ...distplot import plot_dist
from ...plot_utils import _scale_fig_size, format_coords_as_labels
from ...rankplot import plot_rank
from . import backend_kwarg_defaults, backend_show, dealiase_sel_kwargs, matplotlib_kwarg_dealiaser


def plot_trace(
    data,
    var_names,  # pylint: disable=unused-argument
    divergences,
    kind,
    figsize,
    rug,
    lines,
    circ_var_names,
    circ_var_units,
    compact,
    compact_prop,
    combined,
    chain_prop,
    legend,
    labeller,
    plot_kwargs,
    fill_kwargs,
    rug_kwargs,
    hist_kwargs,
    trace_kwargs,
    rank_kwargs,
    plotters,
    divergence_data,
    axes,
    backend_kwargs,
    backend_config,  # pylint: disable=unused-argument
    show,
):
    """Plot distribution (histogram or kernel density estimates) and sampled values.

    Each entry of `plotters` produces one row with two panels: the distribution on the
    left (column 0) and the trace or rank plot on the right (column 1).

    If `divergences` is truthy, the location of divergences (taken from `divergence_data`)
    is drawn as black tick marks on both panels.

    Parameters
    ----------
    data : obj
        Dataset-like object exposing ``chain`` and ``draw`` coordinates and indexable by
        variable name (presumably an ``xarray.Dataset`` prepared by the caller).
    var_names : string, or list of strings
        One or more variables to be plotted. Unused by this backend; the variables actually
        drawn are the ones listed in `plotters`.
    divergences : {"bottom", "top", None, False}
        Plot location of divergences on the traceplots. Options are "bottom", "top", or False-y.
        Any truthy value other than "top" places the marks at the bottom.
    kind : {"trace", "rank_bars", "rank_vlines"}
        Choose between plotting sampled values per iteration and rank plots.
    figsize : figure size tuple
        If None, size is (12, variables * 2)
    rug : bool
        If True adds a rugplot. Defaults to False. Ignored for 2D KDE. Only affects continuous
        variables.
    lines : tuple or list
        List of tuple of (var_name, {'coord': selection}, [line_positions]) to be overplotted as
        vertical lines on the density and horizontal lines on the trace. ``None`` means no lines.
    circ_var_names : string, or list of strings
        List of circular variables to account for when plotting KDE. ``None`` means none.
    circ_var_units : str
        Whether the variables in `circ_var_names` are in "degrees" or "radians".
    compact : bool
        Whether multidimensional variables share a single row. In this function it only
        selects the defaults of `chain_prop` and `compact_prop`; the layout itself follows
        the dimensionality of the values found in `plotters`.
    compact_prop : str, tuple or dict, optional
        Property cycled over the extra (non chain/draw) dimensions of a variable. A string is
        looked up in the matplotlib ``axes.prop_cycle``; a ``(name, values)`` tuple is accepted
        with a FutureWarning; a dict maps property names to value sequences. Defaults to
        "color" when `compact` is true.
    combined : bool
        Flag for combining multiple chains into a single line. If False (default), chains will be
        plotted separately.
    chain_prop : str, tuple or dict, optional
        Property cycled over chains, with the same accepted forms as `compact_prop`. Defaults
        to "color" when `compact` is false and to a cycle of linestyles otherwise.
    legend : bool
        Add a legend to the figure with the chain color code.
    labeller : obj
        Object providing ``make_label_vert(var_name, selection, isel)``, used for panel titles.
    plot_kwargs : dict
        Extra keyword arguments passed to `arviz.plot_dist`. Only affects continuous variables.
        A ``textsize`` entry, if present, is removed and used to scale fonts (default 10).
    fill_kwargs : dict
        Extra keyword arguments passed to `arviz.plot_dist`. Only affects continuous variables.
    rug_kwargs : dict
        Extra keyword arguments passed to `arviz.plot_dist`. Only affects continuous variables.
    hist_kwargs : dict
        Extra keyword arguments passed to `arviz.plot_dist`. Only affects discrete variables.
    trace_kwargs : dict
        Extra keyword arguments passed to `plt.plot`
    rank_kwargs : dict
        Extra keyword arguments passed to `arviz.plot_rank`
    plotters : sequence of tuple
        ``(var_name, selection, isel, value)`` tuples, one per row. It is iterated more than
        once and its length is taken, so it must not be a one-shot iterator.
    divergence_data : obj
        Object exposing ``dims`` and ``sel(...).values``; the values are used as a per-chain
        mask over draws (presumably a boolean ``xarray.DataArray`` -- confirm with caller).
        Only read when `divergences` is truthy.
    axes : 2D array of matplotlib axes, optional
        Existing axes, indexed as ``axes[row, column]``. If None, a new figure is created.
    backend_kwargs : dict, optional
        Keyword arguments for ``plt.figure``, merged over ``backend_kwarg_defaults()``.
    backend_config : obj
        Unused by this backend.
    show : bool
        Passed to ``backend_show``; when that returns a truthy value ``plt.show()`` is called.

    Returns
    -------
    axes : matplotlib axes
        The `axes` argument, or the axes of the newly created figure reshaped to two columns.

    Raises
    ------
    ValueError
        If the line positions of an entry of `lines` are not numeric.

    Examples
    --------
    Plot a subset variables

    .. plot::
        :context: close-figs

        >>> import arviz as az
        >>> data = az.load_arviz_data('non_centered_eight')
        >>> coords = {'school': ['Choate', 'Lawrenceville']}
        >>> az.plot_trace(data, var_names=('theta_t', 'theta'), coords=coords)

    Show all dimensions of multidimensional variables in the same plot

    .. plot::
        :context: close-figs

        >>> az.plot_trace(data, compact=True)

    Combine all chains into one distribution

    .. plot::
        :context: close-figs

        >>> az.plot_trace(data, var_names=('theta_t', 'theta'), coords=coords, combined=True)


    Plot reference lines against distribution and trace

    .. plot::
        :context: close-figs

        >>> lines = (('theta_t',{'school': "Choate"}, [-1]),)
        >>> az.plot_trace(data, var_names=('theta_t', 'theta'), coords=coords, lines=lines)

    """
    # Set plot default backend kwargs
    if backend_kwargs is None:
        backend_kwargs = {}

    if circ_var_names is None:
        circ_var_names = []

    backend_kwargs = {**backend_kwarg_defaults(), **backend_kwargs}

    if lines is None:
        lines = ()

    # One extra property value is reserved for the "combined" distribution.
    num_chain_props = len(data.chain) + 1 if combined else len(data.chain)
    if not compact:
        chain_prop = "color" if chain_prop is None else chain_prop
    else:
        chain_prop = (
            {
                "linestyle": ("solid", "dotted", "dashed", "dashdot"),
            }
            if chain_prop is None
            else chain_prop
        )
        compact_prop = "color" if compact_prop is None else compact_prop

    if isinstance(chain_prop, str):
        chain_prop = {chain_prop: plt.rcParams["axes.prop_cycle"].by_key()[chain_prop]}
    if isinstance(chain_prop, tuple):
        warnings.warn(
            "chain_prop as a tuple will be deprecated in a future warning, use a dict instead",
            FutureWarning,
        )
        chain_prop = {chain_prop[0]: chain_prop[1]}
    # Repeat each property cycle so that every chain (and "combined") gets a value.
    chain_prop = {
        prop_name: [prop for _, prop in zip(range(num_chain_props), cycle(props))]
        for prop_name, props in chain_prop.items()
    }

    if isinstance(compact_prop, str):
        compact_prop = {compact_prop: plt.rcParams["axes.prop_cycle"].by_key()[compact_prop]}
    if isinstance(compact_prop, tuple):
        warnings.warn(
            "compact_prop as a tuple will be deprecated in a future warning, use a dict instead",
            FutureWarning,
        )
        compact_prop = {compact_prop[0]: compact_prop[1]}

    if figsize is None:
        figsize = (12, len(plotters) * 2)

    backend_kwargs.setdefault("figsize", figsize)

    trace_kwargs = matplotlib_kwarg_dealiaser(trace_kwargs, "plot")
    trace_kwargs.setdefault("alpha", 0.35)

    hist_kwargs = matplotlib_kwarg_dealiaser(hist_kwargs, "hist")
    hist_kwargs.setdefault("alpha", 0.35)

    plot_kwargs = matplotlib_kwarg_dealiaser(plot_kwargs, "plot")
    fill_kwargs = matplotlib_kwarg_dealiaser(fill_kwargs, "fill_between")
    rug_kwargs = matplotlib_kwarg_dealiaser(rug_kwargs, "scatter")
    rank_kwargs = matplotlib_kwarg_dealiaser(rank_kwargs, "bar")

    textsize = plot_kwargs.pop("textsize", 10)

    figsize, _, titlesize, xt_labelsize, linewidth, _ = _scale_fig_size(
        figsize, textsize, rows=len(plotters), cols=2
    )

    trace_kwargs.setdefault("linewidth", linewidth)
    plot_kwargs.setdefault("linewidth", linewidth)

    # Check the input for lines: warn (do not fail) about variables that are not plotted.
    all_var_names = {plotter[0] for plotter in plotters}
    invalid_var_names = {line[0] for line in lines if line[0] not in all_var_names}
    if invalid_var_names:
        warnings.warn(
            f"A valid var_name should be provided, found {invalid_var_names} "
            f"expected from {all_var_names}"
        )

    if axes is None:
        fig = plt.figure(**backend_kwargs)
        spec = gridspec.GridSpec(ncols=2, nrows=len(plotters), figure=fig)

    # pylint: disable=too-many-nested-blocks
    for idx, (var_name, selection, isel, value) in enumerate(plotters):
        value = np.atleast_2d(value)
        is_circ_var = var_name in circ_var_names

        # idy == 0 is the distribution panel, idy == 1 the trace/rank panel.
        for idy in range(2):
            # Only the distribution panel of a circular variable is drawn on polar axes;
            # its trace panel just gets the units for tick relabelling.
            circular = is_circ_var and not idy
            circ_units_trace = circ_var_units if is_circ_var and idy else False

            if axes is None:
                ax = fig.add_subplot(spec[idx, idy], polar=circular)
            else:
                ax = axes[idx, idy]

            if len(value.shape) == 2:
                if compact_prop:
                    aux_plot_kwargs = dealiase_sel_kwargs(plot_kwargs, compact_prop, 0)
                    aux_trace_kwargs = dealiase_sel_kwargs(trace_kwargs, compact_prop, 0)
                else:
                    aux_plot_kwargs = plot_kwargs
                    aux_trace_kwargs = trace_kwargs

                ax = _plot_chains_mpl(
                    ax,
                    idy,
                    value,
                    data,
                    chain_prop,
                    combined,
                    xt_labelsize,
                    rug,
                    kind,
                    aux_trace_kwargs,
                    hist_kwargs,
                    aux_plot_kwargs,
                    fill_kwargs,
                    rug_kwargs,
                    rank_kwargs,
                    circular,
                    circ_var_units,
                    circ_units_trace,
                )

            else:
                # More than (chain, draw): overlay every remaining coordinate in this row.
                sub_data = data[var_name].sel(**selection)
                legend_labels = format_coords_as_labels(sub_data, skip_dims=("chain", "draw"))
                legend_title = ", ".join(
                    str(coord_name)
                    for coord_name in sub_data.coords
                    if coord_name not in {"chain", "draw"}
                )
                # Flatten all extra dimensions into a single trailing axis.
                value = value.reshape((value.shape[0], value.shape[1], -1))
                compact_prop_iter = {
                    prop_name: [prop for _, prop in zip(range(value.shape[2]), cycle(props))]
                    for prop_name, props in compact_prop.items()
                }
                handles = []
                for sub_idx, label in zip(range(value.shape[2]), legend_labels):
                    aux_plot_kwargs = dealiase_sel_kwargs(plot_kwargs, compact_prop_iter, sub_idx)
                    aux_trace_kwargs = dealiase_sel_kwargs(trace_kwargs, compact_prop_iter, sub_idx)
                    ax = _plot_chains_mpl(
                        ax,
                        idy,
                        value[..., sub_idx],
                        data,
                        chain_prop,
                        combined,
                        xt_labelsize,
                        rug,
                        kind,
                        aux_trace_kwargs,
                        hist_kwargs,
                        aux_plot_kwargs,
                        fill_kwargs,
                        rug_kwargs,
                        rank_kwargs,
                        circular,
                        circ_var_units,
                        circ_units_trace,
                    )
                    if legend:
                        handles.append(
                            Line2D(
                                [],
                                [],
                                label=label,
                                **dealiase_sel_kwargs(aux_plot_kwargs, chain_prop, 0),
                            )
                        )
                if legend and idy == 0:
                    ax.legend(handles=handles, title=legend_title)

            # Integer-valued variables get their ticks on the histogram bin edges.
            if value[0].dtype.kind == "i" and idy == 0:
                xticks = get_bins(value)
                ax.set_xticks(xticks[:-1])

            # Title height is ``textsize * y``: about 1 by default, scaled for polar axes.
            y = 1 / textsize
            if not idy:
                ax.set_yticks([])
                if circular:
                    y = 0.13 if selection else 0.12
            ax.set_title(
                labeller.make_label_vert(var_name, selection, isel),
                fontsize=titlesize,
                wrap=True,
                y=textsize * y,
            )
            ax.tick_params(labelsize=xt_labelsize)

            # Captured before drawing divergences and reference lines.
            xlims = ax.get_xlim()
            ylims = ax.get_ylim()

            if divergences:
                div_selection = {k: v for k, v in selection.items() if k in divergence_data.dims}
                divs = np.atleast_2d(divergence_data.sel(**div_selection).values)
                ylocs = ylims[1] if divergences == "top" else ylims[0]

                for chain, chain_divs in enumerate(divs):
                    div_draws = data.draw.values[chain_divs]
                    div_idxs = np.arange(len(chain_divs))[chain_divs]
                    if div_idxs.size == 0:
                        continue
                    values = value[chain, div_idxs]

                    # NOTE(review): the trace panel uses hist_kwargs["alpha"] while the
                    # density panels use trace_kwargs["alpha"]; this looks swapped but is
                    # kept as is -- confirm the intended behaviour.
                    if circular:
                        tick = [ax.get_rmin() + ax.get_rmax() * 0.60, ax.get_rmax()]
                        for val in values:
                            ax.plot(
                                [val, val],
                                tick,
                                color="black",
                                markeredgewidth=1.5,
                                markersize=30,
                                alpha=trace_kwargs["alpha"],
                                zorder=0.6,
                            )
                    elif kind == "trace" and idy:
                        ax.plot(
                            div_draws,
                            np.zeros_like(div_idxs) + ylocs,
                            marker="|",
                            color="black",
                            markeredgewidth=1.5,
                            markersize=30,
                            linestyle="None",
                            alpha=hist_kwargs["alpha"],
                            zorder=0.6,
                        )
                    elif not idy:
                        ax.plot(
                            values,
                            np.zeros_like(values) + ylocs,
                            marker="|",
                            color="black",
                            markeredgewidth=1.5,
                            markersize=30,
                            linestyle="None",
                            alpha=trace_kwargs["alpha"],
                            zorder=0.6,
                        )

            # Reference lines: only entries matching both the variable and its selection.
            for _, _, vlines in (j for j in lines if j[0] == var_name and j[1] == selection):
                if isinstance(vlines, (float, int)):
                    line_values = [vlines]
                else:
                    line_values = np.atleast_1d(vlines).ravel()
                    if not np.issubdtype(line_values.dtype, np.number):
                        raise ValueError(f"line-positions should be numeric, found {line_values}")
                if idy:
                    ax.hlines(
                        line_values,
                        xlims[0],
                        xlims[1],
                        colors="black",
                        linewidth=1.5,
                        alpha=trace_kwargs["alpha"],
                    )

                else:
                    ax.vlines(
                        line_values,
                        ylims[0],
                        ylims[1],
                        colors="black",
                        linewidth=1.5,
                        alpha=trace_kwargs["alpha"],
                    )

            if kind == "trace" and idy:
                ax.set_xlim(left=data.draw.min(), right=data.draw.max())

    if legend:
        legend_kwargs = trace_kwargs if combined else plot_kwargs
        handles = [
            Line2D(
                [], [], label=chain_id, **dealiase_sel_kwargs(legend_kwargs, chain_prop, chain_id)
            )
            # len(data.chain) matches the chain count used for ``num_chain_props`` above and
            # avoids relying on the mapping behaviour of ``Dataset.dims``.
            for chain_id in range(len(data.chain))
        ]
        if combined:
            # The last chain_prop entry is the one reserved for the combined distribution.
            handles.insert(
                0,
                Line2D(
                    [], [], label="combined", **dealiase_sel_kwargs(plot_kwargs, chain_prop, -1)
                ),
            )
        # ``ax`` is the last panel drawn; the legend goes on the first axes of its figure.
        ax.figure.axes[0].legend(handles=handles, title="chain", loc="upper right")

    if axes is None:
        axes = np.array(ax.figure.axes).reshape(-1, 2)

    if backend_show(show):
        plt.show()

    return axes


def _plot_chains_mpl(
    axes,
    idy,
    value,
    data,
    chain_prop,
    combined,
    xt_labelsize,
    rug,
    kind,
    trace_kwargs,
    hist_kwargs,
    plot_kwargs,
    fill_kwargs,
    rug_kwargs,
    rank_kwargs,
    circular,
    circ_var_units,
    circ_units_trace,
):
    """Draw all chains of one variable on a single panel.

    Parameters
    ----------
    axes : matplotlib axes
        The single axes to draw on (despite the plural name).
    idy : int
        Panel column: falsy for the distribution panel, truthy for the trace/rank panel.
    value : array
        Samples iterated chain by chain along the first axis.
    data : obj
        Object whose ``draw.values`` provide the x positions of the trace.
    chain_prop : dict
        Per-chain properties, applied through ``dealiase_sel_kwargs`` by chain index; index
        ``-1`` is used for the combined distribution.
    combined : bool
        If true, a single distribution of the flattened `value` is drawn instead of one
        distribution per chain.
    xt_labelsize : float
        Passed to ``plot_dist`` as ``textsize``.
    rug : bool
        Passed to ``plot_dist``.
    kind : {"trace", "rank_bars", "rank_vlines"}
        What to draw on the trace/rank panel.
    trace_kwargs, hist_kwargs, plot_kwargs, fill_kwargs, rug_kwargs, rank_kwargs : dict
        Keyword arguments forwarded to ``axes.plot`` (trace), ``plot_dist`` and ``plot_rank``.
    circular : bool
        Whether this panel shows a circular variable; if false `circ_var_units` is ignored.
    circ_var_units : str
        Forwarded to ``plot_dist`` as ``is_circular`` when `circular` is true.
    circ_units_trace : str or False
        When equal to "degrees", the y tick labels of the trace are rewritten in degrees.

    Returns
    -------
    axes : matplotlib axes
        The axes drawn on, as returned by the last plotting helper called.
    """
    if not circular:
        circ_var_units = False

    # Arguments shared by every ``plot_dist`` call on the distribution panel.
    dist_kwargs = {
        "textsize": xt_labelsize,
        "rug": rug,
        "hist_kwargs": hist_kwargs,
        "fill_kwargs": fill_kwargs,
        "rug_kwargs": rug_kwargs,
        "backend": "matplotlib",
        "show": False,
        "is_circular": circ_var_units,
    }

    for chain_idx, row in enumerate(value):
        if kind == "trace":
            aux_kwargs = dealiase_sel_kwargs(trace_kwargs, chain_prop, chain_idx)
            if idy:
                axes.plot(data.draw.values, row, **aux_kwargs)
                if circ_units_trace == "degrees":
                    # Tick positions are converted from radians and negative angles are
                    # wrapped into [0, 360) for the labels only; positions stay fixed.
                    y_tick_locs = axes.get_yticks()
                    y_tick_labels = [i + 360 if i < 0 else i for i in np.rad2deg(y_tick_locs)]
                    axes.yaxis.set_major_locator(mticker.FixedLocator(y_tick_locs))
                    axes.set_yticklabels([f"{i:.0f}°" for i in y_tick_labels])

        if not combined:
            aux_kwargs = dealiase_sel_kwargs(plot_kwargs, chain_prop, chain_idx)
            if not idy:
                axes = plot_dist(values=row, ax=axes, plot_kwargs=aux_kwargs, **dist_kwargs)

    # Rank plots use all chains at once, so they are drawn after the per-chain loop.
    if kind == "rank_bars" and idy:
        axes = plot_rank(data=value, kind="bars", ax=axes, **rank_kwargs)
    elif kind == "rank_vlines" and idy:
        axes = plot_rank(data=value, kind="vlines", ax=axes, **rank_kwargs)

    if combined:
        aux_kwargs = dealiase_sel_kwargs(plot_kwargs, chain_prop, -1)
        if not idy:
            axes = plot_dist(
                values=value.flatten(), ax=axes, plot_kwargs=aux_kwargs, **dist_kwargs
            )
    return axes