1"""Matplotlib traceplot."""
2import warnings
3from itertools import cycle
4
5from matplotlib import gridspec
6import matplotlib.pyplot as plt
7import numpy as np
8from matplotlib.lines import Line2D
9import matplotlib.ticker as mticker
10
11from ....stats.density_utils import get_bins
12from ...distplot import plot_dist
13from ...plot_utils import _scale_fig_size, format_coords_as_labels
14from ...rankplot import plot_rank
15from . import backend_kwarg_defaults, backend_show, dealiase_sel_kwargs, matplotlib_kwarg_dealiaser
16
17
def plot_trace(
    data,
    var_names,  # pylint: disable=unused-argument
    divergences,
    kind,
    figsize,
    rug,
    lines,
    circ_var_names,
    circ_var_units,
    compact,
    compact_prop,
    combined,
    chain_prop,
    legend,
    labeller,
    plot_kwargs,
    fill_kwargs,
    rug_kwargs,
    hist_kwargs,
    trace_kwargs,
    rank_kwargs,
    plotters,
    divergence_data,
    axes,
    backend_kwargs,
    backend_config,  # pylint: disable=unused-argument
    show,
):
    """Plot distribution (histogram or kernel density estimates) and sampled values.

    If `divergences` is truthy, the location of divergent draws is marked with black ``|``
    ticks along the bottom (default) or top edge of the distribution panels, and of the trace
    panels when ``kind == "trace"``. On circular (polar) panels they are drawn as radial ticks.

    Parameters
    ----------
    data : xarray.Dataset
        Posterior samples. This backend reads ``data.chain``, ``data.draw`` and
        ``data[var_name]``, so it expects a Dataset-like object with ``chain`` and ``draw``
        dimensions (presumably already converted by the public ``az.plot_trace``).
    var_names : string, or list of strings
        One or more variables to be plotted. Unused here; `plotters` already encodes them.
    divergences : {"bottom", "top", None, False}
        Plot location of divergences on the traceplots. Options are "bottom", "top", or False-y.
    kind : {"trace", "rank_bars", "rank_vlines"}
        Choose between plotting sampled values per iteration and rank plots.
    figsize : figure size tuple
        If None, size is (12, variables * 2)
    rug : bool
        If True adds a rugplot. Ignored for 2D KDE. Only affects continuous variables.
    lines : tuple or list
        List of tuple of (var_name, {'coord': selection}, [line_positions]) to be overplotted as
        vertical lines on the density and horizontal lines on the trace. A warning is issued
        for var_names that are not plotted; non-numeric positions raise ``ValueError``.
    circ_var_names : string, or list of strings
        List of circular variables to account for when plotting KDE.
    circ_var_units : str
        Whether the variables in `circ_var_names` are in "degrees" or "radians".
    compact : bool
        Whether multidimensional variables share one row of axes. Here it only selects the
        defaults of `chain_prop` (a linestyle cycle if True, "color" otherwise) and of
        `compact_prop` ("color"); `plotters` must already be built to match.
    compact_prop : str, tuple or dict
        Property used to tell apart the components of a compact multidimensional variable.
        A str is looked up in ``rcParams["axes.prop_cycle"]``; a tuple ``(name, values)`` is
        deprecated in favour of ``{name: values}``.
    combined : bool
        Flag for combining all chains into a single distribution. Traces are always plotted
        per chain.
    chain_prop : str, tuple or dict
        Property used to tell chains apart; same accepted forms as `compact_prop`.
    legend : bool
        Add a legend to the figure with the chain color code.
    labeller : Labeller
        Object whose ``make_label_vert(var_name, selection, isel)`` builds each panel title.
    plot_kwargs : dict
        Extra keyword arguments passed to `arviz.plot_dist`. Only affects continuous variables.
        A ``textsize`` key, if present, is removed and used as the text size (default 10).
    fill_kwargs : dict
        Extra keyword arguments passed to `arviz.plot_dist`. Only affects continuous variables.
    rug_kwargs : dict
        Extra keyword arguments passed to `arviz.plot_dist`. Only affects continuous variables.
    hist_kwargs : dict
        Extra keyword arguments passed to `arviz.plot_dist`. Only affects discrete variables.
    trace_kwargs : dict
        Extra keyword arguments passed to ``Axes.plot`` for the traces. Its ``alpha``
        (default 0.35) is also reused for divergence marks and reference lines.
    rank_kwargs : dict
        Extra keyword arguments passed to `arviz.plot_rank`
    plotters : sequence of tuples
        ``(var_name, selection, isel, values)`` tuples, one row of axes per tuple.
    divergence_data : xarray.DataArray
        Divergence indicators, selected with the dimensions of each ``selection`` it shares
        and used as a boolean mask over draws (assumed boolean with one row per chain).
    axes : 2D array of matplotlib axes, optional
        Axes of shape ``(len(plotters), 2)`` to draw on. If None, a new figure is created.
    backend_kwargs : dict, optional
        Passed to ``plt.figure`` on top of the backend defaults.
    backend_config : dict
        Unused by the matplotlib backend.
    show : bool
        Whether to call ``plt.show`` at the end.

    Returns
    -------
    axes : numpy.ndarray of matplotlib axes
        The axes drawn on. When `axes` is None, the new figure's axes reshaped to two
        columns (distribution panel, trace/rank panel).

    Examples
    --------
    Plot a subset variables

    .. plot::
        :context: close-figs

        >>> import arviz as az
        >>> data = az.load_arviz_data('non_centered_eight')
        >>> coords = {'school': ['Choate', 'Lawrenceville']}
        >>> az.plot_trace(data, var_names=('theta_t', 'theta'), coords=coords)

    Show all dimensions of multidimensional variables in the same plot

    .. plot::
        :context: close-figs

        >>> az.plot_trace(data, compact=True)

    Combine all chains into one distribution

    .. plot::
        :context: close-figs

        >>> az.plot_trace(data, var_names=('theta_t', 'theta'), coords=coords, combined=True)


    Plot reference lines against distribution and trace

    .. plot::
        :context: close-figs

        >>> lines = (('theta_t',{'school': "Choate"}, [-1]),)
        >>> az.plot_trace(data, var_names=('theta_t', 'theta'), coords=coords, lines=lines)

    """
    # Set plot default backend kwargs
    if backend_kwargs is None:
        backend_kwargs = {}

    if circ_var_names is None:
        circ_var_names = []

    backend_kwargs = {**backend_kwarg_defaults(), **backend_kwargs}

    if lines is None:
        lines = ()

    num_chains = len(data.chain)
    # One extra property slot is reserved for the "combined" distribution (index -1).
    num_chain_props = num_chains + 1 if combined else num_chains
    if not compact:
        chain_prop = "color" if chain_prop is None else chain_prop
    else:
        chain_prop = (
            {
                "linestyle": ("solid", "dotted", "dashed", "dashdot"),
            }
            if chain_prop is None
            else chain_prop
        )
        compact_prop = "color" if compact_prop is None else compact_prop

    if isinstance(chain_prop, str):
        chain_prop = {chain_prop: plt.rcParams["axes.prop_cycle"].by_key()[chain_prop]}
    if isinstance(chain_prop, tuple):
        warnings.warn(
            "chain_prop as a tuple will be deprecated in a future warning, use a dict instead",
            FutureWarning,
        )
        chain_prop = {chain_prop[0]: chain_prop[1]}
    # Expand each property cycle to exactly one value per chain (plus "combined").
    chain_prop = {
        prop_name: [prop for _, prop in zip(range(num_chain_props), cycle(props))]
        for prop_name, props in chain_prop.items()
    }

    if isinstance(compact_prop, str):
        compact_prop = {compact_prop: plt.rcParams["axes.prop_cycle"].by_key()[compact_prop]}
    if isinstance(compact_prop, tuple):
        warnings.warn(
            "compact_prop as a tuple will be deprecated in a future warning, use a dict instead",
            FutureWarning,
        )
        compact_prop = {compact_prop[0]: compact_prop[1]}

    if figsize is None:
        figsize = (12, len(plotters) * 2)

    backend_kwargs.setdefault("figsize", figsize)

    trace_kwargs = matplotlib_kwarg_dealiaser(trace_kwargs, "plot")
    trace_kwargs.setdefault("alpha", 0.35)

    hist_kwargs = matplotlib_kwarg_dealiaser(hist_kwargs, "hist")
    hist_kwargs.setdefault("alpha", 0.35)

    plot_kwargs = matplotlib_kwarg_dealiaser(plot_kwargs, "plot")
    fill_kwargs = matplotlib_kwarg_dealiaser(fill_kwargs, "fill_between")
    rug_kwargs = matplotlib_kwarg_dealiaser(rug_kwargs, "scatter")
    rank_kwargs = matplotlib_kwarg_dealiaser(rank_kwargs, "bar")

    textsize = plot_kwargs.pop("textsize", 10)

    figsize, _, titlesize, xt_labelsize, linewidth, _ = _scale_fig_size(
        figsize, textsize, rows=len(plotters), cols=2
    )

    trace_kwargs.setdefault("linewidth", linewidth)
    plot_kwargs.setdefault("linewidth", linewidth)

    # Check the input for lines: warn about reference lines for variables that are not plotted.
    # (`lines` was normalised to a tuple above, so no None check is needed here.)
    all_var_names = {plotter[0] for plotter in plotters}
    invalid_var_names = {line[0] for line in lines if line[0] not in all_var_names}
    if invalid_var_names:
        warnings.warn(
            "A valid var_name should be provided, found {} expected from {}".format(
                invalid_var_names, all_var_names
            )
        )

    if axes is None:
        fig = plt.figure(**backend_kwargs)
        spec = gridspec.GridSpec(ncols=2, nrows=len(plotters), figure=fig)

    # pylint: disable=too-many-nested-blocks
    for idx, (var_name, selection, isel, value) in enumerate(plotters):
        # Both panels of a row use the same samples, so promote them to 2D once per row.
        value = np.atleast_2d(value)
        for idy in range(2):
            # idy == 0: distribution panel (polar for circular variables);
            # idy == 1: trace/rank panel (circular units only change its tick labels).
            circular = var_name in circ_var_names and not idy
            if var_name in circ_var_names and idy:
                circ_units_trace = circ_var_units
            else:
                circ_units_trace = False

            if axes is None:
                ax = fig.add_subplot(spec[idx, idy], polar=circular)
            else:
                ax = axes[idx, idy]

            if len(value.shape) == 2:
                # A single (sub)variable: style it with the first compact property, if any.
                if compact_prop:
                    aux_plot_kwargs = dealiase_sel_kwargs(plot_kwargs, compact_prop, 0)
                    aux_trace_kwargs = dealiase_sel_kwargs(trace_kwargs, compact_prop, 0)
                else:
                    aux_plot_kwargs = plot_kwargs
                    aux_trace_kwargs = trace_kwargs

                ax = _plot_chains_mpl(
                    ax,
                    idy,
                    value,
                    data,
                    chain_prop,
                    combined,
                    xt_labelsize,
                    rug,
                    kind,
                    aux_trace_kwargs,
                    hist_kwargs,
                    aux_plot_kwargs,
                    fill_kwargs,
                    rug_kwargs,
                    rank_kwargs,
                    circular,
                    circ_var_units,
                    circ_units_trace,
                )

            else:
                # Compact multidimensional variable: flatten the extra dims and draw each
                # component on the same axes, distinguished by `compact_prop`.
                sub_data = data[var_name].sel(**selection)
                legend_labels = format_coords_as_labels(sub_data, skip_dims=("chain", "draw"))
                legend_title = ", ".join(
                    [
                        f"{coord_name}"
                        for coord_name in sub_data.coords
                        if coord_name not in {"chain", "draw"}
                    ]
                )
                value = value.reshape((value.shape[0], value.shape[1], -1))
                compact_prop_iter = {
                    prop_name: [prop for _, prop in zip(range(value.shape[2]), cycle(props))]
                    for prop_name, props in compact_prop.items()
                }
                handles = []
                for sub_idx, label in zip(range(value.shape[2]), legend_labels):
                    aux_plot_kwargs = dealiase_sel_kwargs(plot_kwargs, compact_prop_iter, sub_idx)
                    aux_trace_kwargs = dealiase_sel_kwargs(trace_kwargs, compact_prop_iter, sub_idx)
                    ax = _plot_chains_mpl(
                        ax,
                        idy,
                        value[..., sub_idx],
                        data,
                        chain_prop,
                        combined,
                        xt_labelsize,
                        rug,
                        kind,
                        aux_trace_kwargs,
                        hist_kwargs,
                        aux_plot_kwargs,
                        fill_kwargs,
                        rug_kwargs,
                        rank_kwargs,
                        circular,
                        circ_var_units,
                        circ_units_trace,
                    )
                    if legend:
                        handles.append(
                            Line2D(
                                [],
                                [],
                                label=label,
                                **dealiase_sel_kwargs(aux_plot_kwargs, chain_prop, 0),
                            )
                        )
                if legend and idy == 0:
                    ax.legend(handles=handles, title=legend_title)

            if value[0].dtype.kind == "i" and idy == 0:
                xticks = get_bins(value)
                ax.set_xticks(xticks[:-1])
            y = 1 / textsize
            if not idy:
                ax.set_yticks([])
                if circular:
                    y = 0.13 if selection else 0.12
            ax.set_title(
                labeller.make_label_vert(var_name, selection, isel),
                fontsize=titlesize,
                wrap=True,
                y=textsize * y,
            )
            ax.tick_params(labelsize=xt_labelsize)

            xlims = ax.get_xlim()
            ylims = ax.get_ylim()

            if divergences:
                div_selection = {k: v for k, v in selection.items() if k in divergence_data.dims}
                divs = divergence_data.sel(**div_selection).values
                divs = np.atleast_2d(divs)

                for chain, chain_divs in enumerate(divs):
                    # chain_divs is used as a boolean mask over draws.
                    div_draws = data.draw.values[chain_divs]
                    div_idxs = np.arange(len(chain_divs))[chain_divs]
                    if div_idxs.size > 0:
                        if divergences == "top":
                            ylocs = ylims[1]
                        else:
                            ylocs = ylims[0]
                        values = value[chain, div_idxs]

                        if circular:
                            tick = [ax.get_rmin() + ax.get_rmax() * 0.60, ax.get_rmax()]
                            for val in values:
                                ax.plot(
                                    [val, val],
                                    tick,
                                    color="black",
                                    markeredgewidth=1.5,
                                    markersize=30,
                                    alpha=trace_kwargs["alpha"],
                                    zorder=0.6,
                                )
                        else:
                            if kind == "trace" and idy:
                                ax.plot(
                                    div_draws,
                                    np.zeros_like(div_idxs) + ylocs,
                                    marker="|",
                                    color="black",
                                    markeredgewidth=1.5,
                                    markersize=30,
                                    linestyle="None",
                                    alpha=hist_kwargs["alpha"],
                                    zorder=0.6,
                                )
                            elif not idy:
                                ax.plot(
                                    values,
                                    np.zeros_like(values) + ylocs,
                                    marker="|",
                                    color="black",
                                    markeredgewidth=1.5,
                                    markersize=30,
                                    linestyle="None",
                                    alpha=trace_kwargs["alpha"],
                                    zorder=0.6,
                                )

            for _, _, vlines in (j for j in lines if j[0] == var_name and j[1] == selection):
                if isinstance(vlines, (float, int)):
                    line_values = [vlines]
                else:
                    line_values = np.atleast_1d(vlines).ravel()
                    if not np.issubdtype(line_values.dtype, np.number):
                        raise ValueError(f"line-positions should be numeric, found {line_values}")
                if idy:
                    ax.hlines(
                        line_values,
                        xlims[0],
                        xlims[1],
                        colors="black",
                        linewidth=1.5,
                        alpha=trace_kwargs["alpha"],
                    )

                else:
                    ax.vlines(
                        line_values,
                        ylims[0],
                        ylims[1],
                        colors="black",
                        linewidth=1.5,
                        alpha=trace_kwargs["alpha"],
                    )

        # After the inner loop idy == 1, so `ax` is this row's trace panel.
        if kind == "trace" and idy:
            ax.set_xlim(left=data.draw.min(), right=data.draw.max())

    if legend:
        legend_kwargs = trace_kwargs if combined else plot_kwargs
        handles = [
            Line2D(
                [], [], label=chain_id, **dealiase_sel_kwargs(legend_kwargs, chain_prop, chain_id)
            )
            for chain_id in range(num_chains)
        ]
        if combined:
            handles.insert(
                0,
                Line2D(
                    [], [], label="combined", **dealiase_sel_kwargs(plot_kwargs, chain_prop, -1)
                ),
            )
        ax.figure.axes[0].legend(handles=handles, title="chain", loc="upper right")

    if axes is None:
        axes = np.array(ax.figure.axes).reshape(-1, 2)

    if backend_show(show):
        plt.show()

    return axes
