1"""
2Module consolidating common testing functions for checking plotting.
3
4Currently all plotting tests are marked as slow via
5``pytestmark = pytest.mark.slow`` at the module level.
6"""
7
8import os
9from typing import TYPE_CHECKING, Sequence, Union
10import warnings
11
12import numpy as np
13
14from pandas.util._decorators import cache_readonly
15import pandas.util._test_decorators as td
16
17from pandas.core.dtypes.api import is_list_like
18
19import pandas as pd
20from pandas import DataFrame, Series, to_datetime
21import pandas._testing as tm
22
23if TYPE_CHECKING:
24    from matplotlib.axes import Axes
25
26
@td.skip_if_no_mpl
class TestPlotBase:
    """
    Common base class used by the various plotting tests.

    ``setup_method`` resets matplotlib's rcParams and builds the fixture
    frames (``hist_df``, ``tdf``, ``hexbin_df``); ``teardown_method`` closes
    figures through ``tm.close``. The ``_check_*`` helpers assert properties
    of the matplotlib objects produced by plotting calls.
    """

    def setup_method(self, method):
        """
        Reset matplotlib defaults and create the fixtures shared by tests.
        """
        import matplotlib as mpl

        from pandas.plotting._matplotlib import compat

        mpl.rcdefaults()

        # int64 bounds handed to ``np.random.randint`` below; the drawn
        # integers are converted with ``to_datetime`` for hist_df["datetime"]
        self.start_date_to_int64 = 812419200000000000
        self.end_date_to_int64 = 819331200000000000

        # matplotlib version flags, resolved once per test
        self.mpl_ge_2_2_3 = compat.mpl_ge_2_2_3()
        self.mpl_ge_3_0_0 = compat.mpl_ge_3_0_0()
        self.mpl_ge_3_1_0 = compat.mpl_ge_3_1_0()
        self.mpl_ge_3_2_0 = compat.mpl_ge_3_2_0()

        # constants not read inside this class; presumably consumed by the
        # concrete test classes deriving from it
        self.bp_n_objects = 7
        self.polycollection_factor = 2
        self.default_figsize = (6.4, 4.8)
        self.default_tick_position = "left"

        n = 100
        # seeded so that hist_df is reproducible between runs
        with tm.RNGContext(42):
            gender = np.random.choice(["Male", "Female"], size=n)
            classroom = np.random.choice(["A", "B", "C"], size=n)

            self.hist_df = DataFrame(
                {
                    "gender": gender,
                    "classroom": classroom,
                    "height": np.random.normal(66, 4, size=n),
                    "weight": np.random.normal(161, 32, size=n),
                    "category": np.random.randint(4, size=n),
                    "datetime": to_datetime(
                        np.random.randint(
                            self.start_date_to_int64,
                            self.end_date_to_int64,
                            size=n,
                            dtype=np.int64,
                        )
                    ),
                }
            )

        self.tdf = tm.makeTimeDataFrame()
        # NOTE: drawn outside the seeded context above, so not reproducible
        self.hexbin_df = DataFrame(
            {
                "A": np.random.uniform(size=20),
                "B": np.random.uniform(size=20),
                "C": np.arange(20) + np.random.uniform(size=20),
            }
        )

    def teardown_method(self, method):
        """
        Close figures via ``tm.close`` after each test.
        """
        tm.close()

    @cache_readonly
    def plt(self):
        """
        Lazily imported ``matplotlib.pyplot`` module.
        """
        import matplotlib.pyplot as plt

        return plt

    @cache_readonly
    def colorconverter(self):
        """
        Lazily imported ``matplotlib.colors.colorConverter``.
        """
        import matplotlib.colors as colors

        return colors.colorConverter

    def _check_legend_labels(self, axes, labels=None, visible=True):
        """
        Check each axes has expected legend labels

        Parameters
        ----------
        axes : matplotlib Axes object, or its list-like
        labels : list-like
            expected legend labels
        visible : bool
            expected legend visibility. labels are checked only when visible is
            True

        Raises
        ------
        ValueError
            If ``visible`` is True and ``labels`` is None.
        """
        if visible and (labels is None):
            raise ValueError("labels must be specified when visible is True")
        for ax in self._flatten_visible(axes):
            legend = ax.get_legend()
            if visible:
                assert legend is not None
                self._check_text_labels(legend.get_texts(), labels)
            else:
                assert legend is None

    def _check_legend_marker(self, ax, expected_markers=None, visible=True):
        """
        Check ax has expected legend markers

        Parameters
        ----------
        ax : matplotlib Axes object
        expected_markers : list-like
            expected legend markers
        visible : bool
            expected legend visibility. markers are checked only when visible
            is True

        Raises
        ------
        ValueError
            If ``visible`` is True and ``expected_markers`` is None.
        """
        if visible and (expected_markers is None):
            raise ValueError("Markers must be specified when visible is True")
        if visible:
            handles, _ = ax.get_legend_handles_labels()
            markers = [handle.get_marker() for handle in handles]
            assert markers == expected_markers
        else:
            assert ax.get_legend() is None

    def _check_data(self, xp, rs):
        """
        Check both axes hold identical lines, then close figures.

        Parameters
        ----------
        xp : matplotlib Axes object
        rs : matplotlib Axes object
        """
        xp_lines = xp.get_lines()
        rs_lines = rs.get_lines()

        assert len(xp_lines) == len(rs_lines)
        for xpl, rsl in zip(xp_lines, rs_lines):
            tm.assert_almost_equal(xpl.get_xydata(), rsl.get_xydata())
        tm.close()

    def _check_visible(self, collections, visible=True):
        """
        Check each artist is visible or not

        Parameters
        ----------
        collections : matplotlib Artist or its list-like
            target Artist or its list or collection
        visible : bool
            expected visibility
        """
        from matplotlib.collections import Collection

        # wrap a single non-Collection artist so the loop below is uniform
        if not isinstance(collections, Collection) and not is_list_like(collections):
            collections = [collections]

        for patch in collections:
            assert patch.get_visible() == visible

    def _check_patches_all_filled(
        self, axes: Union["Axes", Sequence["Axes"]], filled: bool = True
    ) -> None:
        """
        Check for each artist whether it is filled or not

        Parameters
        ----------
        axes : matplotlib Axes object, or its list-like
        filled : bool
            expected filling
        """
        for ax in self._flatten_visible(axes):
            for patch in ax.patches:
                assert patch.fill == filled

    def _get_colors_mapped(self, series, colors):
        """
        Map every value of ``series`` to the color paired with its unique key.

        Unique values (in ``series.unique()`` order) are zipped with
        ``colors``; the result has one color per element of ``series``.
        """
        unique = series.unique()
        # the number of unique values and of colors may differ depending on
        # the slice under test; zip truncates to the shorter of the two
        mapped = dict(zip(unique, colors))
        return [mapped[v] for v in series.values]

    def _check_colors(
        self, collections, linecolors=None, facecolors=None, mapping=None
    ):
        """
        Check each artist has expected line colors and face colors

        Parameters
        ----------
        collections : list-like
            list or collection of target artist
        linecolors : list-like which has the same length as collections
            list of expected line colors
        facecolors : list-like which has the same length as collections
            list of expected face colors
        mapping : Series
            Series used for color grouping key
            used for andrew_curves, parallel_coordinates, radviz test
        """
        from matplotlib.collections import Collection, LineCollection, PolyCollection
        from matplotlib.lines import Line2D

        conv = self.colorconverter
        if linecolors is not None:

            if mapping is not None:
                linecolors = self._get_colors_mapped(mapping, linecolors)
                linecolors = linecolors[: len(collections)]

            assert len(collections) == len(linecolors)
            for patch, color in zip(collections, linecolors):
                if isinstance(patch, Line2D):
                    # Line2D may hold a string color expression; normalise it
                    result = conv.to_rgba(patch.get_color())
                elif isinstance(patch, (PolyCollection, LineCollection)):
                    result = tuple(patch.get_edgecolor()[0])
                else:
                    result = patch.get_edgecolor()

                expected = conv.to_rgba(color)
                assert result == expected

        if facecolors is not None:

            if mapping is not None:
                facecolors = self._get_colors_mapped(mapping, facecolors)
                facecolors = facecolors[: len(collections)]

            assert len(collections) == len(facecolors)
            for patch, color in zip(collections, facecolors):
                if isinstance(patch, Collection):
                    # returned as list of np.array; compare the first entry
                    result = patch.get_facecolor()[0]
                else:
                    result = patch.get_facecolor()

                if isinstance(result, np.ndarray):
                    result = tuple(result)

                expected = conv.to_rgba(color)
                assert result == expected

    def _check_text_labels(self, texts, expected):
        """
        Check each text has expected labels

        Parameters
        ----------
        texts : matplotlib Text object, or its list-like
            target text, or its list
        expected : str or list-like which has the same length as texts
            expected text label, or its list
        """
        if not is_list_like(texts):
            assert texts.get_text() == expected
        else:
            labels = [t.get_text() for t in texts]
            assert len(labels) == len(expected)
            for label, e in zip(labels, expected):
                assert label == e

    def _check_ticks_props(
        self, axes, xlabelsize=None, xrot=None, ylabelsize=None, yrot=None
    ):
        """
        Check each axes has expected tick properties

        Parameters
        ----------
        axes : matplotlib Axes object, or its list-like
        xlabelsize : number
            expected xticks font size
        xrot : number
            expected xticks rotation
        ylabelsize : number
            expected yticks font size
        yrot : number
            expected yticks rotation
        """
        from matplotlib.ticker import NullFormatter

        for ax in self._flatten_visible(axes):
            if xlabelsize is not None or xrot is not None:
                if isinstance(ax.xaxis.get_minor_formatter(), NullFormatter):
                    # If minor ticks has NullFormatter, rot / fontsize are not
                    # retained
                    labels = ax.get_xticklabels()
                else:
                    labels = ax.get_xticklabels() + ax.get_xticklabels(minor=True)

                for label in labels:
                    if xlabelsize is not None:
                        tm.assert_almost_equal(label.get_fontsize(), xlabelsize)
                    if xrot is not None:
                        tm.assert_almost_equal(label.get_rotation(), xrot)

            if ylabelsize is not None or yrot is not None:
                if isinstance(ax.yaxis.get_minor_formatter(), NullFormatter):
                    # same rule as for the x axis: skip minor tick labels
                    labels = ax.get_yticklabels()
                else:
                    labels = ax.get_yticklabels() + ax.get_yticklabels(minor=True)

                for label in labels:
                    if ylabelsize is not None:
                        tm.assert_almost_equal(label.get_fontsize(), ylabelsize)
                    if yrot is not None:
                        tm.assert_almost_equal(label.get_rotation(), yrot)

    def _check_ax_scales(self, axes, xaxis="linear", yaxis="linear"):
        """
        Check each axes has expected scales

        Parameters
        ----------
        axes : matplotlib Axes object, or its list-like
        xaxis : {'linear', 'log'}
            expected xaxis scale
        yaxis : {'linear', 'log'}
            expected yaxis scale
        """
        for ax in self._flatten_visible(axes):
            assert ax.xaxis.get_scale() == xaxis
            assert ax.yaxis.get_scale() == yaxis

    def _check_axes_shape(self, axes, axes_num=None, layout=None, figsize=None):
        """
        Check expected number of axes is drawn in expected layout

        Parameters
        ----------
        axes : matplotlib Axes object, or its list-like
        axes_num : number
            expected number of axes. Unnecessary axes should be set to
            invisible.
        layout : tuple
            expected layout, (expected number of rows , columns)
        figsize : tuple
            expected figsize. default is matplotlib default
        """
        from pandas.plotting._matplotlib.tools import flatten_axes

        if figsize is None:
            figsize = self.default_figsize
        visible_axes = self._flatten_visible(axes)

        if axes_num is not None:
            assert len(visible_axes) == axes_num
            for ax in visible_axes:
                # check something drawn on visible axes
                assert len(ax.get_children()) > 0

        if layout is not None:
            # layout is estimated from all axes, including invisible ones
            result = self._get_axes_layout(flatten_axes(axes))
            assert result == layout

        tm.assert_numpy_array_equal(
            visible_axes[0].figure.get_size_inches(),
            np.array(figsize, dtype=np.float64),
        )

    def _get_axes_layout(self, axes):
        """
        Estimate the ``(rows, columns)`` layout from the axes positions.

        Counts the distinct y (rows) and x (columns) coordinates of the first
        point returned by each axes' ``get_position().get_points()``.
        """
        x_set = set()
        y_set = set()
        for ax in axes:
            # check axes coordinates to estimate layout
            points = ax.get_position().get_points()
            x_set.add(points[0][0])
            y_set.add(points[0][1])
        return (len(y_set), len(x_set))

    def _flatten_visible(self, axes):
        """
        Flatten axes, and filter only visible

        Parameters
        ----------
        axes : matplotlib Axes object, or its list-like

        Returns
        -------
        list
            Axes for which ``get_visible()`` is truthy.
        """
        from pandas.plotting._matplotlib.tools import flatten_axes

        return [ax for ax in flatten_axes(axes) if ax.get_visible()]

    def _check_has_errorbars(self, axes, xerr=0, yerr=0):
        """
        Check axes has expected number of errorbars

        Parameters
        ----------
        axes : matplotlib Axes object, or its list-like
        xerr : number
            expected number of x errorbar
        yerr : number
            expected number of y errorbar
        """
        for ax in self._flatten_visible(axes):
            containers = ax.containers
            # containers lacking the attribute are counted as "no errorbar"
            xerr_count = sum(1 for c in containers if getattr(c, "has_xerr", False))
            yerr_count = sum(1 for c in containers if getattr(c, "has_yerr", False))
            assert xerr == xerr_count
            assert yerr == yerr_count

    def _check_box_return_type(
        self, returned, return_type, expected_keys=None, check_ax_title=True
    ):
        """
        Check box returned type is correct

        Parameters
        ----------
        returned : object to be tested, returned from boxplot
        return_type : str
            return_type passed to boxplot
        expected_keys : list-like, optional
            group labels in subplot case. If not passed,
            the function checks assuming boxplot uses single ax
        check_ax_title : bool
            Whether to check the ax.title is the same as expected_key
            Intended to be checked by calling from ``boxplot``.
            Normal ``plot`` doesn't attach ``ax.title``, it must be disabled.
        """
        from matplotlib.axes import Axes

        types = {"dict": dict, "axes": Axes, "both": tuple}
        if expected_keys is None:
            # should be fixed when the returning default is changed
            if return_type is None:
                return_type = "dict"

            assert isinstance(returned, types[return_type])
            if return_type == "both":
                assert isinstance(returned.ax, Axes)
                assert isinstance(returned.lines, dict)
        else:
            # should be fixed when the returning default is changed
            if return_type is None:
                for r in self._flatten_visible(returned):
                    assert isinstance(r, Axes)
                return

            assert isinstance(returned, Series)

            assert sorted(returned.keys()) == sorted(expected_keys)
            for key, value in returned.items():
                assert isinstance(value, types[return_type])
                # check returned dict has correct mapping
                if return_type == "axes":
                    if check_ax_title:
                        assert value.get_title() == key
                elif return_type == "both":
                    if check_ax_title:
                        assert value.ax.get_title() == key
                    assert isinstance(value.ax, Axes)
                    assert isinstance(value.lines, dict)
                elif return_type == "dict":
                    line = value["medians"][0]
                    axes = line.axes
                    if check_ax_title:
                        assert axes.get_title() == key
                else:
                    raise AssertionError(f"unexpected return_type: {return_type!r}")

    def _check_grid_settings(self, obj, kinds, kws=None):
        """
        Check plots honour ``rcParams['axes.grid']`` and the ``grid`` keyword.

        Make sure plot defaults to rcParams['axes.grid'] setting, GH 9792

        Parameters
        ----------
        obj : object exposing ``plot(kind=..., **kwargs)``
        kinds : list-like of str
            plot kinds to exercise
        kws : dict, optional
            extra keyword arguments forwarded to every ``obj.plot`` call
        """
        import matplotlib as mpl

        # avoid a shared mutable default argument
        kws = {} if kws is None else kws

        def is_grid_on():
            xticks = self.plt.gca().xaxis.get_major_ticks()
            yticks = self.plt.gca().yaxis.get_major_ticks()
            # for mpl 2.2.2, gridOn and gridline.get_visible disagree.
            # for new MPL, they are the same.

            if self.mpl_ge_3_1_0:
                xoff = all(not g.gridline.get_visible() for g in xticks)
                yoff = all(not g.gridline.get_visible() for g in yticks)
            else:
                xoff = all(not g.gridOn for g in xticks)
                yoff = all(not g.gridOn for g in yticks)

            return not (xoff and yoff)

        # up to four subplots are drawn per kind, all on a single row
        n_cols = 4 * len(kinds)
        spndx = 1
        for kind in kinds:

            # rc says "off", no keyword -> off
            self.plt.subplot(1, n_cols, spndx)
            spndx += 1
            mpl.rc("axes", grid=False)
            obj.plot(kind=kind, **kws)
            assert not is_grid_on()

            # rc says "on", keyword says "off" -> off
            self.plt.subplot(1, n_cols, spndx)
            spndx += 1
            mpl.rc("axes", grid=True)
            obj.plot(kind=kind, grid=False, **kws)
            assert not is_grid_on()

            # the "grid on" expectations are not checked for pie plots
            if kind != "pie":
                # rc says "on", no keyword -> on
                self.plt.subplot(1, n_cols, spndx)
                spndx += 1
                mpl.rc("axes", grid=True)
                obj.plot(kind=kind, **kws)
                assert is_grid_on()

                # rc says "off", keyword says "on" -> on
                self.plt.subplot(1, n_cols, spndx)
                spndx += 1
                mpl.rc("axes", grid=False)
                obj.plot(kind=kind, grid=True, **kws)
                assert is_grid_on()

    def _unpack_cycler(self, rcParams, field="color"):
        """
        Auxiliary function for correctly unpacking cycler after MPL >= 1.5

        Returns the ``field`` entry of every item in
        ``rcParams["axes.prop_cycle"]`` as a list.
        """
        return [v[field] for v in rcParams["axes.prop_cycle"]]
