1"""
2Use this module directly:
3    import xarray.plot as xplt
4
5Or use the methods on a DataArray or Dataset:
6    DataArray.plot._____
7    Dataset.plot._____
8"""
9import functools
10from distutils.version import LooseVersion
11
12import numpy as np
13import pandas as pd
14
15from ..core.alignment import broadcast
16from .facetgrid import _easy_facetgrid
17from .utils import (
18    _add_colorbar,
19    _adjust_legend_subtitles,
20    _assert_valid_xy,
21    _ensure_plottable,
22    _infer_interval_breaks,
23    _infer_xy_labels,
24    _is_numeric,
25    _legend_add_subtitle,
26    _process_cmap_cbar_kwargs,
27    _rescale_imshow_rgb,
28    _resolve_intervals_1dplot,
29    _resolve_intervals_2dplot,
30    _update_axes,
31    get_axis,
32    label_from_attrs,
33    legend_elements,
34    plt,
35)
36
# copied from seaborn
# Default (min, max) marker sizes used by ``scatter`` when it has to build a
# size mapping for the ``markersize`` variable. The mapped values are passed
# as the ``s`` argument of ``Axes.scatter``.
_MARKERSIZE_RANGE = np.array([18.0, 72.0])
39
40
def _infer_scatter_metadata(darray, x, z, hue, hue_style, size):
    """
    Gather the labels and hue/size styling information used by ``scatter``.

    Returns a dict with the keys ``ylabel``, ``xlabel`` and ``zlabel`` (the
    latter two are ``None`` unless the name is a coordinate of ``darray``).
    It also has, for both ``hue`` and ``size``, the selected array, its
    style (``"discrete"`` or ``"continuous"``) and its label. All three are
    ``None`` when that variable was not requested.

    Raises
    ------
    ValueError
        If ``hue_style`` is given but is neither ``"discrete"`` nor
        ``"continuous"``.
    """

    def _classify(name, style):
        """Return ``(array, style, label)`` for the variable ``darray[name]``."""
        selected = darray[name]
        is_numeric = _is_numeric(selected.values)

        if style is not None and style not in ["discrete", "continuous"]:
            raise ValueError(
                f"The style '{style}' is not valid, "
                "valid options are None, 'discrete' or 'continuous'."
            )
        if style is None:
            # Infer the style from the dtype when the caller did not choose.
            style = "continuous" if is_numeric else "discrete"

        return selected, style, label_from_attrs(selected)

    # Axis labels: y always comes from the data itself.
    result = {"ylabel": label_from_attrs(darray)}
    axis_labels = {}
    for key, coord_name in (("xlabel", x), ("zlabel", z)):
        if coord_name in darray.coords:
            axis_labels[key] = label_from_attrs(darray[coord_name])
        else:
            axis_labels[key] = None
    result.update(axis_labels)

    # Styles and labels for the optional hue and size variables.
    for prefix, name, style in (("hue", hue, hue_style), ("size", size, None)):
        info = _classify(name, style) if name else (None, None, None)
        (
            result[f"{prefix}"],
            result[f"{prefix}_style"],
            result[f"{prefix}_label"],
        ) = info

    return result
77
78
# copied from seaborn
def _parse_size(data, norm, width):
    """
    Build a lookup table from the unique values of ``data`` to sizes.

    Numeric data is normalized directly (over its sorted unique values).
    Non-numeric, categorical data is first replaced by ``1..n`` in order of
    first appearance, and those integers are normalized. The normalized
    values in ``[0, 1]`` are then scaled linearly into
    ``[width[0], width[1]]``.

    Parameters
    ----------
    data : DataArray or None
        Values to map. ``None`` short-circuits to ``None``.
    norm : None, tuple or matplotlib.colors.Normalize
        ``None`` creates a fresh ``Normalize``. A tuple is unpacked as the
        positional arguments of ``Normalize``. A ``Normalize`` instance is
        used as-is: it is autoscaled to the data if it has no limits yet,
        but its ``clip`` attribute is left untouched.
    width : sequence of two numbers
        ``(min_width, max_width)`` of the output range.

    Returns
    -------
    pandas.Series or None
        Sizes indexed by the unique levels of ``data``. Levels that the norm
        masks get size 0.

    Raises
    ------
    ValueError
        If ``norm`` is not None, a tuple, or a ``Normalize`` instance.
    """
    if data is None:
        return None

    data = data.values.ravel()

    if not _is_numeric(data):
        # Data is categorical.
        # Use pd.unique instead of np.unique because that keeps
        # the order of the labels:
        levels = pd.unique(data)
        numbers = np.arange(1, 1 + len(levels))
    else:
        # np.unique already returns its result sorted.
        levels = numbers = np.unique(data)

    min_width, max_width = width

    if norm is None:
        norm = plt.Normalize()
    elif isinstance(norm, tuple):
        norm = plt.Normalize(*norm)
    elif not isinstance(norm, plt.Normalize):
        err = "``size_norm`` must be None, tuple, or Normalize object."
        raise ValueError(err)

    # Calling the norm once autoscales vmin/vmax when they are unset.
    if not norm.scaled():
        norm(np.asarray(numbers))

    # Request clipping for this call only instead of setting ``norm.clip``,
    # so a Normalize passed in by the caller is not permanently modified.
    scl = norm(numbers, clip=True)
    widths = np.asarray(min_width + scl * (max_width - min_width))
    if scl.mask.any():
        widths[scl.mask] = 0

    return pd.Series(dict(zip(levels, widths)))
123
124
def _infer_scatter_data(
    darray, x, z, hue, size, size_norm, size_mapping=None, size_range=(1, 10)
):
    """
    Broadcast the variables used by ``scatter`` and encode hue/size values.

    Which variables are included:

    - ``y`` is always ``darray`` itself.
    - ``x`` and ``z`` are included when a name is given.
    - ``hue`` and ``size`` are included only when they name a dimension of
      ``darray``.

    Hue and size values are replaced by the numbers from a lookup table.
    That table is built with ``_parse_size``, except that a ``size_mapping``
    passed in is used directly for size. The inverse table (number ->
    original value) is stored under ``"hue_to_label"`` / ``"size_to_label"``.
    """
    to_broadcast = dict(y=darray)
    for key, name in (("x", x), ("z", z)):
        if name is not None:
            to_broadcast[key] = darray[name]
    for key, name in (("hue", hue), ("size", size)):
        if name in darray.dims:
            to_broadcast[key] = darray[name]

    aligned = broadcast(*to_broadcast.values())
    broadcasted = dict(zip(to_broadcast.keys(), aligned))

    # (key, precomputed mapping, norm, output range) for each encoded variable
    encodings = (
        ("hue", None, None, [0, 1]),
        ("size", size_mapping, size_norm, size_range),
    )
    for key, lookup, norm, output_range in encodings:
        variable = broadcasted.get(key, None)
        if variable is None:
            continue

        if lookup is None:
            lookup = _parse_size(variable, norm, output_range)

        encoded = lookup.loc[variable.values.ravel()].values
        broadcasted[key] = variable.copy(data=np.reshape(encoded, variable.shape))
        broadcasted[f"{key}_to_label"] = pd.Series(lookup.index, index=lookup)

    return broadcasted
