import functools
import itertools
import warnings

import numpy as np

from ..core.formatting import format_item
from .utils import (
    _get_nice_quiver_magnitude,
    _infer_xy_labels,
    _process_cmap_cbar_kwargs,
    label_from_attrs,
    plt,
)

# Overrides axes.labelsize, xtick.major.size, ytick.major.size
# from mpl.rcParams
_FONTSIZE = "small"
# For major ticks on x, y axes
_NTICKS = 5


def _nicetitle(coord, value, maxchar, template):
    """
    Put coord, value in template and truncate at maxchar

    Parameters
    ----------
    coord : str
        Name substituted for ``{coord}`` in ``template``.
    value
        Value formatted with ``format_item`` and substituted for ``{value}``.
    maxchar : int
        Titles longer than this are cut and suffixed with ``"..."``.
    template : str
        Format string containing ``{coord}`` and ``{value}``.

    Returns
    -------
    title : str
    """
    prettyvalue = format_item(value, quote_strings=False)
    title = template.format(coord=coord, value=prettyvalue)

    if len(title) > maxchar:
        # Clamp at zero: a negative slice end counts from the end of the
        # string and would keep almost the whole title when maxchar < 3.
        title = title[: max(maxchar - 3, 0)] + "..."

    return title


class FacetGrid:
    """
    Initialize the Matplotlib figure and FacetGrid object.

    The :class:`FacetGrid` is an object that links a xarray DataArray to
    a Matplotlib figure with a particular structure.

    In particular, :class:`FacetGrid` is used to draw plots with multiple
    axes, where each axes shows the same relationship conditioned on
    different levels of some dimension. It's possible to condition on up to
    two variables by assigning variables to the rows and columns of the
    grid.

    The general approach to plotting here is called "small multiples",
    where the same kind of plot is repeated multiple times, and the
    specific use of small multiples to display the same relationship
    conditioned on one or more other variables is often called a "trellis
    plot".

    The basic workflow is to initialize the :class:`FacetGrid` object with
    the DataArray and the variable names that are used to structure the grid.
    Then plotting functions can be applied to each subset by calling
    :meth:`FacetGrid.map_dataarray` or :meth:`FacetGrid.map`.

    Attributes
    ----------
    axes : ndarray of matplotlib.axes.Axes
        Array containing axes in corresponding position, as returned from
        :py:func:`matplotlib.pyplot.subplots`.
    col_labels : list of matplotlib.text.Text
        Column titles.
    row_labels : list of matplotlib.text.Text
        Row titles.
    fig : matplotlib.figure.Figure
        The figure containing all the axes.
    name_dicts : ndarray of dict
        Array containing dictionaries mapping coordinate names to values. ``None`` is
        used as a sentinel value for axes that should remain empty, i.e.,
        sometimes the rightmost grid positions in the bottom row.
    """

    def __init__(
        self,
        data,
        col=None,
        row=None,
        col_wrap=None,
        sharex=True,
        sharey=True,
        figsize=None,
        aspect=1,
        size=3,
        subplot_kws=None,
    ):
        """
        Parameters
        ----------
        data : DataArray
            xarray DataArray to be plotted.
        row, col : str
            Dimension names that define subsets of the data, which will be drawn
            on separate facets in the grid.
        col_wrap : int, optional
            "Wrap" the grid for the column variable after this number of columns,
            adding rows if ``col_wrap`` is less than the number of facets.
        sharex : bool, optional
            If true, the facets will share *x* axes.
        sharey : bool, optional
            If true, the facets will share *y* axes.
        figsize : tuple, optional
            A tuple (width, height) of the figure in inches.
            If set, overrides ``size`` and ``aspect``.
        aspect : scalar, optional
            Aspect ratio of each facet, so that ``aspect * size`` gives the
            width of each facet in inches.
        size : scalar, optional
            Height (in inches) of each facet. See also: ``aspect``.
        subplot_kws : dict, optional
            Dictionary of keyword arguments for Matplotlib subplots
            (:py:func:`matplotlib.pyplot.subplots`).

        Raises
        ------
        ValueError
            If a faceting coordinate contains repeated values, or if neither
            ``row`` nor ``col`` is given.
        """

        # Handle corner case of nonunique coordinates
        rep_col = col is not None and not data[col].to_index().is_unique
        rep_row = row is not None and not data[row].to_index().is_unique
        if rep_col or rep_row:
            raise ValueError(
                "Coordinates used for faceting cannot "
                "contain repeated (nonunique) values."
            )

        # single_group is the grouping variable, if there is exactly one
        if col and row:
            single_group = False
            nrow = len(data[row])
            ncol = len(data[col])
            nfacet = nrow * ncol
            if col_wrap is not None:
                warnings.warn("Ignoring col_wrap since both col and row were passed")
        elif row and not col:
            single_group = row
        elif not row and col:
            single_group = col
        else:
            raise ValueError("Pass a coordinate name as an argument for row or col")

        # Compute grid shape
        if single_group:
            nfacet = len(data[single_group])
            if col:
                # idea - could add heuristic for nice shapes like 3x4
                ncol = nfacet
            if row:
                ncol = 1
            if col_wrap is not None:
                # Overrides previous settings
                ncol = col_wrap
            nrow = int(np.ceil(nfacet / ncol))

        # Set the subplot kwargs
        subplot_kws = {} if subplot_kws is None else subplot_kws

        if figsize is None:
            # Calculate the base figure size with extra horizontal space for a
            # colorbar
            cbar_space = 1
            figsize = (ncol * size * aspect + cbar_space, nrow * size)

        fig, axes = plt.subplots(
            nrow,
            ncol,
            sharex=sharex,
            sharey=sharey,
            squeeze=False,
            figsize=figsize,
            subplot_kw=subplot_kws,
        )

        # Set up the lists of names for the row and column facet variables
        col_names = list(data[col].to_numpy()) if col else []
        row_names = list(data[row].to_numpy()) if row else []

        if single_group:
            full = [{single_group: x} for x in data[single_group].to_numpy()]
            # Pad with the None sentinel so the list fills the whole grid
            empty = [None for x in range(nrow * ncol - len(full))]
            name_dicts = full + empty
        else:
            rowcols = itertools.product(row_names, col_names)
            name_dicts = [{row: r, col: c} for r, c in rowcols]

        name_dicts = np.array(name_dicts).reshape(nrow, ncol)

        # Set up the class attributes
        # ---------------------------

        # First the public API
        self.data = data
        self.name_dicts = name_dicts
        self.fig = fig
        self.axes = axes
        self.row_names = row_names
        self.col_names = col_names

        # guides
        self.figlegend = None
        self.quiverkey = None
        self.cbar = None

        # Next the private variables
        self._single_group = single_group
        self._nrow = nrow
        self._row_var = row
        self._ncol = ncol
        self._col_var = col
        self._col_wrap = col_wrap
        self.row_labels = [None] * nrow
        self.col_labels = [None] * ncol
        self._x_var = None
        self._y_var = None
        # Populated by map_dataarray_line / map_dataset; initialised here so
        # the attributes always exist on the instance.
        self._hue_var = None
        self._hue_label = None
        self._cmap_extend = None
        self._mappables = []
        self._finalized = False

    @property
    def _left_axes(self):
        """Axes in the first column of the grid."""
        return self.axes[:, 0]

    @property
    def _bottom_axes(self):
        """Axes in the last row of the grid."""
        return self.axes[-1, :]

    def map_dataarray(self, func, x, y, **kwargs):
        """
        Apply a plotting function to a 2d facet's subset of the data.

        This is more convenient and less general than ``FacetGrid.map``

        Parameters
        ----------
        func : callable
            A plotting function with the same signature as a 2d xarray
            plotting method such as `xarray.plot.imshow`
        x, y : string
            Names of the coordinates to plot on x, y axes
        **kwargs
            additional keyword arguments to func

        Returns
        -------
        self : FacetGrid object

        Raises
        ------
        ValueError
            If a non-None ``cbar_ax`` is passed.
        """

        if kwargs.get("cbar_ax", None) is not None:
            raise ValueError("cbar_ax not supported by FacetGrid.")

        cmap_params, cbar_kwargs = _process_cmap_cbar_kwargs(
            func, self.data.to_numpy(), **kwargs
        )

        self._cmap_extend = cmap_params.get("extend")

        # Order is important
        func_kwargs = {
            k: v
            for k, v in kwargs.items()
            if k not in {"cmap", "colors", "cbar_kwargs", "levels"}
        }
        func_kwargs.update(cmap_params)
        # The grid draws a single shared colorbar itself (see below)
        func_kwargs["add_colorbar"] = False
        if func.__name__ != "surface":
            func_kwargs["add_labels"] = False

        # Get x, y labels for the first subplot
        x, y = _infer_xy_labels(
            darray=self.data.loc[self.name_dicts.flat[0]],
            x=x,
            y=y,
            imshow=func.__name__ == "imshow",
            rgb=kwargs.get("rgb", None),
        )

        for d, ax in zip(self.name_dicts.flat, self.axes.flat):
            # None is the sentinel value
            if d is not None:
                subset = self.data.loc[d]
                mappable = func(
                    subset, x=x, y=y, ax=ax, **func_kwargs, _is_facetgrid=True
                )
                self._mappables.append(mappable)

        self._finalize_grid(x, y)

        if kwargs.get("add_colorbar", True):
            self.add_colorbar(**cbar_kwargs)

        return self

    def map_dataarray_line(
        self, func, x, y, hue, add_legend=True, _labels=None, **kwargs
    ):
        """
        Apply a line plotting function to each facet's subset of the data.

        Parameters
        ----------
        func : callable
            Called once per non-empty facet with the facet's subset and the
            keywords ``x``, ``y``, ``ax``, ``hue``, ``add_legend=False``,
            ``_labels=False`` and ``**kwargs``.
        x, y, hue
            Forwarded to ``func`` and to ``_infer_line_data`` (evaluated on
            the first facet) to derive the axis labels and the hue variable.
        add_legend : bool, optional
            If true, and a hue variable and hue label were inferred, add a
            figure-level legend.
        _labels
            Accepted but not used here; ``func`` is always called with
            ``_labels=False``.
        **kwargs
            additional keyword arguments to func

        Returns
        -------
        self : FacetGrid object
        """
        from .plot import _infer_line_data

        for d, ax in zip(self.name_dicts.flat, self.axes.flat):
            # None is the sentinel value
            if d is not None:
                subset = self.data.loc[d]
                mappable = func(
                    subset,
                    x=x,
                    y=y,
                    ax=ax,
                    hue=hue,
                    add_legend=False,
                    _labels=False,
                    **kwargs,
                )
                self._mappables.append(mappable)

        xplt, yplt, hueplt, huelabel = _infer_line_data(
            darray=self.data.loc[self.name_dicts.flat[0]], x=x, y=y, hue=hue
        )
        xlabel = label_from_attrs(xplt)
        ylabel = label_from_attrs(yplt)

        self._hue_var = hueplt
        self._hue_label = huelabel
        self._finalize_grid(xlabel, ylabel)

        if add_legend and hueplt is not None and huelabel is not None:
            self.add_legend()

        return self

    def map_dataset(
        self, func, x=None, y=None, hue=None, hue_style=None, add_guide=None, **kwargs
    ):
        """
        Apply a dataset plotting function to each facet's subset of the data.

        Parameters
        ----------
        func : callable
            Called once per non-empty facet with the keywords ``ds``, ``x``,
            ``y``, ``hue``, ``hue_style``, ``ax`` and ``**kwargs``.
        x, y, hue, hue_style
            Forwarded to ``func`` and to ``_infer_meta_data``.
        add_guide
            Forwarded to ``_infer_meta_data``; ``func`` itself is always
            called with ``add_guide=False`` because the guide (legend,
            colorbar or quiverkey) is drawn once for the whole grid.
        **kwargs
            additional keyword arguments to func

        Returns
        -------
        self : FacetGrid object

        Raises
        ------
        ValueError
            If ``func`` is named ``quiver`` and no ``scale`` is given.
        """
        from .dataset_plot import _infer_meta_data, _parse_size

        kwargs["add_guide"] = False

        if kwargs.get("markersize", None):
            kwargs["size_mapping"] = _parse_size(
                self.data[kwargs["markersize"]], kwargs.pop("size_norm", None)
            )

        meta_data = _infer_meta_data(
            self.data, x, y, hue, hue_style, add_guide, funcname=func.__name__
        )
        kwargs["meta_data"] = meta_data

        # Always bound, so the colorbar branch below cannot hit an unbound
        # local if the continuous-hue branch was not taken.
        cbar_kwargs = {}
        if hue and meta_data["hue_style"] == "continuous":
            cmap_params, cbar_kwargs = _process_cmap_cbar_kwargs(
                func, self.data[hue].to_numpy(), **kwargs
            )
            kwargs["meta_data"]["cmap_params"] = cmap_params
            kwargs["meta_data"]["cbar_kwargs"] = cbar_kwargs

        kwargs["_is_facetgrid"] = True

        if func.__name__ == "quiver" and "scale" not in kwargs:
            raise ValueError("Please provide scale.")
            # TODO: come up with an algorithm for reasonable scale choice

        for d, ax in zip(self.name_dicts.flat, self.axes.flat):
            # None is the sentinel value
            if d is not None:
                subset = self.data.loc[d]
                maybe_mappable = func(
                    ds=subset, x=x, y=y, hue=hue, hue_style=hue_style, ax=ax, **kwargs
                )
                # TODO: this is needed to get legends to work.
                # but maybe_mappable is a list in that case :/
                self._mappables.append(maybe_mappable)

        self._finalize_grid(meta_data["xlabel"], meta_data["ylabel"])

        if hue:
            self._hue_label = meta_data.pop("hue_label", None)
            if meta_data["add_legend"]:
                self._hue_var = meta_data["hue"]
                self.add_legend()
            elif meta_data["add_colorbar"]:
                self.add_colorbar(label=self._hue_label, **cbar_kwargs)

        if meta_data["add_quiverkey"]:
            self.add_quiverkey(kwargs["u"], kwargs["v"])

        return self

    def _finalize_grid(self, *axlabels):
        """Finalize the annotations and layout."""
        if not self._finalized:
            self.set_axis_labels(*axlabels)
            self.set_titles()
            self.fig.tight_layout()

            # Hide the padding axes that have no data (None sentinel)
            for ax, namedict in zip(self.axes.flat, self.name_dicts.flat):
                if namedict is None:
                    ax.set_visible(False)

            self._finalized = True

    def _adjust_fig_for_guide(self, guide):
        """Widen the figure and shrink the subplot area so ``guide`` fits on the right."""
        # Draw the plot to set the bounding boxes correctly
        renderer = self.fig.canvas.get_renderer()
        self.fig.draw(renderer)

        # Calculate and set the new width of the figure so the legend fits
        guide_width = guide.get_window_extent(renderer).width / self.fig.dpi
        figure_width = self.fig.get_figwidth()
        self.fig.set_figwidth(figure_width + guide_width)

        # Draw the plot again to get the new transformations
        self.fig.draw(renderer)

        # Now calculate how much space we need on the right side
        guide_width = guide.get_window_extent(renderer).width / self.fig.dpi
        space_needed = guide_width / (figure_width + guide_width) + 0.02
        # margin = .01
        # _space_needed = margin + space_needed
        right = 1 - space_needed

        # Place the subplot axes to give space for the legend
        self.fig.subplots_adjust(right=right)

    def add_legend(self, **kwargs):
        """
        Draw a figure-level legend at the center right of the figure.

        The handles are the most recently stored mappable, the labels come
        from the hue variable and the title is the hue label. The figure is
        then resized to make room for the legend.

        Parameters
        ----------
        **kwargs
            additional keyword arguments to ``Figure.legend``
        """
        self.figlegend = self.fig.legend(
            handles=self._mappables[-1],
            labels=list(self._hue_var.to_numpy()),
            title=self._hue_label,
            loc="center right",
            **kwargs,
        )
        self._adjust_fig_for_guide(self.figlegend)

    def add_colorbar(self, **kwargs):
        """
        Draw a colorbar.

        Parameters
        ----------
        **kwargs
            additional keyword arguments to ``Figure.colorbar``. ``label``
            defaults to a label derived from the attributes of ``self.data``.

        Returns
        -------
        self : FacetGrid object
        """
        kwargs = kwargs.copy()
        if self._cmap_extend is not None:
            kwargs.setdefault("extend", self._cmap_extend)
        # dont pass extend as kwarg if it is in the mappable
        if hasattr(self._mappables[-1], "extend"):
            kwargs.pop("extend", None)
        # Only derive the default label when none was supplied, so
        # label_from_attrs is not evaluated needlessly.
        if "label" not in kwargs:
            kwargs["label"] = label_from_attrs(self.data)
        self.cbar = self.fig.colorbar(
            self._mappables[-1], ax=list(self.axes.flat), **kwargs
        )
        return self

    def add_quiverkey(self, u, v, **kwargs):
        """
        Draw a quiver key on the last axes of the grid.

        Parameters
        ----------
        u, v : str
            Names of the variables in ``self.data`` used to choose the key
            magnitude; units are read from the ``units`` attribute of ``u``.
        **kwargs
            additional keyword arguments to ``Axes.quiverkey``; these
            override the defaults chosen here (``X``, ``Y``, ``U``,
            ``label``, ``labelpos``, ``coordinates``).

        Returns
        -------
        self : FacetGrid object
        """
        magnitude = _get_nice_quiver_magnitude(self.data[u], self.data[v])
        units = self.data[u].attrs.get("units", "")

        # Defaults first, then caller overrides: previously **kwargs was
        # accepted but silently dropped.
        quiverkey_kwargs = {
            "X": 0.8,
            "Y": 0.9,
            "U": magnitude,
            "label": f"{magnitude}\n{units}",
            "labelpos": "E",
            "coordinates": "figure",
        }
        quiverkey_kwargs.update(kwargs)

        self.quiverkey = self.axes.flat[-1].quiverkey(
            self._mappables[-1], **quiverkey_kwargs
        )

        # TODO: does not work because self.quiverkey.get_window_extent(renderer) = 0
        # https://github.com/matplotlib/matplotlib/issues/18530
        # self._adjust_fig_for_guide(self.quiverkey.text)
        return self

    def set_axis_labels(self, x_var=None, y_var=None):
        """Set axis labels on the left column and bottom row of the grid."""
        if x_var is not None:
            if x_var in self.data.coords:
                self._x_var = x_var
                self.set_xlabels(label_from_attrs(self.data[x_var]))
            else:
                # x_var is a string
                self.set_xlabels(x_var)

        if y_var is not None:
            if y_var in self.data.coords:
                self._y_var = y_var
                self.set_ylabels(label_from_attrs(self.data[y_var]))
            else:
                # y_var is a string
                self.set_ylabels(y_var)
        return self

    def set_xlabels(self, label=None, **kwargs):
        """Label the x axis on the bottom row of the grid."""
        if label is None:
            label = label_from_attrs(self.data[self._x_var])
        for ax in self._bottom_axes:
            ax.set_xlabel(label, **kwargs)
        return self

    def set_ylabels(self, label=None, **kwargs):
        """Label the y axis on the left column of the grid."""
        if label is None:
            label = label_from_attrs(self.data[self._y_var])
        for ax in self._left_axes:
            ax.set_ylabel(label, **kwargs)
        return self

    def set_titles(self, template="{coord} = {value}", maxchar=30, size=None, **kwargs):
        """
        Draw titles either above each facet or on the grid margins.

        Parameters
        ----------
        template : string
            Template for plot titles containing {coord} and {value}
        maxchar : int
            Truncate titles at maxchar
        size : optional
            Font size for titles drawn with ``Axes.set_title``. Defaults to
            ``rcParams["axes.labelsize"]``.
        **kwargs : keyword args
            additional arguments to matplotlib.text

        Returns
        -------
        self: FacetGrid object

        """
        if size is None:
            size = plt.rcParams["axes.labelsize"]

        nicetitle = functools.partial(_nicetitle, maxchar=maxchar, template=template)

        if self._single_group:
            for d, ax in zip(self.name_dicts.flat, self.axes.flat):
                # Only label the ones with data
                if d is not None:
                    coord, value = list(d.items()).pop()
                    title = nicetitle(coord, value)
                    ax.set_title(title, size=size, **kwargs)
        else:
            # The row titles on the right edge of the grid
            for index, (ax, row_name, handle) in enumerate(
                zip(self.axes[:, -1], self.row_names, self.row_labels)
            ):
                title = nicetitle(coord=self._row_var, value=row_name)
                if not handle:
                    self.row_labels[index] = ax.annotate(
                        title,
                        xy=(1.02, 0.5),
                        xycoords="axes fraction",
                        rotation=270,
                        ha="left",
                        va="center",
                        **kwargs,
                    )
                else:
                    handle.set_text(title)
                    # Apply the text properties to the existing label too;
                    # they used to be dropped on repeated calls.
                    handle.update(kwargs)

            # The column titles on the top row
            for index, (ax, col_name, handle) in enumerate(
                zip(self.axes[0, :], self.col_names, self.col_labels)
            ):
                title = nicetitle(coord=self._col_var, value=col_name)
                if not handle:
                    self.col_labels[index] = ax.set_title(title, size=size, **kwargs)
                else:
                    handle.set_text(title)
                    handle.update(kwargs)

        return self

    def set_ticks(self, max_xticks=_NTICKS, max_yticks=_NTICKS, fontsize=_FONTSIZE):
        """
        Set and control tick behavior.

        Parameters
        ----------
        max_xticks, max_yticks : int, optional
            Maximum number of labeled ticks to plot on x, y axes
        fontsize : string or int
            Font size as used by matplotlib text

        Returns
        -------
        self : FacetGrid object

        """
        from matplotlib.ticker import MaxNLocator

        for ax in self.axes.flat:
            # A locator is bound to the Axis it is set on, so every Axis
            # gets its own instance instead of one shared by all facets.
            ax.xaxis.set_major_locator(MaxNLocator(nbins=max_xticks))
            ax.yaxis.set_major_locator(MaxNLocator(nbins=max_yticks))
            for tick in itertools.chain(
                ax.xaxis.get_major_ticks(), ax.yaxis.get_major_ticks()
            ):
                tick.label1.set_fontsize(fontsize)

        return self

    def map(self, func, *args, **kwargs):
        """
        Apply a plotting function to each facet's subset of the data.

        Parameters
        ----------
        func : callable
            A plotting function that takes data and keyword arguments. It
            must plot to the currently active matplotlib Axes and take a
            `color` keyword argument. If faceting on the `hue` dimension,
            it must also take a `label` keyword argument.
        *args : strings
            Column names in self.data that identify variables with data to
            plot. The data for each variable is passed to `func` in the
            order the variables are specified in the call.
        **kwargs : keyword arguments
            All keyword arguments are passed to the plotting function.

        Returns
        -------
        self : FacetGrid object

        """
        for ax, namedict in zip(self.axes.flat, self.name_dicts.flat):
            if namedict is not None:
                data = self.data.loc[namedict]
                # func draws on the current axes, so make this facet current
                plt.sca(ax)
                innerargs = [data[a].to_numpy() for a in args]
                maybe_mappable = func(*innerargs, **kwargs)
                # TODO: better way to verify that an artist is mappable?
                # https://stackoverflow.com/questions/33023036/is-it-possible-to-detect-if-a-matplotlib-artist-is-a-mappable-suitable-for-use-w#33023522
                if maybe_mappable and hasattr(maybe_mappable, "autoscale_None"):
                    self._mappables.append(maybe_mappable)

        # The first two variable names (if any) become the x and y labels
        self._finalize_grid(*args[:2])

        return self


