1""" Test cases for misc plot functions """
2
3import numpy as np
4import pytest
5
6import pandas.util._test_decorators as td
7
8from pandas import DataFrame, Series
9import pandas._testing as tm
10from pandas.tests.plotting.common import TestPlotBase, _check_plot_works
11
12import pandas.plotting as plotting
13
14pytestmark = pytest.mark.slow
15
16
@td.skip_if_mpl
def test_import_error_message():
    """``DataFrame.plot`` raises ImportError mentioning matplotlib (GH-19810)."""
    expected_msg = "matplotlib is required for plotting"
    frame = DataFrame({"A": [1, 2]})

    with pytest.raises(ImportError, match=expected_msg):
        frame.plot()
24
25
def test_get_accessor_args():
    """Exercise ``PlotAccessor._get_call_args`` with rejected and accepted input."""
    get_call_args = plotting._core.PlotAccessor._get_call_args

    # Data that is a plain list must be rejected with a TypeError.
    with pytest.raises(
        TypeError,
        match="Called plot accessor for type list, expected Series or DataFrame",
    ):
        get_call_args(backend_name="", data=[], args=[], kwargs={})

    # Positional arguments combined with Series data must be rejected as well.
    with pytest.raises(
        TypeError, match="should not be called with positional arguments"
    ):
        get_call_args(
            backend_name="",
            data=Series(dtype=object),
            args=["line", None],
            kwargs={},
        )

    # DataFrame data: positional "x" and keyword "y"/"kind" are returned
    # separately, everything else is handed back in the kwargs dict.
    frame_result = get_call_args(
        backend_name="",
        data=DataFrame(),
        args=["x"],
        kwargs={"y": "y", "kind": "bar", "grid": False},
    )
    x, y, kind, kwargs = frame_result
    assert (x, y, kind) == ("x", "y", "bar")
    assert kwargs == {"grid": False}

    # Matplotlib backend name, no arguments at all: x and y stay unset, kind is
    # "line" and the returned kwargs hold 24 entries.
    series_result = get_call_args(
        backend_name="pandas.plotting._matplotlib",
        data=Series(dtype=object),
        args=[],
        kwargs={},
    )
    x, y, kind, kwargs = series_result
    assert x is None
    assert y is None
    assert kind == "line"
    assert len(kwargs) == 24
58
59
@td.skip_if_no_mpl
class TestSeriesPlots(TestPlotBase):
    """Tests for the Series-based helpers exposed by ``pandas.plotting``."""

    def setup_method(self, method):
        """Run the base setup, reset matplotlib's rc settings and build ``self.ts``."""
        TestPlotBase.setup_method(self, method)
        from matplotlib import rcdefaults

        rcdefaults()

        series = tm.makeTimeSeries()
        series.name = "ts"
        self.ts = series

    def test_autocorrelation_plot(self):
        """``autocorrelation_plot`` plots without warnings and applies ``label``."""
        autocorrelation_plot = plotting.autocorrelation_plot

        # Ensure no UserWarning when making plot
        with tm.assert_produces_warning(None):
            # once with the Series itself, once with its underlying values
            for data in (self.ts, self.ts.values):
                _check_plot_works(autocorrelation_plot, series=data)

            ax = autocorrelation_plot(self.ts, label="Test")

        self._check_legend_labels(ax, labels=["Test"])

    def test_lag_plot(self):
        """``lag_plot`` works without extra arguments and with ``lag=5``."""
        lag_plot = plotting.lag_plot

        for extra_kwargs in ({}, {"lag": 5}):
            _check_plot_works(lag_plot, series=self.ts, **extra_kwargs)

    def test_bootstrap_plot(self):
        """``bootstrap_plot`` works when called with ``size=10``."""
        _check_plot_works(plotting.bootstrap_plot, series=self.ts, size=10)