157
158
def _infer_line_data(darray, x, y, hue):
    """
    Work out what goes on the x axis, the y axis and the hue for ``line``.

    Parameters
    ----------
    darray : DataArray
        Data to plot (1D or 2D).
    x, y : str or None
        Name of the coordinate for the x or the y axis. At most one may be
        given; the data values go on the other axis.
    hue : str or None
        Coordinate/dimension that distinguishes the lines. Only used for
        inputs with more than one dimension.

    Returns
    -------
    xplt, yplt : DataArray
        Values for the x and the y axis.
    hueplt : DataArray or None
        The hue coordinate, ``None`` for 1D inputs.
    huelabel : str
        Legend title for the hue, ``""`` for 1D inputs.

    Raises
    ------
    ValueError
        If both ``x`` and ``y`` are given, if a multi-dimensional input gets
        none of ``x``, ``y`` or ``hue``, or if ``hue`` is not a dimension
        when plotting against a 2D coordinate. Unknown names are rejected by
        ``_assert_valid_xy``.
    """
    ndims = len(darray.dims)

    if x is not None and y is not None:
        raise ValueError("Cannot specify both x and y kwargs for line plots.")

    if x is not None:
        _assert_valid_xy(darray, x, "x")

    if y is not None:
        _assert_valid_xy(darray, y, "y")

    if ndims == 1:
        # A single line: no hue, the coordinate goes on the chosen axis.
        if x is not None:
            xplt, yplt = darray[x], darray
        elif y is not None:
            xplt, yplt = darray, darray[y]
        else:  # Both x & y are None
            xplt, yplt = darray[darray.dims[0]], darray
        return xplt, yplt, None, ""

    if x is None and y is None and hue is None:
        raise ValueError("For 2D inputs, please specify either hue, x or y.")

    # Validate hue the same way regardless of which axis holds the coordinate.
    if hue is not None:
        _assert_valid_xy(darray, hue, "hue")

    if y is None:
        # Coordinate on the x axis, data values on the y axis.
        xplt, yplt, huename = _infer_line_coord_and_data(darray, x, hue)
    else:
        # Coordinate on the y axis, data values on the x axis.
        yplt, xplt, huename = _infer_line_coord_and_data(darray, y, hue)

    huelabel = label_from_attrs(darray[huename])
    hueplt = darray[huename]

    return xplt, yplt, hueplt, huelabel


def _infer_line_coord_and_data(darray, coord, hue):
    """Return ``(coordplt, dataplt, huename)`` for a 2D line plot along ``coord``."""
    coordname, huename = _infer_xy_labels(darray=darray, x=coord, y=hue)
    coordplt = darray[coordname]

    if coordplt.ndim > 1:
        # A 2D coordinate must be aligned with the data along the hue dim.
        if huename not in darray.dims:
            raise ValueError(
                "For 2D inputs, hue must be a dimension"
                " i.e. one of " + repr(darray.dims)
            )
        otherindex = 1 if darray.dims.index(huename) == 0 else 0
        otherdim = darray.dims[otherindex]
        dataplt = darray.transpose(otherdim, huename, transpose_coords=False)
        coordplt = coordplt.transpose(otherdim, huename, transpose_coords=False)
    else:
        (coorddim,) = darray[coordname].dims
        (huedim,) = darray[huename].dims
        dataplt = darray.transpose(coorddim, huedim)

    return coordplt, dataplt, huename
240
241
def plot(
    darray,
    row=None,
    col=None,
    col_wrap=None,
    ax=None,
    hue=None,
    rtol=0.01,
    subplot_kws=None,
    **kwargs,
):
    """
    Default plot of DataArray using :py:mod:`matplotlib:matplotlib.pyplot`.

    The array is squeezed (and computed) first. The plotting function is
    then chosen from the number of dimensions left once ``row``, ``col``
    and ``hue`` are set aside:

    =============== ===========================
    Dimensions      Plotting function
    =============== ===========================
    1               :py:func:`xarray.plot.line`
    2               :py:func:`xarray.plot.pcolormesh`
    Anything else   :py:func:`xarray.plot.hist`
    =============== ===========================

    With two remaining dimensions and ``hue`` given, :py:func:`xarray.plot.line`
    is used instead of pcolormesh.

    Parameters
    ----------
    darray : DataArray
    row : str, optional
        If passed, make row faceted plots on this dimension name.
    col : str, optional
        If passed, make column faceted plots on this dimension name.
    hue : str, optional
        If passed, make faceted line plots with hue on this dimension name.
    col_wrap : int, optional
        Use together with ``col`` to wrap faceted plots.
    ax : matplotlib axes object, optional
        If ``None``, use the current axes. Not applicable when using facets.
    rtol : float, optional
        Relative tolerance used to determine if the indexes
        are uniformly spaced. Usually a small positive number.
        (Accepted for API compatibility; not used by this function itself.)
    subplot_kws : dict, optional
        Dictionary of keyword arguments for Matplotlib subplots
        (see :py:meth:`matplotlib:matplotlib.figure.Figure.add_subplot`).
    **kwargs : optional
        Additional keyword arguments for Matplotlib.

    Raises
    ------
    ValueError
        If faceting (``row``/``col``/``hue``) is requested for data that does
        not reduce to 1 or 2 plotting dimensions.

    See Also
    --------
    xarray.DataArray.squeeze
    """
    darray = darray.squeeze().compute()

    # Dimensions consumed by faceting or hue do not count as plot dimensions.
    plot_dims = set(darray.dims) - {row, col, hue}
    ndims = len(plot_dims)

    if ndims not in [1, 2]:
        if row or col or hue:
            raise ValueError(
                "Only 1d and 2d plots are supported for facets in xarray. "
                "See the package `Seaborn` for more options."
            )
        plotfunc = hist
    else:
        if row or col:
            kwargs.update(
                subplot_kws=subplot_kws, row=row, col=col, col_wrap=col_wrap
            )
        if ndims == 1 or hue:
            plotfunc = line
            kwargs["hue"] = hue
        else:
            plotfunc = pcolormesh
            kwargs["subplot_kws"] = subplot_kws

    kwargs["ax"] = ax
    return plotfunc(darray, **kwargs)
