"""Matplotlib Bayesian p-value Posterior predictive plot."""
import matplotlib.pyplot as plt
import numpy as np
from scipy import stats

from ....stats.density_utils import kde
from ....stats.stats_utils import smooth_data
from ...kdeplot import plot_kde
from ...plot_utils import (
    _scale_fig_size,
    is_valid_quantile,
    sample_reference_distribution,
)
from . import backend_kwarg_defaults, backend_show, create_axes_grid, matplotlib_kwarg_dealiaser

# Test statistics that may be requested by name for ``kind="t_stat"``.
_NAMED_T_STATS = {"mean": np.mean, "median": np.median, "std": np.std}


def _apply_t_stat(t_stat, obs_vals, pp_vals):
    """Reduce observed and posterior predictive values with the test statistic ``t_stat``.

    Parameters
    ----------
    t_stat : str, float-like or callable
        One of "mean", "median", "std"; a quantile in (0, 1) given as a number
        or numeric string; or a callable applied directly to the arrays.
    obs_vals : ndarray
        Flattened observed values.
    pp_vals : ndarray
        Predictive values with one row per predictive sample.

    Returns
    -------
    tuple
        ``(obs_stat, pp_stat)``. For named statistics and quantiles ``pp_stat``
        holds one value per row of ``pp_vals``; a callable receives the whole
        2-D ``pp_vals`` and is responsible for reducing it itself.

    Raises
    ------
    ValueError
        If ``t_stat`` is not a recognised statistic.
    """
    if isinstance(t_stat, str) and t_stat in _NAMED_T_STATS:
        tfunc = _NAMED_T_STATS[t_stat]
        return tfunc(obs_vals), tfunc(pp_vals, axis=1)
    if callable(t_stat):
        return t_stat(obs_vals.flatten()), t_stat(pp_vals)
    if is_valid_quantile(t_stat):
        # Use a local name so the caller's ``t_stat`` is never rebound.
        quantile = float(t_stat)
        return np.quantile(obs_vals, q=quantile), np.quantile(pp_vals, q=quantile, axis=1)
    raise ValueError(f"T statistics {t_stat} not implemented")


