from datetime import datetime, timezone, timedelta
import platform
from unittest.mock import MagicMock

import matplotlib.pyplot as plt
from matplotlib.testing.decorators import check_figures_equal, image_comparison
import matplotlib.units as munits
from matplotlib.category import UnitData
import numpy as np
import pytest


# Basic class that wraps numpy array and has units
class Quantity:
    """
    Minimal unit-carrying facade over a numpy array (or a scalar).

    Only the unit pairs listed in `to` can be converted.  Attributes not
    defined on this class are forwarded to the wrapped magnitude, so the
    object can stand in for an array in most read-only contexts.
    """

    def __init__(self, data, units):
        self.magnitude = data
        self.units = units

    def to(self, new_units):
        """
        Return a new Quantity expressed in *new_units*.

        Converting to the current units returns an equivalent Quantity.
        Raises KeyError if the (current, new) unit pair is not in the
        conversion table below.
        """
        factors = {('hours', 'seconds'): 3600, ('minutes', 'hours'): 1 / 60,
                   ('minutes', 'seconds'): 60, ('feet', 'miles'): 1 / 5280.,
                   ('feet', 'inches'): 12, ('miles', 'inches'): 12 * 5280}
        if self.units != new_units:
            mult = factors[self.units, new_units]
            return Quantity(mult * self.magnitude, new_units)
        else:
            return Quantity(self.magnitude, self.units)

    def __getattr__(self, attr):
        # Only reached for attributes not found on the instance/class;
        # delegating to the data makes the facade look array-like.
        return getattr(self.magnitude, attr)

    def __getitem__(self, item):
        if np.iterable(self.magnitude):
            return Quantity(self.magnitude[item], self.units)
        else:
            # Indexing a scalar quantity yields an equivalent quantity.
            return Quantity(self.magnitude, self.units)

    def __array__(self, dtype=None, copy=None):
        # NumPy passes *dtype* when a specific dtype is requested, and
        # NumPy 2 also passes *copy*; accepting both avoids a TypeError /
        # DeprecationWarning when the facade is converted to an ndarray.
        arr = np.asarray(self.magnitude, dtype=dtype)
        return arr.copy() if copy else arr


@pytest.fixture
def quantity_converter():
    """
    Return a ConversionInterface for `Quantity` whose hooks are MagicMocks.

    The real conversion logic runs via ``side_effect``; wrapping it in mocks
    lets tests assert that ``convert``, ``axisinfo`` and ``default_units``
    were actually invoked.
    """
    # Create an instance of the conversion interface and
    # mock so we can check methods called
    qc = munits.ConversionInterface()

    def convert(value, unit, axis):
        if hasattr(value, 'units'):
            return value.to(unit).magnitude
        elif np.iterable(value):
            try:
                return [v.to(unit).magnitude for v in value]
            except AttributeError:
                # Unitless elements are interpreted in the axis' current
                # units before converting.
                return [Quantity(v, axis.get_units()).to(unit).magnitude
                        for v in value]
        else:
            return Quantity(value, axis.get_units()).to(unit).magnitude

    def default_units(value, axis):
        if hasattr(value, 'units'):
            return value.units
        elif np.iterable(value):
            # Use the units of the first element that has any.
            for v in value:
                if hasattr(v, 'units'):
                    return v.units
            return None

    qc.convert = MagicMock(side_effect=convert)
    qc.axisinfo = MagicMock(side_effect=lambda u, a: munits.AxisInfo(label=u))
    qc.default_units = MagicMock(side_effect=default_units)
    return qc


# Tests that the conversion machinery works properly for classes that
# work as a facade over numpy arrays (like pint)
@image_comparison(['plot_pint.png'], remove_text=False, style='mpl20',
                  tol=0 if platform.machine() == 'x86_64' else 0.01)
def test_numpy_facade(quantity_converter):
    # use former defaults to match existing baseline image
    plt.rcParams['axes.formatter.limits'] = -7, 7

    # Register the class
    munits.registry[Quantity] = quantity_converter

    # Simple test
    y = Quantity(np.linspace(0, 30), 'miles')
    x = Quantity(np.linspace(0, 5), 'hours')

    fig, ax = plt.subplots()
    fig.subplots_adjust(left=0.15)  # Make space for label
    ax.plot(x, y, 'tab:blue')
    ax.axhline(Quantity(26400, 'feet'), color='tab:red')
    ax.axvline(Quantity(120, 'minutes'), color='tab:green')
    ax.yaxis.set_units('inches')
    ax.xaxis.set_units('seconds')

    assert quantity_converter.convert.called
    assert quantity_converter.axisinfo.called
    assert quantity_converter.default_units.called


# Tests gh-8908
@image_comparison(['plot_masked_units.png'], remove_text=True, style='mpl20',
                  tol=0 if platform.machine() == 'x86_64' else 0.01)