def _easy_facetgrid(
    data,
    plotfunc,
    kind,
    x=None,
    y=None,
    row=None,
    col=None,
    col_wrap=None,
    sharex=True,
    sharey=True,
    aspect=None,
    size=None,
    subplot_kws=None,
    ax=None,
    figsize=None,
    **kwargs,
):
    """
    Convenience method to call xarray.plot.FacetGrid from 2d plotting methods

    kwargs are the arguments to 2d plotting method

    Parameters
    ----------
    data
        Data passed to :class:`FacetGrid`.
    plotfunc : callable
        Plotting function mapped over the facets.
    kind : {"line", "dataarray", "dataset"}
        Selects ``map_dataarray_line``, ``map_dataarray`` or ``map_dataset``.
    x, y
        Forwarded to the selected ``map_*`` method.
    row, col, col_wrap, sharex, sharey, subplot_kws, figsize
        Forwarded to :class:`FacetGrid`.
    aspect, size : scalar, optional
        Forwarded to :class:`FacetGrid`; default to 1 and 3 respectively.
    ax
        Must be None; faceted plots create their own axes.

    Returns
    -------
    g : FacetGrid object

    Raises
    ------
    ValueError
        If ``ax`` is given, if both ``size`` and ``figsize`` are given, or
        if ``kind`` is not one of the supported values.
    """
    if ax is not None:
        raise ValueError("Can't use axes when making faceted plots.")
    if aspect is None:
        aspect = 1
    if size is None:
        size = 3
    elif figsize is not None:
        raise ValueError("cannot provide both `figsize` and `size` arguments")

    g = FacetGrid(
        data=data,
        col=col,
        row=row,
        col_wrap=col_wrap,
        sharex=sharex,
        sharey=sharey,
        figsize=figsize,
        aspect=aspect,
        size=size,
        subplot_kws=subplot_kws,
    )

    if kind == "line":
        return g.map_dataarray_line(plotfunc, x, y, **kwargs)

    if kind == "dataarray":
        return g.map_dataarray(plotfunc, x, y, **kwargs)

    if kind == "dataset":
        return g.map_dataset(plotfunc, x, y, **kwargs)

    # Fail loudly instead of silently returning None for an unknown kind
    raise ValueError(
        f"Unknown kind {kind!r}; expected 'line', 'dataarray' or 'dataset'."
    )