1from datetime import datetime, timezone, timedelta
2import platform
3from unittest.mock import MagicMock
4
5import matplotlib.pyplot as plt
6from matplotlib.testing.decorators import check_figures_equal, image_comparison
7import matplotlib.units as munits
8from matplotlib.category import UnitData
9import numpy as np
10import pytest
11
12
# Minimal class that wraps a numpy array and carries units
class Quantity:
    """Minimal unit-carrying wrapper around a numpy array or scalar.

    Stores the raw data in ``magnitude`` and a unit label in ``units``.
    Supports conversion via :meth:`to` and unit-preserving indexing.
    Any attribute not found on the wrapper is looked up on the wrapped
    data, so e.g. ``q.shape`` is ``q.magnitude.shape``.
    """

    # Multiplicative factors for the supported (from_units, to_units) pairs.
    # Only the directions listed here are supported; converting between any
    # other pair of differing units raises KeyError in ``to``.
    _FACTORS = {('hours', 'seconds'): 3600, ('minutes', 'hours'): 1 / 60,
                ('minutes', 'seconds'): 60, ('feet', 'miles'): 1 / 5280.,
                ('feet', 'inches'): 12, ('miles', 'inches'): 12 * 5280}

    def __init__(self, data, units):
        self.magnitude = data
        self.units = units

    def to(self, new_units):
        """Return a new Quantity expressed in *new_units*.

        If the units already match, the new Quantity shares the same
        ``magnitude`` object.  Raises KeyError for unsupported unit pairs.
        """
        if self.units != new_units:
            mult = self._FACTORS[self.units, new_units]
            return Quantity(mult * self.magnitude, new_units)
        else:
            return Quantity(self.magnitude, self.units)

    def __getattr__(self, attr):
        # Only reached when normal lookup fails.  If 'magnitude' itself is
        # missing (an instance whose __init__ never ran, e.g. one created via
        # __new__ by copy/pickle), delegating would call back into
        # __getattr__ forever; raise AttributeError instead.
        if attr == 'magnitude':
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute 'magnitude'")
        return getattr(self.magnitude, attr)

    def __getitem__(self, item):
        # Indexing keeps the units; a scalar magnitude is returned unchanged
        # regardless of the index.
        if np.iterable(self.magnitude):
            return Quantity(self.magnitude[item], self.units)
        else:
            return Quantity(self.magnitude, self.units)

    def __array__(self, dtype=None, copy=None):
        # Accept the optional ``dtype`` and ``copy`` arguments of NumPy's
        # ``__array__`` protocol so explicit dtype requests and NumPy 2's
        # ``copy`` keyword do not raise TypeError.
        arr = np.asarray(self.magnitude, dtype=dtype)
        return arr.copy() if copy else arr
40
41
@pytest.fixture
def quantity_converter():
    """Build a ConversionInterface for :class:`Quantity` with spy-able hooks.

    Each hook (``convert``, ``axisinfo``, ``default_units``) is a
    ``MagicMock`` whose ``side_effect`` does the real work.  Tests therefore
    get genuine conversions and can also assert which hooks were called.
    """
    qc = munits.ConversionInterface()

    def convert_one(value, unit, axis):
        # Values with units convert themselves; anything else is taken to be
        # expressed in the axis' current units.
        if hasattr(value, 'units'):
            return value.to(unit).magnitude
        return Quantity(value, axis.get_units()).to(unit).magnitude

    def convert(value, unit, axis):
        if hasattr(value, 'units') or not np.iterable(value):
            return convert_one(value, unit, axis)
        # Convert element by element so that a sequence mixing Quantities
        # and bare numbers is handled per element. An all-or-nothing
        # fallback would re-wrap the Quantities inside new Quantities.
        return [convert_one(v, unit, axis) for v in value]

    def default_units(value, axis):
        # Units of the value itself, else of its first element that has
        # units; None when no units can be found.
        if hasattr(value, 'units'):
            return value.units
        if np.iterable(value):
            for v in value:
                if hasattr(v, 'units'):
                    return v.units
        return None

    qc.convert = MagicMock(side_effect=convert)
    qc.axisinfo = MagicMock(side_effect=lambda u, a: munits.AxisInfo(label=u))
    qc.default_units = MagicMock(side_effect=default_units)
    return qc
