"""Matplotlib Bayesian p-value Posterior predictive plot."""
2import matplotlib.pyplot as plt
3import numpy as np
4from scipy import stats
5
6from ....stats.density_utils import kde
7from ....stats.stats_utils import smooth_data
8from ...kdeplot import plot_kde
9from ...plot_utils import (
10    _scale_fig_size,
11    is_valid_quantile,
12    sample_reference_distribution,
13)
14from . import backend_kwarg_defaults, backend_show, create_axes_grid, matplotlib_kwarg_dealiaser
15
16
def plot_bpv(
    ax,
    length_plotters,
    rows,
    cols,
    obs_plotters,
    pp_plotters,
    total_pp_samples,
    kind,
    t_stat,
    bpv,
    plot_mean,
    reference,
    mse,
    n_ref,
    hdi_prob,
    color,
    figsize,
    textsize,
    labeller,
    plot_ref_kwargs,
    backend_kwargs,
    show,
):
    """Matplotlib bpv plot.

    Draw one panel per variable comparing observed values with posterior
    predictive values, in one of three modes selected by ``kind``.

    Parameters
    ----------
    ax : array-like of matplotlib axes or None
        Axes to draw on. When None, a new grid is built with
        ``create_axes_grid``.
    length_plotters : int
        Number of variables (panels) to draw.
    rows, cols : int
        Grid shape, forwarded to ``_scale_fig_size`` and ``create_axes_grid``.
    obs_plotters, pp_plotters : sequence
        Per-variable 4-item entries, unpacked here as
        ``(var_name, selection, isel, values)``.
    total_pp_samples : int
        Leading dimension used when reshaping each posterior predictive
        array to ``(total_pp_samples, -1)``.
    kind : str
        ``"p_value"`` or ``"u_value"``; any other value selects the
        T statistic plot.
    t_stat : str, float or callable
        Only used in the T statistic plot. One of ``"mean"``, ``"median"``,
        ``"std"``, a callable applied to the observed values and to the
        posterior predictive array, or a value accepted by
        ``is_valid_quantile``.
    bpv : bool
        In the T statistic plot, add a legend entry with the proportion of
        posterior predictive statistics not larger than the observed one.
    plot_mean : bool
        In the T statistic plot, mark the observed statistic with a dot at
        ``y=0``.
    reference : str or None
        ``"analytical"``, ``"samples"`` or None (no reference drawn).
    mse : bool
        In the ``"u_value"`` plot, add a legend entry with the mean squared
        difference between the estimated density and 1, times 100.
    n_ref : int
        Second element of the shape passed to
        ``sample_reference_distribution`` (presumably the number of
        reference draws -- confirm against that helper).
    hdi_prob : float
        Probability used to size the analytical reference band of the
        ``"u_value"`` plot.
    color : matplotlib color
        Color of the main curve, the observed-statistic marker and, by
        default, the sampled reference curves.
    figsize, textsize
        Forwarded to ``_scale_fig_size``.
    labeller
        Object providing ``make_pp_label(var_name, pp_var_name, selection,
        isel)``, used for the panel titles.
    plot_ref_kwargs : dict
        Keyword arguments for the reference artist. They are passed through
        ``matplotlib_kwarg_dealiaser`` and take precedence over the defaults
        set here (``color``, ``alpha``, ``linestyle``, ``linewidth``,
        ``zorder``).
    backend_kwargs : dict or None
        Merged on top of ``backend_kwarg_defaults()`` and forwarded to
        ``create_axes_grid``.
    show : bool or None
        Passed to ``backend_show``; ``plt.show()`` is called when it
        returns a true value.

    Returns
    -------
    axes
        The axes grid that was created, or ``np.asarray(ax)`` when ``ax``
        was given.

    Raises
    ------
    ValueError
        If fewer axes than variables are supplied, or if ``t_stat`` is not
        one of the supported statistics.
    """
    if backend_kwargs is None:
        backend_kwargs = {}

    backend_kwargs = {
        **backend_kwarg_defaults(),
        **backend_kwargs,
    }

    figsize, ax_labelsize, _, _, linewidth, markersize = _scale_fig_size(
        figsize, textsize, rows, cols
    )

    backend_kwargs.setdefault("figsize", figsize)
    backend_kwargs.setdefault("squeeze", True)

    # The u_value analytical reference is a filled band, every other
    # reference is a line, so the kwargs are de-aliased accordingly.
    if (kind == "u_value") and (reference == "analytical"):
        plot_ref_kwargs = matplotlib_kwarg_dealiaser(plot_ref_kwargs, "fill_between")
    else:
        plot_ref_kwargs = matplotlib_kwarg_dealiaser(plot_ref_kwargs, "plot")

    # Defaults are applied with ``setdefault`` so that any value supplied by
    # the caller wins. ``zorder`` and ``linewidth`` are handled here (rather
    # than as explicit keywords next to ``**plot_ref_kwargs``) to avoid a
    # "multiple values for keyword argument" TypeError when the caller
    # provides them.
    if kind == "p_value" and reference == "analytical":
        plot_ref_kwargs.setdefault("color", "k")
        plot_ref_kwargs.setdefault("linestyle", "--")
        plot_ref_kwargs.setdefault("zorder", 1)
    elif kind == "u_value" and reference == "analytical":
        plot_ref_kwargs.setdefault("color", "k")
        plot_ref_kwargs.setdefault("alpha", 0.2)
    else:
        plot_ref_kwargs.setdefault("alpha", 0.1)
        plot_ref_kwargs.setdefault("color", color)
        plot_ref_kwargs.setdefault("linewidth", linewidth)

    if ax is None:
        _, axes = create_axes_grid(length_plotters, rows, cols, backend_kwargs=backend_kwargs)
    else:
        axes = np.asarray(ax)
        if axes.size < length_plotters:
            raise ValueError(
                f"Found {length_plotters} variables to plot but {axes.size} axes instances. "
                "Axes instances must at minimum be equal to variables."
            )

    for i, ax_i in enumerate(np.ravel(axes)[:length_plotters]):
        var_name, selection, isel, obs_vals = obs_plotters[i]
        pp_var_name, _, _, pp_vals = pp_plotters[i]

        obs_vals = obs_vals.flatten()
        pp_vals = pp_vals.reshape(total_pp_samples, -1)

        # Integer-valued data is passed through ``smooth_data`` before any
        # comparison between observed and predicted values.
        if obs_vals.dtype.kind == "i" or pp_vals.dtype.kind == "i":
            obs_vals, pp_vals = smooth_data(obs_vals, pp_vals)

        if kind == "p_value":
            # One value per posterior predictive sample: the fraction of
            # observations not exceeded by the prediction.
            tstat_pit = np.mean(pp_vals <= obs_vals, axis=-1)
            x_s, tstat_pit_dens = kde(tstat_pit)
            ax_i.plot(x_s, tstat_pit_dens, linewidth=linewidth, color=color)
            ax_i.set_yticks([])
            if reference is not None:
                dist = stats.beta(obs_vals.size / 2, obs_vals.size / 2)
                if reference == "analytical":
                    # Evaluate the reference density on its central 99.99%
                    # interval; the beta is symmetric so upb mirrors lwb.
                    lwb = dist.ppf((1 - 0.9999) / 2)
                    upb = 1 - lwb
                    x = np.linspace(lwb, upb, 500)
                    dens_ref = dist.pdf(x)
                    ax_i.plot(x, dens_ref, **plot_ref_kwargs)
                elif reference == "samples":
                    x_ss, u_dens = sample_reference_distribution(
                        dist,
                        (
                            tstat_pit_dens.size,
                            n_ref,
                        ),
                    )
                    ax_i.plot(x_ss, u_dens, **plot_ref_kwargs)

        elif kind == "u_value":
            # One value per observation: the fraction of posterior
            # predictive samples not exceeding it.
            tstat_pit = np.mean(pp_vals <= obs_vals, axis=0)
            x_s, tstat_pit_dens = kde(tstat_pit)
            ax_i.plot(x_s, tstat_pit_dens, color=color)
            if reference is not None:
                if reference == "analytical":
                    n_obs = obs_vals.size
                    hdi_ = stats.beta(n_obs / 2, n_obs / 2).ppf((1 - hdi_prob) / 2)
                    hdi_odds = (hdi_ / (1 - hdi_), (1 - hdi_) / hdi_)
                    ax_i.axhspan(*hdi_odds, **plot_ref_kwargs)
                    # White line at density 1, the level of a uniform density.
                    ax_i.axhline(1, color="w", zorder=1)
                elif reference == "samples":
                    dist = stats.uniform(0, 1)
                    x_ss, u_dens = sample_reference_distribution(dist, (tstat_pit_dens.size, n_ref))
                    ax_i.plot(x_ss, u_dens, **plot_ref_kwargs)
            if mse:
                # Artist drawn only to carry the legend label.
                ax_i.plot(0, 0, label=f"mse={np.mean((1 - tstat_pit_dens)**2) * 100:.2f}")
                ax_i.legend()

            ax_i.set_ylim(0, None)
            ax_i.set_xlim(0, 1)
        else:
            if t_stat in ["mean", "median", "std"]:
                tfunc = {"mean": np.mean, "median": np.median, "std": np.std}[t_stat]
                obs_vals = tfunc(obs_vals)
                pp_vals = tfunc(pp_vals, axis=1)
            elif callable(t_stat):
                # The callable receives the full 2D posterior predictive
                # array and is responsible for reducing it itself.
                obs_vals = t_stat(obs_vals.flatten())
                pp_vals = t_stat(pp_vals)
            elif is_valid_quantile(t_stat):
                t_stat = float(t_stat)
                obs_vals = np.quantile(obs_vals, q=t_stat)
                pp_vals = np.quantile(pp_vals, q=t_stat, axis=1)
            else:
                raise ValueError(f"T statistics {t_stat} not implemented")

            plot_kde(pp_vals, ax=ax_i, plot_kwargs={"color": color})
            ax_i.set_yticks([])
            if bpv:
                p_value = np.mean(pp_vals <= obs_vals)
                # Invisible artist (alpha=0) drawn only to carry the label.
                ax_i.plot(obs_vals, 0, label=f"bpv={p_value:.2f}", alpha=0)
                ax_i.legend()

            if plot_mean:
                # np.mean also accepts plain Python scalars, which a
                # user-supplied callable ``t_stat`` may return.
                ax_i.plot(
                    np.mean(obs_vals), 0, "o", color=color, markeredgecolor="k", markersize=markersize
                )

        ax_i.set_title(
            labeller.make_pp_label(var_name, pp_var_name, selection, isel), fontsize=ax_labelsize
        )

    if backend_show(show):
        plt.show()

    return axes
178