331
332
333# This function signature should not change so that it can use
334# matplotlib format strings
def line(
    darray,
    *args,
    row=None,
    col=None,
    figsize=None,
    aspect=None,
    size=None,
    ax=None,
    hue=None,
    x=None,
    y=None,
    xincrease=None,
    yincrease=None,
    xscale=None,
    yscale=None,
    xticks=None,
    yticks=None,
    xlim=None,
    ylim=None,
    add_legend=True,
    _labels=True,
    **kwargs,
):
    """
    Line plot of DataArray values.

    Wraps :py:func:`matplotlib:matplotlib.pyplot.plot`.

    Parameters
    ----------
    darray : DataArray
        Either 1D or 2D. If 2D, one of ``hue``, ``x`` or ``y`` must be provided.
    row, col : str, optional
        If passed, make row/column faceted plots on this dimension name.
    figsize : tuple, optional
        A tuple (width, height) of the figure in inches.
        Mutually exclusive with ``size`` and ``ax``.
    aspect : scalar, optional
        Aspect ratio of plot, so that ``aspect * size`` gives the *width* in
        inches. Only used if a ``size`` is provided.
    size : scalar, optional
        If provided, create a new figure for the plot with the given size:
        *height* (in inches) of each plot. See also: ``aspect``.
    ax : matplotlib axes object, optional
        Axes on which to plot. By default, the current is used.
        Mutually exclusive with ``size`` and ``figsize``.
    hue : str, optional
        Dimension or coordinate for which you want multiple lines plotted.
        If plotting against a 2D coordinate, ``hue`` must be a dimension.
    x, y : str, optional
        Dimension, coordinate or multi-index level for *x*, *y* axis.
        Only one of these may be specified.
        The other will be used for values from the DataArray on which this
        plot method is called.
    xscale, yscale : {'linear', 'symlog', 'log', 'logit'}, optional
        Specifies scaling for the *x*- and *y*-axis, respectively.
    xticks, yticks : array-like, optional
        Specify tick locations for *x*- and *y*-axis.
    xlim, ylim : array-like, optional
        Specify *x*- and *y*-axis limits.
    xincrease : None, True, or False, optional
        Should the values on the *x* axis be increasing from left to right?
        if ``None``, use the default for the Matplotlib function.
    yincrease : None, True, or False, optional
        Should the values on the *y* axis be increasing from top to bottom?
        if ``None``, use the default for the Matplotlib function.
    add_legend : bool, optional
        Add legend with *y* axis coordinates (2D inputs only).
    *args, **kwargs : optional
        Additional arguments to :py:func:`matplotlib:matplotlib.pyplot.plot`.

    Returns
    -------
    list of matplotlib.lines.Line2D
        Whatever ``ax.plot`` returned.

    Raises
    ------
    ValueError
        If ``darray`` has more than two dimensions.
    """
    # Faceted plots are delegated to the FacetGrid machinery. ``locals()``
    # must be captured before any other local name is bound, so that only
    # this function's parameters are forwarded.
    if row or col:
        allargs = locals().copy()
        allargs.update(allargs.pop("kwargs"))
        allargs.pop("darray")
        return _easy_facetgrid(darray, line, kind="line", **allargs)

    n_dims = len(darray.dims)
    if n_dims > 2:
        raise ValueError(
            "Line plots are for 1- or 2-dimensional DataArrays. "
            f"Passed DataArray has {n_dims} "
            "dimensions"
        )

    # When called back per facet, positional args arrive as kwargs["args"].
    if args:
        assert "args" not in kwargs
    else:
        args = kwargs.pop("args", ())

    ax = get_axis(figsize, size, aspect, ax)
    xplt, yplt, hueplt, hue_label = _infer_line_data(darray, x, y, hue)

    # Replace pd.Interval values (if any) by plottable numbers.
    x_values, y_values, x_suffix, y_suffix, kwargs = _resolve_intervals_1dplot(
        xplt.to_numpy(), yplt.to_numpy(), kwargs
    )
    xlabel = label_from_attrs(xplt, extra=x_suffix)
    ylabel = label_from_attrs(yplt, extra=y_suffix)

    _ensure_plottable(x_values, y_values)

    primitive = ax.plot(x_values, y_values, *args, **kwargs)

    if _labels:
        for set_axis_label, text in ((ax.set_xlabel, xlabel), (ax.set_ylabel, ylabel)):
            if text is not None:
                set_axis_label(text)
        ax.set_title(darray._title_for_slice())

    if darray.ndim == 2 and add_legend:
        ax.legend(handles=primitive, labels=list(hueplt.to_numpy()), title=hue_label)

    # Rotate date tick labels by hand rather than via autofmt_xdate, which
    # would delete x tick labels on other subplots:
    # https://stackoverflow.com/questions/17430105/autofmt-xdate-deletes-x-axis-labels-of-all-subplots
    if np.issubdtype(xplt.dtype, np.datetime64):
        for tick_label in ax.get_xticklabels():
            tick_label.set_rotation(30)
            tick_label.set_ha("right")

    _update_axes(ax, xincrease, yincrease, xscale, yscale, xticks, yticks, xlim, ylim)

    return primitive
464
465
def step(darray, *args, where="pre", drawstyle=None, ds=None, **kwargs):
    """
    Step plot of DataArray values.

    Similar to :py:func:`matplotlib:matplotlib.pyplot.step`. This is
    :py:func:`xarray.plot.line` with a ``drawstyle`` of ``"steps-<where>"``.

    Parameters
    ----------
    where : {'pre', 'post', 'mid'}, default: 'pre'
        Define where the steps should be placed:

        - ``'pre'``: The y value is continued constantly to the left from
          every *x* position, i.e. the interval ``(x[i-1], x[i]]`` has the
          value ``y[i]``.
        - ``'post'``: The y value is continued constantly to the right from
          every *x* position, i.e. the interval ``[x[i], x[i+1])`` has the
          value ``y[i]``.
        - ``'mid'``: Steps occur half-way between the *x* positions.

        Note that this parameter is ignored if one coordinate consists of
        :py:class:`pandas.Interval` values, e.g. as a result of
        :py:func:`xarray.Dataset.groupby_bins`. In this case, the actual
        boundaries of the interval are used.
    drawstyle, ds : str, optional
        Text appended to ``"steps-<where>"``. ``ds`` is an alias of
        ``drawstyle``; at most one of them may be given.
    *args, **kwargs : optional
        Additional arguments for :py:func:`xarray.plot.line`.

    Raises
    ------
    ValueError
        If ``where`` is not one of the allowed values.
    TypeError
        If both ``drawstyle`` and ``ds`` are given.
    """
    if where not in {"pre", "post", "mid"}:
        raise ValueError("'where' argument to step must be 'pre', 'post' or 'mid'")

    if ds is not None and drawstyle is not None:
        raise TypeError("ds and drawstyle are mutually exclusive")

    suffix = drawstyle if drawstyle is not None else ds
    if suffix is None:
        suffix = ""

    return line(darray, *args, drawstyle="steps-" + where + suffix, **kwargs)
505
506
def hist(
    darray,
    figsize=None,
    size=None,
    aspect=None,
    ax=None,
    xincrease=None,
    yincrease=None,
    xscale=None,
    yscale=None,
    xticks=None,
    yticks=None,
    xlim=None,
    ylim=None,
    **kwargs,
):
    """
    Histogram of DataArray.

    Wraps :py:func:`matplotlib:matplotlib.pyplot.hist`.

    The array is flattened and null values (as detected by
    :py:func:`pandas.notnull`) are dropped before binning, so arrays of any
    dimensionality are accepted.

    Parameters
    ----------
    darray : DataArray
        Can have any number of dimensions.
    figsize : tuple, optional
        A tuple (width, height) of the figure in inches.
        Mutually exclusive with ``size`` and ``ax``.
    aspect : scalar, optional
        Aspect ratio of plot, so that ``aspect * size`` gives the *width* in
        inches. Only used if a ``size`` is provided.
    size : scalar, optional
        If provided, create a new figure for the plot with the given size:
        *height* (in inches) of each plot. See also: ``aspect``.
    ax : matplotlib axes object, optional
        Axes on which to plot. By default, use the current axes.
        Mutually exclusive with ``size`` and ``figsize``.
    xincrease, yincrease, xscale, yscale, xticks, yticks, xlim, ylim : optional
        Axis adjustments applied after plotting (see :py:func:`xarray.plot.line`).
    **kwargs : optional
        Additional keyword arguments to :py:func:`matplotlib:matplotlib.pyplot.hist`.

    Returns
    -------
    Whatever ``ax.hist`` returns.
    """
    ax = get_axis(figsize, size, aspect, ax)

    flat_values = np.ravel(darray.to_numpy())
    is_valid = pd.notnull(flat_values)
    primitive = ax.hist(flat_values[is_valid], **kwargs)

    ax.set_title(darray._title_for_slice())
    ax.set_xlabel(label_from_attrs(darray))

    _update_axes(ax, xincrease, yincrease, xscale, yscale, xticks, yticks, xlim, ylim)

    return primitive