73
74
75# Tests that the conversion machinery works properly for classes that
76# work as a facade over numpy arrays (like pint)
@image_comparison(['plot_pint.png'], remove_text=False, style='mpl20',
                  tol=0 if platform.machine() == 'x86_64' else 0.01)
def test_numpy_facade(quantity_converter):
    """Plot Quantity data through the mocked converter and re-unit both axes."""
    # The baseline image was made with the former formatter limits.
    plt.rcParams['axes.formatter.limits'] = -7, 7

    # Route Quantity instances through the mocked converter.
    munits.registry[Quantity] = quantity_converter

    elapsed = Quantity(np.linspace(0, 5), 'hours')
    distance = Quantity(np.linspace(0, 30), 'miles')

    fig, ax = plt.subplots()
    fig.subplots_adjust(left=0.15)  # leave room for the y-axis label
    ax.plot(elapsed, distance, 'tab:blue')
    ax.axhline(Quantity(26400, 'feet'), color='tab:red')
    ax.axvline(Quantity(120, 'minutes'), color='tab:green')
    ax.yaxis.set_units('inches')
    ax.xaxis.set_units('seconds')

    # Every hook of the conversion interface must have been exercised.
    for hook in (quantity_converter.convert,
                 quantity_converter.axisinfo,
                 quantity_converter.default_units):
        assert hook.called
101
102
103# Tests gh-8908
@image_comparison(['plot_masked_units.png'], remove_text=True, style='mpl20',
                  tol=0 if platform.machine() == 'x86_64' else 0.01)
def test_plot_masked_units():
    """Plot a Quantity whose magnitude is a masked array."""
    values = np.linspace(-5, 5)
    # Mask the open interval (-2, 2).
    hidden = (values > -2) & (values < 2)
    quantity = Quantity(np.ma.array(values, mask=hidden), 'meters')

    _, ax = plt.subplots()
    ax.plot(quantity)
113
114
def test_empty_set_limits_with_units(quantity_converter):
    """Setting Quantity limits on an Axes without data converts them."""
    # Register the class
    munits.registry[Quantity] = quantity_converter

    fig, ax = plt.subplots()
    ax.set_xlim(Quantity(-1, 'meters'), Quantity(6, 'meters'))
    ax.set_ylim(Quantity(-1, 'hours'), Quantity(16, 'hours'))

    # With no data plotted, the limits' own units become the axis units.
    # The stored view limits should therefore be the bare magnitudes.
    assert ax.get_xlim() == (-1, 6)
    assert ax.get_ylim() == (-1, 16)
122
123
@image_comparison(['jpl_bar_units.png'],
                  savefig_kwarg={'dpi': 120}, style='mpl20')
def test_jpl_bar_units():
    """Vertical bars at km positions, day-multiple heights, on an Epoch base."""
    import matplotlib.testing.jpl_units as units
    units.register()

    day = units.Duration("ET", 24.0 * 60.0 * 60.0)
    positions = [n * units.km for n in range(3)]
    heights = [n * day for n in (1, 2, 3)]
    base = units.Epoch("ET", dt=datetime(2009, 4, 25))

    _, ax = plt.subplots()
    ax.bar(positions, heights, bottom=base)
    lower = base - 1 * day
    upper = base + heights[-1] + (1.001) * day
    ax.set_ylim([lower, upper])
137
138
@image_comparison(['jpl_barh_units.png'],
                  savefig_kwarg={'dpi': 120}, style='mpl20')
