"""Test cases for misc plot functions."""

import numpy as np
import pytest

import pandas.util._test_decorators as td

from pandas import DataFrame, Series
import pandas._testing as tm
from pandas.tests.plotting.common import TestPlotBase, _check_plot_works

import pandas.plotting as plotting

pytestmark = pytest.mark.slow


@td.skip_if_mpl
def test_import_error_message():
    # GH-19810
    # Runs only when matplotlib is NOT installed: plotting must raise an
    # ImportError that names matplotlib.
    df = DataFrame({"A": [1, 2]})

    with pytest.raises(ImportError, match="matplotlib is required for plotting"):
        df.plot()


def test_get_accessor_args():
    func = plotting._core.PlotAccessor._get_call_args

    msg = "Called plot accessor for type list, expected Series or DataFrame"
    with pytest.raises(TypeError, match=msg):
        func(backend_name="", data=[], args=[], kwargs={})

    msg = "should not be called with positional arguments"
    with pytest.raises(TypeError, match=msg):
        func(backend_name="", data=Series(dtype=object), args=["line", None], kwargs={})

    # Positional ``x`` plus keyword ``y``/``kind`` are split out of kwargs.
    x, y, kind, kwargs = func(
        backend_name="",
        data=DataFrame(),
        args=["x"],
        kwargs={"y": "y", "kind": "bar", "grid": False},
    )
    assert x == "x"
    assert y == "y"
    assert kind == "bar"
    assert kwargs == {"grid": False}

    # With the matplotlib backend and no user kwargs, the returned kwargs
    # are populated (24 entries) and kind defaults to "line".
    x, y, kind, kwargs = func(
        backend_name="pandas.plotting._matplotlib",
        data=Series(dtype=object),
        args=[],
        kwargs={},
    )
    assert x is None
    assert y is None
    assert kind == "line"
    assert len(kwargs) == 24


@td.skip_if_no_mpl
class TestSeriesPlots(TestPlotBase):
    def setup_method(self, method):
        TestPlotBase.setup_method(self, method)
        import matplotlib as mpl

        mpl.rcdefaults()

        self.ts = tm.makeTimeSeries()
        self.ts.name = "ts"

    def test_autocorrelation_plot(self):
        from pandas.plotting import autocorrelation_plot

        # Ensure no UserWarning when making plot
        with tm.assert_produces_warning(None):
            _check_plot_works(autocorrelation_plot, series=self.ts)
            _check_plot_works(autocorrelation_plot, series=self.ts.values)

            ax = autocorrelation_plot(self.ts, label="Test")
        self._check_legend_labels(ax, labels=["Test"])

    def test_lag_plot(self):
        from pandas.plotting import lag_plot

        _check_plot_works(lag_plot, series=self.ts)
        _check_plot_works(lag_plot, series=self.ts, lag=5)

    def test_bootstrap_plot(self):
        from pandas.plotting import bootstrap_plot

        _check_plot_works(bootstrap_plot, series=self.ts, size=10)


