1import functools
2import itertools
3import warnings
4
5import numpy as np
6
7from ..core.formatting import format_item
8from .utils import (
9    _get_nice_quiver_magnitude,
10    _infer_xy_labels,
11    _process_cmap_cbar_kwargs,
12    label_from_attrs,
13    plt,
14)
15
16# Overrides axes.labelsize, xtick.major.size, ytick.major.size
17# from mpl.rcParams
18_FONTSIZE = "small"
19# For major ticks on x, y axes
20_NTICKS = 5
21
22
def _nicetitle(coord, value, maxchar, template):
    """
    Fill ``template`` with ``coord`` and the formatted ``value``; when the
    result is longer than ``maxchar`` it is cut and ended with "..."
    """
    text = template.format(
        coord=coord, value=format_item(value, quote_strings=False)
    )

    if len(text) <= maxchar:
        return text

    # Reserve three characters for the ellipsis marker
    keep = maxchar - 3
    return text[:keep] + "..."
34
35
36class FacetGrid:
37    """
38    Initialize the Matplotlib figure and FacetGrid object.
39
40    The :class:`FacetGrid` is an object that links a xarray DataArray to
41    a Matplotlib figure with a particular structure.
42
43    In particular, :class:`FacetGrid` is used to draw plots with multiple
44    axes, where each axes shows the same relationship conditioned on
45    different levels of some dimension. It's possible to condition on up to
46    two variables by assigning variables to the rows and columns of the
47    grid.
48
49    The general approach to plotting here is called "small multiples",
50    where the same kind of plot is repeated multiple times, and the
51    specific use of small multiples to display the same relationship
    conditioned on one or more other variables is often called a "trellis
53    plot".
54
55    The basic workflow is to initialize the :class:`FacetGrid` object with
56    the DataArray and the variable names that are used to structure the grid.
57    Then plotting functions can be applied to each subset by calling
58    :meth:`FacetGrid.map_dataarray` or :meth:`FacetGrid.map`.
59
60    Attributes
61    ----------
62    axes : ndarray of matplotlib.axes.Axes
63        Array containing axes in corresponding position, as returned from
64        :py:func:`matplotlib.pyplot.subplots`.
65    col_labels : list of matplotlib.text.Text
66        Column titles.
67    row_labels : list of matplotlib.text.Text
68        Row titles.
69    fig : matplotlib.figure.Figure
70        The figure containing all the axes.
71    name_dicts : ndarray of dict
72        Array containing dictionaries mapping coordinate names to values. ``None`` is
73        used as a sentinel value for axes that should remain empty, i.e.,
74        sometimes the rightmost grid positions in the bottom row.
75    """
76
77    def __init__(
78        self,
79        data,
80        col=None,
81        row=None,
82        col_wrap=None,
83        sharex=True,
84        sharey=True,
85        figsize=None,
86        aspect=1,
87        size=3,
88        subplot_kws=None,
89    ):
90        """
91        Parameters
92        ----------
93        data : DataArray
94            xarray DataArray to be plotted.
95        row, col : str
96            Dimesion names that define subsets of the data, which will be drawn
97            on separate facets in the grid.
98        col_wrap : int, optional
99            "Wrap" the grid the for the column variable after this number of columns,
100            adding rows if ``col_wrap`` is less than the number of facets.
101        sharex : bool, optional
102            If true, the facets will share *x* axes.
103        sharey : bool, optional
104            If true, the facets will share *y* axes.
105        figsize : tuple, optional
106            A tuple (width, height) of the figure in inches.
107            If set, overrides ``size`` and ``aspect``.
108        aspect : scalar, optional
109            Aspect ratio of each facet, so that ``aspect * size`` gives the
110            width of each facet in inches.
111        size : scalar, optional
112            Height (in inches) of each facet. See also: ``aspect``.
113        subplot_kws : dict, optional
114            Dictionary of keyword arguments for Matplotlib subplots
115            (:py:func:`matplotlib.pyplot.subplots`).
116
117        """
118
119        # Handle corner case of nonunique coordinates
120        rep_col = col is not None and not data[col].to_index().is_unique
121        rep_row = row is not None and not data[row].to_index().is_unique
122        if rep_col or rep_row:
123            raise ValueError(
124                "Coordinates used for faceting cannot "
125                "contain repeated (nonunique) values."
126            )
127
128        # single_group is the grouping variable, if there is exactly one
129        if col and row:
130            single_group = False
131            nrow = len(data[row])
132            ncol = len(data[col])
133            nfacet = nrow * ncol
134            if col_wrap is not None:
135                warnings.warn("Ignoring col_wrap since both col and row were passed")
136        elif row and not col:
137            single_group = row
138        elif not row and col:
139            single_group = col
140        else:
141            raise ValueError("Pass a coordinate name as an argument for row or col")
142
143        # Compute grid shape
144        if single_group:
145            nfacet = len(data[single_group])
146            if col:
147                # idea - could add heuristic for nice shapes like 3x4
148                ncol = nfacet
149            if row:
150                ncol = 1
151            if col_wrap is not None:
152                # Overrides previous settings
153                ncol = col_wrap
154            nrow = int(np.ceil(nfacet / ncol))
155
156        # Set the subplot kwargs
157        subplot_kws = {} if subplot_kws is None else subplot_kws
158
159        if figsize is None:
160            # Calculate the base figure size with extra horizontal space for a
161            # colorbar
162            cbar_space = 1
163            figsize = (ncol * size * aspect + cbar_space, nrow * size)
164
165        fig, axes = plt.subplots(
166            nrow,
167            ncol,
168            sharex=sharex,
169            sharey=sharey,
170            squeeze=False,
171            figsize=figsize,
172            subplot_kw=subplot_kws,
173        )
174
175        # Set up the lists of names for the row and column facet variables
176        col_names = list(data[col].to_numpy()) if col else []
177        row_names = list(data[row].to_numpy()) if row else []
178
179        if single_group:
180            full = [{single_group: x} for x in data[single_group].to_numpy()]
181            empty = [None for x in range(nrow * ncol - len(full))]
182            name_dicts = full + empty
183        else:
184            rowcols = itertools.product(row_names, col_names)
185            name_dicts = [{row: r, col: c} for r, c in rowcols]
186
187        name_dicts = np.array(name_dicts).reshape(nrow, ncol)
188
189        # Set up the class attributes
190        # ---------------------------
191
192        # First the public API
193        self.data = data
194        self.name_dicts = name_dicts
195        self.fig = fig
196        self.axes = axes
197        self.row_names = row_names
198        self.col_names = col_names
199
200        # guides
201        self.figlegend = None
202        self.quiverkey = None
203        self.cbar = None
204
205        # Next the private variables
206        self._single_group = single_group
207        self._nrow = nrow
208        self._row_var = row
209        self._ncol = ncol
210        self._col_var = col
211        self._col_wrap = col_wrap
212        self.row_labels = [None] * nrow
213        self.col_labels = [None] * ncol
214        self._x_var = None
215        self._y_var = None
216        self._cmap_extend = None
217        self._mappables = []
218        self._finalized = False
219
220    @property
221    def _left_axes(self):
222        return self.axes[:, 0]
223
224    @property
225    def _bottom_axes(self):
226        return self.axes[-1, :]
227
228    def map_dataarray(self, func, x, y, **kwargs):
229        """
230        Apply a plotting function to a 2d facet's subset of the data.
231
232        This is more convenient and less general than ``FacetGrid.map``
233
234        Parameters
235        ----------
236        func : callable
237            A plotting function with the same signature as a 2d xarray
238            plotting method such as `xarray.plot.imshow`
239        x, y : string
240            Names of the coordinates to plot on x, y axes
241        **kwargs
242            additional keyword arguments to func
243
244        Returns
245        -------
246        self : FacetGrid object
247
248        """
249
250        if kwargs.get("cbar_ax", None) is not None:
251            raise ValueError("cbar_ax not supported by FacetGrid.")
252
253        cmap_params, cbar_kwargs = _process_cmap_cbar_kwargs(
254            func, self.data.to_numpy(), **kwargs
255        )
256
257        self._cmap_extend = cmap_params.get("extend")
258
259        # Order is important
260        func_kwargs = {
261            k: v
262            for k, v in kwargs.items()
263            if k not in {"cmap", "colors", "cbar_kwargs", "levels"}
264        }
265        func_kwargs.update(cmap_params)
266        func_kwargs["add_colorbar"] = False
267        if func.__name__ != "surface":
268            func_kwargs["add_labels"] = False
269
270        # Get x, y labels for the first subplot
271        x, y = _infer_xy_labels(
272            darray=self.data.loc[self.name_dicts.flat[0]],
273            x=x,
274            y=y,
275            imshow=func.__name__ == "imshow",
276            rgb=kwargs.get("rgb", None),
277        )
278
279        for d, ax in zip(self.name_dicts.flat, self.axes.flat):
280            # None is the sentinel value
281            if d is not None:
282                subset = self.data.loc[d]
283                mappable = func(
284                    subset, x=x, y=y, ax=ax, **func_kwargs, _is_facetgrid=True
285                )
286                self._mappables.append(mappable)
287
288        self._finalize_grid(x, y)
289
290        if kwargs.get("add_colorbar", True):
291            self.add_colorbar(**cbar_kwargs)
292
293        return self
294
295    def map_dataarray_line(
296        self, func, x, y, hue, add_legend=True, _labels=None, **kwargs
297    ):
298        from .plot import _infer_line_data
299
300        for d, ax in zip(self.name_dicts.flat, self.axes.flat):
301            # None is the sentinel value
302            if d is not None:
303                subset = self.data.loc[d]
304                mappable = func(
305                    subset,
306                    x=x,
307                    y=y,
308                    ax=ax,
309                    hue=hue,
310                    add_legend=False,
311                    _labels=False,
312                    **kwargs,
313                )
314                self._mappables.append(mappable)
315
316        xplt, yplt, hueplt, huelabel = _infer_line_data(
317            darray=self.data.loc[self.name_dicts.flat[0]], x=x, y=y, hue=hue
318        )
319        xlabel = label_from_attrs(xplt)
320        ylabel = label_from_attrs(yplt)
321
322        self._hue_var = hueplt
323        self._hue_label = huelabel
324        self._finalize_grid(xlabel, ylabel)
325
326        if add_legend and hueplt is not None and huelabel is not None:
327            self.add_legend()
328
329        return self
330
331    def map_dataset(
332        self, func, x=None, y=None, hue=None, hue_style=None, add_guide=None, **kwargs
333    ):
334        from .dataset_plot import _infer_meta_data, _parse_size
335
336        kwargs["add_guide"] = False
337
338        if kwargs.get("markersize", None):
339            kwargs["size_mapping"] = _parse_size(
340                self.data[kwargs["markersize"]], kwargs.pop("size_norm", None)
341            )
342
343        meta_data = _infer_meta_data(
344            self.data, x, y, hue, hue_style, add_guide, funcname=func.__name__
345        )
346        kwargs["meta_data"] = meta_data
347
348        if hue and meta_data["hue_style"] == "continuous":
349            cmap_params, cbar_kwargs = _process_cmap_cbar_kwargs(
350                func, self.data[hue].to_numpy(), **kwargs
351            )
352            kwargs["meta_data"]["cmap_params"] = cmap_params
353            kwargs["meta_data"]["cbar_kwargs"] = cbar_kwargs
354
355        kwargs["_is_facetgrid"] = True
356
357        if func.__name__ == "quiver" and "scale" not in kwargs:
358            raise ValueError("Please provide scale.")
359            # TODO: come up with an algorithm for reasonable scale choice
360
361        for d, ax in zip(self.name_dicts.flat, self.axes.flat):
362            # None is the sentinel value
363            if d is not None:
364                subset = self.data.loc[d]
365                maybe_mappable = func(
366                    ds=subset, x=x, y=y, hue=hue, hue_style=hue_style, ax=ax, **kwargs
367                )
368                # TODO: this is needed to get legends to work.
369                # but maybe_mappable is a list in that case :/
370                self._mappables.append(maybe_mappable)
371
372        self._finalize_grid(meta_data["xlabel"], meta_data["ylabel"])
373
374        if hue:
375            self._hue_label = meta_data.pop("hue_label", None)
376            if meta_data["add_legend"]:
377                self._hue_var = meta_data["hue"]
378                self.add_legend()
379            elif meta_data["add_colorbar"]:
380                self.add_colorbar(label=self._hue_label, **cbar_kwargs)
381
382        if meta_data["add_quiverkey"]:
383            self.add_quiverkey(kwargs["u"], kwargs["v"])
384
385        return self
386
387    def _finalize_grid(self, *axlabels):
388        """Finalize the annotations and layout."""
389        if not self._finalized:
390            self.set_axis_labels(*axlabels)
391            self.set_titles()
392            self.fig.tight_layout()
393
394            for ax, namedict in zip(self.axes.flat, self.name_dicts.flat):
395                if namedict is None:
396                    ax.set_visible(False)
397
398            self._finalized = True
399
400    def _adjust_fig_for_guide(self, guide):
401        # Draw the plot to set the bounding boxes correctly
402        renderer = self.fig.canvas.get_renderer()
403        self.fig.draw(renderer)
404
405        # Calculate and set the new width of the figure so the legend fits
406        guide_width = guide.get_window_extent(renderer).width / self.fig.dpi
407        figure_width = self.fig.get_figwidth()
408        self.fig.set_figwidth(figure_width + guide_width)
409
410        # Draw the plot again to get the new transformations
411        self.fig.draw(renderer)
412
413        # Now calculate how much space we need on the right side
414        guide_width = guide.get_window_extent(renderer).width / self.fig.dpi
415        space_needed = guide_width / (figure_width + guide_width) + 0.02
416        # margin = .01
417        # _space_needed = margin + space_needed
418        right = 1 - space_needed
419
420        # Place the subplot axes to give space for the legend
421        self.fig.subplots_adjust(right=right)
422
423    def add_legend(self, **kwargs):
424        self.figlegend = self.fig.legend(
425            handles=self._mappables[-1],
426            labels=list(self._hue_var.to_numpy()),
427            title=self._hue_label,
428            loc="center right",
429            **kwargs,
430        )
431        self._adjust_fig_for_guide(self.figlegend)
432
433    def add_colorbar(self, **kwargs):
434        """Draw a colorbar."""
435        kwargs = kwargs.copy()
436        if self._cmap_extend is not None:
437            kwargs.setdefault("extend", self._cmap_extend)
438        # dont pass extend as kwarg if it is in the mappable
439        if hasattr(self._mappables[-1], "extend"):
440            kwargs.pop("extend", None)
441        if "label" not in kwargs:
442            kwargs.setdefault("label", label_from_attrs(self.data))
443        self.cbar = self.fig.colorbar(
444            self._mappables[-1], ax=list(self.axes.flat), **kwargs
445        )
446        return self
447
448    def add_quiverkey(self, u, v, **kwargs):
449        kwargs = kwargs.copy()
450
451        magnitude = _get_nice_quiver_magnitude(self.data[u], self.data[v])
452        units = self.data[u].attrs.get("units", "")
453        self.quiverkey = self.axes.flat[-1].quiverkey(
454            self._mappables[-1],
455            X=0.8,
456            Y=0.9,
457            U=magnitude,
458            label=f"{magnitude}\n{units}",
459            labelpos="E",
460            coordinates="figure",
461        )
462
463        # TODO: does not work because self.quiverkey.get_window_extent(renderer) = 0
464        # https://github.com/matplotlib/matplotlib/issues/18530
465        # self._adjust_fig_for_guide(self.quiverkey.text)
466        return self
467
468    def set_axis_labels(self, x_var=None, y_var=None):
469        """Set axis labels on the left column and bottom row of the grid."""
470        if x_var is not None:
471            if x_var in self.data.coords:
472                self._x_var = x_var
473                self.set_xlabels(label_from_attrs(self.data[x_var]))
474            else:
475                # x_var is a string
476                self.set_xlabels(x_var)
477
478        if y_var is not None:
479            if y_var in self.data.coords:
480                self._y_var = y_var
481                self.set_ylabels(label_from_attrs(self.data[y_var]))
482            else:
483                self.set_ylabels(y_var)
484        return self
485
486    def set_xlabels(self, label=None, **kwargs):
487        """Label the x axis on the bottom row of the grid."""
488        if label is None:
489            label = label_from_attrs(self.data[self._x_var])
490        for ax in self._bottom_axes:
491            ax.set_xlabel(label, **kwargs)
492        return self
493
494    def set_ylabels(self, label=None, **kwargs):
495        """Label the y axis on the left column of the grid."""
496        if label is None:
497            label = label_from_attrs(self.data[self._y_var])
498        for ax in self._left_axes:
499            ax.set_ylabel(label, **kwargs)
500        return self
501
502    def set_titles(self, template="{coord} = {value}", maxchar=30, size=None, **kwargs):
503        """
504        Draw titles either above each facet or on the grid margins.
505
506        Parameters
507        ----------
508        template : string
509            Template for plot titles containing {coord} and {value}
510        maxchar : int
511            Truncate titles at maxchar
512        **kwargs : keyword args
513            additional arguments to matplotlib.text
514
515        Returns
516        -------
517        self: FacetGrid object
518
519        """
520        if size is None:
521            size = plt.rcParams["axes.labelsize"]
522
523        nicetitle = functools.partial(_nicetitle, maxchar=maxchar, template=template)
524
525        if self._single_group:
526            for d, ax in zip(self.name_dicts.flat, self.axes.flat):
527                # Only label the ones with data
528                if d is not None:
529                    coord, value = list(d.items()).pop()
530                    title = nicetitle(coord, value, maxchar=maxchar)
531                    ax.set_title(title, size=size, **kwargs)
532        else:
533            # The row titles on the right edge of the grid
534            for index, (ax, row_name, handle) in enumerate(
535                zip(self.axes[:, -1], self.row_names, self.row_labels)
536            ):
537                title = nicetitle(coord=self._row_var, value=row_name, maxchar=maxchar)
538                if not handle:
539                    self.row_labels[index] = ax.annotate(
540                        title,
541                        xy=(1.02, 0.5),
542                        xycoords="axes fraction",
543                        rotation=270,
544                        ha="left",
545                        va="center",
546                        **kwargs,
547                    )
548                else:
549                    handle.set_text(title)
550
551            # The column titles on the top row
552            for index, (ax, col_name, handle) in enumerate(
553                zip(self.axes[0, :], self.col_names, self.col_labels)
554            ):
555                title = nicetitle(coord=self._col_var, value=col_name, maxchar=maxchar)
556                if not handle:
557                    self.col_labels[index] = ax.set_title(title, size=size, **kwargs)
558                else:
559                    handle.set_text(title)
560
561        return self
562
563    def set_ticks(self, max_xticks=_NTICKS, max_yticks=_NTICKS, fontsize=_FONTSIZE):
564        """
565        Set and control tick behavior.
566
567        Parameters
568        ----------
569        max_xticks, max_yticks : int, optional
570            Maximum number of labeled ticks to plot on x, y axes
571        fontsize : string or int
572            Font size as used by matplotlib text
573
574        Returns
575        -------
576        self : FacetGrid object
577
578        """
579        from matplotlib.ticker import MaxNLocator
580
581        # Both are necessary
582        x_major_locator = MaxNLocator(nbins=max_xticks)
583        y_major_locator = MaxNLocator(nbins=max_yticks)
584
585        for ax in self.axes.flat:
586            ax.xaxis.set_major_locator(x_major_locator)
587            ax.yaxis.set_major_locator(y_major_locator)
588            for tick in itertools.chain(
589                ax.xaxis.get_major_ticks(), ax.yaxis.get_major_ticks()
590            ):
591                tick.label1.set_fontsize(fontsize)
592
593        return self
594
595    def map(self, func, *args, **kwargs):
596        """
597        Apply a plotting function to each facet's subset of the data.
598
599        Parameters
600        ----------
601        func : callable
602            A plotting function that takes data and keyword arguments. It
603            must plot to the currently active matplotlib Axes and take a
604            `color` keyword argument. If faceting on the `hue` dimension,
605            it must also take a `label` keyword argument.
606        *args : strings
607            Column names in self.data that identify variables with data to
608            plot. The data for each variable is passed to `func` in the
609            order the variables are specified in the call.
610        **kwargs : keyword arguments
611            All keyword arguments are passed to the plotting function.
612
613        Returns
614        -------
615        self : FacetGrid object
616
617        """
618        for ax, namedict in zip(self.axes.flat, self.name_dicts.flat):
619            if namedict is not None:
620                data = self.data.loc[namedict]
621                plt.sca(ax)
622                innerargs = [data[a].to_numpy() for a in args]
623                maybe_mappable = func(*innerargs, **kwargs)
624                # TODO: better way to verify that an artist is mappable?
625                # https://stackoverflow.com/questions/33023036/is-it-possible-to-detect-if-a-matplotlib-artist-is-a-mappable-suitable-for-use-w#33023522
626                if maybe_mappable and hasattr(maybe_mappable, "autoscale_None"):
627                    self._mappables.append(maybe_mappable)
628
629        self._finalize_grid(*args[:2])
630
631        return self
632
633
634def _easy_facetgrid(
635    data,
636    plotfunc,
637    kind,
638    x=None,
639    y=None,
640    row=None,
641    col=None,
642    col_wrap=None,
643    sharex=True,
644    sharey=True,
645    aspect=None,
646    size=None,
647    subplot_kws=None,
648    ax=None,
649    figsize=None,
650    **kwargs,
651):
652    """
653    Convenience method to call xarray.plot.FacetGrid from 2d plotting methods
654
655    kwargs are the arguments to 2d plotting method
656    """
657    if ax is not None:
658        raise ValueError("Can't use axes when making faceted plots.")
659    if aspect is None:
660        aspect = 1
661    if size is None:
662        size = 3
663    elif figsize is not None:
664        raise ValueError("cannot provide both `figsize` and `size` arguments")
665
666    g = FacetGrid(
667        data=data,
668        col=col,
669        row=row,
670        col_wrap=col_wrap,
671        sharex=sharex,
672        sharey=sharey,
673        figsize=figsize,
674        aspect=aspect,
675        size=size,
676        subplot_kws=subplot_kws,
677    )
678
679    if kind == "line":
680        return g.map_dataarray_line(plotfunc, x, y, **kwargs)
681
682    if kind == "dataarray":
683        return g.map_dataarray(plotfunc, x, y, **kwargs)
684
685    if kind == "dataset":
686        return g.map_dataset(plotfunc, x, y, **kwargs)
687