563
564
def scatter(
    darray,
    *args,
    row=None,
    col=None,
    figsize=None,
    aspect=None,
    size=None,
    ax=None,
    hue=None,
    hue_style=None,
    x=None,
    z=None,
    xincrease=None,
    yincrease=None,
    xscale=None,
    yscale=None,
    xticks=None,
    yticks=None,
    xlim=None,
    ylim=None,
    add_legend=None,
    add_colorbar=None,
    cbar_kwargs=None,
    cbar_ax=None,
    vmin=None,
    vmax=None,
    norm=None,
    infer_intervals=None,
    center=None,
    levels=None,
    robust=None,
    colors=None,
    extend=None,
    cmap=None,
    _labels=True,
    **kwargs,
):
    """
    Scatter plot a DataArray along some coordinates.

    The values of ``darray`` are plotted against the coordinate named by
    ``x`` and, for a 3D plot, also against ``z``.

    Parameters
    ----------
    darray : DataArray
        Dataarray to plot. Its values are used for the *y* axis.
    x, z : str, optional
        Coordinate names for the *x* axis and, optionally, the *z* axis.
        Passing ``z`` without ``ax`` creates a 3D axes.
    hue: str, optional
        Variable by which to color scattered points
    hue_style: str, optional
        Can be either 'discrete' (legend) or 'continuous' (color bar).
    markersize: str, optional
        scatter only. Variable by which to vary size of scattered points.
        Passed via ``**kwargs``; ``linewidth`` is accepted as a fallback name.
    size_norm: optional
        Either None, a tuple, or 'Norm' instance to normalize the
        'markersize' variable.
    add_guide: bool, optional
        Add a guide that depends on hue_style
            - for "discrete", build a legend.
              This is the default for non-numeric `hue` variables.
            - for "continuous",  build a colorbar
    add_legend, add_colorbar : bool, optional
        Explicitly enable or disable the legend / colorbar. When given, they
        take precedence over ``add_guide``.
    cbar_kwargs : dict, optional
        Keyword arguments for the colorbar. ``label`` defaults to the hue label.
    cbar_ax : matplotlib axes object, optional
        Axes in which to draw the colorbar.
    row : str, optional
        If passed, make row faceted plots on this dimension name
    col : str, optional
        If passed, make column faceted plots on this dimension name
    col_wrap : int, optional
        Use together with ``col`` to wrap faceted plots
    ax : matplotlib axes object, optional
        If None, uses the current axis. Not applicable when using facets.
    figsize : tuple, optional
        A tuple (width, height) of the figure in inches.
        Mutually exclusive with ``size`` and ``ax``.
    subplot_kws : dict, optional
        Dictionary of keyword arguments for matplotlib subplots. Only applies
        to FacetGrid plotting.
    aspect : scalar, optional
        Aspect ratio of plot, so that ``aspect * size`` gives the width in
        inches. Only used if a ``size`` is provided.
    size : scalar, optional
        If provided, create a new figure for the plot with the given size.
        Height (in inches) of each plot. See also: ``aspect``.
    norm : ``matplotlib.colors.Normalize`` instance, optional
        If the ``norm`` has vmin or vmax specified, the corresponding kwarg
        must be None.
    vmin, vmax : float, optional
        Values to anchor the colormap, otherwise they are inferred from the
        data and other keyword arguments. When a diverging dataset is inferred,
        setting one of these values will fix the other by symmetry around
        ``center``. Setting both values prevents use of a diverging colormap.
        If discrete levels are provided as an explicit list, both of these
        values are ignored.
    cmap : str or colormap, optional
        The mapping from data values to color space. Either a
        matplotlib colormap name or object. If not provided, this will
        be either ``viridis`` (if the function infers a sequential
        dataset) or ``RdBu_r`` (if the function infers a diverging
        dataset).  When `Seaborn` is installed, ``cmap`` may also be a
        `seaborn` color palette. If ``cmap`` is seaborn color palette
        and the plot type is not ``contour`` or ``contourf``, ``levels``
        must also be specified.
    colors : color-like or list of color-like, optional
        A single color or a list of colors. If the plot type is not ``contour``
        or ``contourf``, the ``levels`` argument is required.
    center : float, optional
        The value at which to center the colormap. Passing this value implies
        use of a diverging colormap. Setting it to ``False`` prevents use of a
        diverging colormap.
    robust : bool, optional
        If True and ``vmin`` or ``vmax`` are absent, the colormap range is
        computed with 2nd and 98th percentiles instead of the extreme values.
    extend : {"neither", "both", "min", "max"}, optional
        How to draw arrows extending the colorbar beyond its limits. If not
        provided, extend is inferred from vmin, vmax and the data limits.
    levels : int or list-like object, optional
        Split the colormap (cmap) into discrete color intervals. If an integer
        is provided, "nice" levels are chosen based on the data range: this can
        imply that the final number of levels is not exactly the expected one.
        Setting ``vmin`` and/or ``vmax`` with ``levels=N`` is equivalent to
        setting ``levels=np.linspace(vmin, vmax, N)``.
    **kwargs : optional
        Additional keyword arguments to matplotlib

    Returns
    -------
    The artist returned by ``ax.scatter``.

    Raises
    ------
    NotImplementedError
        If a colorbar is requested for a discrete ``hue``.
    """
    # Handle facetgrids first. ``locals()`` is captured before any other
    # local is bound, so only the parameters are forwarded.
    if row or col:
        allargs = locals().copy()
        allargs.update(allargs.pop("kwargs"))
        allargs.pop("darray")
        # NOTE(review): a ``subplot_kws`` passed by the caller would end up in
        # ``allargs`` and collide with the explicit keyword below — confirm.
        subplot_kws = dict(projection="3d") if z is not None else None
        return _easy_facetgrid(
            darray, scatter, kind="dataarray", subplot_kws=subplot_kws, **allargs
        )

    # Per-panel calls from a FacetGrid pass ``_is_facetgrid=True``.
    _is_facetgrid = kwargs.pop("_is_facetgrid", False)
    if _is_facetgrid:
        # NOTE(review): these keys presumably come from the FacetGrid call
        # path and are not valid ``Axes.scatter`` arguments — confirm.
        kwargs.pop("y", None)
        kwargs.pop("args", None)
        kwargs.pop("add_labels", None)

    # "linewidth" is always popped; it is only used if "markersize" is absent.
    _sizes = kwargs.pop("markersize", kwargs.pop("linewidth", None))
    size_norm = kwargs.pop("size_norm", None)
    size_mapping = kwargs.pop("size_mapping", None)  # set by facetgrid
    cmap_params = kwargs.pop("cmap_params", None)

    # ``figsize`` is a named parameter and is used as given by the caller.
    # (It can never appear in ``kwargs``, so popping it from there would
    # always discard it.)
    subplot_kws = dict()
    if z is not None and ax is None:
        # TODO: Importing Axes3D is not necessary in matplotlib >= 3.2.
        # Remove when minimum requirement of matplotlib is 3.2:
        from mpl_toolkits.mplot3d import Axes3D  # type: ignore # noqa

        subplot_kws.update(projection="3d")
        ax = get_axis(figsize, size, aspect, ax, **subplot_kws)
        # Using 30, 30 minimizes rotation of the plot. Making it easier to
        # build on your intuition from 2D plots:
        if LooseVersion(plt.matplotlib.__version__) < "3.5.0":
            ax.view_init(azim=30, elev=30)
        else:
            # https://github.com/matplotlib/matplotlib/pull/19873
            ax.view_init(azim=30, elev=30, vertical_axis="y")
    else:
        ax = get_axis(figsize, size, aspect, ax, **subplot_kws)

    _data = _infer_scatter_metadata(darray, x, z, hue, hue_style, _sizes)

    # Explicit add_legend/add_colorbar win; otherwise derive them from
    # add_guide and the hue style.
    add_guide = kwargs.pop("add_guide", None)
    if add_legend is not None:
        pass
    elif add_guide is None or add_guide is True:
        add_legend = True if _data["hue_style"] == "discrete" else False
    else:
        add_legend = False

    if add_colorbar is not None:
        pass
    elif add_guide is None or add_guide is True:
        add_colorbar = True if _data["hue_style"] == "continuous" else False
    else:
        add_colorbar = False

    # need to infer size_mapping with full dataset
    _data.update(
        _infer_scatter_data(
            darray,
            x,
            z,
            hue,
            _sizes,
            size_norm,
            size_mapping,
            _MARKERSIZE_RANGE,
        )
    )

    cmap_params_subset = {}
    if _data["hue"] is not None:
        kwargs.update(c=_data["hue"].values.ravel())
        # All current locals are forwarded; the helper picks out the
        # colormap-related names (cmap, vmin, vmax, norm, ...).
        cmap_params, cbar_kwargs = _process_cmap_cbar_kwargs(
            scatter, _data["hue"].values, **locals()
        )

        # subset that can be passed to scatter, hist2d
        cmap_params_subset = {
            vv: cmap_params[vv] for vv in ["vmin", "vmax", "norm", "cmap"]
        }

    if _data["size"] is not None:
        kwargs.update(s=_data["size"].values.ravel())

    if LooseVersion(plt.matplotlib.__version__) < "3.5.0":
        # Plot the data. 3d plots has the z value in upward direction
        # instead of y. To make jumping between 2d and 3d easy and intuitive
        # switch the order so that z is shown in the depthwise direction:
        axis_order = ["x", "z", "y"]
    else:
        # Switching axis order not needed in 3.5.0, can also simplify the code
        # that uses axis_order:
        # https://github.com/matplotlib/matplotlib/pull/19873
        axis_order = ["x", "y", "z"]

    # Only the variables actually present (x/y and optionally z) are passed.
    primitive = ax.scatter(
        *[
            _data[v].values.ravel()
            for v in axis_order
            if _data.get(v, None) is not None
        ],
        **cmap_params_subset,
        **kwargs,
    )

    # Set x, y, z labels, in the same order the data was passed:
    i = 0
    set_label = [ax.set_xlabel, ax.set_ylabel, getattr(ax, "set_zlabel", None)]
    for v in axis_order:
        if _data.get(f"{v}label", None) is not None:
            set_label[i](_data[f"{v}label"])
            i += 1

    if add_legend:

        def to_label(data, key, x):
            """Map prop values back to its original values."""
            if key in data:
                # Use reindex to be less sensitive to float errors.
                # Return as numpy array since legend_elements
                # seems to require that:
                return data[key].reindex(x, method="nearest").to_numpy()
            else:
                return x

        handles, labels = [], []
        for subtitle, prop, func in [
            (
                _data["hue_label"],
                "colors",
                functools.partial(to_label, _data, "hue_to_label"),
            ),
            (
                _data["size_label"],
                "sizes",
                functools.partial(to_label, _data, "size_to_label"),
            ),
        ]:
            if subtitle:
                # Get legend handles and labels that displays the
                # values correctly. Order might be different because
                # legend_elements uses np.unique instead of pd.unique,
                # FacetGrid.add_legend might have troubles with this:
                hdl, lbl = legend_elements(primitive, prop, num="auto", func=func)
                hdl, lbl = _legend_add_subtitle(hdl, lbl, subtitle, ax.scatter)
                handles += hdl
                labels += lbl
        legend = ax.legend(handles, labels, framealpha=0.5)
        _adjust_legend_subtitles(legend)

    if add_colorbar and _data["hue_label"]:
        if _data["hue_style"] == "discrete":
            raise NotImplementedError("Cannot create a colorbar for non numerics.")
        cbar_kwargs = {} if cbar_kwargs is None else cbar_kwargs
        if "label" not in cbar_kwargs:
            cbar_kwargs["label"] = _data["hue_label"]
        _add_colorbar(primitive, ax, cbar_ax, cbar_kwargs, cmap_params)

    return primitive