92
93
94@td.skip_if_no_mpl
95class TestDataFramePlots(TestPlotBase):
96    @td.skip_if_no_scipy
97    def test_scatter_matrix_axis(self):
98        from pandas.plotting._matplotlib.compat import mpl_ge_3_0_0
99
100        scatter_matrix = plotting.scatter_matrix
101
102        with tm.RNGContext(42):
103            df = DataFrame(np.random.randn(100, 3))
104
105        # we are plotting multiples on a sub-plot
106        with tm.assert_produces_warning(
107            UserWarning, raise_on_extra_warnings=mpl_ge_3_0_0()
108        ):
109            axes = _check_plot_works(
110                scatter_matrix, filterwarnings="always", frame=df, range_padding=0.1
111            )
112        axes0_labels = axes[0][0].yaxis.get_majorticklabels()
113
114        # GH 5662
115        expected = ["-2", "0", "2"]
116        self._check_text_labels(axes0_labels, expected)
117        self._check_ticks_props(axes, xlabelsize=8, xrot=90, ylabelsize=8, yrot=0)
118
119        df[0] = (df[0] - 2) / 3
120
121        # we are plotting multiples on a sub-plot
122        with tm.assert_produces_warning(UserWarning):
123            axes = _check_plot_works(
124                scatter_matrix, filterwarnings="always", frame=df, range_padding=0.1
125            )
126        axes0_labels = axes[0][0].yaxis.get_majorticklabels()
127        expected = ["-1.0", "-0.5", "0.0"]
128        self._check_text_labels(axes0_labels, expected)
129        self._check_ticks_props(axes, xlabelsize=8, xrot=90, ylabelsize=8, yrot=0)
130
131    def test_andrews_curves(self, iris):
132        from matplotlib import cm
133
134        from pandas.plotting import andrews_curves
135
136        df = iris
137        # Ensure no UserWarning when making plot
138        with tm.assert_produces_warning(None):
139            _check_plot_works(andrews_curves, frame=df, class_column="Name")
140
141        rgba = ("#556270", "#4ECDC4", "#C7F464")
142        ax = _check_plot_works(
143            andrews_curves, frame=df, class_column="Name", color=rgba
144        )
145        self._check_colors(
146            ax.get_lines()[:10], linecolors=rgba, mapping=df["Name"][:10]
147        )
148
149        cnames = ["dodgerblue", "aquamarine", "seagreen"]
150        ax = _check_plot_works(
151            andrews_curves, frame=df, class_column="Name", color=cnames
152        )
153        self._check_colors(
154            ax.get_lines()[:10], linecolors=cnames, mapping=df["Name"][:10]
155        )
156
157        ax = _check_plot_works(
158            andrews_curves, frame=df, class_column="Name", colormap=cm.jet
159        )
160        cmaps = [cm.jet(n) for n in np.linspace(0, 1, df["Name"].nunique())]
161        self._check_colors(
162            ax.get_lines()[:10], linecolors=cmaps, mapping=df["Name"][:10]
163        )
164
165        length = 10
166        df = DataFrame(
167            {
168                "A": np.random.rand(length),
169                "B": np.random.rand(length),
170                "C": np.random.rand(length),
171                "Name": ["A"] * length,
172            }
173        )
174
175        _check_plot_works(andrews_curves, frame=df, class_column="Name")
176
177        rgba = ("#556270", "#4ECDC4", "#C7F464")
178        ax = _check_plot_works(
179            andrews_curves, frame=df, class_column="Name", color=rgba
180        )
181        self._check_colors(
182            ax.get_lines()[:10], linecolors=rgba, mapping=df["Name"][:10]
183        )
184
185        cnames = ["dodgerblue", "aquamarine", "seagreen"]
186        ax = _check_plot_works(
187            andrews_curves, frame=df, class_column="Name", color=cnames
188        )
189        self._check_colors(
190            ax.get_lines()[:10], linecolors=cnames, mapping=df["Name"][:10]
191        )
192
193        ax = _check_plot_works(
194            andrews_curves, frame=df, class_column="Name", colormap=cm.jet
195        )
196        cmaps = [cm.jet(n) for n in np.linspace(0, 1, df["Name"].nunique())]
197        self._check_colors(
198            ax.get_lines()[:10], linecolors=cmaps, mapping=df["Name"][:10]
199        )
200
201        colors = ["b", "g", "r"]
202        df = DataFrame({"A": [1, 2, 3], "B": [1, 2, 3], "C": [1, 2, 3], "Name": colors})
203        ax = andrews_curves(df, "Name", color=colors)
204        handles, labels = ax.get_legend_handles_labels()
205        self._check_colors(handles, linecolors=colors)
206
207    def test_parallel_coordinates(self, iris):
208        from matplotlib import cm
209
210        from pandas.plotting import parallel_coordinates
211
212        df = iris
213
214        ax = _check_plot_works(parallel_coordinates, frame=df, class_column="Name")
215        nlines = len(ax.get_lines())
216        nxticks = len(ax.xaxis.get_ticklabels())
217
218        rgba = ("#556270", "#4ECDC4", "#C7F464")
219        ax = _check_plot_works(
220            parallel_coordinates, frame=df, class_column="Name", color=rgba
221        )
222        self._check_colors(
223            ax.get_lines()[:10], linecolors=rgba, mapping=df["Name"][:10]
224        )
225
226        cnames = ["dodgerblue", "aquamarine", "seagreen"]
227        ax = _check_plot_works(
228            parallel_coordinates, frame=df, class_column="Name", color=cnames
229        )
230        self._check_colors(
231            ax.get_lines()[:10], linecolors=cnames, mapping=df["Name"][:10]
232        )
233
234        ax = _check_plot_works(
235            parallel_coordinates, frame=df, class_column="Name", colormap=cm.jet
236        )
237        cmaps = [cm.jet(n) for n in np.linspace(0, 1, df["Name"].nunique())]
238        self._check_colors(
239            ax.get_lines()[:10], linecolors=cmaps, mapping=df["Name"][:10]
240        )
241
242        ax = _check_plot_works(
243            parallel_coordinates, frame=df, class_column="Name", axvlines=False
244        )
245        assert len(ax.get_lines()) == (nlines - nxticks)
246
247        colors = ["b", "g", "r"]
248        df = DataFrame({"A": [1, 2, 3], "B": [1, 2, 3], "C": [1, 2, 3], "Name": colors})
249        ax = parallel_coordinates(df, "Name", color=colors)
250        handles, labels = ax.get_legend_handles_labels()
251        self._check_colors(handles, linecolors=colors)
252
253    # not sure if this is indicative of a problem
254    @pytest.mark.filterwarnings("ignore:Attempting to set:UserWarning")
255    def test_parallel_coordinates_with_sorted_labels(self):
256        """ For #15908 """
257        from pandas.plotting import parallel_coordinates
258
259        df = DataFrame(
260            {
261                "feat": list(range(30)),
262                "class": [2 for _ in range(10)]
263                + [3 for _ in range(10)]
264                + [1 for _ in range(10)],
265            }
266        )
267        ax = parallel_coordinates(df, "class", sort_labels=True)
268        polylines, labels = ax.get_legend_handles_labels()
269        color_label_tuples = zip(
270            [polyline.get_color() for polyline in polylines], labels
271        )
272        ordered_color_label_tuples = sorted(color_label_tuples, key=lambda x: x[1])
273        prev_next_tupels = zip(
274            list(ordered_color_label_tuples[0:-1]), list(ordered_color_label_tuples[1:])
275        )
276        for prev, nxt in prev_next_tupels:
277            # labels and colors are ordered strictly increasing
278            assert prev[1] < nxt[1] and prev[0] < nxt[0]
279
280    def test_radviz(self, iris):
281        from matplotlib import cm
282
283        from pandas.plotting import radviz
284
285        df = iris
286        # Ensure no UserWarning when making plot
287        with tm.assert_produces_warning(None):
288            _check_plot_works(radviz, frame=df, class_column="Name")
289
290        rgba = ("#556270", "#4ECDC4", "#C7F464")
291        ax = _check_plot_works(radviz, frame=df, class_column="Name", color=rgba)
292        # skip Circle drawn as ticks
293        patches = [p for p in ax.patches[:20] if p.get_label() != ""]
294        self._check_colors(patches[:10], facecolors=rgba, mapping=df["Name"][:10])
295
296        cnames = ["dodgerblue", "aquamarine", "seagreen"]
297        _check_plot_works(radviz, frame=df, class_column="Name", color=cnames)
298        patches = [p for p in ax.patches[:20] if p.get_label() != ""]
299        self._check_colors(patches, facecolors=cnames, mapping=df["Name"][:10])
300
301        _check_plot_works(radviz, frame=df, class_column="Name", colormap=cm.jet)
302        cmaps = [cm.jet(n) for n in np.linspace(0, 1, df["Name"].nunique())]
303        patches = [p for p in ax.patches[:20] if p.get_label() != ""]
304        self._check_colors(patches, facecolors=cmaps, mapping=df["Name"][:10])
305
306        colors = [[0.0, 0.0, 1.0, 1.0], [0.0, 0.5, 1.0, 1.0], [1.0, 0.0, 0.0, 1.0]]
307        df = DataFrame(
308            {"A": [1, 2, 3], "B": [2, 1, 3], "C": [3, 2, 1], "Name": ["b", "g", "r"]}
309        )
310        ax = radviz(df, "Name", color=colors)
311        handles, labels = ax.get_legend_handles_labels()
312        self._check_colors(handles, facecolors=colors)
313
314    def test_subplot_titles(self, iris):
315        df = iris.drop("Name", axis=1).head()
316        # Use the column names as the subplot titles
317        title = list(df.columns)
318
319        # Case len(title) == len(df)
320        plot = df.plot(subplots=True, title=title)
321        assert [p.get_title() for p in plot] == title
322
323        # Case len(title) > len(df)
324        msg = (
325            "The length of `title` must equal the number of columns if "
326            "using `title` of type `list` and `subplots=True`"
327        )
328        with pytest.raises(ValueError, match=msg):
329            df.plot(subplots=True, title=title + ["kittens > puppies"])
330
331        # Case len(title) < len(df)
332        with pytest.raises(ValueError, match=msg):
333            df.plot(subplots=True, title=title[:2])
334
335        # Case subplots=False and title is of type list
336        msg = (
337            "Using `title` of type `list` is not supported unless "
338            "`subplots=True` is passed"
339        )
340        with pytest.raises(ValueError, match=msg):
341            df.plot(subplots=False, title=title)
342
343        # Case df with 3 numeric columns but layout of (2,2)
344        plot = df.drop("SepalWidth", axis=1).plot(
345            subplots=True, layout=(2, 2), title=title[:-1]
346        )
347        title_list = [ax.get_title() for sublist in plot for ax in sublist]
348        assert title_list == title[:3] + [""]
349
350    def test_get_standard_colors_random_seed(self):
351        # GH17525
352        df = DataFrame(np.zeros((10, 10)))
353
354        # Make sure that the np.random.seed isn't reset by get_standard_colors
355        plotting.parallel_coordinates(df, 0)
356        rand1 = np.random.random()
357        plotting.parallel_coordinates(df, 0)
358        rand2 = np.random.random()
359        assert rand1 != rand2
360
361        # Make sure it produces the same colors every time it's called
362        from pandas.plotting._matplotlib.style import get_standard_colors
363
364        color1 = get_standard_colors(1, color_type="random")
365        color2 = get_standard_colors(1, color_type="random")
366        assert color1 == color2
367
368    def test_get_standard_colors_default_num_colors(self):
369        from pandas.plotting._matplotlib.style import get_standard_colors
370
371        # Make sure the default color_types returns the specified amount
372        color1 = get_standard_colors(1, color_type="default")
373        color2 = get_standard_colors(9, color_type="default")
374        color3 = get_standard_colors(20, color_type="default")
375        assert len(color1) == 1
376        assert len(color2) == 9
377        assert len(color3) == 20
378
379    def test_plot_single_color(self):
380        # Example from #20585. All 3 bars should have the same color
381        df = DataFrame(
382            {
383                "account-start": ["2017-02-03", "2017-03-03", "2017-01-01"],
384                "client": ["Alice Anders", "Bob Baker", "Charlie Chaplin"],
385                "balance": [-1432.32, 10.43, 30000.00],
386                "db-id": [1234, 2424, 251],
387                "proxy-id": [525, 1525, 2542],
388                "rank": [52, 525, 32],
389            }
390        )
391        ax = df.client.value_counts().plot.bar()
392        colors = [rect.get_facecolor() for rect in ax.get_children()[0:3]]
393        assert all(color == colors[0] for color in colors)
394
395    def test_get_standard_colors_no_appending(self):
396        # GH20726
397
398        # Make sure not to add more colors so that matplotlib can cycle
399        # correctly.
400        from matplotlib import cm
401
402        from pandas.plotting._matplotlib.style import get_standard_colors
403
404        color_before = cm.gnuplot(range(5))
405        color_after = get_standard_colors(1, color=color_before)
406        assert len(color_after) == len(color_before)
407
408        df = DataFrame(np.random.randn(48, 4), columns=list("ABCD"))
409
410        color_list = cm.gnuplot(np.linspace(0, 1, 16))
411        p = df.A.plot.bar(figsize=(16, 7), color=color_list)
412        assert p.patches[1].get_facecolor() == p.patches[17].get_facecolor()
413
414    def test_dictionary_color(self):
415        # issue-8193
416        # Test plot color dictionary format
417        data_files = ["a", "b"]
418
419        expected = [(0.5, 0.24, 0.6), (0.3, 0.7, 0.7)]
420
421        df1 = DataFrame(np.random.rand(2, 2), columns=data_files)
422        dic_color = {"b": (0.3, 0.7, 0.7), "a": (0.5, 0.24, 0.6)}
423
424        # Bar color test
425        ax = df1.plot(kind="bar", color=dic_color)
426        colors = [rect.get_facecolor()[0:-1] for rect in ax.get_children()[0:3:2]]
427        assert all(color == expected[index] for index, color in enumerate(colors))
428
429        # Line color test
430        ax = df1.plot(kind="line", color=dic_color)
431        colors = [rect.get_color() for rect in ax.get_lines()[0:2]]
432        assert all(color == expected[index] for index, color in enumerate(colors))
433
434    def test_has_externally_shared_axis_x_axis(self):
435        # GH33819
436        # Test _has_externally_shared_axis() works for x-axis
437        func = plotting._matplotlib.tools._has_externally_shared_axis
438
439        fig = self.plt.figure()
440        plots = fig.subplots(2, 4)
441
442        # Create *externally* shared axes for first and third columns
443        plots[0][0] = fig.add_subplot(231, sharex=plots[1][0])
444        plots[0][2] = fig.add_subplot(233, sharex=plots[1][2])
445
446        # Create *internally* shared axes for second and third columns
447        plots[0][1].twinx()
448        plots[0][2].twinx()
449
450        # First  column is only externally shared
451        # Second column is only internally shared
452        # Third  column is both
453        # Fourth column is neither
454        assert func(plots[0][0], "x")
455        assert not func(plots[0][1], "x")
456        assert func(plots[0][2], "x")
457        assert not func(plots[0][3], "x")
458
459    def test_has_externally_shared_axis_y_axis(self):
460        # GH33819
461        # Test _has_externally_shared_axis() works for y-axis
462        func = plotting._matplotlib.tools._has_externally_shared_axis
463
464        fig = self.plt.figure()
465        plots = fig.subplots(4, 2)
466
467        # Create *externally* shared axes for first and third rows
468        plots[0][0] = fig.add_subplot(321, sharey=plots[0][1])
469        plots[2][0] = fig.add_subplot(325, sharey=plots[2][1])
470
471        # Create *internally* shared axes for second and third rows
472        plots[1][0].twiny()
473        plots[2][0].twiny()
474
475        # First  row is only externally shared
476        # Second row is only internally shared
477        # Third  row is both
478        # Fourth row is neither
479        assert func(plots[0][0], "y")
480        assert not func(plots[1][0], "y")
481        assert func(plots[2][0], "y")
482        assert not func(plots[3][0], "y")
483
484    def test_has_externally_shared_axis_invalid_compare_axis(self):
485        # GH33819
486        # Test _has_externally_shared_axis() raises an exception when
487        # passed an invalid value as compare_axis parameter
488        func = plotting._matplotlib.tools._has_externally_shared_axis
489
490        fig = self.plt.figure()
491        plots = fig.subplots(4, 2)
492
493        # Create arbitrary axes
494        plots[0][0] = fig.add_subplot(321, sharey=plots[0][1])
495
496        # Check that an invalid compare_axis value triggers the expected exception
497        msg = "needs 'x' or 'y' as a second parameter"
498        with pytest.raises(ValueError, match=msg):
499            func(plots[0][0], "z")
500
501    def test_externally_shared_axes(self):
502        # Example from GH33819
503        # Create data
504        df = DataFrame({"a": np.random.randn(1000), "b": np.random.randn(1000)})
505
506        # Create figure
507        fig = self.plt.figure()
508        plots = fig.subplots(2, 3)
509
510        # Create *externally* shared axes
511        plots[0][0] = fig.add_subplot(231, sharex=plots[1][0])
512        # note: no plots[0][1] that's the twin only case
513        plots[0][2] = fig.add_subplot(233, sharex=plots[1][2])
514
515        # Create *internally* shared axes
516        # note: no plots[0][0] that's the external only case
517        twin_ax1 = plots[0][1].twinx()
518        twin_ax2 = plots[0][2].twinx()
519
520        # Plot data to primary axes
521        df["a"].plot(ax=plots[0][0], title="External share only").set_xlabel(
522            "this label should never be visible"
523        )
524        df["a"].plot(ax=plots[1][0])
525
526        df["a"].plot(ax=plots[0][1], title="Internal share (twin) only").set_xlabel(
527            "this label should always be visible"
528        )
529        df["a"].plot(ax=plots[1][1])
530
531        df["a"].plot(ax=plots[0][2], title="Both").set_xlabel(
532            "this label should never be visible"
533        )
534        df["a"].plot(ax=plots[1][2])
535
536        # Plot data to twinned axes
537        df["b"].plot(ax=twin_ax1, color="green")
538        df["b"].plot(ax=twin_ax2, color="yellow")
539
540        assert not plots[0][0].xaxis.get_label().get_visible()
541        assert plots[0][1].xaxis.get_label().get_visible()
542        assert not plots[0][2].xaxis.get_label().get_visible()
543