@td.skip_if_no_mpl
class TestDataFramePlots(TestPlotBase):
    @td.skip_if_no_scipy
    def test_scatter_matrix_axis(self):
        from pandas.plotting._matplotlib.compat import mpl_ge_3_0_0

        scatter_matrix = plotting.scatter_matrix

        with tm.RNGContext(42):
            df = DataFrame(np.random.randn(100, 3))

        # we are plotting multiples on a sub-plot
        with tm.assert_produces_warning(
            UserWarning, raise_on_extra_warnings=mpl_ge_3_0_0()
        ):
            axes = _check_plot_works(
                scatter_matrix, filterwarnings="always", frame=df, range_padding=0.1
            )
        axes0_labels = axes[0][0].yaxis.get_majorticklabels()

        # GH 5662
        expected = ["-2", "0", "2"]
        self._check_text_labels(axes0_labels, expected)
        self._check_ticks_props(axes, xlabelsize=8, xrot=90, ylabelsize=8, yrot=0)

        df[0] = (df[0] - 2) / 3

        # we are plotting multiples on a sub-plot
        with tm.assert_produces_warning(UserWarning):
            axes = _check_plot_works(
                scatter_matrix, filterwarnings="always", frame=df, range_padding=0.1
            )
        axes0_labels = axes[0][0].yaxis.get_majorticklabels()
        expected = ["-1.0", "-0.5", "0.0"]
        self._check_text_labels(axes0_labels, expected)
        self._check_ticks_props(axes, xlabelsize=8, xrot=90, ylabelsize=8, yrot=0)

    def _check_line_colors_by_class(self, plot_func, df):
        """
        Plot ``df`` with ``plot_func`` using several color specifications.

        ``plot_func`` is called with ``class_column="Name"`` and, in turn, an
        rgba hex tuple, a list of named colors, and the ``cm.jet`` colormap.
        After each plot, the first ten lines must be colored according to the
        class of the corresponding row in ``df["Name"]``.
        """
        from matplotlib import cm

        rgba = ("#556270", "#4ECDC4", "#C7F464")
        cnames = ["dodgerblue", "aquamarine", "seagreen"]
        for color in (rgba, cnames):
            ax = _check_plot_works(
                plot_func, frame=df, class_column="Name", color=color
            )
            self._check_colors(
                ax.get_lines()[:10], linecolors=color, mapping=df["Name"][:10]
            )

        ax = _check_plot_works(
            plot_func, frame=df, class_column="Name", colormap=cm.jet
        )
        # One colormap sample per distinct class, evenly spaced over [0, 1].
        cmaps = [cm.jet(n) for n in np.linspace(0, 1, df["Name"].nunique())]
        self._check_colors(
            ax.get_lines()[:10], linecolors=cmaps, mapping=df["Name"][:10]
        )

    def test_andrews_curves(self, iris):
        from pandas.plotting import andrews_curves

        df = iris
        # Ensure no UserWarning when making plot
        with tm.assert_produces_warning(None):
            _check_plot_works(andrews_curves, frame=df, class_column="Name")

        self._check_line_colors_by_class(andrews_curves, df)

        # Same checks on a frame where every row belongs to a single class.
        length = 10
        df = DataFrame(
            {
                "A": np.random.rand(length),
                "B": np.random.rand(length),
                "C": np.random.rand(length),
                "Name": ["A"] * length,
            }
        )

        _check_plot_works(andrews_curves, frame=df, class_column="Name")

        self._check_line_colors_by_class(andrews_curves, df)

        # Legend handles must carry the requested colors.
        colors = ["b", "g", "r"]
        df = DataFrame({"A": [1, 2, 3], "B": [1, 2, 3], "C": [1, 2, 3], "Name": colors})
        ax = andrews_curves(df, "Name", color=colors)
        handles, labels = ax.get_legend_handles_labels()
        self._check_colors(handles, linecolors=colors)

    def test_parallel_coordinates(self, iris):
        from pandas.plotting import parallel_coordinates

        df = iris

        ax = _check_plot_works(parallel_coordinates, frame=df, class_column="Name")
        nlines = len(ax.get_lines())
        nxticks = len(ax.xaxis.get_ticklabels())

        self._check_line_colors_by_class(parallel_coordinates, df)

        # Disabling the vertical axis lines removes one line per x tick.
        ax = _check_plot_works(
            parallel_coordinates, frame=df, class_column="Name", axvlines=False
        )
        assert len(ax.get_lines()) == (nlines - nxticks)

        # Legend handles must carry the requested colors.
        colors = ["b", "g", "r"]
        df = DataFrame({"A": [1, 2, 3], "B": [1, 2, 3], "C": [1, 2, 3], "Name": colors})
        ax = parallel_coordinates(df, "Name", color=colors)
        handles, labels = ax.get_legend_handles_labels()
        self._check_colors(handles, linecolors=colors)

    # not sure if this is indicative of a problem
    @pytest.mark.filterwarnings("ignore:Attempting to set:UserWarning")
    def test_parallel_coordinates_with_sorted_labels(self):
        """GH 15908: with sort_labels=True, colors follow the sorted labels."""
        from pandas.plotting import parallel_coordinates

        df = DataFrame(
            {
                "feat": list(range(30)),
                "class": [2 for _ in range(10)]
                + [3 for _ in range(10)]
                + [1 for _ in range(10)],
            }
        )
        ax = parallel_coordinates(df, "class", sort_labels=True)
        polylines, labels = ax.get_legend_handles_labels()
        color_label_tuples = zip(
            [polyline.get_color() for polyline in polylines], labels
        )
        ordered_color_label_tuples = sorted(color_label_tuples, key=lambda x: x[1])
        # Compare each (color, label) pair with its successor.
        adjacent_pairs = zip(ordered_color_label_tuples, ordered_color_label_tuples[1:])
        for prev, nxt in adjacent_pairs:
            # labels and colors are ordered strictly increasing
            assert prev[1] < nxt[1] and prev[0] < nxt[0]

    def test_radviz(self, iris):
        from matplotlib import cm

        from pandas.plotting import radviz

        df = iris
        # Ensure no UserWarning when making plot
        with tm.assert_produces_warning(None):
            _check_plot_works(radviz, frame=df, class_column="Name")

        rgba = ("#556270", "#4ECDC4", "#C7F464")
        ax = _check_plot_works(radviz, frame=df, class_column="Name", color=rgba)
        # skip Circle drawn as ticks
        patches = [p for p in ax.patches[:20] if p.get_label() != ""]
        self._check_colors(patches[:10], facecolors=rgba, mapping=df["Name"][:10])

        # Rebind ``ax`` after every plot; otherwise the assertions below would
        # inspect the axes returned by the rgba plot above.
        cnames = ["dodgerblue", "aquamarine", "seagreen"]
        ax = _check_plot_works(radviz, frame=df, class_column="Name", color=cnames)
        patches = [p for p in ax.patches[:20] if p.get_label() != ""]
        self._check_colors(patches[:10], facecolors=cnames, mapping=df["Name"][:10])

        ax = _check_plot_works(radviz, frame=df, class_column="Name", colormap=cm.jet)
        cmaps = [cm.jet(n) for n in np.linspace(0, 1, df["Name"].nunique())]
        patches = [p for p in ax.patches[:20] if p.get_label() != ""]
        self._check_colors(patches, facecolors=cmaps, mapping=df["Name"][:10])

        # Legend handles must carry the requested colors.
        colors = [[0.0, 0.0, 1.0, 1.0], [0.0, 0.5, 1.0, 1.0], [1.0, 0.0, 0.0, 1.0]]
        df = DataFrame(
            {"A": [1, 2, 3], "B": [2, 1, 3], "C": [3, 2, 1], "Name": ["b", "g", "r"]}
        )
        ax = radviz(df, "Name", color=colors)
        handles, labels = ax.get_legend_handles_labels()
        self._check_colors(handles, facecolors=colors)

    def test_subplot_titles(self, iris):
        df = iris.drop("Name", axis=1).head()
        # Use the column names as the subplot titles
        title = list(df.columns)

        # Case len(title) == number of columns
        plot = df.plot(subplots=True, title=title)
        assert [p.get_title() for p in plot] == title

        # Case len(title) > number of columns
        msg = (
            "The length of `title` must equal the number of columns if "
            "using `title` of type `list` and `subplots=True`"
        )
        with pytest.raises(ValueError, match=msg):
            df.plot(subplots=True, title=title + ["kittens > puppies"])

        # Case len(title) < number of columns
        with pytest.raises(ValueError, match=msg):
            df.plot(subplots=True, title=title[:2])

        # Case subplots=False and title is of type list
        msg = (
            "Using `title` of type `list` is not supported unless "
            "`subplots=True` is passed"
        )
        with pytest.raises(ValueError, match=msg):
            df.plot(subplots=False, title=title)

        # Case df with 3 numeric columns but layout of (2,2)
        plot = df.drop("SepalWidth", axis=1).plot(
            subplots=True, layout=(2, 2), title=title[:-1]
        )
        title_list = [ax.get_title() for sublist in plot for ax in sublist]
        assert title_list == title[:3] + [""]

    def test_get_standard_colors_random_seed(self):
        # GH17525
        df = DataFrame(np.zeros((10, 10)))

        # Make sure that the np.random.seed isn't reset by get_standard_colors
        plotting.parallel_coordinates(df, 0)
        rand1 = np.random.random()
        plotting.parallel_coordinates(df, 0)
        rand2 = np.random.random()
        assert rand1 != rand2

        # Make sure it produces the same colors every time it's called
        from pandas.plotting._matplotlib.style import get_standard_colors

        color1 = get_standard_colors(1, color_type="random")
        color2 = get_standard_colors(1, color_type="random")
        assert color1 == color2

    def test_get_standard_colors_default_num_colors(self):
        from pandas.plotting._matplotlib.style import get_standard_colors

        # Make sure the default color_types returns the specified amount
        color1 = get_standard_colors(1, color_type="default")
        color2 = get_standard_colors(9, color_type="default")
        color3 = get_standard_colors(20, color_type="default")
        assert len(color1) == 1
        assert len(color2) == 9
        assert len(color3) == 20

    def test_plot_single_color(self):
        # Example from #20585. All 3 bars should have the same color
        df = DataFrame(
            {
                "account-start": ["2017-02-03", "2017-03-03", "2017-01-01"],
                "client": ["Alice Anders", "Bob Baker", "Charlie Chaplin"],
                "balance": [-1432.32, 10.43, 30000.00],
                "db-id": [1234, 2424, 251],
                "proxy-id": [525, 1525, 2542],
                "rank": [52, 525, 32],
            }
        )
        ax = df.client.value_counts().plot.bar()
        colors = [rect.get_facecolor() for rect in ax.get_children()[0:3]]
        assert all(color == colors[0] for color in colors)

    def test_get_standard_colors_no_appending(self):
        # GH20726

        # Make sure not to add more colors so that matplotlib can cycle
        # correctly.
        from matplotlib import cm

        from pandas.plotting._matplotlib.style import get_standard_colors

        color_before = cm.gnuplot(range(5))
        color_after = get_standard_colors(1, color=color_before)
        assert len(color_after) == len(color_before)

        df = DataFrame(np.random.randn(48, 4), columns=list("ABCD"))

        # 16 colors cycle over 48 bars, so bars 1 and 17 share a color.
        color_list = cm.gnuplot(np.linspace(0, 1, 16))
        p = df.A.plot.bar(figsize=(16, 7), color=color_list)
        assert p.patches[1].get_facecolor() == p.patches[17].get_facecolor()

    def test_dictionary_color(self):
        # issue-8193
        # Test plot color dictionary format
        data_files = ["a", "b"]

        expected = [(0.5, 0.24, 0.6), (0.3, 0.7, 0.7)]

        df1 = DataFrame(np.random.rand(2, 2), columns=data_files)
        dic_color = {"b": (0.3, 0.7, 0.7), "a": (0.5, 0.24, 0.6)}

        # Bar color test
        ax = df1.plot(kind="bar", color=dic_color)
        colors = [rect.get_facecolor()[0:-1] for rect in ax.get_children()[0:3:2]]
        assert all(color == expected[index] for index, color in enumerate(colors))

        # Line color test
        ax = df1.plot(kind="line", color=dic_color)
        colors = [rect.get_color() for rect in ax.get_lines()[0:2]]
        assert all(color == expected[index] for index, color in enumerate(colors))

    def test_has_externally_shared_axis_x_axis(self):
        # GH33819
        # Test _has_externally_shared_axis() works for x-axis
        func = plotting._matplotlib.tools._has_externally_shared_axis

        fig = self.plt.figure()
        plots = fig.subplots(2, 4)

        # Create *externally* shared axes for first and third columns
        plots[0][0] = fig.add_subplot(231, sharex=plots[1][0])
        plots[0][2] = fig.add_subplot(233, sharex=plots[1][2])

        # Create *internally* shared axes for second and third columns
        plots[0][1].twinx()
        plots[0][2].twinx()

        # First column is only externally shared
        # Second column is only internally shared
        # Third column is both
        # Fourth column is neither
        assert func(plots[0][0], "x")
        assert not func(plots[0][1], "x")
        assert func(plots[0][2], "x")
        assert not func(plots[0][3], "x")

    def test_has_externally_shared_axis_y_axis(self):
        # GH33819
        # Test _has_externally_shared_axis() works for y-axis
        func = plotting._matplotlib.tools._has_externally_shared_axis

        fig = self.plt.figure()
        plots = fig.subplots(4, 2)

        # Create *externally* shared axes for first and third rows
        plots[0][0] = fig.add_subplot(321, sharey=plots[0][1])
        plots[2][0] = fig.add_subplot(325, sharey=plots[2][1])

        # Create *internally* shared axes for second and third rows
        plots[1][0].twiny()
        plots[2][0].twiny()

        # First row is only externally shared
        # Second row is only internally shared
        # Third row is both
        # Fourth row is neither
        assert func(plots[0][0], "y")
        assert not func(plots[1][0], "y")
        assert func(plots[2][0], "y")
        assert not func(plots[3][0], "y")

    def test_has_externally_shared_axis_invalid_compare_axis(self):
        # GH33819
        # Test _has_externally_shared_axis() raises an exception when
        # passed an invalid value as compare_axis parameter
        func = plotting._matplotlib.tools._has_externally_shared_axis

        fig = self.plt.figure()
        plots = fig.subplots(4, 2)

        # Create arbitrary axes
        plots[0][0] = fig.add_subplot(321, sharey=plots[0][1])

        # Check that an invalid compare_axis value triggers the expected exception
        msg = "needs 'x' or 'y' as a second parameter"
        with pytest.raises(ValueError, match=msg):
            func(plots[0][0], "z")

    def test_externally_shared_axes(self):
        # Example from GH33819
        # Create data
        df = DataFrame({"a": np.random.randn(1000), "b": np.random.randn(1000)})

        # Create figure
        fig = self.plt.figure()
        plots = fig.subplots(2, 3)

        # Create *externally* shared axes
        plots[0][0] = fig.add_subplot(231, sharex=plots[1][0])
        # note: no plots[0][1] that's the twin only case
        plots[0][2] = fig.add_subplot(233, sharex=plots[1][2])

        # Create *internally* shared axes
        # note: no plots[0][0] that's the external only case
        twin_ax1 = plots[0][1].twinx()
        twin_ax2 = plots[0][2].twinx()

        # Plot data to primary axes
        df["a"].plot(ax=plots[0][0], title="External share only").set_xlabel(
            "this label should never be visible"
        )
        df["a"].plot(ax=plots[1][0])

        df["a"].plot(ax=plots[0][1], title="Internal share (twin) only").set_xlabel(
            "this label should always be visible"
        )
        df["a"].plot(ax=plots[1][1])

        df["a"].plot(ax=plots[0][2], title="Both").set_xlabel(
            "this label should never be visible"
        )
        df["a"].plot(ax=plots[1][2])

        # Plot data to twinned axes
        df["b"].plot(ax=twin_ax1, color="green")
        df["b"].plot(ax=twin_ax2, color="yellow")

        assert not plots[0][0].xaxis.get_label().get_visible()
        assert plots[0][1].xaxis.get_label().get_visible()
        assert not plots[0][2].xaxis.get_label().get_visible()