1import numpy as np
2import pandas as pd
3import pytest
4
5from statsmodels.graphics.agreement import mean_diff_plot
6
7try:
8    import matplotlib.pyplot as plt
9except ImportError:
10    pass
11
12
@pytest.mark.matplotlib
def test_mean_diff_plot(close_figures):
    """Smoke-test ``mean_diff_plot`` across its main inputs and style options.

    Every call must succeed and return a matplotlib ``Figure``. When an ``ax``
    is supplied, the returned figure must be the figure that owns that axis.
    Figures created here are cleaned up by the ``close_figures`` fixture.
    """
    # Use a local RandomState instead of seeding the global NumPy RNG, so the
    # data is reproducible without leaking RNG state into other tests. The
    # values match the legacy ``np.random.seed(11111)`` stream.
    rng = np.random.RandomState(11111)
    m1 = rng.random_sample(20)
    m2 = rng.random_sample(20)

    # Basic test on an explicitly created axis.
    fig = plt.figure()
    ax = fig.add_subplot(111)
    assert mean_diff_plot(m1, m2, ax=ax) is fig
    # The scatter of differences vs. means should have been drawn on ``ax``.
    assert len(ax.collections) > 0

    # Test with pandas Series; without ``ax`` a new figure is returned.
    p1 = pd.Series(m1)
    p2 = pd.Series(m2)
    assert isinstance(mean_diff_plot(p1, p2), plt.Figure)

    # Test plotting on one assigned axis of a multi-axis figure: only that
    # axis should receive the scatter.
    fig, axes = plt.subplots(2)
    assert mean_diff_plot(m1, m2, ax=axes[0]) is fig
    assert len(axes[0].collections) > 0
    assert len(axes[1].collections) == 0

    # Test the limits-of-agreement setting (sd_limit=0 disables those lines).
    assert isinstance(mean_diff_plot(m1, m2, sd_limit=0), plt.Figure)

    # Test aesthetic controls for each styled element of the plot.
    style_options = [
        {'scatter_kwds': {'color': 'green', 's': 10}},
        {'mean_line_kwds': {'color': 'green', 'lw': 5}},
        {'limit_lines_kwds': {'color': 'green', 'lw': 5, 'ls': 'dotted'}},
    ]
    for options in style_options:
        assert isinstance(mean_diff_plot(m1, m2, **options), plt.Figure)
48