def plot_bpv(
    ax,
    length_plotters,
    rows,
    cols,
    obs_plotters,
    pp_plotters,
    total_pp_samples,
    kind,
    t_stat,
    bpv,
    plot_mean,
    reference,
    mse,
    n_ref,
    hdi_prob,
    color,
    figsize,
    textsize,
    labeller,
    plot_ref_kwargs,
    backend_kwargs,
    show,
):
    """Matplotlib bpv plot.

    Draws one panel per variable comparing observed data with posterior
    predictive samples. Three kinds are supported:

    * ``"p_value"``: KDE of, for each predictive sample, the fraction of
      observations that are >= the predicted value. The reference is
      Beta(n_obs / 2, n_obs / 2).
    * ``"u_value"``: KDE of, for each observation, the fraction of predictive
      samples <= that observation. The reference is a band derived from
      Beta(n_obs / 2, n_obs / 2) quantiles ("analytical") or draws from
      Uniform(0, 1) ("samples").
    * anything else: a KDE of the test statistic ``t_stat`` over the predictive
      samples, optionally annotated with the Bayesian p-value and the observed
      statistic.

    Parameters
    ----------
    ax : matplotlib axes or array of axes, optional
        Axes to draw on. If None, a new grid of ``rows`` x ``cols`` is created.
    length_plotters : int
        Number of panels to draw; ``ax`` must provide at least this many axes.
    obs_plotters, pp_plotters : sequence of tuple
        Per-panel ``(var_name, selection, isel, values)`` tuples for observed
        and posterior predictive data respectively.
    total_pp_samples : int
        Number of predictive samples; ``pp_vals`` is reshaped to this many rows.
    reference : {None, "analytical", "samples"}
        Which reference to draw for the ``p_value`` / ``u_value`` kinds.
    plot_ref_kwargs : dict, optional
        Extra keyword arguments for the reference artists. A ``linewidth`` here
        overrides the figure-scaled default used for sampled references.
    backend_kwargs : dict, optional
        Passed to ``create_axes_grid`` when new axes are created.
    show : bool, optional
        Forwarded to ``backend_show`` to decide whether to call ``plt.show()``.

    Returns
    -------
    ndarray of matplotlib axes
        The axes that were drawn on.

    Raises
    ------
    ValueError
        If fewer axes than ``length_plotters`` are supplied, or if ``t_stat`` is
        not recognised.
    """
    if backend_kwargs is None:
        backend_kwargs = {}

    backend_kwargs = {
        **backend_kwarg_defaults(),
        **backend_kwargs,
    }

    figsize, ax_labelsize, _, _, linewidth, markersize = _scale_fig_size(
        figsize, textsize, rows, cols
    )

    backend_kwargs.setdefault("figsize", figsize)
    backend_kwargs.setdefault("squeeze", True)

    # The analytical u_value reference is a shaded span; every other reference is a line.
    if kind == "u_value" and reference == "analytical":
        plot_ref_kwargs = matplotlib_kwarg_dealiaser(plot_ref_kwargs, "fill_between")
    else:
        plot_ref_kwargs = matplotlib_kwarg_dealiaser(plot_ref_kwargs, "plot")

    if kind == "p_value" and reference == "analytical":
        plot_ref_kwargs.setdefault("color", "k")
        plot_ref_kwargs.setdefault("linestyle", "--")
    elif kind == "u_value" and reference == "analytical":
        plot_ref_kwargs.setdefault("color", "k")
        plot_ref_kwargs.setdefault("alpha", 0.2)
    else:
        plot_ref_kwargs.setdefault("alpha", 0.1)
        plot_ref_kwargs.setdefault("color", color)

    # Sampled references use the scaled linewidth by default, but a user-supplied
    # linewidth (possibly dealiased from "lw") must win instead of raising a
    # duplicate-keyword TypeError.
    sample_ref_kwargs = {"linewidth": linewidth, **plot_ref_kwargs}

    if ax is None:
        _, axes = create_axes_grid(length_plotters, rows, cols, backend_kwargs=backend_kwargs)
    else:
        axes = np.asarray(ax)
        if axes.size < length_plotters:
            raise ValueError(
                f"Found {length_plotters} variables to plot but {axes.size} axes instances. "
                "Axes instances must at minimum be equal to variables."
            )

    for i, ax_i in enumerate(np.ravel(axes)[:length_plotters]):
        var_name, selection, isel, obs_vals = obs_plotters[i]
        pp_var_name, _, _, pp_vals = pp_plotters[i]

        obs_vals = obs_vals.flatten()
        # One row per predictive sample, one column per observation.
        pp_vals = pp_vals.reshape(total_pp_samples, -1)

        # Discrete (signed or unsigned integer) data is smoothed before PIT/KDE.
        if obs_vals.dtype.kind in ("i", "u") or pp_vals.dtype.kind in ("i", "u"):
            obs_vals, pp_vals = smooth_data(obs_vals, pp_vals)

        if kind == "p_value":
            # Per predictive sample: fraction of observations >= the prediction.
            tstat_pit = np.mean(pp_vals <= obs_vals, axis=-1)
            x_s, tstat_pit_dens = kde(tstat_pit)
            ax_i.plot(x_s, tstat_pit_dens, linewidth=linewidth, color=color)
            ax_i.set_yticks([])
            if reference is not None:
                dist = stats.beta(obs_vals.size / 2, obs_vals.size / 2)
                if reference == "analytical":
                    lwb = dist.ppf((1 - 0.9999) / 2)
                    upb = 1 - lwb
                    x = np.linspace(lwb, upb, 500)
                    dens_ref = dist.pdf(x)
                    ax_i.plot(x, dens_ref, zorder=1, **plot_ref_kwargs)
                elif reference == "samples":
                    x_ss, u_dens = sample_reference_distribution(
                        dist, (tstat_pit_dens.size, n_ref)
                    )
                    ax_i.plot(x_ss, u_dens, **sample_ref_kwargs)

        elif kind == "u_value":
            # Per observation: fraction of predictive samples <= the observation.
            tstat_pit = np.mean(pp_vals <= obs_vals, axis=0)
            x_s, tstat_pit_dens = kde(tstat_pit)
            ax_i.plot(x_s, tstat_pit_dens, color=color)
            if reference is not None:
                if reference == "analytical":
                    n_obs = obs_vals.size
                    hdi_ = stats.beta(n_obs / 2, n_obs / 2).ppf((1 - hdi_prob) / 2)
                    hdi_odds = (hdi_ / (1 - hdi_), (1 - hdi_) / hdi_)
                    ax_i.axhspan(*hdi_odds, **plot_ref_kwargs)
                    ax_i.axhline(1, color="w", zorder=1)
                elif reference == "samples":
                    dist = stats.uniform(0, 1)
                    x_ss, u_dens = sample_reference_distribution(dist, (tstat_pit_dens.size, n_ref))
                    ax_i.plot(x_ss, u_dens, **sample_ref_kwargs)
            if mse:
                # Invisible point whose only purpose is to carry the legend label.
                ax_i.plot(0, 0, label=f"mse={np.mean((1 - tstat_pit_dens)**2) * 100:.2f}")
                ax_i.legend()

            ax_i.set_ylim(0, None)
            ax_i.set_xlim(0, 1)
        else:
            obs_vals, pp_vals = _apply_t_stat(t_stat, obs_vals, pp_vals)

            plot_kde(pp_vals, ax=ax_i, plot_kwargs={"color": color})
            ax_i.set_yticks([])
            if bpv:
                p_value = np.mean(pp_vals <= obs_vals)
                # Transparent point used only to show the p-value in the legend.
                ax_i.plot(obs_vals, 0, label=f"bpv={p_value:.2f}", alpha=0)
                ax_i.legend()

            if plot_mean:
                # np.mean also accepts plain Python floats returned by a callable t_stat.
                ax_i.plot(
                    np.mean(obs_vals),
                    0,
                    "o",
                    color=color,
                    markeredgecolor="k",
                    markersize=markersize,
                )

        ax_i.set_title(
            labeller.make_pp_label(var_name, pp_var_name, selection, isel), fontsize=ax_labelsize
        )

    if backend_show(show):
        plt.show()

    return axes