"""
Module consolidating common testing functions for checking plotting.

Currently all plotting tests are marked as slow via
``pytestmark = pytest.mark.slow`` at the module level.
"""

import os
from typing import TYPE_CHECKING, Sequence, Union
import warnings

import numpy as np

from pandas.util._decorators import cache_readonly
import pandas.util._test_decorators as td

from pandas.core.dtypes.api import is_list_like

import pandas as pd
from pandas import DataFrame, Series, to_datetime
import pandas._testing as tm

if TYPE_CHECKING:
    from matplotlib.axes import Axes


@td.skip_if_no_mpl
class TestPlotBase:
    """
    This is a common base class used for various plotting tests
    """

    def setup_method(self, method):
        """
        Reset matplotlib rc settings and build the shared fixtures.

        Stores the matplotlib version flags reported by
        ``pandas.plotting._matplotlib.compat``, a handful of expected
        default values, and three frames (``hist_df``, ``tdf``,
        ``hexbin_df``) used by the plotting tests.
        """
        import matplotlib as mpl

        from pandas.plotting._matplotlib import compat

        mpl.rcdefaults()

        # bounds (as int64 nanoseconds) for the random "datetime" column
        self.start_date_to_int64 = 812419200000000000
        self.end_date_to_int64 = 819331200000000000

        self.mpl_ge_2_2_3 = compat.mpl_ge_2_2_3()
        self.mpl_ge_3_0_0 = compat.mpl_ge_3_0_0()
        self.mpl_ge_3_1_0 = compat.mpl_ge_3_1_0()
        self.mpl_ge_3_2_0 = compat.mpl_ge_3_2_0()

        self.bp_n_objects = 7
        self.polycollection_factor = 2
        self.default_figsize = (6.4, 4.8)
        self.default_tick_position = "left"

        n = 100
        # seed the RNG so that hist_df is reproducible across tests
        with tm.RNGContext(42):
            gender = np.random.choice(["Male", "Female"], size=n)
            classroom = np.random.choice(["A", "B", "C"], size=n)

            self.hist_df = DataFrame(
                {
                    "gender": gender,
                    "classroom": classroom,
                    "height": np.random.normal(66, 4, size=n),
                    "weight": np.random.normal(161, 32, size=n),
                    "category": np.random.randint(4, size=n),
                    "datetime": to_datetime(
                        np.random.randint(
                            self.start_date_to_int64,
                            self.end_date_to_int64,
                            size=n,
                            dtype=np.int64,
                        )
                    ),
                }
            )

        self.tdf = tm.makeTimeDataFrame()
        self.hexbin_df = DataFrame(
            {
                "A": np.random.uniform(size=20),
                "B": np.random.uniform(size=20),
                "C": np.arange(20) + np.random.uniform(size=20),
            }
        )

    def teardown_method(self, method):
        """
        Close figures via ``tm.close`` after each test.
        """
        tm.close()

    @cache_readonly
    def plt(self):
        """
        Lazily imported ``matplotlib.pyplot`` module.
        """
        import matplotlib.pyplot as plt

        return plt

    @cache_readonly
    def colorconverter(self):
        """
        Lazily imported ``matplotlib.colors.colorConverter``.
        """
        import matplotlib.colors as colors

        return colors.colorConverter

    def _check_legend_labels(self, axes, labels=None, visible=True):
        """
        Check each axes has expected legend labels

        Parameters
        ----------
        axes : matplotlib Axes object, or its list-like
        labels : list-like
            expected legend labels
        visible : bool
            expected legend visibility. labels are checked only when visible is
            True

        Raises
        ------
        ValueError
            If ``visible`` is True and ``labels`` is None.
        """
        if visible and (labels is None):
            raise ValueError("labels must be specified when visible is True")
        axes = self._flatten_visible(axes)
        for ax in axes:
            if visible:
                assert ax.get_legend() is not None
                self._check_text_labels(ax.get_legend().get_texts(), labels)
            else:
                assert ax.get_legend() is None

    def _check_legend_marker(self, ax, expected_markers=None, visible=True):
        """
        Check ax has expected legend markers

        Parameters
        ----------
        ax : matplotlib Axes object
        expected_markers : list-like
            expected legend markers
        visible : bool
            expected legend visibility. markers are checked only when visible
            is True

        Raises
        ------
        ValueError
            If ``visible`` is True and ``expected_markers`` is None.
        """
        if visible and (expected_markers is None):
            raise ValueError("Markers must be specified when visible is True")
        if visible:
            handles, _ = ax.get_legend_handles_labels()
            markers = [handle.get_marker() for handle in handles]
            assert markers == expected_markers
        else:
            assert ax.get_legend() is None

    def _check_data(self, xp, rs):
        """
        Check each axes has identical lines

        Parameters
        ----------
        xp : matplotlib Axes object
        rs : matplotlib Axes object
        """
        xp_lines = xp.get_lines()
        rs_lines = rs.get_lines()

        def check_line(xpl, rsl):
            xpdata = xpl.get_xydata()
            rsdata = rsl.get_xydata()
            tm.assert_almost_equal(xpdata, rsdata)

        assert len(xp_lines) == len(rs_lines)
        for xpl, rsl in zip(xp_lines, rs_lines):
            check_line(xpl, rsl)
        tm.close()

    def _check_visible(self, collections, visible=True):
        """
        Check each artist is visible or not

        Parameters
        ----------
        collections : matplotlib Artist or its list-like
            target Artist or its list or collection
        visible : bool
            expected visibility
        """
        from matplotlib.collections import Collection

        # wrap a single artist so that the loop below handles both cases
        if not isinstance(collections, Collection) and not is_list_like(collections):
            collections = [collections]

        for patch in collections:
            assert patch.get_visible() == visible

    def _check_patches_all_filled(
        self, axes: Union["Axes", Sequence["Axes"]], filled: bool = True
    ) -> None:
        """
        Check for each artist whether it is filled or not

        Parameters
        ----------
        axes : matplotlib Axes object, or its list-like
        filled : bool
            expected filling
        """

        axes = self._flatten_visible(axes)
        for ax in axes:
            for patch in ax.patches:
                assert patch.fill == filled

    def _get_colors_mapped(self, series, colors):
        """
        Map each value of ``series`` to a color.

        The unique values of ``series`` (in order of appearance) are paired
        with ``colors`` positionally; the returned list holds the color for
        each element of ``series.values``.
        """
        unique = series.unique()
        # the lengths of unique and colors can differ
        # depending on the slice value
        mapped = dict(zip(unique, colors))
        return [mapped[v] for v in series.values]

    def _check_colors(
        self, collections, linecolors=None, facecolors=None, mapping=None
    ):
        """
        Check each artist has expected line colors and face colors

        Parameters
        ----------
        collections : list-like
            list or collection of target artist
        linecolors : list-like which has the same length as collections
            list of expected line colors
        facecolors : list-like which has the same length as collections
            list of expected face colors
        mapping : Series
            Series used for color grouping key
            used for andrew_curves, parallel_coordinates, radviz test
        """
        from matplotlib.collections import Collection, LineCollection, PolyCollection
        from matplotlib.lines import Line2D

        conv = self.colorconverter
        if linecolors is not None:

            if mapping is not None:
                linecolors = self._get_colors_mapped(mapping, linecolors)
                linecolors = linecolors[: len(collections)]

            assert len(collections) == len(linecolors)
            for patch, color in zip(collections, linecolors):
                if isinstance(patch, Line2D):
                    result = patch.get_color()
                    # Line2D may contain a string color expression
                    result = conv.to_rgba(result)
                elif isinstance(patch, (PolyCollection, LineCollection)):
                    result = tuple(patch.get_edgecolor()[0])
                else:
                    result = patch.get_edgecolor()

                expected = conv.to_rgba(color)
                assert result == expected

        if facecolors is not None:

            if mapping is not None:
                facecolors = self._get_colors_mapped(mapping, facecolors)
                facecolors = facecolors[: len(collections)]

            assert len(collections) == len(facecolors)
            for patch, color in zip(collections, facecolors):
                if isinstance(patch, Collection):
                    # returned as list of np.array
                    result = patch.get_facecolor()[0]
                else:
                    result = patch.get_facecolor()

                if isinstance(result, np.ndarray):
                    result = tuple(result)

                expected = conv.to_rgba(color)
                assert result == expected

    def _check_text_labels(self, texts, expected):
        """
        Check each text has expected labels

        Parameters
        ----------
        texts : matplotlib Text object, or its list-like
            target text, or its list
        expected : str or list-like which has the same length as texts
            expected text label, or its list
        """
        if not is_list_like(texts):
            assert texts.get_text() == expected
        else:
            labels = [t.get_text() for t in texts]
            assert len(labels) == len(expected)
            for label, e in zip(labels, expected):
                assert label == e

    def _check_ticks_props(
        self, axes, xlabelsize=None, xrot=None, ylabelsize=None, yrot=None
    ):
        """
        Check each axes has expected tick properties

        Parameters
        ----------
        axes : matplotlib Axes object, or its list-like
        xlabelsize : number
            expected xticks font size
        xrot : number
            expected xticks rotation
        ylabelsize : number
            expected yticks font size
        yrot : number
            expected yticks rotation
        """
        from matplotlib.ticker import NullFormatter

        axes = self._flatten_visible(axes)
        for ax in axes:
            if xlabelsize is not None or xrot is not None:
                if isinstance(ax.xaxis.get_minor_formatter(), NullFormatter):
                    # If minor ticks has NullFormatter, rot / fontsize are not
                    # retained
                    labels = ax.get_xticklabels()
                else:
                    labels = ax.get_xticklabels() + ax.get_xticklabels(minor=True)

                for label in labels:
                    if xlabelsize is not None:
                        tm.assert_almost_equal(label.get_fontsize(), xlabelsize)
                    if xrot is not None:
                        tm.assert_almost_equal(label.get_rotation(), xrot)

            if ylabelsize is not None or yrot is not None:
                if isinstance(ax.yaxis.get_minor_formatter(), NullFormatter):
                    labels = ax.get_yticklabels()
                else:
                    labels = ax.get_yticklabels() + ax.get_yticklabels(minor=True)

                for label in labels:
                    if ylabelsize is not None:
                        tm.assert_almost_equal(label.get_fontsize(), ylabelsize)
                    if yrot is not None:
                        tm.assert_almost_equal(label.get_rotation(), yrot)

    def _check_ax_scales(self, axes, xaxis="linear", yaxis="linear"):
        """
        Check each axes has expected scales

        Parameters
        ----------
        axes : matplotlib Axes object, or its list-like
        xaxis : {'linear', 'log'}
            expected xaxis scale
        yaxis : {'linear', 'log'}
            expected yaxis scale
        """
        axes = self._flatten_visible(axes)
        for ax in axes:
            assert ax.xaxis.get_scale() == xaxis
            assert ax.yaxis.get_scale() == yaxis

    def _check_axes_shape(self, axes, axes_num=None, layout=None, figsize=None):
        """
        Check expected number of axes is drawn in expected layout

        Parameters
        ----------
        axes : matplotlib Axes object, or its list-like
        axes_num : number
            expected number of axes. Unnecessary axes should be set to
            invisible.
        layout : tuple
            expected layout, (expected number of rows , columns)
        figsize : tuple
            expected figsize. default is matplotlib default
        """
        from pandas.plotting._matplotlib.tools import flatten_axes

        if figsize is None:
            figsize = self.default_figsize
        visible_axes = self._flatten_visible(axes)

        if axes_num is not None:
            assert len(visible_axes) == axes_num
            for ax in visible_axes:
                # check something drawn on visible axes
                assert len(ax.get_children()) > 0

        if layout is not None:
            # layout is estimated from all axes, including invisible ones
            result = self._get_axes_layout(flatten_axes(axes))
            assert result == layout

        tm.assert_numpy_array_equal(
            visible_axes[0].figure.get_size_inches(),
            np.array(figsize, dtype=np.float64),
        )

    def _get_axes_layout(self, axes):
        """
        Estimate the ``(rows, columns)`` layout of the given axes.

        Counts the distinct y and x coordinates of the first point of each
        axes' position.
        """
        x_set = set()
        y_set = set()
        for ax in axes:
            # check axes coordinates to estimate layout
            points = ax.get_position().get_points()
            x_set.add(points[0][0])
            y_set.add(points[0][1])
        return (len(y_set), len(x_set))

    def _flatten_visible(self, axes):
        """
        Flatten axes, and filter only visible

        Parameters
        ----------
        axes : matplotlib Axes object, or its list-like

        """
        from pandas.plotting._matplotlib.tools import flatten_axes

        axes = flatten_axes(axes)
        axes = [ax for ax in axes if ax.get_visible()]
        return axes

    def _check_has_errorbars(self, axes, xerr=0, yerr=0):
        """
        Check axes has expected number of errorbars

        Parameters
        ----------
        axes : matplotlib Axes object, or its list-like
        xerr : number
            expected number of x errorbar
        yerr : number
            expected number of y errorbar
        """
        axes = self._flatten_visible(axes)
        for ax in axes:
            containers = ax.containers
            xerr_count = 0
            yerr_count = 0
            for c in containers:
                # containers without these attributes count as no errorbar
                has_xerr = getattr(c, "has_xerr", False)
                has_yerr = getattr(c, "has_yerr", False)
                if has_xerr:
                    xerr_count += 1
                if has_yerr:
                    yerr_count += 1
            assert xerr == xerr_count
            assert yerr == yerr_count

    def _check_box_return_type(
        self, returned, return_type, expected_keys=None, check_ax_title=True
    ):
        """
        Check box returned type is correct

        Parameters
        ----------
        returned : object to be tested, returned from boxplot
        return_type : str
            return_type passed to boxplot
        expected_keys : list-like, optional
            group labels in subplot case. If not passed,
            the function checks assuming boxplot uses single ax
        check_ax_title : bool
            Whether to check the ax.title is the same as expected_key
            Intended to be checked by calling from ``boxplot``.
            Normal ``plot`` doesn't attach ``ax.title``, it must be disabled.
        """
        from matplotlib.axes import Axes

        types = {"dict": dict, "axes": Axes, "both": tuple}
        if expected_keys is None:
            # should be fixed when the returning default is changed
            if return_type is None:
                return_type = "dict"

            assert isinstance(returned, types[return_type])
            if return_type == "both":
                assert isinstance(returned.ax, Axes)
                assert isinstance(returned.lines, dict)
        else:
            # should be fixed when the returning default is changed
            if return_type is None:
                for r in self._flatten_visible(returned):
                    assert isinstance(r, Axes)
                return

            assert isinstance(returned, Series)

            assert sorted(returned.keys()) == sorted(expected_keys)
            for key, value in returned.items():
                assert isinstance(value, types[return_type])
                # check returned dict has correct mapping
                if return_type == "axes":
                    if check_ax_title:
                        assert value.get_title() == key
                elif return_type == "both":
                    if check_ax_title:
                        assert value.ax.get_title() == key
                    assert isinstance(value.ax, Axes)
                    assert isinstance(value.lines, dict)
                elif return_type == "dict":
                    line = value["medians"][0]
                    axes = line.axes
                    if check_ax_title:
                        assert axes.get_title() == key
                else:
                    raise AssertionError

    def _check_grid_settings(self, obj, kinds, kws=None):
        """
        Check that plots default to the ``rcParams['axes.grid']`` setting.

        Parameters
        ----------
        obj : object with a ``plot`` method
            object to be plotted
        kinds : list-like of str
            plot kinds to check
        kws : dict, optional
            extra keyword arguments passed to every ``obj.plot`` call.
            Defaults to no extra arguments.
        """
        # Make sure plot defaults to rcParams['axes.grid'] setting, GH 9792

        import matplotlib as mpl

        # avoid a shared mutable default argument
        if kws is None:
            kws = {}

        def is_grid_on():
            xticks = self.plt.gca().xaxis.get_major_ticks()
            yticks = self.plt.gca().yaxis.get_major_ticks()
            # for mpl 2.2.2, gridOn and gridline.get_visible disagree.
            # for new MPL, they are the same.

            if self.mpl_ge_3_1_0:
                xoff = all(not g.gridline.get_visible() for g in xticks)
                yoff = all(not g.gridline.get_visible() for g in yticks)
            else:
                xoff = all(not g.gridOn for g in xticks)
                yoff = all(not g.gridOn for g in yticks)

            return not (xoff and yoff)

        spndx = 1
        for kind in kinds:

            # rc grid off, no explicit grid argument -> grid off
            self.plt.subplot(1, 4 * len(kinds), spndx)
            spndx += 1
            mpl.rc("axes", grid=False)
            obj.plot(kind=kind, **kws)
            assert not is_grid_on()

            # rc grid on, explicit grid=False -> grid off
            self.plt.subplot(1, 4 * len(kinds), spndx)
            spndx += 1
            mpl.rc("axes", grid=True)
            obj.plot(kind=kind, grid=False, **kws)
            assert not is_grid_on()

            if kind != "pie":
                # rc grid on, no explicit grid argument -> grid on
                self.plt.subplot(1, 4 * len(kinds), spndx)
                spndx += 1
                mpl.rc("axes", grid=True)
                obj.plot(kind=kind, **kws)
                assert is_grid_on()

                # rc grid off, explicit grid=True -> grid on
                self.plt.subplot(1, 4 * len(kinds), spndx)
                spndx += 1
                mpl.rc("axes", grid=False)
                obj.plot(kind=kind, grid=True, **kws)
                assert is_grid_on()

    def _unpack_cycler(self, rcParams, field="color"):
        """
        Auxiliary function for correctly unpacking cycler after MPL >= 1.5
        """
        return [v[field] for v in rcParams["axes.prop_cycle"]]


