# -*- coding: utf-8 -*-
# Licensed under a 3-clause BSD style license - see LICENSE.rst
"""Tests for the matplotlib unit support enabled by ``quantity_support``."""

import io

import pytest

from astropy.utils.compat.optional_deps import HAS_PLT
if HAS_PLT:
    import matplotlib.pyplot as plt
import numpy as np

from astropy import units as u
from astropy.coordinates import Angle
from astropy.visualization.units import quantity_support

# Every test in this module drives pyplot; skip them all when it is missing.
requires_plt = pytest.mark.skipif(not HAS_PLT, reason='requires matplotlib')


def teardown_function(function):
    """Close all figures opened by the test that just ran.

    ``plt`` is only bound when matplotlib is importable, so the call is
    guarded to avoid a NameError on installs without matplotlib.
    """
    if HAS_PLT:
        plt.close('all')


@requires_plt
def test_units():
    plt.figure()

    with quantity_support(), io.BytesIO() as buff:
        plt.plot([1, 2, 3] * u.m, [3, 4, 5] * u.kg, label='label')
        plt.plot([105, 210, 315] * u.cm, [3050, 3025, 3010] * u.g)
        plt.legend()
        # Also test fill_between, which requires actual conversion to ndarray
        # with numpy >=1.10 (#4654).
        plt.fill_between([1, 3] * u.m, [3, 5] * u.kg, [3050, 3010] * u.g)
        # Render to an in-memory SVG so drawing with unit-ful data is
        # exercised as well, not just plotting.
        plt.savefig(buff, format='svg')

        # The first quantities plotted fix the axis units.
        assert plt.gca().xaxis.get_units() == u.m
        assert plt.gca().yaxis.get_units() == u.kg


@requires_plt
def test_units_errbarr():
    with quantity_support():
        x = [1, 2, 3] * u.s
        y = [1, 2, 3] * u.m
        yerr = [3, 2, 1] * u.cm

        _, ax = plt.subplots()
        ax.errorbar(x, y, yerr=yerr)

        assert ax.xaxis.get_units() == u.s
        assert ax.yaxis.get_units() == u.m


@requires_plt
def test_incompatible_units():
    # Pick the expected exception by feature detection rather than by
    # version number: a minversion check does not work properly for
    # matplotlib dev builds.
    try:
        # https://github.com/matplotlib/matplotlib/pull/13005
        from matplotlib.units import ConversionError
    except ImportError:
        err_type = u.UnitConversionError
    else:
        err_type = ConversionError

    plt.figure()

    with quantity_support():
        plt.plot([1, 2, 3] * u.m)
        # kg cannot be converted to the axis unit (m) set by the first plot.
        with pytest.raises(err_type):
            plt.plot([105, 210, 315] * u.kg)


@requires_plt
def test_quantity_subclass():
    """Check that subclasses are recognized.

    This sadly is not done by matplotlib.units itself, though
    there is a PR to change it:
    https://github.com/matplotlib/matplotlib/pull/13536
    """
    plt.figure()

    with quantity_support():
        plt.scatter(Angle([1, 2, 3], u.deg), [3, 4, 5] * u.kg)
        plt.scatter([105, 210, 315] * u.arcsec, [3050, 3025, 3010] * u.g)
        plt.plot(Angle([105, 210, 315], u.arcsec), [3050, 3025, 3010] * u.g)

        assert plt.gca().xaxis.get_units() == u.deg
        assert plt.gca().yaxis.get_units() == u.kg


@requires_plt
def test_nested():
    # Entering quantity_support twice must work, and leaving the inner
    # context must not disable support for the still-active outer one.
    with quantity_support():

        with quantity_support():

            fig = plt.figure()
            ax = fig.add_subplot(1, 1, 1)
            ax.scatter(Angle([1, 2, 3], u.deg), [3, 4, 5] * u.kg)

            assert ax.xaxis.get_units() == u.deg
            assert ax.yaxis.get_units() == u.kg

        fig = plt.figure()
        ax = fig.add_subplot(1, 1, 1)
        ax.scatter(Angle([1, 2, 3], u.arcsec), [3, 4, 5] * u.pc)

        assert ax.xaxis.get_units() == u.arcsec
        assert ax.yaxis.get_units() == u.pc


@requires_plt
def test_empty_hist():

    with quantity_support():
        fig = plt.figure()
        ax = fig.add_subplot(1, 1, 1)
        ax.hist([1, 2, 3, 4] * u.mmag, bins=100)
        # The second call results in an empty list being passed to the
        # unit converter in matplotlib >= 3.1
        ax.hist([] * u.mmag, bins=100)


@requires_plt
def test_radian_formatter():
    with quantity_support():
        fig, ax = plt.subplots()
        ax.plot([1, 2, 3], [1, 2, 3] * u.rad * np.pi)
        # Tick labels are only populated once the canvas has been drawn.
        fig.canvas.draw()
        labels = [tl.get_text() for tl in ax.yaxis.get_ticklabels()]
        assert labels == ['π/2', 'π', '3π/2', '2π', '5π/2', '3π', '7π/2']