1# -*- coding: utf-8 -*-
2# Licensed under a 3-clause BSD style license - see LICENSE.rst
3
4import io
5
6import pytest
7
8from astropy.utils.compat.optional_deps import HAS_PLT
9if HAS_PLT:
10    import matplotlib.pyplot as plt
11import numpy as np
12
13from astropy import units as u
14from astropy.coordinates import Angle
15from astropy.visualization.units import quantity_support
16
17
def teardown_function(function):
    """Close every matplotlib figure left open by the test that just ran.

    ``plt`` is only imported when ``HAS_PLT`` is true (see the guarded import
    at the top of the module). The call is therefore guarded the same way, so
    teardown cannot fail with ``NameError`` when matplotlib is unavailable.
    """
    if HAS_PLT:
        plt.close('all')
20
21
@pytest.mark.skipif('not HAS_PLT')
def test_units():
    """Plot, legend, fill_between and save with mixed but compatible units.

    The first data plotted are in m / kg and later data in cm / g. The axes
    are expected to report m and kg afterwards.
    """
    plt.figure()

    with quantity_support():
        buff = io.BytesIO()

        lengths_m = [1, 2, 3] * u.m
        masses_kg = [3, 4, 5] * u.kg
        plt.plot(lengths_m, masses_kg, label='label')
        plt.plot([105, 210, 315] * u.cm, [3050, 3025, 3010] * u.g)
        plt.legend()
        # fill_between needs the quantities really converted to ndarray
        # (required with numpy >= 1.10, see #4654), so exercise it too.
        plt.fill_between([1, 3] * u.m, [3, 5] * u.kg, [3050, 3010] * u.g)
        plt.savefig(buff, format='svg')

        current_axes = plt.gca()
        assert current_axes.xaxis.get_units() == u.m
        assert current_axes.yaxis.get_units() == u.kg
39
40
@pytest.mark.skipif('not HAS_PLT')
def test_units_errbarr():
    """errorbar accepts Quantity error bars in a different, compatible unit.

    The y values are in metres while ``yerr`` is in centimetres. The axes are
    expected to report the units of the x and y data (s and m).
    """
    # No importorskip is needed: the skipif marker already skips this test
    # when matplotlib is missing. No separate plt.figure() is needed either:
    # plt.subplots() below creates its own figure.
    with quantity_support():
        x = [1, 2, 3] * u.s
        y = [1, 2, 3] * u.m
        yerr = [3, 2, 1] * u.cm

        _, ax = plt.subplots()
        ax.errorbar(x, y, yerr=yerr)

        assert ax.xaxis.get_units() == u.s
        assert ax.yaxis.get_units() == u.m
56
57
@pytest.mark.skipif('not HAS_PLT')
def test_incompatible_units():
    """Plotting data whose unit cannot convert to the axis unit must fail."""
    # Feature-detect the expected exception rather than checking versions;
    # a minversion check does not work properly for matplotlib dev builds.
    # If matplotlib provides its own ConversionError
    # (https://github.com/matplotlib/matplotlib/pull/13005), expect that.
    # Otherwise expect astropy's UnitConversionError.
    try:
        from matplotlib.units import ConversionError as err_type
    except ImportError:
        err_type = u.UnitConversionError

    plt.figure()

    with quantity_support():
        # A single-argument plot puts the metre data on the y axis...
        plt.plot([1, 2, 3] * u.m)
        # ...so kilograms cannot be converted onto it.
        with pytest.raises(err_type):
            plt.plot([105, 210, 315] * u.kg)
75
76
@pytest.mark.skipif('not HAS_PLT')
def test_quantity_subclass():
    """Check that Quantity subclasses (here ``Angle``) are recognized.

    matplotlib.units does not do this by itself, although a PR proposes it:
    https://github.com/matplotlib/matplotlib/pull/13536
    """
    plt.figure()

    with quantity_support():
        # The first call fixes the axis units (deg, kg). The later calls mix
        # plain Quantity and Angle data in compatible units.
        first_angles = Angle([1, 2, 3], u.deg)
        plt.scatter(first_angles, [3, 4, 5] * u.kg)
        plt.scatter([105, 210, 315] * u.arcsec, [3050, 3025, 3010] * u.g)
        later_angles = Angle([105, 210, 315], u.arcsec)
        plt.plot(later_angles, [3050, 3025, 3010] * u.g)

        axes = plt.gca()
        assert axes.xaxis.get_units() == u.deg
        assert axes.yaxis.get_units() == u.kg
94
95
@pytest.mark.skipif('not HAS_PLT')
def test_nested():
    """Unit support keeps working after an inner, nested quantity_support exits."""

    def scatter_on_fresh_axes(xdata, ydata):
        # Use a brand-new figure each time so each check starts from axes
        # that have no units assigned yet.
        new_ax = plt.figure().add_subplot(1, 1, 1)
        new_ax.scatter(xdata, ydata)
        return new_ax

    with quantity_support():

        with quantity_support():
            ax = scatter_on_fresh_axes(Angle([1, 2, 3], u.deg),
                                       [3, 4, 5] * u.kg)
            assert ax.xaxis.get_units() == u.deg
            assert ax.yaxis.get_units() == u.kg

        # Only the inner context has exited here; the outer one is still
        # active and must still provide unit support.
        ax = scatter_on_fresh_axes(Angle([1, 2, 3], u.arcsec),
                                   [3, 4, 5] * u.pc)
        assert ax.xaxis.get_units() == u.arcsec
        assert ax.yaxis.get_units() == u.pc
116
117
@pytest.mark.skipif('not HAS_PLT')
def test_empty_hist():
    """A histogram of an empty Quantity must not break the unit converter."""

    with quantity_support():
        ax = plt.figure().add_subplot(1, 1, 1)
        n_bins = 100
        ax.hist([1, 2, 3, 4] * u.mmag, bins=n_bins)
        # With matplotlib >= 3.1 this second call passes an empty list to
        # the unit converter. The test passes as long as it does not raise.
        ax.hist([] * u.mmag, bins=n_bins)
128
129
@pytest.mark.skipif('not HAS_PLT')
def test_radian_formatter():
    """Radian-valued y data gets tick labels written as multiples of pi."""
    expected_labels = ['π/2', 'π', '3π/2', '2π', '5π/2', '3π', '7π/2']

    with quantity_support():
        fig, ax = plt.subplots()
        angles = [1, 2, 3] * u.rad * np.pi
        ax.plot([1, 2, 3], angles)
        # Draw the canvas so the tick labels are actually computed.
        fig.canvas.draw()
        rendered = [label.get_text() for label in ax.yaxis.get_ticklabels()]
        assert rendered == expected_labels
138