def test_plot_masked_units():
    data = np.linspace(-5, 5)
    data_masked = np.ma.array(data, mask=(data > -2) & (data < 2))
    data_masked_units = Quantity(data_masked, 'meters')

    fig, ax = plt.subplots()
    ax.plot(data_masked_units)


def test_empty_set_limits_with_units(quantity_converter):
    # Register the class
    munits.registry[Quantity] = quantity_converter

    fig, ax = plt.subplots()
    ax.set_xlim(Quantity(-1, 'meters'), Quantity(6, 'meters'))
    ax.set_ylim(Quantity(-1, 'hours'), Quantity(16, 'hours'))


@image_comparison(['jpl_bar_units.png'],
                  savefig_kwarg={'dpi': 120}, style='mpl20')
def test_jpl_bar_units():
    import matplotlib.testing.jpl_units as units
    units.register()

    day = units.Duration("ET", 24.0 * 60.0 * 60.0)
    x = [0 * units.km, 1 * units.km, 2 * units.km]
    w = [1 * day, 2 * day, 3 * day]
    b = units.Epoch("ET", dt=datetime(2009, 4, 25))
    fig, ax = plt.subplots()
    ax.bar(x, w, bottom=b)
    ax.set_ylim([b - 1 * day, b + w[-1] + (1.001) * day])


@image_comparison(['jpl_barh_units.png'],
                  savefig_kwarg={'dpi': 120}, style='mpl20')
def test_jpl_barh_units():
    import matplotlib.testing.jpl_units as units
    units.register()

    day = units.Duration("ET", 24.0 * 60.0 * 60.0)
    x = [0 * units.km, 1 * units.km, 2 * units.km]
    w = [1 * day, 2 * day, 3 * day]
    b = units.Epoch("ET", dt=datetime(2009, 4, 25))

    fig, ax = plt.subplots()
    ax.barh(x, w, left=b)
    ax.set_xlim([b - 1 * day, b + w[-1] + (1.001) * day])


def test_empty_arrays():
    # Check that plotting an empty array with a dtype works
    plt.scatter(np.array([], dtype='datetime64[ns]'), np.array([]))


def test_scatter_element0_masked():
    times = np.arange('2005-02', '2005-03', dtype='datetime64[D]')
    y = np.arange(len(times), dtype=float)
    y[0] = np.nan
    fig, ax = plt.subplots()
    ax.scatter(times, y)
    fig.canvas.draw()


@check_figures_equal(extensions=["png"])
def test_subclass(fig_test, fig_ref):
    class subdate(datetime):
        pass

    fig_test.subplots().plot(subdate(2000, 1, 1), 0, "o")
    fig_ref.subplots().plot(datetime(2000, 1, 1), 0, "o")


def test_shared_axis_quantity(quantity_converter):
    munits.registry[Quantity] = quantity_converter
    x = Quantity(np.linspace(0, 1, 10), "hours")
    y1 = Quantity(np.linspace(1, 2, 10), "feet")
    y2 = Quantity(np.linspace(3, 4, 10), "feet")
    fig, (ax1, ax2) = plt.subplots(2, 1, sharex='all', sharey='all')
    ax1.plot(x, y1)
    ax2.plot(x, y2)
    assert ax1.xaxis.get_units() == ax2.xaxis.get_units() == "hours"
    # Compare ax1 against ax2 (not ax2 against itself) so the shared y axis
    # is actually checked.
    assert ax1.yaxis.get_units() == ax2.yaxis.get_units() == "feet"
    ax1.xaxis.set_units("seconds")
    ax2.yaxis.set_units("inches")
    assert ax1.xaxis.get_units() == ax2.xaxis.get_units() == "seconds"
    assert ax1.yaxis.get_units() == ax2.yaxis.get_units() == "inches"


def test_shared_axis_datetime():
    # datetime uses dates.DateConverter
    y1 = [datetime(2020, i, 1, tzinfo=timezone.utc) for i in range(1, 13)]
    y2 = [datetime(2021, i, 1, tzinfo=timezone.utc) for i in range(1, 13)]
    fig, (ax1, ax2) = plt.subplots(1, 2, sharey=True)
    ax1.plot(y1)
    ax2.plot(y2)
    ax1.yaxis.set_units(timezone(timedelta(hours=5)))
    assert ax2.yaxis.units == timezone(timedelta(hours=5))


def test_shared_axis_categorical():
    # str uses category.StrCategoryConverter
    d1 = {"a": 1, "b": 2}
    d2 = {"a": 3, "b": 4}
    fig, (ax1, ax2) = plt.subplots(1, 2, sharex=True, sharey=True)
    ax1.plot(d1.keys(), d1.values())
    ax2.plot(d2.keys(), d2.values())
    ax1.xaxis.set_units(UnitData(["c", "d"]))
    assert "c" in ax2.xaxis.get_units()._mapping.keys()