1""" 2Use this module directly: 3 import xarray.plot as xplt 4 5Or use the methods on a DataArray or Dataset: 6 DataArray.plot._____ 7 Dataset.plot._____ 8""" 9import functools 10from distutils.version import LooseVersion 11 12import numpy as np 13import pandas as pd 14 15from ..core.alignment import broadcast 16from .facetgrid import _easy_facetgrid 17from .utils import ( 18 _add_colorbar, 19 _adjust_legend_subtitles, 20 _assert_valid_xy, 21 _ensure_plottable, 22 _infer_interval_breaks, 23 _infer_xy_labels, 24 _is_numeric, 25 _legend_add_subtitle, 26 _process_cmap_cbar_kwargs, 27 _rescale_imshow_rgb, 28 _resolve_intervals_1dplot, 29 _resolve_intervals_2dplot, 30 _update_axes, 31 get_axis, 32 label_from_attrs, 33 legend_elements, 34 plt, 35) 36 37# copied from seaborn 38_MARKERSIZE_RANGE = np.array([18.0, 72.0]) 39 40 41def _infer_scatter_metadata(darray, x, z, hue, hue_style, size): 42 def _determine_array(darray, name, array_style): 43 """Find and determine what type of array it is.""" 44 array = darray[name] 45 array_is_numeric = _is_numeric(array.values) 46 47 if array_style is None: 48 array_style = "continuous" if array_is_numeric else "discrete" 49 elif array_style not in ["discrete", "continuous"]: 50 raise ValueError( 51 f"The style '{array_style}' is not valid, " 52 "valid options are None, 'discrete' or 'continuous'." 53 ) 54 55 array_label = label_from_attrs(array) 56 57 return array, array_style, array_label 58 59 # Add nice looking labels: 60 out = dict(ylabel=label_from_attrs(darray)) 61 out.update( 62 { 63 k: label_from_attrs(darray[v]) if v in darray.coords else None 64 for k, v in [("xlabel", x), ("zlabel", z)] 65 } 66 ) 67 68 # Add styles and labels for the dataarrays: 69 for type_, a, style in [("hue", hue, hue_style), ("size", size, None)]: 70 tp, stl, lbl = f"{type_}", f"{type_}_style", f"{type_}_label" 71 if a: 72 out[tp], out[stl], out[lbl] = _determine_array(darray, a, style) 73 else: 74 out[tp], out[stl], out[lbl] = None, None, None 75 76 return out 77 78 79# copied from seaborn 80def _parse_size(data, norm, width): 81 """ 82 Determine what type of data it is. Then normalize it to width. 83 84 If the data is categorical, normalize it to numbers. 85 """ 86 if data is None: 87 return None 88 89 data = data.values.ravel() 90 91 if not _is_numeric(data): 92 # Data is categorical. 93 # Use pd.unique instead of np.unique because that keeps 94 # the order of the labels: 95 levels = pd.unique(data) 96 numbers = np.arange(1, 1 + len(levels)) 97 else: 98 levels = numbers = np.sort(np.unique(data)) 99 100 min_width, max_width = width 101 # width_range = min_width, max_width 102 103 if norm is None: 104 norm = plt.Normalize() 105 elif isinstance(norm, tuple): 106 norm = plt.Normalize(*norm) 107 elif not isinstance(norm, plt.Normalize): 108 err = "``size_norm`` must be None, tuple, or Normalize object." 109 raise ValueError(err) 110 111 norm.clip = True 112 if not norm.scaled(): 113 norm(np.asarray(numbers)) 114 # limits = norm.vmin, norm.vmax 115 116 scl = norm(numbers) 117 widths = np.asarray(min_width + scl * (max_width - min_width)) 118 if scl.mask.any(): 119 widths[scl.mask] = 0 120 sizes = dict(zip(levels, widths)) 121 122 return pd.Series(sizes) 123 124 125def _infer_scatter_data( 126 darray, x, z, hue, size, size_norm, size_mapping=None, size_range=(1, 10) 127): 128 # Broadcast together all the chosen variables: 129 to_broadcast = dict(y=darray) 130 to_broadcast.update( 131 {k: darray[v] for k, v in dict(x=x, z=z).items() if v is not None} 132 ) 133 to_broadcast.update( 134 {k: darray[v] for k, v in dict(hue=hue, size=size).items() if v in darray.dims} 135 ) 136 broadcasted = dict(zip(to_broadcast.keys(), broadcast(*(to_broadcast.values())))) 137 138 # Normalize hue and size and create lookup tables: 139 for type_, mapping, norm, width in [ 140 ("hue", None, None, [0, 1]), 141 ("size", size_mapping, size_norm, size_range), 142 ]: 143 broadcasted_type = broadcasted.get(type_, None) 144 if broadcasted_type is not None: 145 if mapping is None: 146 mapping = _parse_size(broadcasted_type, norm, width) 147 148 broadcasted[type_] = broadcasted_type.copy( 149 data=np.reshape( 150 mapping.loc[broadcasted_type.values.ravel()].values, 151 broadcasted_type.shape, 152 ) 153 ) 154 broadcasted[f"{type_}_to_label"] = pd.Series(mapping.index, index=mapping) 155 156 return broadcasted 157 158 159def _infer_line_data(darray, x, y, hue): 160 161 ndims = len(darray.dims) 162 163 if x is not None and y is not None: 164 raise ValueError("Cannot specify both x and y kwargs for line plots.") 165 166 if x is not None: 167 _assert_valid_xy(darray, x, "x") 168 169 if y is not None: 170 _assert_valid_xy(darray, y, "y") 171 172 if ndims == 1: 173 huename = None 174 hueplt = None 175 huelabel = "" 176 177 if x is not None: 178 xplt = darray[x] 179 yplt = darray 180 181 elif y is not None: 182 xplt = darray 183 yplt = darray[y] 184 185 else: # Both x & y are None 186 dim = darray.dims[0] 187 xplt = darray[dim] 188 yplt = darray 189 190 else: 191 if x is None and y is None and hue is None: 192 raise ValueError("For 2D inputs, please specify either hue, x or y.") 193 194 if y is None: 195 if hue is not None: 196 _assert_valid_xy(darray, hue, "hue") 197 xname, huename = _infer_xy_labels(darray=darray, x=x, y=hue) 198 xplt = darray[xname] 199 if xplt.ndim > 1: 200 if huename in darray.dims: 201 otherindex = 1 if darray.dims.index(huename) == 0 else 0 202 otherdim = darray.dims[otherindex] 203 yplt = darray.transpose(otherdim, huename, transpose_coords=False) 204 xplt = xplt.transpose(otherdim, huename, transpose_coords=False) 205 else: 206 raise ValueError( 207 "For 2D inputs, hue must be a dimension" 208 " i.e. one of " + repr(darray.dims) 209 ) 210 211 else: 212 (xdim,) = darray[xname].dims 213 (huedim,) = darray[huename].dims 214 yplt = darray.transpose(xdim, huedim) 215 216 else: 217 yname, huename = _infer_xy_labels(darray=darray, x=y, y=hue) 218 yplt = darray[yname] 219 if yplt.ndim > 1: 220 if huename in darray.dims: 221 otherindex = 1 if darray.dims.index(huename) == 0 else 0 222 otherdim = darray.dims[otherindex] 223 xplt = darray.transpose(otherdim, huename, transpose_coords=False) 224 yplt = yplt.transpose(otherdim, huename, transpose_coords=False) 225 else: 226 raise ValueError( 227 "For 2D inputs, hue must be a dimension" 228 " i.e. one of " + repr(darray.dims) 229 ) 230 231 else: 232 (ydim,) = darray[yname].dims 233 (huedim,) = darray[huename].dims 234 xplt = darray.transpose(ydim, huedim) 235 236 huelabel = label_from_attrs(darray[huename]) 237 hueplt = darray[huename] 238 239 return xplt, yplt, hueplt, huelabel 240 241 242def plot( 243 darray, 244 row=None, 245 col=None, 246 col_wrap=None, 247 ax=None, 248 hue=None, 249 rtol=0.01, 250 subplot_kws=None, 251 **kwargs, 252): 253 """ 254 Default plot of DataArray using :py:mod:`matplotlib:matplotlib.pyplot`. 255 256 Calls xarray plotting function based on the dimensions of 257 the squeezed DataArray. 258 259 =============== =========================== 260 Dimensions Plotting function 261 =============== =========================== 262 1 :py:func:`xarray.plot.line` 263 2 :py:func:`xarray.plot.pcolormesh` 264 Anything else :py:func:`xarray.plot.hist` 265 =============== =========================== 266 267 Parameters 268 ---------- 269 darray : DataArray 270 row : str, optional 271 If passed, make row faceted plots on this dimension name. 272 col : str, optional 273 If passed, make column faceted plots on this dimension name. 274 hue : str, optional 275 If passed, make faceted line plots with hue on this dimension name. 276 col_wrap : int, optional 277 Use together with ``col`` to wrap faceted plots. 278 ax : matplotlib axes object, optional 279 If ``None``, use the current axes. Not applicable when using facets. 280 rtol : float, optional 281 Relative tolerance used to determine if the indexes 282 are uniformly spaced. Usually a small positive number. 283 subplot_kws : dict, optional 284 Dictionary of keyword arguments for Matplotlib subplots 285 (see :py:meth:`matplotlib:matplotlib.figure.Figure.add_subplot`). 286 **kwargs : optional 287 Additional keyword arguments for Matplotlib. 288 289 See Also 290 -------- 291 xarray.DataArray.squeeze 292 """ 293 darray = darray.squeeze().compute() 294 295 plot_dims = set(darray.dims) 296 plot_dims.discard(row) 297 plot_dims.discard(col) 298 plot_dims.discard(hue) 299 300 ndims = len(plot_dims) 301 302 error_msg = ( 303 "Only 1d and 2d plots are supported for facets in xarray. " 304 "See the package `Seaborn` for more options." 305 ) 306 307 if ndims in [1, 2]: 308 if row or col: 309 kwargs["subplot_kws"] = subplot_kws 310 kwargs["row"] = row 311 kwargs["col"] = col 312 kwargs["col_wrap"] = col_wrap 313 if ndims == 1: 314 plotfunc = line 315 kwargs["hue"] = hue 316 elif ndims == 2: 317 if hue: 318 plotfunc = line 319 kwargs["hue"] = hue 320 else: 321 plotfunc = pcolormesh 322 kwargs["subplot_kws"] = subplot_kws 323 else: 324 if row or col or hue: 325 raise ValueError(error_msg) 326 plotfunc = hist 327 328 kwargs["ax"] = ax 329 330 return plotfunc(darray, **kwargs) 331 332 333# This function signature should not change so that it can use 334# matplotlib format strings 335def line( 336 darray, 337 *args, 338 row=None, 339 col=None, 340 figsize=None, 341 aspect=None, 342 size=None, 343 ax=None, 344 hue=None, 345 x=None, 346 y=None, 347 xincrease=None, 348 yincrease=None, 349 xscale=None, 350 yscale=None, 351 xticks=None, 352 yticks=None, 353 xlim=None, 354 ylim=None, 355 add_legend=True, 356 _labels=True, 357 **kwargs, 358): 359 """ 360 Line plot of DataArray values. 361 362 Wraps :py:func:`matplotlib:matplotlib.pyplot.plot`. 363 364 Parameters 365 ---------- 366 darray : DataArray 367 Either 1D or 2D. If 2D, one of ``hue``, ``x`` or ``y`` must be provided. 368 figsize : tuple, optional 369 A tuple (width, height) of the figure in inches. 370 Mutually exclusive with ``size`` and ``ax``. 371 aspect : scalar, optional 372 Aspect ratio of plot, so that ``aspect * size`` gives the *width* in 373 inches. Only used if a ``size`` is provided. 374 size : scalar, optional 375 If provided, create a new figure for the plot with the given size: 376 *height* (in inches) of each plot. See also: ``aspect``. 377 ax : matplotlib axes object, optional 378 Axes on which to plot. By default, the current is used. 379 Mutually exclusive with ``size`` and ``figsize``. 380 hue : str, optional 381 Dimension or coordinate for which you want multiple lines plotted. 382 If plotting against a 2D coordinate, ``hue`` must be a dimension. 383 x, y : str, optional 384 Dimension, coordinate or multi-index level for *x*, *y* axis. 385 Only one of these may be specified. 386 The other will be used for values from the DataArray on which this 387 plot method is called. 388 xscale, yscale : {'linear', 'symlog', 'log', 'logit'}, optional 389 Specifies scaling for the *x*- and *y*-axis, respectively. 390 xticks, yticks : array-like, optional 391 Specify tick locations for *x*- and *y*-axis. 392 xlim, ylim : array-like, optional 393 Specify *x*- and *y*-axis limits. 394 xincrease : None, True, or False, optional 395 Should the values on the *x* axis be increasing from left to right? 396 if ``None``, use the default for the Matplotlib function. 397 yincrease : None, True, or False, optional 398 Should the values on the *y* axis be increasing from top to bottom? 399 if ``None``, use the default for the Matplotlib function. 400 add_legend : bool, optional 401 Add legend with *y* axis coordinates (2D inputs only). 402 *args, **kwargs : optional 403 Additional arguments to :py:func:`matplotlib:matplotlib.pyplot.plot`. 404 """ 405 # Handle facetgrids first 406 if row or col: 407 allargs = locals().copy() 408 allargs.update(allargs.pop("kwargs")) 409 allargs.pop("darray") 410 return _easy_facetgrid(darray, line, kind="line", **allargs) 411 412 ndims = len(darray.dims) 413 if ndims > 2: 414 raise ValueError( 415 "Line plots are for 1- or 2-dimensional DataArrays. " 416 "Passed DataArray has {ndims} " 417 "dimensions".format(ndims=ndims) 418 ) 419 420 # The allargs dict passed to _easy_facetgrid above contains args 421 if args == (): 422 args = kwargs.pop("args", ()) 423 else: 424 assert "args" not in kwargs 425 426 ax = get_axis(figsize, size, aspect, ax) 427 xplt, yplt, hueplt, hue_label = _infer_line_data(darray, x, y, hue) 428 429 # Remove pd.Intervals if contained in xplt.values and/or yplt.values. 430 xplt_val, yplt_val, x_suffix, y_suffix, kwargs = _resolve_intervals_1dplot( 431 xplt.to_numpy(), yplt.to_numpy(), kwargs 432 ) 433 xlabel = label_from_attrs(xplt, extra=x_suffix) 434 ylabel = label_from_attrs(yplt, extra=y_suffix) 435 436 _ensure_plottable(xplt_val, yplt_val) 437 438 primitive = ax.plot(xplt_val, yplt_val, *args, **kwargs) 439 440 if _labels: 441 if xlabel is not None: 442 ax.set_xlabel(xlabel) 443 444 if ylabel is not None: 445 ax.set_ylabel(ylabel) 446 447 ax.set_title(darray._title_for_slice()) 448 449 if darray.ndim == 2 and add_legend: 450 ax.legend(handles=primitive, labels=list(hueplt.to_numpy()), title=hue_label) 451 452 # Rotate dates on xlabels 453 # Do this without calling autofmt_xdate so that x-axes ticks 454 # on other subplots (if any) are not deleted. 455 # https://stackoverflow.com/questions/17430105/autofmt-xdate-deletes-x-axis-labels-of-all-subplots 456 if np.issubdtype(xplt.dtype, np.datetime64): 457 for xlabels in ax.get_xticklabels(): 458 xlabels.set_rotation(30) 459 xlabels.set_ha("right") 460 461 _update_axes(ax, xincrease, yincrease, xscale, yscale, xticks, yticks, xlim, ylim) 462 463 return primitive 464 465 466def step(darray, *args, where="pre", drawstyle=None, ds=None, **kwargs): 467 """ 468 Step plot of DataArray values. 469 470 Similar to :py:func:`matplotlib:matplotlib.pyplot.step`. 471 472 Parameters 473 ---------- 474 where : {'pre', 'post', 'mid'}, default: 'pre' 475 Define where the steps should be placed: 476 477 - ``'pre'``: The y value is continued constantly to the left from 478 every *x* position, i.e. the interval ``(x[i-1], x[i]]`` has the 479 value ``y[i]``. 480 - ``'post'``: The y value is continued constantly to the right from 481 every *x* position, i.e. the interval ``[x[i], x[i+1])`` has the 482 value ``y[i]``. 483 - ``'mid'``: Steps occur half-way between the *x* positions. 484 485 Note that this parameter is ignored if one coordinate consists of 486 :py:class:`pandas.Interval` values, e.g. as a result of 487 :py:func:`xarray.Dataset.groupby_bins`. In this case, the actual 488 boundaries of the interval are used. 489 *args, **kwargs : optional 490 Additional arguments for :py:func:`xarray.plot.line`. 491 """ 492 if where not in {"pre", "post", "mid"}: 493 raise ValueError("'where' argument to step must be 'pre', 'post' or 'mid'") 494 495 if ds is not None: 496 if drawstyle is None: 497 drawstyle = ds 498 else: 499 raise TypeError("ds and drawstyle are mutually exclusive") 500 if drawstyle is None: 501 drawstyle = "" 502 drawstyle = "steps-" + where + drawstyle 503 504 return line(darray, *args, drawstyle=drawstyle, **kwargs) 505 506 507def hist( 508 darray, 509 figsize=None, 510 size=None, 511 aspect=None, 512 ax=None, 513 xincrease=None, 514 yincrease=None, 515 xscale=None, 516 yscale=None, 517 xticks=None, 518 yticks=None, 519 xlim=None, 520 ylim=None, 521 **kwargs, 522): 523 """ 524 Histogram of DataArray. 525 526 Wraps :py:func:`matplotlib:matplotlib.pyplot.hist`. 527 528 Plots *N*-dimensional arrays by first flattening the array. 529 530 Parameters 531 ---------- 532 darray : DataArray 533 Can have any number of dimensions. 534 figsize : tuple, optional 535 A tuple (width, height) of the figure in inches. 536 Mutually exclusive with ``size`` and ``ax``. 537 aspect : scalar, optional 538 Aspect ratio of plot, so that ``aspect * size`` gives the *width* in 539 inches. Only used if a ``size`` is provided. 540 size : scalar, optional 541 If provided, create a new figure for the plot with the given size: 542 *height* (in inches) of each plot. See also: ``aspect``. 543 ax : matplotlib axes object, optional 544 Axes on which to plot. By default, use the current axes. 545 Mutually exclusive with ``size`` and ``figsize``. 546 **kwargs : optional 547 Additional keyword arguments to :py:func:`matplotlib:matplotlib.pyplot.hist`. 548 549 """ 550 ax = get_axis(figsize, size, aspect, ax) 551 552 no_nan = np.ravel(darray.to_numpy()) 553 no_nan = no_nan[pd.notnull(no_nan)] 554 555 primitive = ax.hist(no_nan, **kwargs) 556 557 ax.set_title(darray._title_for_slice()) 558 ax.set_xlabel(label_from_attrs(darray)) 559 560 _update_axes(ax, xincrease, yincrease, xscale, yscale, xticks, yticks, xlim, ylim) 561 562 return primitive 563 564 565def scatter( 566 darray, 567 *args, 568 row=None, 569 col=None, 570 figsize=None, 571 aspect=None, 572 size=None, 573 ax=None, 574 hue=None, 575 hue_style=None, 576 x=None, 577 z=None, 578 xincrease=None, 579 yincrease=None, 580 xscale=None, 581 yscale=None, 582 xticks=None, 583 yticks=None, 584 xlim=None, 585 ylim=None, 586 add_legend=None, 587 add_colorbar=None, 588 cbar_kwargs=None, 589 cbar_ax=None, 590 vmin=None, 591 vmax=None, 592 norm=None, 593 infer_intervals=None, 594 center=None, 595 levels=None, 596 robust=None, 597 colors=None, 598 extend=None, 599 cmap=None, 600 _labels=True, 601 **kwargs, 602): 603 """ 604 Scatter plot a DataArray along some coordinates. 605 606 Parameters 607 ---------- 608 darray : DataArray 609 Dataarray to plot. 610 x, y : str 611 Variable names for x, y axis. 612 hue: str, optional 613 Variable by which to color scattered points 614 hue_style: str, optional 615 Can be either 'discrete' (legend) or 'continuous' (color bar). 616 markersize: str, optional 617 scatter only. Variable by which to vary size of scattered points. 618 size_norm: optional 619 Either None or 'Norm' instance to normalize the 'markersize' variable. 620 add_guide: bool, optional 621 Add a guide that depends on hue_style 622 - for "discrete", build a legend. 623 This is the default for non-numeric `hue` variables. 624 - for "continuous", build a colorbar 625 row : str, optional 626 If passed, make row faceted plots on this dimension name 627 col : str, optional 628 If passed, make column faceted plots on this dimension name 629 col_wrap : int, optional 630 Use together with ``col`` to wrap faceted plots 631 ax : matplotlib axes object, optional 632 If None, uses the current axis. Not applicable when using facets. 633 subplot_kws : dict, optional 634 Dictionary of keyword arguments for matplotlib subplots. Only applies 635 to FacetGrid plotting. 636 aspect : scalar, optional 637 Aspect ratio of plot, so that ``aspect * size`` gives the width in 638 inches. Only used if a ``size`` is provided. 639 size : scalar, optional 640 If provided, create a new figure for the plot with the given size. 641 Height (in inches) of each plot. See also: ``aspect``. 642 norm : ``matplotlib.colors.Normalize`` instance, optional 643 If the ``norm`` has vmin or vmax specified, the corresponding kwarg 644 must be None. 645 vmin, vmax : float, optional 646 Values to anchor the colormap, otherwise they are inferred from the 647 data and other keyword arguments. When a diverging dataset is inferred, 648 setting one of these values will fix the other by symmetry around 649 ``center``. Setting both values prevents use of a diverging colormap. 650 If discrete levels are provided as an explicit list, both of these 651 values are ignored. 652 cmap : str or colormap, optional 653 The mapping from data values to color space. Either a 654 matplotlib colormap name or object. If not provided, this will 655 be either ``viridis`` (if the function infers a sequential 656 dataset) or ``RdBu_r`` (if the function infers a diverging 657 dataset). When `Seaborn` is installed, ``cmap`` may also be a 658 `seaborn` color palette. If ``cmap`` is seaborn color palette 659 and the plot type is not ``contour`` or ``contourf``, ``levels`` 660 must also be specified. 661 colors : color-like or list of color-like, optional 662 A single color or a list of colors. If the plot type is not ``contour`` 663 or ``contourf``, the ``levels`` argument is required. 664 center : float, optional 665 The value at which to center the colormap. Passing this value implies 666 use of a diverging colormap. Setting it to ``False`` prevents use of a 667 diverging colormap. 668 robust : bool, optional 669 If True and ``vmin`` or ``vmax`` are absent, the colormap range is 670 computed with 2nd and 98th percentiles instead of the extreme values. 671 extend : {"neither", "both", "min", "max"}, optional 672 How to draw arrows extending the colorbar beyond its limits. If not 673 provided, extend is inferred from vmin, vmax and the data limits. 674 levels : int or list-like object, optional 675 Split the colormap (cmap) into discrete color intervals. If an integer 676 is provided, "nice" levels are chosen based on the data range: this can 677 imply that the final number of levels is not exactly the expected one. 678 Setting ``vmin`` and/or ``vmax`` with ``levels=N`` is equivalent to 679 setting ``levels=np.linspace(vmin, vmax, N)``. 680 **kwargs : optional 681 Additional keyword arguments to matplotlib 682 """ 683 # Handle facetgrids first 684 if row or col: 685 allargs = locals().copy() 686 allargs.update(allargs.pop("kwargs")) 687 allargs.pop("darray") 688 subplot_kws = dict(projection="3d") if z is not None else None 689 return _easy_facetgrid( 690 darray, scatter, kind="dataarray", subplot_kws=subplot_kws, **allargs 691 ) 692 693 # Further 694 _is_facetgrid = kwargs.pop("_is_facetgrid", False) 695 if _is_facetgrid: 696 # Why do I need to pop these here? 697 kwargs.pop("y", None) 698 kwargs.pop("args", None) 699 kwargs.pop("add_labels", None) 700 701 _sizes = kwargs.pop("markersize", kwargs.pop("linewidth", None)) 702 size_norm = kwargs.pop("size_norm", None) 703 size_mapping = kwargs.pop("size_mapping", None) # set by facetgrid 704 cmap_params = kwargs.pop("cmap_params", None) 705 706 figsize = kwargs.pop("figsize", None) 707 subplot_kws = dict() 708 if z is not None and ax is None: 709 # TODO: Importing Axes3D is not necessary in matplotlib >= 3.2. 710 # Remove when minimum requirement of matplotlib is 3.2: 711 from mpl_toolkits.mplot3d import Axes3D # type: ignore # noqa 712 713 subplot_kws.update(projection="3d") 714 ax = get_axis(figsize, size, aspect, ax, **subplot_kws) 715 # Using 30, 30 minimizes rotation of the plot. Making it easier to 716 # build on your intuition from 2D plots: 717 if LooseVersion(plt.matplotlib.__version__) < "3.5.0": 718 ax.view_init(azim=30, elev=30) 719 else: 720 # https://github.com/matplotlib/matplotlib/pull/19873 721 ax.view_init(azim=30, elev=30, vertical_axis="y") 722 else: 723 ax = get_axis(figsize, size, aspect, ax, **subplot_kws) 724 725 _data = _infer_scatter_metadata(darray, x, z, hue, hue_style, _sizes) 726 727 add_guide = kwargs.pop("add_guide", None) 728 if add_legend is not None: 729 pass 730 elif add_guide is None or add_guide is True: 731 add_legend = True if _data["hue_style"] == "discrete" else False 732 elif add_legend is None: 733 add_legend = False 734 735 if add_colorbar is not None: 736 pass 737 elif add_guide is None or add_guide is True: 738 add_colorbar = True if _data["hue_style"] == "continuous" else False 739 else: 740 add_colorbar = False 741 742 # need to infer size_mapping with full dataset 743 _data.update( 744 _infer_scatter_data( 745 darray, 746 x, 747 z, 748 hue, 749 _sizes, 750 size_norm, 751 size_mapping, 752 _MARKERSIZE_RANGE, 753 ) 754 ) 755 756 cmap_params_subset = {} 757 if _data["hue"] is not None: 758 kwargs.update(c=_data["hue"].values.ravel()) 759 cmap_params, cbar_kwargs = _process_cmap_cbar_kwargs( 760 scatter, _data["hue"].values, **locals() 761 ) 762 763 # subset that can be passed to scatter, hist2d 764 cmap_params_subset = { 765 vv: cmap_params[vv] for vv in ["vmin", "vmax", "norm", "cmap"] 766 } 767 768 if _data["size"] is not None: 769 kwargs.update(s=_data["size"].values.ravel()) 770 771 if LooseVersion(plt.matplotlib.__version__) < "3.5.0": 772 # Plot the data. 3d plots has the z value in upward direction 773 # instead of y. To make jumping between 2d and 3d easy and intuitive 774 # switch the order so that z is shown in the depthwise direction: 775 axis_order = ["x", "z", "y"] 776 else: 777 # Switching axis order not needed in 3.5.0, can also simplify the code 778 # that uses axis_order: 779 # https://github.com/matplotlib/matplotlib/pull/19873 780 axis_order = ["x", "y", "z"] 781 782 primitive = ax.scatter( 783 *[ 784 _data[v].values.ravel() 785 for v in axis_order 786 if _data.get(v, None) is not None 787 ], 788 **cmap_params_subset, 789 **kwargs, 790 ) 791 792 # Set x, y, z labels: 793 i = 0 794 set_label = [ax.set_xlabel, ax.set_ylabel, getattr(ax, "set_zlabel", None)] 795 for v in axis_order: 796 if _data.get(f"{v}label", None) is not None: 797 set_label[i](_data[f"{v}label"]) 798 i += 1 799 800 if add_legend: 801 802 def to_label(data, key, x): 803 """Map prop values back to its original values.""" 804 if key in data: 805 # Use reindex to be less sensitive to float errors. 806 # Return as numpy array since legend_elements 807 # seems to require that: 808 return data[key].reindex(x, method="nearest").to_numpy() 809 else: 810 return x 811 812 handles, labels = [], [] 813 for subtitle, prop, func in [ 814 ( 815 _data["hue_label"], 816 "colors", 817 functools.partial(to_label, _data, "hue_to_label"), 818 ), 819 ( 820 _data["size_label"], 821 "sizes", 822 functools.partial(to_label, _data, "size_to_label"), 823 ), 824 ]: 825 if subtitle: 826 # Get legend handles and labels that displays the 827 # values correctly. Order might be different because 828 # legend_elements uses np.unique instead of pd.unique, 829 # FacetGrid.add_legend might have troubles with this: 830 hdl, lbl = legend_elements(primitive, prop, num="auto", func=func) 831 hdl, lbl = _legend_add_subtitle(hdl, lbl, subtitle, ax.scatter) 832 handles += hdl 833 labels += lbl 834 legend = ax.legend(handles, labels, framealpha=0.5) 835 _adjust_legend_subtitles(legend) 836 837 if add_colorbar and _data["hue_label"]: 838 if _data["hue_style"] == "discrete": 839 raise NotImplementedError("Cannot create a colorbar for non numerics.") 840 cbar_kwargs = {} if cbar_kwargs is None else cbar_kwargs 841 if "label" not in cbar_kwargs: 842 cbar_kwargs["label"] = _data["hue_label"] 843 _add_colorbar(primitive, ax, cbar_ax, cbar_kwargs, cmap_params) 844 845 return primitive 846 847 848# MUST run before any 2d plotting functions are defined since 849# _plot2d decorator adds them as methods here. 850class _PlotMethods: 851 """ 852 Enables use of xarray.plot functions as attributes on a DataArray. 853 For example, DataArray.plot.imshow 854 """ 855 856 __slots__ = ("_da",) 857 858 def __init__(self, darray): 859 self._da = darray 860 861 def __call__(self, **kwargs): 862 return plot(self._da, **kwargs) 863 864 # we can't use functools.wraps here since that also modifies the name / qualname 865 __doc__ = __call__.__doc__ = plot.__doc__ 866 __call__.__wrapped__ = plot # type: ignore[attr-defined] 867 __call__.__annotations__ = plot.__annotations__ 868 869 @functools.wraps(hist) 870 def hist(self, ax=None, **kwargs): 871 return hist(self._da, ax=ax, **kwargs) 872 873 @functools.wraps(line) 874 def line(self, *args, **kwargs): 875 return line(self._da, *args, **kwargs) 876 877 @functools.wraps(step) 878 def step(self, *args, **kwargs): 879 return step(self._da, *args, **kwargs) 880 881 @functools.wraps(scatter) 882 def _scatter(self, *args, **kwargs): 883 return scatter(self._da, *args, **kwargs) 884 885 886def override_signature(f): 887 def wrapper(func): 888 func.__wrapped__ = f 889 890 return func 891 892 return wrapper 893 894 895def _plot2d(plotfunc): 896 """ 897 Decorator for common 2d plotting logic 898 899 Also adds the 2d plot method to class _PlotMethods 900 """ 901 commondoc = """ 902 Parameters 903 ---------- 904 darray : DataArray 905 Must be two-dimensional, unless creating faceted plots. 906 x : str, optional 907 Coordinate for *x* axis. If ``None``, use ``darray.dims[1]``. 908 y : str, optional 909 Coordinate for *y* axis. If ``None``, use ``darray.dims[0]``. 910 figsize : tuple, optional 911 A tuple (width, height) of the figure in inches. 912 Mutually exclusive with ``size`` and ``ax``. 913 aspect : scalar, optional 914 Aspect ratio of plot, so that ``aspect * size`` gives the *width* in 915 inches. Only used if a ``size`` is provided. 916 size : scalar, optional 917 If provided, create a new figure for the plot with the given size: 918 *height* (in inches) of each plot. See also: ``aspect``. 919 ax : matplotlib axes object, optional 920 Axes on which to plot. By default, use the current axes. 921 Mutually exclusive with ``size`` and ``figsize``. 922 row : string, optional 923 If passed, make row faceted plots on this dimension name. 924 col : string, optional 925 If passed, make column faceted plots on this dimension name. 926 col_wrap : int, optional 927 Use together with ``col`` to wrap faceted plots. 928 xscale, yscale : {'linear', 'symlog', 'log', 'logit'}, optional 929 Specifies scaling for the *x*- and *y*-axis, respectively. 930 xticks, yticks : array-like, optional 931 Specify tick locations for *x*- and *y*-axis. 932 xlim, ylim : array-like, optional 933 Specify *x*- and *y*-axis limits. 934 xincrease : None, True, or False, optional 935 Should the values on the *x* axis be increasing from left to right? 936 If ``None``, use the default for the Matplotlib function. 937 yincrease : None, True, or False, optional 938 Should the values on the *y* axis be increasing from top to bottom? 939 If ``None``, use the default for the Matplotlib function. 940 add_colorbar : bool, optional 941 Add colorbar to axes. 942 add_labels : bool, optional 943 Use xarray metadata to label axes. 944 norm : matplotlib.colors.Normalize, optional 945 If ``norm`` has ``vmin`` or ``vmax`` specified, the corresponding 946 kwarg must be ``None``. 947 vmin, vmax : float, optional 948 Values to anchor the colormap, otherwise they are inferred from the 949 data and other keyword arguments. When a diverging dataset is inferred, 950 setting one of these values will fix the other by symmetry around 951 ``center``. Setting both values prevents use of a diverging colormap. 952 If discrete levels are provided as an explicit list, both of these 953 values are ignored. 954 cmap : matplotlib colormap name or colormap, optional 955 The mapping from data values to color space. If not provided, this 956 will be either be ``'viridis'`` (if the function infers a sequential 957 dataset) or ``'RdBu_r'`` (if the function infers a diverging dataset). 958 See :doc:`Choosing Colormaps in Matplotlib <matplotlib:tutorials/colors/colormaps>` 959 for more information. 960 961 If *seaborn* is installed, ``cmap`` may also be a 962 `seaborn color palette <https://seaborn.pydata.org/tutorial/color_palettes.html>`_. 963 Note: if ``cmap`` is a seaborn color palette and the plot type 964 is not ``'contour'`` or ``'contourf'``, ``levels`` must also be specified. 965 colors : str or array-like of color-like, optional 966 A single color or a sequence of colors. If the plot type is not ``'contour'`` 967 or ``'contourf'``, the ``levels`` argument is required. 968 center : float, optional 969 The value at which to center the colormap. Passing this value implies 970 use of a diverging colormap. Setting it to ``False`` prevents use of a 971 diverging colormap. 972 robust : bool, optional 973 If ``True`` and ``vmin`` or ``vmax`` are absent, the colormap range is 974 computed with 2nd and 98th percentiles instead of the extreme values. 975 extend : {'neither', 'both', 'min', 'max'}, optional 976 How to draw arrows extending the colorbar beyond its limits. If not 977 provided, ``extend`` is inferred from ``vmin``, ``vmax`` and the data limits. 978 levels : int or array-like, optional 979 Split the colormap (``cmap``) into discrete color intervals. If an integer 980 is provided, "nice" levels are chosen based on the data range: this can 981 imply that the final number of levels is not exactly the expected one. 982 Setting ``vmin`` and/or ``vmax`` with ``levels=N`` is equivalent to 983 setting ``levels=np.linspace(vmin, vmax, N)``. 984 infer_intervals : bool, optional 985 Only applies to pcolormesh. If ``True``, the coordinate intervals are 986 passed to pcolormesh. If ``False``, the original coordinates are used 987 (this can be useful for certain map projections). The default is to 988 always infer intervals, unless the mesh is irregular and plotted on 989 a map projection. 990 subplot_kws : dict, optional 991 Dictionary of keyword arguments for Matplotlib subplots. Only used 992 for 2D and faceted plots. 993 (see :py:meth:`matplotlib:matplotlib.figure.Figure.add_subplot`). 994 cbar_ax : matplotlib axes object, optional 995 Axes in which to draw the colorbar. 996 cbar_kwargs : dict, optional 997 Dictionary of keyword arguments to pass to the colorbar 998 (see :meth:`matplotlib:matplotlib.figure.Figure.colorbar`). 999 **kwargs : optional 1000 Additional keyword arguments to wrapped Matplotlib function. 1001 1002 Returns 1003 ------- 1004 artist : 1005 The same type of primitive artist that the wrapped Matplotlib 1006 function returns. 1007 """ 1008 1009 # Build on the original docstring 1010 plotfunc.__doc__ = f"{plotfunc.__doc__}\n{commondoc}" 1011 1012 # plotfunc and newplotfunc have different signatures: 1013 # - plotfunc: (x, y, z, ax, **kwargs) 1014 # - newplotfunc: (darray, x, y, **kwargs) 1015 # where plotfunc accepts numpy arrays, while newplotfunc accepts a DataArray 1016 # and variable names. newplotfunc also explicitly lists most kwargs, so we 1017 # need to shorten it 1018 def signature(darray, x, y, **kwargs): 1019 pass 1020 1021 @override_signature(signature) 1022 @functools.wraps(plotfunc) 1023 def newplotfunc( 1024 darray, 1025 x=None, 1026 y=None, 1027 figsize=None, 1028 size=None, 1029 aspect=None, 1030 ax=None, 1031 row=None, 1032 col=None, 1033 col_wrap=None, 1034 xincrease=True, 1035 yincrease=True, 1036 add_colorbar=None, 1037 add_labels=True, 1038 vmin=None, 1039 vmax=None, 1040 cmap=None, 1041 center=None, 1042 robust=False, 1043 extend=None, 1044 levels=None, 1045 infer_intervals=None, 1046 colors=None, 1047 subplot_kws=None, 1048 cbar_ax=None, 1049 cbar_kwargs=None, 1050 xscale=None, 1051 yscale=None, 1052 xticks=None, 1053 yticks=None, 1054 xlim=None, 1055 ylim=None, 1056 norm=None, 1057 **kwargs, 1058 ): 1059 # All 2d plots in xarray share this function signature. 1060 # Method signature below should be consistent. 1061 1062 # Decide on a default for the colorbar before facetgrids 1063 if add_colorbar is None: 1064 add_colorbar = True 1065 if plotfunc.__name__ == "contour" or ( 1066 plotfunc.__name__ == "surface" and cmap is None 1067 ): 1068 add_colorbar = False 1069 imshow_rgb = plotfunc.__name__ == "imshow" and darray.ndim == ( 1070 3 + (row is not None) + (col is not None) 1071 ) 1072 if imshow_rgb: 1073 # Don't add a colorbar when showing an image with explicit colors 1074 add_colorbar = False 1075 # Matplotlib does not support normalising RGB data, so do it here. 1076 # See eg. https://github.com/matplotlib/matplotlib/pull/10220 1077 if robust or vmax is not None or vmin is not None: 1078 darray = _rescale_imshow_rgb(darray.as_numpy(), vmin, vmax, robust) 1079 vmin, vmax, robust = None, None, False 1080 1081 if subplot_kws is None: 1082 subplot_kws = dict() 1083 1084 if plotfunc.__name__ == "surface" and not kwargs.get("_is_facetgrid", False): 1085 if ax is None: 1086 # TODO: Importing Axes3D is no longer necessary in matplotlib >= 3.2. 1087 # Remove when minimum requirement of matplotlib is 3.2: 1088 from mpl_toolkits.mplot3d import Axes3D # type: ignore # noqa: F401 1089 1090 # delete so it does not end up in locals() 1091 del Axes3D 1092 1093 # Need to create a "3d" Axes instance for surface plots 1094 subplot_kws["projection"] = "3d" 1095 1096 # In facet grids, shared axis labels don't make sense for surface plots 1097 sharex = False 1098 sharey = False 1099 1100 # Handle facetgrids first 1101 if row or col: 1102 allargs = locals().copy() 1103 del allargs["darray"] 1104 del allargs["imshow_rgb"] 1105 allargs.update(allargs.pop("kwargs")) 1106 # Need the decorated plotting function 1107 allargs["plotfunc"] = globals()[plotfunc.__name__] 1108 return _easy_facetgrid(darray, kind="dataarray", **allargs) 1109 1110 if ( 1111 plotfunc.__name__ == "surface" 1112 and not kwargs.get("_is_facetgrid", False) 1113 and ax is not None 1114 ): 1115 import mpl_toolkits # type: ignore 1116 1117 if not isinstance(ax, mpl_toolkits.mplot3d.Axes3D): 1118 raise ValueError( 1119 "If ax is passed to surface(), it must be created with " 1120 'projection="3d"' 1121 ) 1122 1123 rgb = kwargs.pop("rgb", None) 1124 if rgb is not None and plotfunc.__name__ != "imshow": 1125 raise ValueError('The "rgb" keyword is only valid for imshow()') 1126 elif rgb is not None and not imshow_rgb: 1127 raise ValueError( 1128 'The "rgb" keyword is only valid for imshow()' 1129 "with a three-dimensional array (per facet)" 1130 ) 1131 1132 xlab, ylab = _infer_xy_labels( 1133 darray=darray, x=x, y=y, imshow=imshow_rgb, rgb=rgb 1134 ) 1135 1136 xval = darray[xlab] 1137 yval = darray[ylab] 1138 1139 if xval.ndim > 1 or yval.ndim > 1 or plotfunc.__name__ == "surface": 1140 # Passing 2d coordinate values, need to ensure they are transposed the same 1141 # way as darray. 1142 # Also surface plots always need 2d coordinates 1143 xval = xval.broadcast_like(darray) 1144 yval = yval.broadcast_like(darray) 1145 dims = darray.dims 1146 else: 1147 dims = (yval.dims[0], xval.dims[0]) 1148 1149 # May need to transpose for correct x, y labels 1150 # xlab may be the name of a coord, we have to check for dim names 1151 if imshow_rgb: 1152 # For RGB[A] images, matplotlib requires the color dimension 1153 # to be last. In Xarray the order should be unimportant, so 1154 # we transpose to (y, x, color) to make this work. 1155 yx_dims = (ylab, xlab) 1156 dims = yx_dims + tuple(d for d in darray.dims if d not in yx_dims) 1157 1158 if dims != darray.dims: 1159 darray = darray.transpose(*dims, transpose_coords=True) 1160 1161 # better to pass the ndarrays directly to plotting functions 1162 xval = xval.to_numpy() 1163 yval = yval.to_numpy() 1164 1165 # Pass the data as a masked ndarray too 1166 zval = darray.to_masked_array(copy=False) 1167 1168 # Replace pd.Intervals if contained in xval or yval. 1169 xplt, xlab_extra = _resolve_intervals_2dplot(xval, plotfunc.__name__) 1170 yplt, ylab_extra = _resolve_intervals_2dplot(yval, plotfunc.__name__) 1171 1172 _ensure_plottable(xplt, yplt, zval) 1173 1174 cmap_params, cbar_kwargs = _process_cmap_cbar_kwargs( 1175 plotfunc, 1176 zval.data, 1177 **locals(), 1178 _is_facetgrid=kwargs.pop("_is_facetgrid", False), 1179 ) 1180 1181 if "contour" in plotfunc.__name__: 1182 # extend is a keyword argument only for contour and contourf, but 1183 # passing it to the colorbar is sufficient for imshow and 1184 # pcolormesh 1185 kwargs["extend"] = cmap_params["extend"] 1186 kwargs["levels"] = cmap_params["levels"] 1187 # if colors == a single color, matplotlib draws dashed negative 1188 # contours. we lose this feature if we pass cmap and not colors 1189 if isinstance(colors, str): 1190 cmap_params["cmap"] = None 1191 kwargs["colors"] = colors 1192 1193 if "pcolormesh" == plotfunc.__name__: 1194 kwargs["infer_intervals"] = infer_intervals 1195 kwargs["xscale"] = xscale 1196 kwargs["yscale"] = yscale 1197 1198 if "imshow" == plotfunc.__name__ and isinstance(aspect, str): 1199 # forbid usage of mpl strings 1200 raise ValueError("plt.imshow's `aspect` kwarg is not available in xarray") 1201 1202 ax = get_axis(figsize, size, aspect, ax, **subplot_kws) 1203 1204 primitive = plotfunc( 1205 xplt, 1206 yplt, 1207 zval, 1208 ax=ax, 1209 cmap=cmap_params["cmap"], 1210 vmin=cmap_params["vmin"], 1211 vmax=cmap_params["vmax"], 1212 norm=cmap_params["norm"], 1213 **kwargs, 1214 ) 1215 1216 # Label the plot with metadata 1217 if add_labels: 1218 ax.set_xlabel(label_from_attrs(darray[xlab], xlab_extra)) 1219 ax.set_ylabel(label_from_attrs(darray[ylab], ylab_extra)) 1220 ax.set_title(darray._title_for_slice()) 1221 if plotfunc.__name__ == "surface": 1222 ax.set_zlabel(label_from_attrs(darray)) 1223 1224 if add_colorbar: 1225 if add_labels and "label" not in cbar_kwargs: 1226 cbar_kwargs["label"] = label_from_attrs(darray) 1227 cbar = _add_colorbar(primitive, ax, cbar_ax, cbar_kwargs, cmap_params) 1228 elif cbar_ax is not None or cbar_kwargs: 1229 # inform the user about keywords which aren't used 1230 raise ValueError( 1231 "cbar_ax and cbar_kwargs can't be used with add_colorbar=False." 1232 ) 1233 1234 # origin kwarg overrides yincrease 1235 if "origin" in kwargs: 1236 yincrease = None 1237 1238 _update_axes( 1239 ax, xincrease, yincrease, xscale, yscale, xticks, yticks, xlim, ylim 1240 ) 1241 1242 # Rotate dates on xlabels 1243 # Do this without calling autofmt_xdate so that x-axes ticks 1244 # on other subplots (if any) are not deleted. 1245 # https://stackoverflow.com/questions/17430105/autofmt-xdate-deletes-x-axis-labels-of-all-subplots 1246 if np.issubdtype(xplt.dtype, np.datetime64): 1247 for xlabels in ax.get_xticklabels(): 1248 xlabels.set_rotation(30) 1249 xlabels.set_ha("right") 1250 1251 return primitive 1252 1253 # For use as DataArray.plot.plotmethod 1254 @functools.wraps(newplotfunc) 1255 def plotmethod( 1256 _PlotMethods_obj, 1257 x=None, 1258 y=None, 1259 figsize=None, 1260 size=None, 1261 aspect=None, 1262 ax=None, 1263 row=None, 1264 col=None, 1265 col_wrap=None, 1266 xincrease=True, 1267 yincrease=True, 1268 add_colorbar=None, 1269 add_labels=True, 1270 vmin=None, 1271 vmax=None, 1272 cmap=None, 1273 colors=None, 1274 center=None, 1275 robust=False, 1276 extend=None, 1277 levels=None, 1278 infer_intervals=None, 1279 subplot_kws=None, 1280 cbar_ax=None, 1281 cbar_kwargs=None, 1282 xscale=None, 1283 yscale=None, 1284 xticks=None, 1285 yticks=None, 1286 xlim=None, 1287 ylim=None, 1288 norm=None, 1289 **kwargs, 1290 ): 1291 """ 1292 The method should have the same signature as the function. 1293 1294 This just makes the method work on Plotmethods objects, 1295 and passes all the other arguments straight through. 1296 """ 1297 allargs = locals() 1298 allargs["darray"] = _PlotMethods_obj._da 1299 allargs.update(kwargs) 1300 for arg in ["_PlotMethods_obj", "newplotfunc", "kwargs"]: 1301 del allargs[arg] 1302 return newplotfunc(**allargs) 1303 1304 # Add to class _PlotMethods 1305 setattr(_PlotMethods, plotmethod.__name__, plotmethod) 1306 1307 return newplotfunc 1308 1309 1310@_plot2d 1311def imshow(x, y, z, ax, **kwargs): 1312 """ 1313 Image plot of 2D DataArray. 1314 1315 Wraps :py:func:`matplotlib:matplotlib.pyplot.imshow`. 1316 1317 While other plot methods require the DataArray to be strictly 1318 two-dimensional, ``imshow`` also accepts a 3D array where some 1319 dimension can be interpreted as RGB or RGBA color channels and 1320 allows this dimension to be specified via the kwarg ``rgb=``. 1321 1322 Unlike :py:func:`matplotlib:matplotlib.pyplot.imshow`, which ignores ``vmin``/``vmax`` 1323 for RGB(A) data, 1324 xarray *will* use ``vmin`` and ``vmax`` for RGB(A) data 1325 by applying a single scaling factor and offset to all bands. 1326 Passing ``robust=True`` infers ``vmin`` and ``vmax`` 1327 :ref:`in the usual way <robust-plotting>`. 1328 1329 .. note:: 1330 This function needs uniformly spaced coordinates to 1331 properly label the axes. Call :py:meth:`DataArray.plot` to check. 1332 1333 The pixels are centered on the coordinates. For example, if the coordinate 1334 value is 3.2, then the pixels for those coordinates will be centered on 3.2. 1335 """ 1336 1337 if x.ndim != 1 or y.ndim != 1: 1338 raise ValueError( 1339 "imshow requires 1D coordinates, try using pcolormesh or contour(f)" 1340 ) 1341 1342 def _center_pixels(x): 1343 """Center the pixels on the coordinates.""" 1344 if np.issubdtype(x.dtype, str): 1345 # When using strings as inputs imshow converts it to 1346 # integers. Choose extent values which puts the indices in 1347 # in the center of the pixels: 1348 return 0 - 0.5, len(x) - 0.5 1349 1350 try: 1351 # Center the pixels assuming uniform spacing: 1352 xstep = 0.5 * (x[1] - x[0]) 1353 except IndexError: 1354 # Arbitrary default value, similar to matplotlib behaviour: 1355 xstep = 0.1 1356 1357 return x[0] - xstep, x[-1] + xstep 1358 1359 # Center the pixels: 1360 left, right = _center_pixels(x) 1361 top, bottom = _center_pixels(y) 1362 1363 defaults = {"origin": "upper", "interpolation": "nearest"} 1364 1365 if not hasattr(ax, "projection"): 1366 # not for cartopy geoaxes 1367 defaults["aspect"] = "auto" 1368 1369 # Allow user to override these defaults 1370 defaults.update(kwargs) 1371 1372 if defaults["origin"] == "upper": 1373 defaults["extent"] = [left, right, bottom, top] 1374 else: 1375 defaults["extent"] = [left, right, top, bottom] 1376 1377 if z.ndim == 3: 1378 # matplotlib imshow uses black for missing data, but Xarray makes 1379 # missing data transparent. We therefore add an alpha channel if 1380 # there isn't one, and set it to transparent where data is masked. 1381 if z.shape[-1] == 3: 1382 alpha = np.ma.ones(z.shape[:2] + (1,), dtype=z.dtype) 1383 if np.issubdtype(z.dtype, np.integer): 1384 alpha *= 255 1385 z = np.ma.concatenate((z, alpha), axis=2) 1386 else: 1387 z = z.copy() 1388 z[np.any(z.mask, axis=-1), -1] = 0 1389 1390 primitive = ax.imshow(z, **defaults) 1391 1392 # If x or y are strings the ticklabels have been replaced with 1393 # integer indices. Replace them back to strings: 1394 for axis, v in [("x", x), ("y", y)]: 1395 if np.issubdtype(v.dtype, str): 1396 getattr(ax, f"set_{axis}ticks")(np.arange(len(v))) 1397 getattr(ax, f"set_{axis}ticklabels")(v) 1398 1399 return primitive 1400 1401 1402@_plot2d 1403def contour(x, y, z, ax, **kwargs): 1404 """ 1405 Contour plot of 2D DataArray. 1406 1407 Wraps :py:func:`matplotlib:matplotlib.pyplot.contour`. 1408 """ 1409 primitive = ax.contour(x, y, z, **kwargs) 1410 return primitive 1411 1412 1413@_plot2d 1414def contourf(x, y, z, ax, **kwargs): 1415 """ 1416 Filled contour plot of 2D DataArray. 1417 1418 Wraps :py:func:`matplotlib:matplotlib.pyplot.contourf`. 1419 """ 1420 primitive = ax.contourf(x, y, z, **kwargs) 1421 return primitive 1422 1423 1424@_plot2d 1425def pcolormesh(x, y, z, ax, xscale=None, yscale=None, infer_intervals=None, **kwargs): 1426 """ 1427 Pseudocolor plot of 2D DataArray. 1428 1429 Wraps :py:func:`matplotlib:matplotlib.pyplot.pcolormesh`. 1430 """ 1431 1432 # decide on a default for infer_intervals (GH781) 1433 x = np.asarray(x) 1434 if infer_intervals is None: 1435 if hasattr(ax, "projection"): 1436 if len(x.shape) == 1: 1437 infer_intervals = True 1438 else: 1439 infer_intervals = False 1440 else: 1441 infer_intervals = True 1442 1443 if ( 1444 infer_intervals 1445 and not np.issubdtype(x.dtype, str) 1446 and ( 1447 (np.shape(x)[0] == np.shape(z)[1]) 1448 or ((x.ndim > 1) and (np.shape(x)[1] == np.shape(z)[1])) 1449 ) 1450 ): 1451 if len(x.shape) == 1: 1452 x = _infer_interval_breaks(x, check_monotonic=True, scale=xscale) 1453 else: 1454 # we have to infer the intervals on both axes 1455 x = _infer_interval_breaks(x, axis=1, scale=xscale) 1456 x = _infer_interval_breaks(x, axis=0, scale=xscale) 1457 1458 if ( 1459 infer_intervals 1460 and not np.issubdtype(y.dtype, str) 1461 and (np.shape(y)[0] == np.shape(z)[0]) 1462 ): 1463 if len(y.shape) == 1: 1464 y = _infer_interval_breaks(y, check_monotonic=True, scale=yscale) 1465 else: 1466 # we have to infer the intervals on both axes 1467 y = _infer_interval_breaks(y, axis=1, scale=yscale) 1468 y = _infer_interval_breaks(y, axis=0, scale=yscale) 1469 1470 primitive = ax.pcolormesh(x, y, z, **kwargs) 1471 1472 # by default, pcolormesh picks "round" values for bounds 1473 # this results in ugly looking plots with lots of surrounding whitespace 1474 if not hasattr(ax, "projection") and x.ndim == 1 and y.ndim == 1: 1475 # not a cartopy geoaxis 1476 ax.set_xlim(x[0], x[-1]) 1477 ax.set_ylim(y[0], y[-1]) 1478 1479 return primitive 1480 1481 1482@_plot2d 1483def surface(x, y, z, ax, **kwargs): 1484 """ 1485 Surface plot of 2D DataArray. 1486 1487 Wraps :py:meth:`matplotlib:mpl_toolkits.mplot3d.axes3d.Axes3D.plot_surface`. 1488 """ 1489 primitive = ax.plot_surface(x, y, z, **kwargs) 1490 return primitive 1491