846
847
# MUST run before any 2d plotting functions are defined since
# _plot2d decorator adds them as methods here.
class _PlotMethods:
    """
    Enables use of xarray.plot functions as attributes on a DataArray.
    For example, DataArray.plot.imshow
    """

    # Only the wrapped DataArray is stored (no per-instance __dict__).
    __slots__ = ("_da",)

    def __init__(self, darray):
        # The DataArray every plotting method below operates on.
        self._da = darray

    def __call__(self, **kwargs):
        # Calling the accessor dispatches to the generic ``plot`` function.
        return plot(self._da, **kwargs)

    # we can't use functools.wraps here since that also modifies the name / qualname
    # This class-level ``__doc__`` assignment replaces the class docstring
    # above with ``plot``'s docstring.
    __doc__ = __call__.__doc__ = plot.__doc__
    # ``__wrapped__`` makes ``inspect.signature`` report ``plot``'s signature.
    __call__.__wrapped__ = plot  # type: ignore[attr-defined]
    __call__.__annotations__ = plot.__annotations__

    # Each method below copies metadata from the module-level function of the
    # same name via functools.wraps. Inside the method bodies, ``hist``,
    # ``line``, ... refer to those module-level functions, not to these
    # methods, because a class body is not an enclosing scope for its methods.
    @functools.wraps(hist)
    def hist(self, ax=None, **kwargs):
        return hist(self._da, ax=ax, **kwargs)

    @functools.wraps(line)
    def line(self, *args, **kwargs):
        return line(self._da, *args, **kwargs)

    @functools.wraps(step)
    def step(self, *args, **kwargs):
        return step(self._da, *args, **kwargs)

    # NOTE(review): exposed under a private name, presumably because DataArray
    # scatter plotting is not yet public API — confirm.
    @functools.wraps(scatter)
    def _scatter(self, *args, **kwargs):
        return scatter(self._da, *args, **kwargs)
884
885
def override_signature(f):
    """Build a decorator that makes introspection report ``f``'s signature.

    ``inspect.signature`` follows the ``__wrapped__`` attribute, so pointing a
    function's ``__wrapped__`` at ``f`` changes the signature shown by
    ``help()``/``inspect`` without touching the function's behaviour.

    Parameters
    ----------
    f : callable
        Function whose signature should be advertised.

    Returns
    -------
    callable
        Decorator that sets ``func.__wrapped__ = f`` and returns the very same
        ``func`` object.
    """

    def _advertise(func):
        setattr(func, "__wrapped__", f)
        return func

    return _advertise