def test_jpl_barh_units():
    """Horizontal bars at km positions, day-multiple widths, from an Epoch."""
    import matplotlib.testing.jpl_units as units
    units.register()

    day = units.Duration("ET", 24.0 * 60.0 * 60.0)
    positions = [n * units.km for n in range(3)]
    widths = [n * day for n in (1, 2, 3)]
    start = units.Epoch("ET", dt=datetime(2009, 4, 25))

    _, ax = plt.subplots()
    ax.barh(positions, widths, left=start)
    lower = start - 1 * day
    upper = start + widths[-1] + (1.001) * day
    ax.set_xlim([lower, upper])
153
154
def test_empty_arrays():
    """Scattering empty arrays that carry an explicit dtype must not raise."""
    no_times = np.array([], dtype='datetime64[ns]')
    no_values = np.array([])
    plt.scatter(no_times, no_values)
158
159
def test_scatter_element0_masked():
    """Scatter datetime x-values whose first y-value is NaN, then draw."""
    times = np.arange('2005-02', '2005-03', dtype='datetime64[D]')
    # 0.0, 1.0, ... with the first entry replaced by NaN.
    values = np.arange(times.size, dtype=float)
    values[:1] = np.nan
    fig, ax = plt.subplots()
    ax.scatter(times, values)
    fig.canvas.draw()
167
168
@check_figures_equal(extensions=["png"])
def test_subclass(fig_test, fig_ref):
    """A datetime subclass must render exactly like a plain datetime."""
    class SubDatetime(datetime):
        pass

    ax_test = fig_test.subplots()
    ax_ref = fig_ref.subplots()
    ax_test.plot(SubDatetime(2000, 1, 1), 0, "o")
    ax_ref.plot(datetime(2000, 1, 1), 0, "o")
176
177
def test_shared_axis_quantity(quantity_converter):
    """Quantity units are shared between shared axes and propagate on change."""
    munits.registry[Quantity] = quantity_converter
    x = Quantity(np.linspace(0, 1, 10), "hours")
    y1 = Quantity(np.linspace(1, 2, 10), "feet")
    y2 = Quantity(np.linspace(3, 4, 10), "feet")
    fig, (ax1, ax2) = plt.subplots(2, 1, sharex='all', sharey='all')
    ax1.plot(x, y1)
    ax2.plot(x, y2)
    assert ax1.xaxis.get_units() == ax2.xaxis.get_units() == "hours"
    # Compare the two *different* y axes; comparing an axis with itself
    # could never detect a mismatch.
    assert ax1.yaxis.get_units() == ax2.yaxis.get_units() == "feet"
    # Changing units on one axis must be reflected on its shared sibling.
    ax1.xaxis.set_units("seconds")
    ax2.yaxis.set_units("inches")
    assert ax1.xaxis.get_units() == ax2.xaxis.get_units() == "seconds"
    assert ax1.yaxis.get_units() == ax2.yaxis.get_units() == "inches"
192
193
def test_shared_axis_datetime():
    """A timezone set on one shared datetime axis reaches its sibling."""
    # datetime uses dates.DateConverter
    first_year = [datetime(2020, month, 1, tzinfo=timezone.utc)
                  for month in range(1, 13)]
    second_year = [datetime(2021, month, 1, tzinfo=timezone.utc)
                   for month in range(1, 13)]
    plus_five = timezone(timedelta(hours=5))

    fig, (ax1, ax2) = plt.subplots(1, 2, sharey=True)
    ax1.plot(first_year)
    ax2.plot(second_year)
    ax1.yaxis.set_units(plus_five)
    assert ax2.yaxis.units == plus_five
203
204
def test_shared_axis_categorical():
    """New category units set on one shared axis reach its sibling."""
    # str uses category.StrCategoryConverter
    first = {"a": 1, "b": 2}
    second = {"a": 3, "b": 4}
    fig, (ax1, ax2) = plt.subplots(1, 2, sharex=True, sharey=True)
    for ax, data in ((ax1, first), (ax2, second)):
        ax.plot(data.keys(), data.values())
    ax1.xaxis.set_units(UnitData(["c", "d"]))
    assert "c" in ax2.xaxis.get_units()._mapping
214