1from matplotlib.cbook import iterable
2import matplotlib.pyplot as plt
3from matplotlib.testing.decorators import image_comparison
4import matplotlib.units as munits
5import numpy as np
6import datetime
7
8try:
9    # mock in python 3.3+
10    from unittest.mock import MagicMock
11except ImportError:
12    from mock import MagicMock
13
14
# Basic class that wraps a numpy array and carries units
class Quantity(object):
    """Minimal unit-carrying wrapper around a scalar or numpy array.

    Acts as a thin facade over ``magnitude`` (in the spirit of pint):
    unknown attributes are forwarded to the wrapped data, indexing keeps
    the units, and ``np.asarray`` sees the bare magnitude.

    Parameters
    ----------
    data : scalar or array-like
        The wrapped value, stored as ``magnitude``.
    units : str
        Unit name, stored as ``units``.
    """

    # Multiplicative factor for converting (from_units, to_units).  The
    # float literals keep fractional factors non-zero under Python 2's
    # integer division.
    _factors = {('hours', 'seconds'): 3600, ('minutes', 'hours'): 1 / 60.,
                ('minutes', 'seconds'): 60, ('feet', 'miles'): 1 / 5280.,
                ('feet', 'inches'): 12, ('miles', 'inches'): 12 * 5280}

    def __init__(self, data, units):
        self.magnitude = data
        self.units = units

    def to(self, new_units):
        """Return a new Quantity expressed in *new_units*.

        Raises KeyError if no factor is known for the requested
        (self.units, new_units) pair.
        """
        if self.units != new_units:
            mult = self._factors[self.units, new_units]
            return Quantity(mult * self.magnitude, new_units)
        else:
            return Quantity(self.magnitude, self.units)

    def __getattr__(self, attr):
        # Only reached when normal lookup fails.  If ``magnitude`` itself is
        # missing (e.g. an instance made via __new__ during copy/pickle),
        # raise AttributeError instead of recursing forever.
        if attr == 'magnitude':
            raise AttributeError(attr)
        return getattr(self.magnitude, attr)

    def __getitem__(self, item):
        """Index the magnitude and keep the units.

        For a non-iterable magnitude the item is ignored and an equivalent
        Quantity is returned.
        """
        if np.iterable(self.magnitude):
            return Quantity(self.magnitude[item], self.units)
        else:
            return Quantity(self.magnitude, self.units)

    def __array__(self, dtype=None):
        # Accept the optional dtype argument that NumPy's array protocol
        # may pass.
        return np.asarray(self.magnitude, dtype=dtype)
42
43
44# Tests that the conversion machinery works properly for classes that
45# work as a facade over numpy arrays (like pint)
@image_comparison(baseline_images=['plot_pint'],
                  extensions=['png'], remove_text=False, style='mpl20')
def test_numpy_facade():
    """Check unit conversion for a class that fronts a numpy array (pint-like).

    A ConversionInterface whose hooks are MagicMock wrappers is registered
    for Quantity, so after plotting the test can confirm that matplotlib
    invoked convert, axisinfo and default_units.
    """
    interface = munits.ConversionInterface()

    def convert(value, unit, axis):
        """Return the magnitude(s) of *value* expressed in *unit*."""
        if hasattr(value, 'units'):
            # A single object that carries its own units.
            return value.to(unit).magnitude
        if not iterable(value):
            # A bare scalar: interpret it in the axis' current units.
            return Quantity(value, axis.get_units()).to(unit).magnitude
        try:
            # A sequence whose items each provide ``to``.
            return [item.to(unit).magnitude for item in value]
        except AttributeError:
            # A sequence of bare values in the axis' current units.
            return [Quantity(item, axis.get_units()).to(unit).magnitude
                    for item in value]

    interface.convert = MagicMock(side_effect=convert)
    interface.axisinfo = MagicMock(
        side_effect=lambda u, a: munits.AxisInfo(label=u))
    interface.default_units = MagicMock(side_effect=lambda x, a: x.units)

    # Route every Quantity through the mocked interface.
    munits.registry[Quantity] = interface

    distance = Quantity(np.linspace(0, 30), 'miles')
    elapsed = Quantity(np.linspace(0, 5), 'hours')

    fig, ax = plt.subplots()
    fig.subplots_adjust(left=0.15)  # leave room for the y-axis label
    ax.plot(elapsed, distance, 'tab:blue')
    ax.axhline(Quantity(26400, 'feet'), color='tab:red')
    ax.axvline(Quantity(120, 'minutes'), color='tab:green')
    ax.yaxis.set_units('inches')
    ax.xaxis.set_units('seconds')

    for hook in (interface.convert, interface.axisinfo,
                 interface.default_units):
        assert hook.called
87
88
89# Tests gh-8908
@image_comparison(baseline_images=['plot_masked_units'],
                  extensions=['png'], remove_text=True, style='mpl20')
def test_plot_masked_units():
    """Regression test for gh-8908: plot a masked array wrapped in units."""
    values = np.linspace(-5, 5)
    # Hide the samples strictly inside (-2, 2).
    hidden = (values > -2) & (values < 2)
    masked_quantity = Quantity(np.ma.array(values, mask=hidden), 'meters')

    fig, ax = plt.subplots()
    ax.plot(masked_quantity)
99
100
@image_comparison(baseline_images=['jpl_bar_units'], extensions=['png'],
                  savefig_kwarg={'dpi': 120}, style='mpl20')
def test_jpl_bar_units():
    """Draw vertical bars with JPL distance positions and duration heights.

    The bars start at an Epoch, and the y-limits pad the bar range by one
    day on each side.
    """
    from datetime import datetime
    import matplotlib.testing.jpl_units as jpl
    jpl.register()

    day = jpl.Duration("ET", 24.0 * 60.0 * 60.0)
    positions = [i * jpl.km for i in range(3)]
    heights = [n * day for n in (1, 2, 3)]
    base = jpl.Epoch("ET", dt=datetime(2009, 4, 25))

    fig, ax = plt.subplots()
    ax.bar(positions, heights, bottom=base)
    margin = 1 * day
    ax.set_ylim([base - margin, base + heights[-1] + margin])
116
117
@image_comparison(baseline_images=['jpl_barh_units'], extensions=['png'],
                  savefig_kwarg={'dpi': 120}, style='mpl20')
def test_jpl_barh_units():
    """Draw horizontal bars with JPL distance positions and duration widths.

    The bars start at an Epoch, and the x-limits pad the bar range by one
    day on each side.
    """
    from datetime import datetime
    import matplotlib.testing.jpl_units as jpl
    jpl.register()

    day = jpl.Duration("ET", 24.0 * 60.0 * 60.0)
    positions = [i * jpl.km for i in range(3)]
    widths = [n * day for n in (1, 2, 3)]
    base = jpl.Epoch("ET", dt=datetime(2009, 4, 25))

    fig, ax = plt.subplots()
    ax.barh(positions, widths, left=base)
    margin = 1 * day
    ax.set_xlim([base - margin, base + widths[-1] + margin])
133