1"""Plot kde or histograms and values from MCMC samples."""
2import warnings
3from typing import Any, Callable, List, Mapping, Optional, Tuple, Union, Sequence
4
5from ..data import CoordSpec, InferenceData, convert_to_dataset
6from ..labels import BaseLabeller
7from ..rcparams import rcParams
8from ..sel_utils import xarray_var_iter
9from ..utils import _var_names, get_coords
10from .plot_utils import KwargSpec, get_plotting_function
11
12
def plot_trace(
    data: InferenceData,
    var_names: Optional[Sequence[str]] = None,
    filter_vars: Optional[str] = None,
    transform: Optional[Callable] = None,
    coords: Optional[CoordSpec] = None,
    divergences: Optional[str] = "auto",
    kind: Optional[str] = "trace",
    figsize: Optional[Tuple[float, float]] = None,
    rug: bool = False,
    lines: Optional[List[Tuple[str, CoordSpec, Any]]] = None,
    circ_var_names: Optional[List[str]] = None,
    circ_var_units: str = "radians",
    compact: bool = True,
    compact_prop: Optional[Union[str, Mapping[str, Any]]] = None,
    combined: bool = False,
    chain_prop: Optional[Union[str, Mapping[str, Any]]] = None,
    legend: bool = False,
    plot_kwargs: Optional[KwargSpec] = None,
    fill_kwargs: Optional[KwargSpec] = None,
    rug_kwargs: Optional[KwargSpec] = None,
    hist_kwargs: Optional[KwargSpec] = None,
    trace_kwargs: Optional[KwargSpec] = None,
    rank_kwargs: Optional[KwargSpec] = None,
    labeller=None,
    axes=None,
    backend: Optional[str] = None,
    backend_config: Optional[KwargSpec] = None,
    backend_kwargs: Optional[KwargSpec] = None,
    show: Optional[bool] = None,
):
    """Plot distribution (histogram or kernel density estimates) and sampled values or rank plot.

    If `divergences` data is available in `sample_stats`, will plot the location of divergences as
    dashed vertical lines.

    Parameters
    ----------
    data: obj
        Any object that can be converted to an az.InferenceData object
        Refer to documentation of az.convert_to_dataset for details
    var_names: str or list of str, optional
        One or more variables to be plotted. Prefix the variables by `~` when you want
        to exclude them from the plot.
    filter_vars: {None, "like", "regex"}, optional, default=None
        If `None` (default), interpret var_names as the real variables names. If "like",
        interpret var_names as substrings of the real variables names. If "regex",
        interpret var_names as regular expressions on the real variables names. A la
        `pandas.filter`.
    transform: callable, optional
        Function to transform data (defaults to None i.e. the identity function). It is
        applied to the posterior dataset after the `coords` selection.
    coords: dict of {str: slice or array_like}, optional
        Coordinates of var_names to be plotted. Passed to `Dataset.sel`
    divergences: {"auto", "bottom", "top", None}, optional
        Plot location of divergences on the traceplots. With "auto" (default) the
        location is "top" when ``rug=True`` and "bottom" otherwise. Divergences are
        silently skipped when `data` has no ``sample_stats`` group or no ``diverging``
        variable.
    kind: {"trace", "rank_bars", "rank_vlines"}, optional
        Choose between plotting sampled values per iteration and rank plots.
    figsize: tuple of (float, float), optional
        If None, size is (12, variables * 2)
    rug: bool, optional
        If True adds a rugplot of samples. Defaults to False. Ignored for 2D KDE.
        Only affects continuous variables.
    lines: list of tuple of (str, dict, array_like), optional
        List of (var_name, {'coord': selection}, [line, positions]) to be overplotted as
        vertical lines on the density and horizontal lines on the trace.
    circ_var_names : str or list of str, optional
        List of circular variables to account for when plotting KDE.
    circ_var_units : str
        Whether the variables in `circ_var_names` are in "degrees" or "radians".
    compact: bool, optional
        Plot multidimensional variables in a single plot. Defaults to True.
    compact_prop: str or dict {str: array_like}, optional
        Tuple containing the property name and the property values to distinguish different
        dimensions with compact=True
    combined: bool, optional
        Flag for combining multiple chains into a single line. If False (default), chains will be
        plotted separately.
    chain_prop: str or dict {str: array_like}, optional
        Tuple containing the property name and the property values to distinguish different chains
    legend: bool, optional
        Add a legend to the figure with the chain color code.
    plot_kwargs, fill_kwargs, rug_kwargs, hist_kwargs: dict, optional
        Extra keyword arguments passed to `arviz.plot_dist`. Only affects continuous variables.
    trace_kwargs: dict, optional
        Extra keyword arguments passed to `plt.plot`
    rank_kwargs : dict, optional
        Extra keyword arguments passed to `arviz.plot_rank`
    labeller : labeller instance, optional
        Class providing the method `make_label_vert` to generate the labels in the plot titles.
        Read the :ref:`label_guide` for more details and usage examples.
    axes: axes, optional
        Matplotlib axes or bokeh figures.
    backend: {"matplotlib", "bokeh"}, optional
        Select plotting backend. Defaults to ``rcParams["plot.backend"]``.
    backend_config: dict, optional
        Currently specifies the bounds to use for bokeh axes. Defaults to value set in rcParams.
    backend_kwargs: dict, optional
        These are kwargs specific to the backend being used. For additional documentation
        check the plotting method of the backend.
    show: bool, optional
        Call backend show function.

    Returns
    -------
    axes: matplotlib axes or bokeh figures

    Raises
    ------
    ValueError
        If `kind` is not one of "trace", "rank_vlines" or "rank_bars".

    Warns
    -----
    UserWarning
        If there are more variables to plot than ``rcParams["plot.max_subplots"]``
        allows; only the first variables that fit are plotted.

    Examples
    --------
    Plot a subset variables and select them with partial naming

    .. plot::
        :context: close-figs

        >>> import arviz as az
        >>> data = az.load_arviz_data('non_centered_eight')
        >>> coords = {'school': ['Choate', 'Lawrenceville']}
        >>> az.plot_trace(data, var_names=('theta'), filter_vars="like", coords=coords)

    Show all dimensions of multidimensional variables in the same plot

    .. plot::
        :context: close-figs

        >>> az.plot_trace(data, compact=True)

    Display a rank plot instead of trace

    .. plot::
        :context: close-figs

        >>> az.plot_trace(data, var_names=["mu", "tau"], kind="rank_bars")

    Combine all chains into one distribution and select variables with regular expressions

    .. plot::
        :context: close-figs

        >>> az.plot_trace(
        >>>     data, var_names=('^theta'), filter_vars="regex", coords=coords, combined=True
        >>> )


    Plot reference lines against distribution and trace

    .. plot::
        :context: close-figs

        >>> lines = (('theta_t',{'school': "Choate"}, [-1]),)
        >>> az.plot_trace(data, var_names=('theta_t', 'theta'), coords=coords, lines=lines)

    """
    if kind not in {"trace", "rank_vlines", "rank_bars"}:
        raise ValueError("The value of kind must be either trace, rank_vlines or rank_bars.")

    if coords is None:
        coords = {}

    if labeller is None:
        labeller = BaseLabeller()

    if divergences == "auto":
        divergences = "top" if rug else "bottom"

    # ``False`` is the "nothing to draw" marker handed to the backend.
    divergence_data = False
    if divergences:
        try:
            divergence_data = convert_to_dataset(data, group="sample_stats").diverging
        except (ValueError, AttributeError):  # No sample_stats, or no `.diverging`
            divergences = None
        else:
            # Only chain/draw selections are forwarded to the divergence data;
            # any other coordinate in ``coords`` is left out of this selection.
            divergence_data = get_coords(
                divergence_data, {k: v for k, v in coords.items() if k in ("chain", "draw")}
            )

    coords_data = get_coords(convert_to_dataset(data, group="posterior"), coords)

    if transform is not None:
        coords_data = transform(coords_data)

    var_names = _var_names(var_names, coords_data, filter_vars)

    # With compact=True every non-sample dimension stays inside a single plotter.
    if compact:
        skip_dims = set(coords_data.dims) - {"chain", "draw"}
    else:
        skip_dims = set()

    plotters = list(
        xarray_var_iter(coords_data, var_names=var_names, combined=True, skip_dims=skip_dims)
    )

    # The subplot budget is halved to get the number of plotters that fit
    # (presumably because each plotter is drawn as a distribution + trace/rank pair),
    # with a minimum of one. Keep the configured value separately so the warning
    # reports the real rcParams setting rather than the halved number.
    max_subplots = rcParams["plot.max_subplots"]
    max_plots = len(plotters) if max_subplots is None else max(max_subplots // 2, 1)
    if len(plotters) > max_plots:
        warnings.warn(
            "rcParams['plot.max_subplots'] ({max_subplots}) allows at most {max_plots} "
            "variables in a trace plot, which is smaller than the number "
            "of variables to plot ({len_plotters}), generating only {max_plots} "
            "plots".format(
                max_subplots=max_subplots, max_plots=max_plots, len_plotters=len(plotters)
            ),
            UserWarning,
        )
        plotters = plotters[:max_plots]

    # TODO: Check if this can be further simplified
    trace_plot_args = dict(
        # User Kwargs
        data=coords_data,
        var_names=var_names,
        divergences=divergences,
        kind=kind,
        figsize=figsize,
        rug=rug,
        lines=lines,
        circ_var_names=circ_var_names,
        circ_var_units=circ_var_units,
        plot_kwargs=plot_kwargs,
        fill_kwargs=fill_kwargs,
        rug_kwargs=rug_kwargs,
        hist_kwargs=hist_kwargs,
        trace_kwargs=trace_kwargs,
        rank_kwargs=rank_kwargs,
        compact=compact,
        compact_prop=compact_prop,
        combined=combined,
        chain_prop=chain_prop,
        legend=legend,
        labeller=labeller,
        # Generated kwargs
        divergence_data=divergence_data,
        plotters=plotters,
        axes=axes,
        backend_config=backend_config,
        backend_kwargs=backend_kwargs,
        show=show,
    )

    if backend is None:
        backend = rcParams["plot.backend"]
    backend = backend.lower()

    # Dispatch to the backend-specific implementation.
    plot = get_plotting_function("plot_trace", "traceplot", backend)
    axes = plot(**trace_plot_args)

    return axes
257