"""Tests for ``interaction_plot`` and ``_recode`` from factorplots."""
import numpy as np
from numpy.testing import assert_equal, assert_raises
from pandas import Series
import pytest

from statsmodels.graphics.factorplots import _recode, interaction_plot

try:
    import matplotlib.pyplot as plt
except ImportError:
    # matplotlib is optional at import time; every test that touches ``plt``
    # carries the ``matplotlib`` mark (presumably skipped when it is missing
    # -- confirm against the project's conftest).
    pass


class TestInteractionPlot(object):
    """Exercise ``interaction_plot`` on a small seeded random data set."""

    @classmethod
    def setup_class(cls):
        """Create the shared factor and response arrays (60 observations)."""
        # Fixed seed so every run plots the same data.
        np.random.seed(12345)
        cls.weight = np.random.randint(1, 4, size=60)
        cls.duration = np.random.randint(1, 3, size=60)
        cls.days = np.log(np.random.randint(1, 30, size=60))

    @pytest.mark.matplotlib
    def test_plot_both(self, close_figures):
        """Explicit colors and markers together yield a figure."""
        fig = interaction_plot(self.weight, self.duration, self.days,
                               colors=['red', 'blue'], markers=['D', '^'],
                               ms=10)
        # Check the result instead of leaving ``fig`` unused.
        assert isinstance(fig, plt.Figure)

    @pytest.mark.matplotlib
    def test_plot_rainbow(self, close_figures):
        """Markers without explicit colors yield a figure."""
        fig = interaction_plot(self.weight, self.duration, self.days,
                               markers=['D', '^'], ms=10)
        assert isinstance(fig, plt.Figure)

    @pytest.mark.matplotlib
    @pytest.mark.parametrize('astype', ['str', 'int'])
    def test_plot_pandas(self, astype, close_figures):
        """Series names end up in the legend title and the axis labels."""
        weight = Series(self.weight, name='Weight').astype(astype)
        duration = Series(self.duration, name='Duration')
        days = Series(self.days, name='Days')
        fig = interaction_plot(weight, duration, days,
                               markers=['D', '^'], ms=10)
        ax = fig.axes[0]
        trace = ax.get_legend().get_title().get_text()
        assert_equal(trace, 'Duration')
        assert_equal(ax.get_ylabel(), 'mean of Days')
        assert_equal(ax.get_xlabel(), 'Weight')

    @pytest.mark.matplotlib
    def test_formatting(self, close_figures):
        """Matching numbers of colors and linestyles are accepted."""
        fig = interaction_plot(self.weight, self.duration, self.days,
                               colors=['r', 'g'], linestyles=['--', '-.'])
        assert isinstance(fig, plt.Figure)

    @pytest.mark.matplotlib
    def test_formatting_errors(self, close_figures):
        """A wrong number of markers, colors or linestyles is rejected."""
        # ``duration`` only takes two values, so one marker is too few and
        # three colors / linestyles are too many.
        assert_raises(ValueError, interaction_plot, self.weight,
                      self.duration, self.days, markers=['D'])
        assert_raises(ValueError, interaction_plot, self.weight,
                      self.duration, self.days, colors=['b', 'r', 'g'])
        assert_raises(ValueError, interaction_plot, self.weight,
                      self.duration, self.days,
                      linestyles=['--', '-.', ':'])

    @pytest.mark.matplotlib
    def test_plottype(self, close_figures):
        """'line' and 'scatter' plot types work; unknown types raise."""
        fig = interaction_plot(self.weight, self.duration, self.days,
                               plottype='line')
        assert isinstance(fig, plt.Figure)
        fig = interaction_plot(self.weight, self.duration, self.days,
                               plottype='scatter')
        assert isinstance(fig, plt.Figure)
        assert_raises(ValueError, interaction_plot, self.weight,
                      self.duration, self.days, plottype='unknown')

    def test_recode_series(self):
        """``_recode`` must keep the index of the Series it is given."""
        # Non-default index (0, 2, ..., 38) so a reset index is detectable.
        series = Series(['a', 'b'] * 10, index=np.arange(0, 40, 2),
                        name='index_test')
        series_ = _recode(series, {'a': 0, 'b': 1})
        assert_equal(series_.index.values, series.index.values,
                     err_msg='_recode changed the index')