"""Plot kde or histograms and values from MCMC samples."""
import warnings
from typing import Any, Callable, List, Mapping, Optional, Tuple, Union, Sequence

from ..data import CoordSpec, InferenceData, convert_to_dataset
from ..labels import BaseLabeller
from ..rcparams import rcParams
from ..sel_utils import xarray_var_iter
from ..utils import _var_names, get_coords
from .plot_utils import KwargSpec, get_plotting_function


def plot_trace(
    data: InferenceData,
    var_names: Optional[Sequence[str]] = None,
    filter_vars: Optional[str] = None,
    transform: Optional[Callable] = None,
    coords: Optional[CoordSpec] = None,
    divergences: Optional[str] = "auto",
    kind: Optional[str] = "trace",
    figsize: Optional[Tuple[float, float]] = None,
    rug: bool = False,
    lines: Optional[List[Tuple[str, CoordSpec, Any]]] = None,
    circ_var_names: Optional[List[str]] = None,
    circ_var_units: str = "radians",
    compact: bool = True,
    compact_prop: Optional[Union[str, Mapping[str, Any]]] = None,
    combined: bool = False,
    chain_prop: Optional[Union[str, Mapping[str, Any]]] = None,
    legend: bool = False,
    plot_kwargs: Optional[KwargSpec] = None,
    fill_kwargs: Optional[KwargSpec] = None,
    rug_kwargs: Optional[KwargSpec] = None,
    hist_kwargs: Optional[KwargSpec] = None,
    trace_kwargs: Optional[KwargSpec] = None,
    rank_kwargs: Optional[KwargSpec] = None,
    labeller: Optional[BaseLabeller] = None,
    axes=None,
    backend: Optional[str] = None,
    backend_config: Optional[KwargSpec] = None,
    backend_kwargs: Optional[KwargSpec] = None,
    show: Optional[bool] = None,
):
    """Plot distribution (histogram or kernel density estimates) and sampled values or rank plot.

    If `divergences` data is available in `sample_stats`, will plot the location of divergences as
    dashed vertical lines.

    Parameters
    ----------
    data: obj
        Any object that can be converted to an az.InferenceData object
        Refer to documentation of az.convert_to_dataset for details
    var_names: str or list of str, optional
        One or more variables to be plotted. Prefix the variables by `~` when you want
        to exclude them from the plot.
    filter_vars: {None, "like", "regex"}, optional, default=None
        If `None` (default), interpret var_names as the real variables names. If "like",
        interpret var_names as substrings of the real variables names. If "regex",
        interpret var_names as regular expressions on the real variables names. A la
        `pandas.filter`.
    coords: dict of {str: slice or array_like}, optional
        Coordinates of var_names to be plotted. Passed to `Dataset.sel`. Only the
        "chain" and "draw" entries are applied to the divergence data.
    divergences: {"auto", "bottom", "top", None}, optional
        Plot location of divergences on the traceplots. The default, "auto", places them
        at the top when ``rug=True`` and at the bottom otherwise. Divergences are silently
        skipped when the data has no `sample_stats` group or no `diverging` variable.
    kind: {"trace", "rank_bars", "rank_vlines"}, optional
        Choose between plotting sampled values per iteration and rank plots.
    transform: callable, optional
        Function to transform data (defaults to None, i.e. the identity function).
        It is applied to the posterior after `coords` have been selected.
    figsize: tuple of (float, float), optional
        If None, size is (12, variables * 2)
    rug: bool, optional
        If True adds a rugplot of samples. Defaults to False. Ignored for 2D KDE.
        Only affects continuous variables.
    lines: list of tuple of (str, dict, array_like), optional
        List of (var_name, {'coord': selection}, [line, positions]) to be overplotted as
        vertical lines on the density and horizontal lines on the trace.
    circ_var_names : str or list of str, optional
        List of circular variables to account for when plotting KDE.
    circ_var_units : str
        Whether the variables in `circ_var_names` are in "degrees" or "radians".
    compact: bool, optional
        Plot multidimensional variables in a single plot.
    compact_prop: str or dict {str: array_like}, optional
        Name of the property, or mapping from the property name to the property values,
        used to distinguish different dimensions with compact=True.
    combined: bool, optional
        Flag for combining multiple chains into a single line. If False (default), chains will be
        plotted separately.
    chain_prop: str or dict {str: array_like}, optional
        Name of the property, or mapping from the property name to the property values,
        used to distinguish different chains.
    legend: bool, optional
        Add a legend to the figure with the chain color code.
    plot_kwargs, fill_kwargs, rug_kwargs, hist_kwargs: dict, optional
        Extra keyword arguments passed to `arviz.plot_dist`. Only affects continuous variables.
    trace_kwargs: dict, optional
        Extra keyword arguments passed to `plt.plot`
    labeller : labeller instance, optional
        Class providing the method `make_label_vert` to generate the labels in the plot titles.
        Read the :ref:`label_guide` for more details and usage examples.
        Defaults to ``BaseLabeller()``.
    rank_kwargs : dict, optional
        Extra keyword arguments passed to `arviz.plot_rank`
    axes: axes, optional
        Matplotlib axes or bokeh figures.
    backend: {"matplotlib", "bokeh"}, optional
        Select plotting backend. Defaults to ``rcParams["plot.backend"]``.
    backend_config: dict, optional
        Currently specifies the bounds to use for bokeh axes. Defaults to value set in rcParams.
    backend_kwargs: dict, optional
        These are kwargs specific to the backend being used. For additional documentation
        check the plotting method of the backend.
    show: bool, optional
        Call backend show function.

    Returns
    -------
    axes: matplotlib axes or bokeh figures

    Raises
    ------
    ValueError
        If `kind` is not one of "trace", "rank_vlines" or "rank_bars".

    Warns
    -----
    UserWarning
        If ``rcParams["plot.max_subplots"]`` cannot fit every selected variable; only the
        first variables that fit are plotted.

    Examples
    --------
    Plot a subset of variables and select them with partial naming

    .. plot::
        :context: close-figs

        >>> import arviz as az
        >>> data = az.load_arviz_data('non_centered_eight')
        >>> coords = {'school': ['Choate', 'Lawrenceville']}
        >>> az.plot_trace(data, var_names=('theta'), filter_vars="like", coords=coords)

    Show all dimensions of multidimensional variables in the same plot

    .. plot::
        :context: close-figs

        >>> az.plot_trace(data, compact=True)

    Display a rank plot instead of trace

    .. plot::
        :context: close-figs

        >>> az.plot_trace(data, var_names=["mu", "tau"], kind="rank_bars")

    Combine all chains into one distribution and select variables with regular expressions

    .. plot::
        :context: close-figs

        >>> az.plot_trace(
        ...     data, var_names=('^theta'), filter_vars="regex", coords=coords, combined=True
        ... )


    Plot reference lines against distribution and trace

    .. plot::
        :context: close-figs

        >>> lines = (('theta_t',{'school': "Choate"}, [-1]),)
        >>> az.plot_trace(data, var_names=('theta_t', 'theta'), coords=coords, lines=lines)

    """
    if kind not in {"trace", "rank_vlines", "rank_bars"}:
        raise ValueError("The value of kind must be either trace, rank_vlines or rank_bars.")

    if divergences == "auto":
        divergences = "top" if rug else "bottom"
    if divergences:
        try:
            divergence_data = convert_to_dataset(data, group="sample_stats").diverging
        except (ValueError, AttributeError):  # No sample_stats, or no `.diverging`
            divergences = None

    if coords is None:
        coords = {}

    if labeller is None:
        labeller = BaseLabeller()

    if divergences:
        # Divergences live in sample_stats, which only shares the chain/draw dimensions
        # with the posterior, so only those coordinate selections can be applied to it.
        divergence_data = get_coords(
            divergence_data, {k: v for k, v in coords.items() if k in ("chain", "draw")}
        )
    else:
        divergence_data = False

    coords_data = get_coords(convert_to_dataset(data, group="posterior"), coords)

    if transform is not None:
        coords_data = transform(coords_data)

    var_names = _var_names(var_names, coords_data, filter_vars)

    # With compact=True every non chain/draw dimension is kept inside a single plotter
    # instead of being split into one plotter per coordinate value.
    if compact:
        skip_dims = set(coords_data.dims) - {"chain", "draw"}
    else:
        skip_dims = set()

    plotters = list(
        xarray_var_iter(coords_data, var_names=var_names, combined=True, skip_dims=skip_dims)
    )

    max_subplots = rcParams["plot.max_subplots"]
    # Each variable takes two subplots (distribution + trace/rank), so the subplot budget
    # allows half as many variables (but always at least one).
    max_plots = len(plotters) if max_subplots is None else max(max_subplots // 2, 1)
    if len(plotters) > max_plots:
        warnings.warn(
            f"rcParams['plot.max_subplots'] ({max_subplots}) is smaller than the number "
            f"of subplots required ({2 * len(plotters)}, two per variable), generating "
            f"only {max_plots} variables",
            UserWarning,
        )
        plotters = plotters[:max_plots]

    # TODO: Check if this can be further simplified
    trace_plot_args = dict(
        # User Kwargs
        data=coords_data,
        var_names=var_names,
        divergences=divergences,
        kind=kind,
        figsize=figsize,
        rug=rug,
        lines=lines,
        circ_var_names=circ_var_names,
        circ_var_units=circ_var_units,
        plot_kwargs=plot_kwargs,
        fill_kwargs=fill_kwargs,
        rug_kwargs=rug_kwargs,
        hist_kwargs=hist_kwargs,
        trace_kwargs=trace_kwargs,
        rank_kwargs=rank_kwargs,
        compact=compact,
        compact_prop=compact_prop,
        combined=combined,
        chain_prop=chain_prop,
        legend=legend,
        labeller=labeller,
        # Generated kwargs
        divergence_data=divergence_data,
        plotters=plotters,
        axes=axes,
        backend_config=backend_config,
        backend_kwargs=backend_kwargs,
        show=show,
    )

    if backend is None:
        backend = rcParams["plot.backend"]
    backend = backend.lower()

    plot = get_plotting_function("plot_trace", "traceplot", backend)
    axes = plot(**trace_plot_args)

    return axes