448
449
def _plot_chains_mpl(
    axes,
    idy,
    value,
    data,
    chain_prop,
    combined,
    xt_labelsize,
    rug,
    kind,
    trace_kwargs,
    hist_kwargs,
    plot_kwargs,
    fill_kwargs,
    rug_kwargs,
    rank_kwargs,
    circular,
    circ_var_units,
    circ_units_trace,
):
    """Draw all chains of one (sub)variable onto a single matplotlib axes.

    Parameters
    ----------
    axes : matplotlib axes
        Axes to draw on.
    idy : int
        Panel column: 0 draws distributions, any other value draws the trace or rank plot.
    value : numpy.ndarray
        Samples with one row per chain; trace rows are drawn against ``data.draw.values``.
    data : xarray.Dataset
        Only ``data.draw.values`` is read, as the x coordinates of the traces.
    chain_prop : dict
        Maps property names to per-chain values; index ``-1`` styles the combined distribution.
    combined : bool
        If True, one distribution of all chains (flattened) is drawn instead of one per chain.
    xt_labelsize : float
        Text size forwarded to ``plot_dist``.
    rug : bool
        Forwarded to ``plot_dist``.
    kind : {"trace", "rank_bars", "rank_vlines"}
        What to draw on the ``idy != 0`` panel.
    trace_kwargs : dict
        Keyword arguments for ``axes.plot`` of each chain's trace.
    hist_kwargs, plot_kwargs, fill_kwargs, rug_kwargs : dict
        Forwarded to ``plot_dist`` (``plot_kwargs`` is first styled per chain).
    rank_kwargs : dict
        Forwarded to ``plot_rank``.
    circular : bool
        Whether this is the (polar) distribution panel of a circular variable.
    circ_var_units : str or False
        Forwarded to ``plot_dist`` as ``is_circular``; forced to False when not `circular`.
    circ_units_trace : str or False
        When "degrees", the trace panel's radian y ticks are relabelled in degrees.

    Returns
    -------
    axes : matplotlib axes
        The axes drawn on (as returned by ``plot_dist``/``plot_rank`` when those are used).
    """
    if not circular:
        # Only a circular distribution panel should wrap its KDE.
        circ_var_units = False

    draw_trace = kind == "trace" and bool(idy)
    for chain_idx, row in enumerate(value):
        if draw_trace:
            aux_kwargs = dealiase_sel_kwargs(trace_kwargs, chain_prop, chain_idx)
            axes.plot(data.draw.values, row, **aux_kwargs)

        if not combined and not idy:
            aux_kwargs = dealiase_sel_kwargs(plot_kwargs, chain_prop, chain_idx)
            axes = plot_dist(
                values=row,
                textsize=xt_labelsize,
                rug=rug,
                ax=axes,
                hist_kwargs=hist_kwargs,
                plot_kwargs=aux_kwargs,
                fill_kwargs=fill_kwargs,
                rug_kwargs=rug_kwargs,
                backend="matplotlib",
                show=False,
                is_circular=circ_var_units,
            )

    if draw_trace and circ_units_trace == "degrees" and len(value):
        # Relabel the radian ticks as degrees in [0, 360). Done once, after every chain is
        # drawn, because FixedLocator freezes the tick positions: doing it per chain would
        # pin them to the first chain's range only.
        y_tick_locs = axes.get_yticks()
        y_tick_labels = [i + 360 if i < 0 else i for i in np.rad2deg(y_tick_locs)]
        axes.yaxis.set_major_locator(mticker.FixedLocator(y_tick_locs))
        axes.set_yticklabels([f"{i:.0f}°" for i in y_tick_labels])

    if kind == "rank_bars" and idy:
        axes = plot_rank(data=value, kind="bars", ax=axes, **rank_kwargs)
    elif kind == "rank_vlines" and idy:
        axes = plot_rank(data=value, kind="vlines", ax=axes, **rank_kwargs)

    if combined:
        aux_kwargs = dealiase_sel_kwargs(plot_kwargs, chain_prop, -1)
        if not idy:
            axes = plot_dist(
                values=value.flatten(),
                textsize=xt_labelsize,
                rug=rug,
                ax=axes,
                hist_kwargs=hist_kwargs,
                plot_kwargs=aux_kwargs,
                fill_kwargs=fill_kwargs,
                rug_kwargs=rug_kwargs,
                backend="matplotlib",
                show=False,
                is_circular=circ_var_units,
            )
    return axes
524