"""Matplotlib essplot."""
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import rankdata

from ...plot_utils import _scale_fig_size
from . import backend_kwarg_defaults, backend_show, create_axes_grid, matplotlib_kwarg_dealiaser


def plot_ess(
    ax,
    plotters,
    xdata,
    ess_tail_dataset,
    mean_ess,
    sd_ess,
    idata,
    data,
    kind,
    extra_methods,
    textsize,
    rows,
    cols,
    figsize,
    kwargs,
    extra_kwargs,
    text_kwargs,
    n_samples,
    relative,
    min_ess,
    labeller,
    ylabel,
    rug,
    rug_kind,
    rug_kwargs,
    hline_kwargs,
    backend_kwargs,
    show,
):
    """Matplotlib ess plot.

    Draw one subplot per entry of ``plotters`` showing effective sample size (ESS)
    against the total number of draws (``kind == "evolution"``) or against the
    quantile (any other ``kind``).

    Parameters
    ----------
    ax : matplotlib Axes, array of Axes, or None
        Axes to draw on. When None, a ``rows`` x ``cols`` grid is created with
        ``create_axes_grid`` using ``backend_kwargs``.
    plotters : iterable of (var_name, selection, isel, x)
        One entry per subplot; ``x`` holds the ESS values plotted against ``xdata``.
    xdata : array-like
        Shared x coordinates (draw counts for "evolution", quantiles otherwise).
    ess_tail_dataset : dataset-like
        Indexed as ``ess_tail_dataset[var_name].sel(**selection)``; drawn as the
        "tail" curve in "evolution" mode.
    mean_ess, sd_ess : dataset-like
        Indexed the same way; their scalar values are drawn as horizontal lines
        (with "mean"/"sd" annotations) when ``extra_methods`` is truthy.
    idata : InferenceData-like
        Must expose ``sample_stats[rug_kind]`` when a rug is requested.
    data : dataset-like
        Per-variable draws; ranked to position the rug markers.
    kind : str
        "evolution", or another kind whose name becomes the main curve label.
    extra_methods : bool
        Whether to draw the mean/sd ESS reference lines.
    textsize, rows, cols, figsize
        Forwarded to ``_scale_fig_size`` to derive figure and font sizes.
    kwargs, extra_kwargs, text_kwargs, rug_kwargs, hline_kwargs : dict or None
        Matplotlib keyword arguments for the main curve, the tail/extra lines, the
        annotations, the rug markers and the minimum-ESS line, respectively.
        ``text_kwargs`` may contain ``x`` (annotation x position, default 1);
        ``rug_kwargs`` may contain ``space`` (rug offset below zero as a fraction
        of the subplot's maximum ESS, default 0.1).
    n_samples : int
        Total number of samples; divides ``min_ess`` when ``relative`` is True.
    relative : bool
        Whether ESS values are relative (affects the y label, the threshold
        line and tick formatting).
    min_ess : float
        Threshold drawn as a horizontal reference line.
    labeller
        Object providing ``make_label_vert(var_name, selection, isel)`` for titles.
    ylabel : str
        Format string receiving "ESS" or "Relative ESS".
    rug : bool
        Whether to draw rug markers below the quantile curve.
    rug_kind : str
        Name of the ``sample_stats`` variable used as a boolean mask for the rug.
    backend_kwargs : dict or None
        Passed to ``create_axes_grid`` after merging with the backend defaults.
    show : bool or None
        Passed to ``backend_show`` to decide whether to call ``plt.show()``.

    Returns
    -------
    ax
        The axes drawn on (created here when ``ax`` was None).

    Raises
    ------
    ValueError
        If a rug is requested but ``idata`` lacks ``sample_stats`` or
        ``sample_stats`` lacks ``rug_kind``.
    """
    if backend_kwargs is None:
        backend_kwargs = {}

    backend_kwargs = {
        **backend_kwarg_defaults(),
        **backend_kwargs,
    }

    (figsize, ax_labelsize, titlesize, xt_labelsize, _linewidth, _markersize) = _scale_fig_size(
        figsize, textsize, rows, cols
    )
    backend_kwargs.setdefault("figsize", figsize)
    backend_kwargs["squeeze"] = True

    kwargs = matplotlib_kwarg_dealiaser(kwargs, "plot")
    _linestyle = "-" if kind == "evolution" else "none"
    kwargs.setdefault("linestyle", _linestyle)
    kwargs.setdefault("linewidth", _linewidth)
    kwargs.setdefault("markersize", _markersize)
    kwargs.setdefault("marker", "o")
    kwargs.setdefault("zorder", 3)

    extra_kwargs = matplotlib_kwarg_dealiaser(extra_kwargs, "plot")
    if kind == "evolution":
        # The tail curve inherits every main-curve style the user did not override.
        extra_kwargs = {
            **extra_kwargs,
            **{key: item for key, item in kwargs.items() if key not in extra_kwargs},
        }
        kwargs.setdefault("label", "bulk")
        extra_kwargs.setdefault("label", "tail")
    else:
        extra_kwargs.setdefault("linewidth", _linewidth / 2)
        extra_kwargs.setdefault("color", "k")
        extra_kwargs.setdefault("alpha", 0.5)
        kwargs.setdefault("label", kind)

    hline_kwargs = matplotlib_kwarg_dealiaser(hline_kwargs, "plot")
    hline_kwargs.setdefault("linewidth", _linewidth)
    hline_kwargs.setdefault("linestyle", "--")
    hline_kwargs.setdefault("color", "gray")
    hline_kwargs.setdefault("alpha", 0.7)

    if extra_methods:
        text_kwargs = matplotlib_kwarg_dealiaser(text_kwargs, "text")
        text_x = text_kwargs.pop("x", 1)
        text_kwargs.setdefault("fontsize", xt_labelsize * 0.7)
        # .get: in "evolution" mode extra_kwargs is not guaranteed to define these.
        text_kwargs.setdefault("alpha", extra_kwargs.get("alpha"))
        text_kwargs.setdefault("color", extra_kwargs.get("color"))
        text_kwargs.setdefault("horizontalalignment", "right")
        text_va = text_kwargs.pop("verticalalignment", None)

    draw_rug = rug and kind != "evolution"
    if draw_rug:
        # Normalise rug options once: popping "space" inside the per-variable loop
        # would silently reset it to the default for every subplot after the first.
        rug_kwargs = matplotlib_kwarg_dealiaser(rug_kwargs, "plot")
        if not hasattr(idata, "sample_stats"):
            raise ValueError("InferenceData object must contain sample_stats for rug plot")
        if not hasattr(idata.sample_stats, rug_kind):
            raise ValueError(f"InferenceData does not contain {rug_kind} data")
        rug_kwargs.setdefault("marker", "|")
        rug_kwargs.setdefault("linestyle", rug_kwargs.pop("ls", "None"))
        rug_kwargs.setdefault("color", rug_kwargs.pop("c", kwargs.get("color", "C0")))
        rug_kwargs.setdefault("markersize", rug_kwargs.pop("ms", 2 * _markersize))
        rug_space_frac = rug_kwargs.pop("space", 0.1)
        # The mask depends only on idata, not on the variable being plotted.
        rug_mask = idata.sample_stats[rug_kind].values.flatten()

    if ax is None:
        _, ax = create_axes_grid(
            len(plotters),
            rows,
            cols,
            backend_kwargs=backend_kwargs,
        )

    for (var_name, selection, isel, x), ax_ in zip(plotters, np.ravel(ax)):
        ax_.plot(xdata, x, **kwargs)
        if kind == "evolution":
            ess_tail = ess_tail_dataset[var_name].sel(**selection)
            ax_.plot(xdata, ess_tail, **extra_kwargs)
        elif draw_rug:
            values = data[var_name].sel(**selection).values.flatten()
            ranks = rankdata(values, method="average")[rug_mask]
            # rankdata yields ranks in 1..N; shift to 0..N-1 so the markers span
            # exactly [0, 1] and stay inside the quantile x-limits set below.
            rug_x = (ranks - 1) / (len(rug_mask) - 1)
            rug_y = np.zeros_like(rug_x) - np.max(x) * rug_space_frac
            ax_.plot(rug_x, rug_y, **rug_kwargs)
            ax_.axhline(0, color="k", linewidth=_linewidth, alpha=0.7)
        if extra_methods:
            mean_ess_i = mean_ess[var_name].sel(**selection).values.item()
            sd_ess_i = sd_ess[var_name].sel(**selection).values.item()
            if text_va is None:
                # Place each label on the side away from the other line.
                mean_va = "bottom" if mean_ess_i >= sd_ess_i else "top"
                sd_va = "bottom" if sd_ess_i > mean_ess_i else "top"
            else:
                mean_va = sd_va = text_va
            ax_.axhline(mean_ess_i, **extra_kwargs)
            ax_.annotate("mean", (text_x, mean_ess_i), va=mean_va, **text_kwargs)
            ax_.axhline(sd_ess_i, **extra_kwargs)
            ax_.annotate("sd", (text_x, sd_ess_i), va=sd_va, **text_kwargs)

        # Minimum-ESS reference; in relative mode it is a fraction of n_samples.
        ax_.axhline(min_ess / n_samples if relative else min_ess, **hline_kwargs)

        ax_.set_title(
            labeller.make_label_vert(var_name, selection, isel), fontsize=titlesize, wrap=True
        )
        ax_.tick_params(labelsize=xt_labelsize)
        ax_.set_xlabel(
            "Total number of draws" if kind == "evolution" else "Quantile", fontsize=ax_labelsize
        )
        ax_.set_ylabel(
            ylabel.format("Relative ESS" if relative else "ESS"), fontsize=ax_labelsize, wrap=True
        )
        if kind == "evolution":
            ax_.legend(title="Method", fontsize=xt_labelsize, title_fontsize=xt_labelsize)
        else:
            ax_.set_xlim(0, 1)
        if rug:
            ax_.yaxis.get_major_locator().set_params(nbins="auto", steps=[1, 2, 5, 10])
            _, ymax = ax_.get_ylim()
            yticks = ax_.get_yticks()
            if not relative:
                # Absolute ESS ticks are shown as integers; relative ESS is
                # fractional and would collapse to 0 under an integer cast.
                yticks = yticks.astype(np.int64)
            # Hide the negative ticks that only exist because of the rug offset.
            yticks = yticks[(yticks >= 0) & (yticks < ymax)]
            ax_.set_yticks(yticks)
            if not relative:
                ax_.set_yticklabels(yticks)
        else:
            ax_.set_ylim(bottom=0)

    if backend_show(show):
        plt.show()

    return ax