1import numpy as np
2from numpy.testing import assert_equal, assert_raises
3from pandas import Series
4import pytest
5
6from statsmodels.graphics.factorplots import _recode, interaction_plot
7
8try:
9    import matplotlib.pyplot as plt
10except ImportError:
11    pass
12
13
class TestInteractionPlot(object):
    """Tests for ``interaction_plot`` and the ``_recode`` helper."""

    @classmethod
    def setup_class(cls):
        """Build shared random factor/response arrays (60 observations)."""
        # A private RandomState yields exactly the same legacy stream as
        # ``np.random.seed(12345)`` followed by ``np.random.randint`` calls,
        # but does not reset the global RNG that other tests may rely on.
        rs = np.random.RandomState(12345)
        cls.weight = rs.randint(1, 4, size=60)      # x factor, levels 1..3
        cls.duration = rs.randint(1, 3, size=60)    # trace factor, levels 1..2
        cls.days = np.log(rs.randint(1, 30, size=60))  # response

    @pytest.mark.matplotlib
    def test_plot_both(self, close_figures):
        """Explicit colors and markers (one per trace level) are accepted."""
        fig = interaction_plot(self.weight, self.duration, self.days,
                               colors=['red', 'blue'], markers=['D', '^'],
                               ms=10)
        assert isinstance(fig, plt.Figure)

    @pytest.mark.matplotlib
    def test_plot_rainbow(self, close_figures):
        """Without ``colors`` the function picks colors itself."""
        fig = interaction_plot(self.weight, self.duration, self.days,
                               markers=['D', '^'], ms=10)
        assert isinstance(fig, plt.Figure)

    @pytest.mark.matplotlib
    @pytest.mark.parametrize('astype', ['str', 'int'])
    def test_plot_pandas(self, astype, close_figures):
        """Series names are used for the legend title and axis labels."""
        weight = Series(self.weight, name='Weight').astype(astype)
        duration = Series(self.duration, name='Duration')
        days = Series(self.days, name='Days')
        fig = interaction_plot(weight, duration, days,
                               markers=['D', '^'], ms=10)
        ax = fig.axes[0]
        trace = ax.get_legend().get_title().get_text()
        assert_equal(trace, 'Duration')
        assert_equal(ax.get_ylabel(), 'mean of Days')
        assert_equal(ax.get_xlabel(), 'Weight')

    @pytest.mark.matplotlib
    def test_formatting(self, close_figures):
        """Explicit colors and linestyles produce a Figure."""
        fig = interaction_plot(self.weight, self.duration, self.days,
                               colors=['r', 'g'], linestyles=['--', '-.'])
        assert isinstance(fig, plt.Figure)

    @pytest.mark.matplotlib
    def test_formatting_errors(self, close_figures):
        """Style lists whose length differs from the trace levels raise."""
        # ``duration`` has two levels, so each list below has the wrong length.
        assert_raises(ValueError, interaction_plot, self.weight,
                      self.duration, self.days, markers=['D'])
        assert_raises(ValueError, interaction_plot, self.weight,
                      self.duration, self.days, colors=['b', 'r', 'g'])
        assert_raises(ValueError, interaction_plot, self.weight,
                      self.duration, self.days,
                      linestyles=['--', '-.', ':'])

    @pytest.mark.matplotlib
    def test_plottype(self, close_figures):
        """'line' and 'scatter' are supported; anything else raises."""
        for plottype in ('line', 'scatter'):
            fig = interaction_plot(self.weight, self.duration, self.days,
                                   plottype=plottype)
            assert isinstance(fig, plt.Figure)
        assert_raises(ValueError, interaction_plot, self.weight,
                      self.duration, self.days, plottype='unknown')

    def test_recode_series(self):
        """``_recode`` maps levels to codes and keeps index and name."""
        series = Series(['a', 'b'] * 10, index=np.arange(0, 40, 2),
                        name='index_test')
        series_ = _recode(series, {'a': 0, 'b': 1})
        assert_equal(series_.index.values, series.index.values,
                     err_msg='_recode changed the index')
        # The level -> code mapping must actually have been applied.
        assert_equal(series_.values, np.array([0, 1] * 10),
                     err_msg='_recode produced wrong codes')
        assert_equal(series_.name, 'index_test')
72