893
894
def _plot2d(plotfunc):
    """
    Decorator for common 2d plotting logic

    Also adds the 2d plot method to class _PlotMethods

    Parameters
    ----------
    plotfunc : callable
        Low-level plotting function called as
        ``plotfunc(x, y, z, ax=ax, cmap=..., vmin=..., vmax=..., norm=..., **kwargs)``
        with array coordinates and a masked ``z``. Its ``__name__`` selects
        special handling (``"imshow"``, ``"contour"``/``"contourf"``,
        ``"pcolormesh"``, ``"surface"``). Its ``__doc__`` is extended in place
        with the shared parameter documentation below.

    Returns
    -------
    newplotfunc : callable
        DataArray-level plotting function ``newplotfunc(darray, x, y, **kwargs)``.
        A matching method named ``plotfunc.__name__`` is attached to
        ``_PlotMethods``.
    """
    commondoc = """
    Parameters
    ----------
    darray : DataArray
        Must be two-dimensional, unless creating faceted plots.
    x : str, optional
        Coordinate for *x* axis. If ``None``, use ``darray.dims[1]``.
    y : str, optional
        Coordinate for *y* axis. If ``None``, use ``darray.dims[0]``.
    figsize : tuple, optional
        A tuple (width, height) of the figure in inches.
        Mutually exclusive with ``size`` and ``ax``.
    aspect : scalar, optional
        Aspect ratio of plot, so that ``aspect * size`` gives the *width* in
        inches. Only used if a ``size`` is provided.
    size : scalar, optional
        If provided, create a new figure for the plot with the given size:
        *height* (in inches) of each plot. See also: ``aspect``.
    ax : matplotlib axes object, optional
        Axes on which to plot. By default, use the current axes.
        Mutually exclusive with ``size`` and ``figsize``.
    row : string, optional
        If passed, make row faceted plots on this dimension name.
    col : string, optional
        If passed, make column faceted plots on this dimension name.
    col_wrap : int, optional
        Use together with ``col`` to wrap faceted plots.
    xscale, yscale : {'linear', 'symlog', 'log', 'logit'}, optional
        Specifies scaling for the *x*- and *y*-axis, respectively.
    xticks, yticks : array-like, optional
        Specify tick locations for *x*- and *y*-axis.
    xlim, ylim : array-like, optional
        Specify *x*- and *y*-axis limits.
    xincrease : None, True, or False, optional
        Should the values on the *x* axis be increasing from left to right?
        If ``None``, use the default for the Matplotlib function.
    yincrease : None, True, or False, optional
        Should the values on the *y* axis be increasing from top to bottom?
        If ``None``, use the default for the Matplotlib function.
    add_colorbar : bool, optional
        Add colorbar to axes.
    add_labels : bool, optional
        Use xarray metadata to label axes.
    norm : matplotlib.colors.Normalize, optional
        If ``norm`` has ``vmin`` or ``vmax`` specified, the corresponding
        kwarg must be ``None``.
    vmin, vmax : float, optional
        Values to anchor the colormap, otherwise they are inferred from the
        data and other keyword arguments. When a diverging dataset is inferred,
        setting one of these values will fix the other by symmetry around
        ``center``. Setting both values prevents use of a diverging colormap.
        If discrete levels are provided as an explicit list, both of these
        values are ignored.
    cmap : matplotlib colormap name or colormap, optional
        The mapping from data values to color space. If not provided, this
        will be either ``'viridis'`` (if the function infers a sequential
        dataset) or ``'RdBu_r'`` (if the function infers a diverging dataset).
        See :doc:`Choosing Colormaps in Matplotlib <matplotlib:tutorials/colors/colormaps>`
        for more information.

        If *seaborn* is installed, ``cmap`` may also be a
        `seaborn color palette <https://seaborn.pydata.org/tutorial/color_palettes.html>`_.
        Note: if ``cmap`` is a seaborn color palette and the plot type
        is not ``'contour'`` or ``'contourf'``, ``levels`` must also be specified.
    colors : str or array-like of color-like, optional
        A single color or a sequence of colors. If the plot type is not ``'contour'``
        or ``'contourf'``, the ``levels`` argument is required.
    center : float, optional
        The value at which to center the colormap. Passing this value implies
        use of a diverging colormap. Setting it to ``False`` prevents use of a
        diverging colormap.
    robust : bool, optional
        If ``True`` and ``vmin`` or ``vmax`` are absent, the colormap range is
        computed with 2nd and 98th percentiles instead of the extreme values.
    extend : {'neither', 'both', 'min', 'max'}, optional
        How to draw arrows extending the colorbar beyond its limits. If not
        provided, ``extend`` is inferred from ``vmin``, ``vmax`` and the data limits.
    levels : int or array-like, optional
        Split the colormap (``cmap``) into discrete color intervals. If an integer
        is provided, "nice" levels are chosen based on the data range: this can
        imply that the final number of levels is not exactly the expected one.
        Setting ``vmin`` and/or ``vmax`` with ``levels=N`` is equivalent to
        setting ``levels=np.linspace(vmin, vmax, N)``.
    infer_intervals : bool, optional
        Only applies to pcolormesh. If ``True``, the coordinate intervals are
        passed to pcolormesh. If ``False``, the original coordinates are used
        (this can be useful for certain map projections). The default is to
        always infer intervals, unless the mesh is irregular and plotted on
        a map projection.
    subplot_kws : dict, optional
        Dictionary of keyword arguments for Matplotlib subplots. Only used
        for 2D and faceted plots.
        (see :py:meth:`matplotlib:matplotlib.figure.Figure.add_subplot`).
    cbar_ax : matplotlib axes object, optional
        Axes in which to draw the colorbar.
    cbar_kwargs : dict, optional
        Dictionary of keyword arguments to pass to the colorbar
        (see :meth:`matplotlib:matplotlib.figure.Figure.colorbar`).
    **kwargs : optional
        Additional keyword arguments to wrapped Matplotlib function.

    Returns
    -------
    artist :
        The same type of primitive artist that the wrapped Matplotlib
        function returns.
    """

    # Build on the original docstring
    plotfunc.__doc__ = f"{plotfunc.__doc__}\n{commondoc}"

    # plotfunc and newplotfunc have different signatures:
    # - plotfunc: (x, y, z, ax, **kwargs)
    # - newplotfunc: (darray, x, y, **kwargs)
    # where plotfunc accepts numpy arrays, while newplotfunc accepts a DataArray
    # and variable names. newplotfunc also explicitly lists most kwargs, so we
    # need to shorten it
    def signature(darray, x, y, **kwargs):
        pass

    @override_signature(signature)
    @functools.wraps(plotfunc)
    def newplotfunc(
        darray,
        x=None,
        y=None,
        figsize=None,
        size=None,
        aspect=None,
        ax=None,
        row=None,
        col=None,
        col_wrap=None,
        xincrease=True,
        yincrease=True,
        add_colorbar=None,
        add_labels=True,
        vmin=None,
        vmax=None,
        cmap=None,
        center=None,
        robust=False,
        extend=None,
        levels=None,
        infer_intervals=None,
        colors=None,
        subplot_kws=None,
        cbar_ax=None,
        cbar_kwargs=None,
        xscale=None,
        yscale=None,
        xticks=None,
        yticks=None,
        xlim=None,
        ylim=None,
        norm=None,
        **kwargs,
    ):
        # All 2d plots in xarray share this function signature.
        # Method signature below should be consistent.
        # NOTE: ``locals()`` is forwarded wholesale to ``_easy_facetgrid`` and
        # ``_process_cmap_cbar_kwargs`` below, so every local defined before
        # those calls becomes a keyword argument there.

        # Decide on a default for the colorbar before facetgrids
        if add_colorbar is None:
            add_colorbar = True
            if plotfunc.__name__ == "contour" or (
                plotfunc.__name__ == "surface" and cmap is None
            ):
                add_colorbar = False
        # An RGB(A) image carries one extra (color) dimension per facet.
        imshow_rgb = plotfunc.__name__ == "imshow" and darray.ndim == (
            3 + (row is not None) + (col is not None)
        )
        if imshow_rgb:
            # Don't add a colorbar when showing an image with explicit colors
            add_colorbar = False
            # Matplotlib does not support normalising RGB data, so do it here.
            # See eg. https://github.com/matplotlib/matplotlib/pull/10220
            if robust or vmax is not None or vmin is not None:
                darray = _rescale_imshow_rgb(darray.as_numpy(), vmin, vmax, robust)
                vmin, vmax, robust = None, None, False

        if subplot_kws is None:
            subplot_kws = dict()

        if plotfunc.__name__ == "surface" and not kwargs.get("_is_facetgrid", False):
            if ax is None:
                # TODO: Importing Axes3D is no longer necessary in matplotlib >= 3.2.
                # Remove when minimum requirement of matplotlib is 3.2:
                from mpl_toolkits.mplot3d import Axes3D  # type: ignore  # noqa: F401

                # delete so it does not end up in locals()
                del Axes3D

                # Need to create a "3d" Axes instance for surface plots
                subplot_kws["projection"] = "3d"

            # In facet grids, shared axis labels don't make sense for surface plots
            sharex = False
            sharey = False

        # Handle facetgrids first
        if row or col:
            allargs = locals().copy()
            del allargs["darray"]
            del allargs["imshow_rgb"]
            allargs.update(allargs.pop("kwargs"))
            # Need the decorated plotting function
            allargs["plotfunc"] = globals()[plotfunc.__name__]
            return _easy_facetgrid(darray, kind="dataarray", **allargs)

        if (
            plotfunc.__name__ == "surface"
            and not kwargs.get("_is_facetgrid", False)
            and ax is not None
        ):
            import mpl_toolkits  # type: ignore

            if not isinstance(ax, mpl_toolkits.mplot3d.Axes3D):
                raise ValueError(
                    "If ax is passed to surface(), it must be created with "
                    'projection="3d"'
                )

        rgb = kwargs.pop("rgb", None)
        if rgb is not None and plotfunc.__name__ != "imshow":
            raise ValueError('The "rgb" keyword is only valid for imshow()')
        elif rgb is not None and not imshow_rgb:
            # The trailing space keeps the two literals from running together.
            raise ValueError(
                'The "rgb" keyword is only valid for imshow() '
                "with a three-dimensional array (per facet)"
            )

        xlab, ylab = _infer_xy_labels(
            darray=darray, x=x, y=y, imshow=imshow_rgb, rgb=rgb
        )

        xval = darray[xlab]
        yval = darray[ylab]

        if xval.ndim > 1 or yval.ndim > 1 or plotfunc.__name__ == "surface":
            # Passing 2d coordinate values, need to ensure they are transposed the same
            # way as darray.
            # Also surface plots always need 2d coordinates
            xval = xval.broadcast_like(darray)
            yval = yval.broadcast_like(darray)
            dims = darray.dims
        else:
            dims = (yval.dims[0], xval.dims[0])

        # May need to transpose for correct x, y labels
        # xlab may be the name of a coord, we have to check for dim names
        if imshow_rgb:
            # For RGB[A] images, matplotlib requires the color dimension
            # to be last.  In Xarray the order should be unimportant, so
            # we transpose to (y, x, color) to make this work.
            yx_dims = (ylab, xlab)
            dims = yx_dims + tuple(d for d in darray.dims if d not in yx_dims)

        if dims != darray.dims:
            darray = darray.transpose(*dims, transpose_coords=True)

        # better to pass the ndarrays directly to plotting functions
        xval = xval.to_numpy()
        yval = yval.to_numpy()

        # Pass the data as a masked ndarray too
        zval = darray.to_masked_array(copy=False)

        # Replace pd.Intervals if contained in xval or yval.
        xplt, xlab_extra = _resolve_intervals_2dplot(xval, plotfunc.__name__)
        yplt, ylab_extra = _resolve_intervals_2dplot(yval, plotfunc.__name__)

        _ensure_plottable(xplt, yplt, zval)

        cmap_params, cbar_kwargs = _process_cmap_cbar_kwargs(
            plotfunc,
            zval.data,
            **locals(),
            _is_facetgrid=kwargs.pop("_is_facetgrid", False),
        )

        if "contour" in plotfunc.__name__:
            # extend is a keyword argument only for contour and contourf, but
            # passing it to the colorbar is sufficient for imshow and
            # pcolormesh
            kwargs["extend"] = cmap_params["extend"]
            kwargs["levels"] = cmap_params["levels"]
            # if colors == a single color, matplotlib draws dashed negative
            # contours. we lose this feature if we pass cmap and not colors
            if isinstance(colors, str):
                cmap_params["cmap"] = None
                kwargs["colors"] = colors

        if "pcolormesh" == plotfunc.__name__:
            kwargs["infer_intervals"] = infer_intervals
            kwargs["xscale"] = xscale
            kwargs["yscale"] = yscale

        if "imshow" == plotfunc.__name__ and isinstance(aspect, str):
            # forbid usage of mpl strings
            raise ValueError("plt.imshow's `aspect` kwarg is not available in xarray")

        ax = get_axis(figsize, size, aspect, ax, **subplot_kws)

        primitive = plotfunc(
            xplt,
            yplt,
            zval,
            ax=ax,
            cmap=cmap_params["cmap"],
            vmin=cmap_params["vmin"],
            vmax=cmap_params["vmax"],
            norm=cmap_params["norm"],
            **kwargs,
        )

        # Label the plot with metadata
        if add_labels:
            ax.set_xlabel(label_from_attrs(darray[xlab], xlab_extra))
            ax.set_ylabel(label_from_attrs(darray[ylab], ylab_extra))
            ax.set_title(darray._title_for_slice())
            if plotfunc.__name__ == "surface":
                ax.set_zlabel(label_from_attrs(darray))

        if add_colorbar:
            if add_labels and "label" not in cbar_kwargs:
                cbar_kwargs["label"] = label_from_attrs(darray)
            _add_colorbar(primitive, ax, cbar_ax, cbar_kwargs, cmap_params)
        elif cbar_ax is not None or cbar_kwargs:
            # inform the user about keywords which aren't used
            raise ValueError(
                "cbar_ax and cbar_kwargs can't be used with add_colorbar=False."
            )

        # origin kwarg overrides yincrease
        if "origin" in kwargs:
            yincrease = None

        _update_axes(
            ax, xincrease, yincrease, xscale, yscale, xticks, yticks, xlim, ylim
        )

        # Rotate dates on xlabels
        # Do this without calling autofmt_xdate so that x-axes ticks
        # on other subplots (if any) are not deleted.
        # https://stackoverflow.com/questions/17430105/autofmt-xdate-deletes-x-axis-labels-of-all-subplots
        if np.issubdtype(xplt.dtype, np.datetime64):
            for xlabels in ax.get_xticklabels():
                xlabels.set_rotation(30)
                xlabels.set_ha("right")

        return primitive

    # For use as DataArray.plot.plotmethod
    @functools.wraps(newplotfunc)
    def plotmethod(
        _PlotMethods_obj,
        x=None,
        y=None,
        figsize=None,
        size=None,
        aspect=None,
        ax=None,
        row=None,
        col=None,
        col_wrap=None,
        xincrease=True,
        yincrease=True,
        add_colorbar=None,
        add_labels=True,
        vmin=None,
        vmax=None,
        cmap=None,
        colors=None,
        center=None,
        robust=False,
        extend=None,
        levels=None,
        infer_intervals=None,
        subplot_kws=None,
        cbar_ax=None,
        cbar_kwargs=None,
        xscale=None,
        yscale=None,
        xticks=None,
        yticks=None,
        xlim=None,
        ylim=None,
        norm=None,
        **kwargs,
    ):
        """
        The method should have the same signature as the function.

        This just makes the method work on Plotmethods objects,
        and passes all the other arguments straight through.
        """
        allargs = locals()
        allargs["darray"] = _PlotMethods_obj._da
        allargs.update(kwargs)
        # ``newplotfunc`` is a closure variable and therefore shows up in
        # ``locals()``; drop it together with the accessor and raw kwargs.
        for arg in ["_PlotMethods_obj", "newplotfunc", "kwargs"]:
            del allargs[arg]
        return newplotfunc(**allargs)

    # Add to class _PlotMethods
    setattr(_PlotMethods, plotmethod.__name__, plotmethod)

    return newplotfunc
