1"""Matplotlib energyplot."""
2import matplotlib.pyplot as plt
3import numpy as np
4from scipy.stats import rankdata
5
6from ...plot_utils import _scale_fig_size
7from . import backend_kwarg_defaults, backend_show, create_axes_grid, matplotlib_kwarg_dealiaser
8
9
def plot_ess(
    ax,
    plotters,
    xdata,
    ess_tail_dataset,
    mean_ess,
    sd_ess,
    idata,
    data,
    kind,
    extra_methods,
    textsize,
    rows,
    cols,
    figsize,
    kwargs,
    extra_kwargs,
    text_kwargs,
    n_samples,
    relative,
    min_ess,
    labeller,
    ylabel,
    rug,
    rug_kind,
    rug_kwargs,
    hline_kwargs,
    backend_kwargs,
    show,
):
    """Matplotlib ess plot.

    Draw one panel per entry of ``plotters``: the ESS values against ``xdata``, an optional
    second curve (``kind == "evolution"``) or rug (other kinds), optional mean/sd reference
    lines, and a horizontal line marking the minimum desired ESS.

    Parameters
    ----------
    ax : array-like of matplotlib axes or None
        Axes to draw on. If None, a grid with ``len(plotters)`` panels is created with
        ``create_axes_grid``.
    plotters : sequence of tuple
        Tuples ``(var_name, selection, isel, x)``; ``x`` is plotted against ``xdata``.
    xdata : array-like
        x coordinates shared by all panels.
    ess_tail_dataset : mapping
        Indexed by ``var_name`` and ``.sel(**selection)``. Only read when
        ``kind == "evolution"``.
    mean_ess, sd_ess : mapping
        Indexed like ``ess_tail_dataset``; must select a single value. Only read when
        ``extra_methods`` is truthy.
    idata : object
        Must expose ``sample_stats[rug_kind]`` when a rug is drawn.
    data : mapping
        Indexed by ``var_name`` and ``.sel(**selection)``; values are ranked to place the rug.
    kind : str
        ``"evolution"`` selects connected lines, a legend and the "Total number of draws"
        x label; any other value selects markers only, the "Quantile" x label and x limits
        fixed to ``(0, 1)``.
    extra_methods : bool
        If truthy, draw and annotate horizontal lines at the mean and sd ESS.
    textsize, rows, cols, figsize
        Forwarded to ``_scale_fig_size`` to derive font, line and marker sizes.
    kwargs, extra_kwargs, text_kwargs, rug_kwargs, hline_kwargs : dict or None
        Styling for the main curve, the secondary curve / reference lines, the annotations,
        the rug and the threshold line. ``text_kwargs`` accepts an extra ``x`` key (annotation
        x position, default 1) and ``rug_kwargs`` an extra ``space`` key (rug offset below
        zero as a fraction of ``max(x)``, default 0.1).
    n_samples : int
        Divisor for the threshold line when ``relative`` is truthy.
    relative : bool
        If truthy the threshold is drawn at ``min_ess / n_samples`` and the y label is
        formatted with "Relative ESS" instead of "ESS".
    min_ess : number
        Minimum desired ESS used for the threshold line.
    labeller : object
        Its ``make_label_vert(var_name, selection, isel)`` result is used as panel title.
    ylabel : str
        Format string receiving "Relative ESS" or "ESS".
    rug : bool
        If truthy and ``kind != "evolution"``, draw a rug of the samples flagged by
        ``idata.sample_stats[rug_kind]``. Whenever truthy, y ticks are restricted to
        non-negative integers below the upper y limit.
    rug_kind : str
        Name of the ``sample_stats`` variable used to index the ranked samples.
        NOTE(review): presumably a boolean mask -- confirm against callers.
    backend_kwargs : dict or None
        Extra arguments merged over ``backend_kwarg_defaults()`` and given to
        ``create_axes_grid``.
    show : bool or None
        Passed to ``backend_show`` to decide whether ``plt.show()`` is called.

    Returns
    -------
    ax
        The axes that were passed in, or the newly created ones.

    Raises
    ------
    ValueError
        If a rug is requested for a non-evolution kind and ``idata`` has no ``sample_stats``
        or ``sample_stats`` has no ``rug_kind`` variable.
    """
    if backend_kwargs is None:
        backend_kwargs = {}

    backend_kwargs = {
        **backend_kwarg_defaults(),
        **backend_kwargs,
    }

    (figsize, ax_labelsize, titlesize, xt_labelsize, _linewidth, _markersize) = _scale_fig_size(
        figsize, textsize, rows, cols
    )
    backend_kwargs.setdefault("figsize", figsize)
    backend_kwargs["squeeze"] = True

    kwargs = matplotlib_kwarg_dealiaser(kwargs, "plot")
    _linestyle = "-" if kind == "evolution" else "none"
    kwargs.setdefault("linestyle", _linestyle)
    kwargs.setdefault("linewidth", _linewidth)
    kwargs.setdefault("markersize", _markersize)
    kwargs.setdefault("marker", "o")
    kwargs.setdefault("zorder", 3)

    extra_kwargs = matplotlib_kwarg_dealiaser(extra_kwargs, "plot")
    if kind == "evolution":
        # The tail curve inherits every style of the bulk curve it does not override.
        extra_kwargs = {
            **extra_kwargs,
            **{key: item for key, item in kwargs.items() if key not in extra_kwargs},
        }
        kwargs.setdefault("label", "bulk")
        extra_kwargs.setdefault("label", "tail")
    else:
        extra_kwargs.setdefault("linewidth", _linewidth / 2)
        extra_kwargs.setdefault("color", "k")
        extra_kwargs.setdefault("alpha", 0.5)
    kwargs.setdefault("label", kind)

    hline_kwargs = matplotlib_kwarg_dealiaser(hline_kwargs, "plot")
    hline_kwargs.setdefault("linewidth", _linewidth)
    hline_kwargs.setdefault("linestyle", "--")
    hline_kwargs.setdefault("color", "gray")
    hline_kwargs.setdefault("alpha", 0.7)
    if extra_methods:
        text_kwargs = matplotlib_kwarg_dealiaser(text_kwargs, "text")
        text_x = text_kwargs.pop("x", 1)
        text_kwargs.setdefault("fontsize", xt_labelsize * 0.7)
        text_kwargs.setdefault("alpha", extra_kwargs["alpha"])
        text_kwargs.setdefault("color", extra_kwargs["color"])
        text_kwargs.setdefault("horizontalalignment", "right")
        text_va = text_kwargs.pop("verticalalignment", None)

    # The rug is only drawn for non-evolution kinds. Its settings do not depend on the
    # variable being plotted, so they are validated and resolved once, before any axes are
    # created. Resolving them inside the loop would pop a user supplied "space" on the first
    # panel and silently fall back to the default on all following panels.
    draw_rug = bool(rug) and kind != "evolution"
    if draw_rug:
        if not hasattr(idata, "sample_stats"):
            raise ValueError("InferenceData object must contain sample_stats for rug plot")
        if not hasattr(idata.sample_stats, rug_kind):
            raise ValueError(f"InferenceData does not contain {rug_kind} data")
        rug_kwargs = matplotlib_kwarg_dealiaser(rug_kwargs, "plot")
        rug_kwargs.setdefault("marker", "|")
        rug_kwargs.setdefault("linestyle", rug_kwargs.pop("ls", "None"))
        rug_kwargs.setdefault("color", rug_kwargs.pop("c", kwargs.get("color", "C0")))
        rug_kwargs.setdefault("markersize", rug_kwargs.pop("ms", 2 * _markersize))
        # "space" is not a matplotlib argument; remove it before forwarding the kwargs.
        rug_space_fraction = rug_kwargs.pop("space", 0.1)
        rug_mask = idata.sample_stats[rug_kind].values.flatten()

    if ax is None:
        _, ax = create_axes_grid(
            len(plotters),
            rows,
            cols,
            backend_kwargs=backend_kwargs,
        )

    # Honour ``min_ess`` in relative mode too instead of a hard-coded 400.
    threshold = min_ess / n_samples if relative else min_ess
    xlabel_text = "Total number of draws" if kind == "evolution" else "Quantile"
    ylabel_text = ylabel.format("Relative ESS" if relative else "ESS")

    for (var_name, selection, isel, x), ax_ in zip(plotters, np.ravel(ax)):
        ax_.plot(xdata, x, **kwargs)
        if kind == "evolution":
            ess_tail = ess_tail_dataset[var_name].sel(**selection)
            ax_.plot(xdata, ess_tail, **extra_kwargs)
        elif draw_rug:
            values = data[var_name].sel(**selection).values.flatten()
            # Rank all samples, keep the flagged ones and scale the ranks onto the x axis.
            ranks = rankdata(values, method="average")[rug_mask]
            rug_x = ranks / (len(rug_mask) - 1)
            # Place the rug below zero, offset by a fraction of the tallest ESS value.
            rug_y = np.zeros_like(ranks) - np.max(x) * rug_space_fraction
            ax_.plot(rug_x, rug_y, **rug_kwargs)
            ax_.axhline(0, color="k", linewidth=_linewidth, alpha=0.7)
        if extra_methods:
            mean_ess_i = mean_ess[var_name].sel(**selection).values.item()
            sd_ess_i = sd_ess[var_name].sel(**selection).values.item()
            # Unless the user fixed the alignment, put the label of the higher line above it
            # and the label of the lower line below it so the two texts do not overlap.
            if text_va is not None:
                mean_va = sd_va = text_va
            else:
                mean_va = "bottom" if mean_ess_i >= sd_ess_i else "top"
                sd_va = "bottom" if sd_ess_i > mean_ess_i else "top"
            ax_.axhline(mean_ess_i, **extra_kwargs)
            ax_.annotate("mean", (text_x, mean_ess_i), va=mean_va, **text_kwargs)
            ax_.axhline(sd_ess_i, **extra_kwargs)
            ax_.annotate("sd", (text_x, sd_ess_i), va=sd_va, **text_kwargs)

        ax_.axhline(threshold, **hline_kwargs)

        ax_.set_title(
            labeller.make_label_vert(var_name, selection, isel), fontsize=titlesize, wrap=True
        )
        ax_.tick_params(labelsize=xt_labelsize)
        ax_.set_xlabel(xlabel_text, fontsize=ax_labelsize)
        ax_.set_ylabel(ylabel_text, fontsize=ax_labelsize, wrap=True)
        if kind == "evolution":
            ax_.legend(title="Method", fontsize=xt_labelsize, title_fontsize=xt_labelsize)
        else:
            ax_.set_xlim(0, 1)
        if rug:
            # Keep only non-negative integer ticks inside the visible range so the area
            # below zero (where the rug lives) carries no tick labels.
            ax_.yaxis.get_major_locator().set_params(nbins="auto", steps=[1, 2, 5, 10])
            _, ymax = ax_.get_ylim()
            yticks = ax_.get_yticks().astype(np.int64)
            yticks = yticks[(yticks >= 0) & (yticks < ymax)]
            ax_.set_yticks(yticks)
            ax_.set_yticklabels(yticks)
        else:
            ax_.set_ylim(bottom=0)

    if backend_show(show):
        plt.show()

    return ax
175