def _check_plot_works(f, filterwarnings="always", default_axes=False, **kwargs):
    """
    Create plot and ensure that plot return object is valid.

    Parameters
    ----------
    f : func
        Plotting function.
    filterwarnings : str
        Warnings filter.
        See https://docs.python.org/3/library/warnings.html#warning-filter
    default_axes : bool, optional
        If False (default):
            - If `ax` not in `kwargs`, then create subplot(211) and plot there
            - Create new subplot(212) and plot there as well
            - Mind special corner case for bootstrap_plot (see `_gen_two_subplots`)
        If True:
            - Simply run plotting function with kwargs provided
            - All required axes instances will be created automatically
            - It is recommended to use it when the plotting function
            creates multiple axes itself. It helps avoid warnings like
            'UserWarning: To output multiple subplots,
            the figure containing the passed axes is being cleared'
    **kwargs
        Keyword arguments passed to the plotting function.

    Returns
    -------
    Plot object returned by the last plotting.
    """
    import matplotlib.pyplot as plt

    if default_axes:
        gen_plots = _gen_default_plot
    else:
        gen_plots = _gen_two_subplots

    ret = None
    with warnings.catch_warnings():
        warnings.simplefilter(filterwarnings)
        # resolve the figure before entering ``try`` so that ``fig`` is
        # always bound when the ``finally`` clause runs
        fig = kwargs.get("figure", plt.gcf())
        try:
            plt.clf()

            for ret in gen_plots(f, fig, **kwargs):
                tm.assert_is_valid_plot_return_object(ret)

            # make sure the resulting figure can actually be rendered
            with tm.ensure_clean(return_filelike=True) as path:
                plt.savefig(path)
        finally:
            tm.close(fig)

        return ret


def _gen_default_plot(f, fig, **kwargs):
    """
    Create plot in a default way.

    Calls ``f(**kwargs)`` once and yields its result; ``fig`` is accepted
    only for signature parity with ``_gen_two_subplots`` and is not used.
    """
    yield f(**kwargs)


def _gen_two_subplots(f, fig, **kwargs):
    """
    Create plot on two subplots forcefully created.

    First yields ``f(**kwargs)``, creating subplot(211) on ``fig``
    beforehand only when the caller did not pass ``ax``. Then yields a
    second call to ``f`` with ``ax`` set to a new subplot(212), except for
    ``pd.plotting.bootstrap_plot``, which is called again without ``ax``.
    """
    # only create the first subplot when the caller did not supply an axes;
    # ``kwargs.get("ax", fig.add_subplot(211))`` would always create it
    if "ax" not in kwargs:
        fig.add_subplot(211)
    yield f(**kwargs)

    if f is pd.plotting.bootstrap_plot:
        assert "ax" not in kwargs
    else:
        kwargs["ax"] = fig.add_subplot(212)
    yield f(**kwargs)


def curpath():
    """
    Return the absolute path of the directory containing this file.
    """
    return os.path.dirname(os.path.abspath(__file__))