1308
1309
@_plot2d
def imshow(x, y, z, ax, **kwargs):
    """
    Image plot of 2D DataArray.

    Wraps :py:func:`matplotlib:matplotlib.pyplot.imshow`.

    While other plot methods require the DataArray to be strictly
    two-dimensional, ``imshow`` also accepts a 3D array where some
    dimension can be interpreted as RGB or RGBA color channels and
    allows this dimension to be specified via the kwarg ``rgb=``.

    Unlike :py:func:`matplotlib:matplotlib.pyplot.imshow`, which ignores ``vmin``/``vmax``
    for RGB(A) data,
    xarray *will* use ``vmin`` and ``vmax`` for RGB(A) data
    by applying a single scaling factor and offset to all bands.
    Passing ``robust=True`` infers ``vmin`` and ``vmax``
    :ref:`in the usual way <robust-plotting>`.

    .. note::
        This function needs uniformly spaced coordinates to
        properly label the axes. Call :py:meth:`DataArray.plot` to check.

    The pixels are centered on the coordinates. For example, if the coordinate
    value is 3.2, then the pixels for those coordinates will be centered on 3.2.
    """
    if x.ndim != 1 or y.ndim != 1:
        raise ValueError(
            "imshow requires 1D coordinates, try using pcolormesh or contour(f)"
        )

    def _pixel_edges(coord):
        """Return the outer edges of the first and last pixel along ``coord``."""
        if np.issubdtype(coord.dtype, str):
            # imshow places string coordinates at integer positions 0..n-1,
            # so the edges lie half a unit outside that range.
            return 0 - 0.5, len(coord) - 0.5

        try:
            # Half of the first spacing; only exact for uniform spacing.
            half_width = 0.5 * (coord[1] - coord[0])
        except IndexError:
            # A single value has no spacing: arbitrary default, similar to
            # matplotlib behaviour.
            half_width = 0.1

        return coord[0] - half_width, coord[-1] + half_width

    left, right = _pixel_edges(x)
    top, bottom = _pixel_edges(y)

    imshow_kwargs = {"origin": "upper", "interpolation": "nearest"}
    if not hasattr(ax, "projection"):
        # leave the aspect alone on cartopy geoaxes
        imshow_kwargs["aspect"] = "auto"
    # Caller-supplied keywords take precedence over the defaults above.
    imshow_kwargs.update(kwargs)

    # The extent is always computed here, after the caller's keywords.
    upper_origin = imshow_kwargs["origin"] == "upper"
    imshow_kwargs["extent"] = (
        [left, right, bottom, top] if upper_origin else [left, right, top, bottom]
    )

    if z.ndim == 3:
        # matplotlib draws missing RGB(A) data black; xarray makes it
        # transparent instead by zeroing the alpha channel wherever any band
        # is masked, adding an opaque alpha channel first for RGB input.
        if z.shape[-1] != 3:
            # Already RGBA: copy so the alpha edit below does not write into
            # the array that was passed in.
            z = z.copy()
        else:
            opaque = np.ma.ones(z.shape[:2] + (1,), dtype=z.dtype)
            if np.issubdtype(z.dtype, np.integer):
                opaque *= 255
            z = np.ma.concatenate((z, opaque), axis=2)
        any_band_masked = np.any(z.mask, axis=-1)
        z[any_band_masked, -1] = 0

    artist = ax.imshow(z, **imshow_kwargs)

    # imshow replaced string tick labels by integer indices; restore them.
    for axis_name, coord in (("x", x), ("y", y)):
        if not np.issubdtype(coord.dtype, str):
            continue
        getattr(ax, f"set_{axis_name}ticks")(np.arange(len(coord)))
        getattr(ax, f"set_{axis_name}ticklabels")(coord)

    return artist