558
559
def _check_plot_works(f, filterwarnings="always", default_axes=False, **kwargs):
    """
    Create plot and ensure that plot return object is valid.

    Each object yielded by the plot generator is validated with
    ``tm.assert_is_valid_plot_return_object``, the current figure is saved to
    a temporary file, and the figure is closed afterwards whether or not an
    error occurred.

    Parameters
    ----------
    f : func
        Plotting function.
    filterwarnings : str
        Warnings filter.
        See https://docs.python.org/3/library/warnings.html#warning-filter
    default_axes : bool, optional
        If False (default):
            - If `ax` not in `kwargs`, then create subplot(211) and plot there
            - Create new subplot(212) and plot there as well
            - Mind special corner case for bootstrap_plot (see `_gen_two_subplots`)
        If True:
            - Simply run plotting function with kwargs provided
            - All required axes instances will be created automatically
            - It is recommended to use it when the plotting function
            creates multiple axes itself. It helps avoid warnings like
            'UserWarning: To output multiple subplots,
            the figure containing the passed axes is being cleared'
    **kwargs
        Keyword arguments passed to the plotting function.

    Returns
    -------
    Plot object returned by the last plotting.
    """
    import matplotlib.pyplot as plt

    gen_plots = _gen_default_plot if default_axes else _gen_two_subplots

    ret = None
    with warnings.catch_warnings():
        warnings.simplefilter(filterwarnings)

        # Bind ``fig`` before entering ``try`` so the ``finally`` clause can
        # never hit an unbound name and mask the original exception.
        fig = kwargs.get("figure", plt.gcf())
        try:
            plt.clf()

            for ret in gen_plots(f, fig, **kwargs):
                tm.assert_is_valid_plot_return_object(ret)

            # saving forces a full render of the current figure
            with tm.ensure_clean(return_filelike=True) as path:
                plt.savefig(path)
        finally:
            tm.close(fig)

    return ret
616
617
618def _gen_default_plot(f, fig, **kwargs):
619    """
620    Create plot in a default way.
621    """
622    yield f(**kwargs)
623
624
625def _gen_two_subplots(f, fig, **kwargs):
626    """
627    Create plot on two subplots forcefully created.
628    """
629    kwargs.get("ax", fig.add_subplot(211))
630    yield f(**kwargs)
631
632    if f is pd.plotting.bootstrap_plot:
633        assert "ax" not in kwargs
634    else:
635        kwargs["ax"] = fig.add_subplot(212)
636    yield f(**kwargs)
637
638
def curpath():
    """
    Return the absolute path of the directory containing this module.
    """
    return os.path.dirname(os.path.abspath(__file__))
642