1400
1401
@_plot2d
def contour(x, y, z, ax, **kwargs):
    """
    Contour plot of 2D DataArray.

    Wraps :py:func:`matplotlib:matplotlib.pyplot.contour`.
    """
    # Coordinates, masked data and all styling keywords go straight through.
    return ax.contour(x, y, z, **kwargs)
1411
1412
@_plot2d
def contourf(x, y, z, ax, **kwargs):
    """
    Filled contour plot of 2D DataArray.

    Wraps :py:func:`matplotlib:matplotlib.pyplot.contourf`.
    """
    # Coordinates, masked data and all styling keywords go straight through.
    return ax.contourf(x, y, z, **kwargs)
1422
1423
@_plot2d
def pcolormesh(x, y, z, ax, xscale=None, yscale=None, infer_intervals=None, **kwargs):
    """
    Pseudocolor plot of 2D DataArray.

    Wraps :py:func:`matplotlib:matplotlib.pyplot.pcolormesh`.
    """
    # Both coordinates are inspected via ``.ndim``/``.dtype``/``.shape`` below.
    x = np.asarray(x)
    y = np.asarray(y)

    # decide on a default for infer_intervals (GH781): always infer, except
    # for 2D (irregular) meshes on axes with a ``projection`` attribute
    # (e.g. cartopy GeoAxes).
    if infer_intervals is None:
        infer_intervals = x.ndim == 1 or not hasattr(ax, "projection")

    # Only infer breaks when the coordinate holds one value per cell rather
    # than cell edges. The last axis of ``x`` runs along the columns of ``z``
    # (for both 1D and 2D ``x``); the first axis of ``y`` runs along its rows.
    if (
        infer_intervals
        and not np.issubdtype(x.dtype, str)
        and x.shape[-1] == np.shape(z)[1]
    ):
        if x.ndim == 1:
            x = _infer_interval_breaks(x, check_monotonic=True, scale=xscale)
        else:
            # we have to infer the intervals on both axes
            x = _infer_interval_breaks(x, axis=1, scale=xscale)
            x = _infer_interval_breaks(x, axis=0, scale=xscale)

    if (
        infer_intervals
        and not np.issubdtype(y.dtype, str)
        and y.shape[0] == np.shape(z)[0]
    ):
        if y.ndim == 1:
            y = _infer_interval_breaks(y, check_monotonic=True, scale=yscale)
        else:
            # we have to infer the intervals on both axes
            y = _infer_interval_breaks(y, axis=1, scale=yscale)
            y = _infer_interval_breaks(y, axis=0, scale=yscale)

    primitive = ax.pcolormesh(x, y, z, **kwargs)

    # by default, pcolormesh picks "round" values for bounds
    # this results in ugly looking plots with lots of surrounding whitespace
    if not hasattr(ax, "projection") and x.ndim == 1 and y.ndim == 1:
        # not a cartopy geoaxis
        ax.set_xlim(x[0], x[-1])
        ax.set_ylim(y[0], y[-1])

    return primitive
1480
1481
@_plot2d
def surface(x, y, z, ax, **kwargs):
    """
    Surface plot of 2D DataArray.

    Wraps :py:meth:`matplotlib:mpl_toolkits.mplot3d.axes3d.Axes3D.plot_surface`.
    """
    # ``ax`` must be a 3D axes; the decorator enforces/creates one.
    return ax.plot_surface(x, y, z, **kwargs)
1491