# Licensed under a 3-clause BSD style license - see LICENSE.rst
"""Tests for `ImageNormalize`, `simple_norm` and `imshow_norm`."""

from packaging.version import Version
import pytest
import numpy as np
from numpy import ma
from numpy.testing import assert_allclose, assert_equal

from astropy.utils.exceptions import AstropyDeprecationWarning
from astropy.visualization.mpl_normalize import ImageNormalize, simple_norm, imshow_norm
from astropy.visualization.interval import ManualInterval, PercentileInterval
from astropy.visualization.stretch import LogStretch, PowerStretch, SqrtStretch
from astropy.utils.compat.optional_deps import HAS_MATPLOTLIB, HAS_PLT  # noqa

if HAS_MATPLOTLIB:
    import matplotlib
    MATPLOTLIB_LT_32 = Version(matplotlib.__version__) < Version('3.2')

# [0, 3, 6, 9, 12, 15]: straddles the vmin=2 / vmax=10 limits used below.
DATA = np.linspace(0., 15., 6)
# [0, 1, 2] and its linear mapping onto [0, 1] (0.5 * DATA2 == DATA2 / 2).
DATA2 = np.arange(3)
DATA2SCL = 0.5 * DATA2
# [-3, -2, ..., 3]: the first two values fall below vmin=-1 in the
# invalid-keyword tests.
DATA3 = np.linspace(-3., 3., 7)
STRETCHES = (SqrtStretch(), PowerStretch(0.5), LogStretch())
INVALID = (None, -np.inf, -1)


@pytest.mark.skipif('HAS_MATPLOTLIB')
def test_normalize_error_message():
    """Without matplotlib, constructing ImageNormalize raises ImportError."""
    with pytest.raises(ImportError) as exc:
        ImageNormalize()
    assert (exc.value.args[0] == "matplotlib is required in order to use "
            "this class.")


@pytest.mark.skipif('not HAS_MATPLOTLIB')
class TestNormalize:
    """Tests of `ImageNormalize` construction and calling."""

    def test_invalid_interval(self):
        """Passing the interval class instead of an instance is rejected."""
        with pytest.raises(TypeError):
            ImageNormalize(vmin=2., vmax=10., interval=ManualInterval,
                           clip=True)

    def test_invalid_stretch(self):
        """Passing the stretch class instead of an instance is rejected."""
        with pytest.raises(TypeError):
            ImageNormalize(vmin=2., vmax=10., stretch=SqrtStretch,
                           clip=True)

    def test_stretch_none(self):
        """An explicit ``stretch=None`` raises ValueError."""
        with pytest.raises(ValueError):
            ImageNormalize(vmin=2., vmax=10., stretch=None)

    def test_scalar(self):
        """Scalars normalize the same via vmin/vmax or a ManualInterval."""
        norm = ImageNormalize(vmin=2., vmax=10., stretch=SqrtStretch(),
                              clip=True)
        norm2 = ImageNormalize(data=6, interval=ManualInterval(2, 10),
                               stretch=SqrtStretch(), clip=True)
        # sqrt((6 - 2) / (10 - 2)) == sqrt(0.5)
        assert_allclose(norm(6), 0.70710678)
        assert_allclose(norm(6), norm2(6))

    def test_clip(self):
        """With clip=True, values outside [vmin, vmax] map to 0 or 1."""
        norm = ImageNormalize(vmin=2., vmax=10., stretch=SqrtStretch(),
                              clip=True)
        norm2 = ImageNormalize(DATA, interval=ManualInterval(2, 10),
                               stretch=SqrtStretch(), clip=True)
        output = norm(DATA)
        expected = [0., 0.35355339, 0.70710678, 0.93541435, 1., 1.]
        assert_allclose(output, expected)
        assert_allclose(output.mask, [0, 0, 0, 0, 0, 0])
        assert_allclose(output, norm2(DATA))

    def test_noclip(self):
        """With clip=False, out-of-range values are not clipped."""
        norm = ImageNormalize(vmin=2., vmax=10., stretch=SqrtStretch(),
                              clip=False, invalid=None)
        norm2 = ImageNormalize(DATA, interval=ManualInterval(2, 10),
                               stretch=SqrtStretch(), clip=False,
                               invalid=None)
        output = norm(DATA)
        # The first value maps below 0, where sqrt is undefined -> NaN
        # (invalid=None leaves it unreplaced); values above vmax exceed 1.
        expected = [np.nan, 0.35355339, 0.70710678, 0.93541435, 1.11803399,
                    1.27475488]
        assert_allclose(output, expected)
        assert_allclose(output.mask, [0, 0, 0, 0, 0, 0])
        assert_allclose(norm.inverse(norm(DATA))[1:], DATA[1:])
        assert_allclose(output, norm2(DATA))

    def test_implicit_autoscale(self):
        """A vmin or vmax of None is filled in from the data."""
        norm = ImageNormalize(vmin=None, vmax=10., stretch=SqrtStretch(),
                              clip=False)
        norm2 = ImageNormalize(DATA, interval=ManualInterval(None, 10),
                               stretch=SqrtStretch(), clip=False)
        output = norm(DATA)
        assert norm.vmin == np.min(DATA)
        assert norm.vmax == 10.
        assert_allclose(output, norm2(DATA))

        norm = ImageNormalize(vmin=2., vmax=None, stretch=SqrtStretch(),
                              clip=False)
        norm2 = ImageNormalize(DATA, interval=ManualInterval(2, None),
                               stretch=SqrtStretch(), clip=False)
        output = norm(DATA)
        assert norm.vmin == 2.
        assert norm.vmax == np.max(DATA)
        assert_allclose(output, norm2(DATA))

    def test_call_clip(self):
        """Test that the clip keyword is used when calling the object."""
        data = np.arange(5)
        norm = ImageNormalize(vmin=1., vmax=3., clip=False)

        output = norm(data, clip=True)
        assert_equal(output.data, [0, 0, 0.5, 1.0, 1.0])
        assert np.all(~output.mask)

        output = norm(data, clip=False)
        assert_equal(output.data, [-0.5, 0, 0.5, 1.0, 1.5])
        assert np.all(~output.mask)

    def test_masked_clip(self):
        """Masked input with clip=True gives an output with no mask set."""
        mdata = ma.array(DATA, mask=[0, 0, 1, 0, 0, 0])
        norm = ImageNormalize(vmin=2., vmax=10., stretch=SqrtStretch(),
                              clip=True)
        norm2 = ImageNormalize(mdata, interval=ManualInterval(2, 10),
                               stretch=SqrtStretch(), clip=True)
        output = norm(mdata)
        # The masked element (index 2) is expected to come out as 1.
        expected = [0., 0.35355339, 1., 0.93541435, 1., 1.]
        assert_allclose(output.filled(-10), expected)
        assert_allclose(output.mask, [0, 0, 0, 0, 0, 0])
        assert_allclose(output, norm2(mdata))

    def test_masked_noclip(self):
        """Masked input with clip=False keeps the input mask."""
        mdata = ma.array(DATA, mask=[0, 0, 1, 0, 0, 0])
        norm = ImageNormalize(vmin=2., vmax=10., stretch=SqrtStretch(),
                              clip=False, invalid=None)
        norm2 = ImageNormalize(mdata, interval=ManualInterval(2, 10),
                               stretch=SqrtStretch(), clip=False,
                               invalid=None)
        output = norm(mdata)
        # -10 is the fill value used for the still-masked element.
        expected = [np.nan, 0.35355339, -10, 0.93541435, 1.11803399,
                    1.27475488]
        assert_allclose(output.filled(-10), expected)
        assert_allclose(output.mask, [0, 0, 1, 0, 0, 0])

        assert_allclose(norm.inverse(norm(DATA))[1:], DATA[1:])
        assert_allclose(output, norm2(mdata))

    def test_invalid_data(self):
        """NaN and inf values do not affect the computed vmin/vmax."""
        data = np.arange(25.).reshape((5, 5))
        data[2, 2] = np.nan
        data[1, 2] = np.inf
        percent = 85.0
        interval = PercentileInterval(percent)

        # initialized without data
        norm = ImageNormalize(interval=interval)
        norm(data)  # sets vmin/vmax
        assert_equal((norm.vmin, norm.vmax), (1.65, 22.35))

        # initialized with data
        norm2 = ImageNormalize(data, interval=interval)
        assert_equal((norm2.vmin, norm2.vmax), (norm.vmin, norm.vmax))

        norm3 = simple_norm(data, 'linear', percent=percent)
        assert_equal((norm3.vmin, norm3.vmax), (norm.vmin, norm.vmax))

        assert_allclose(norm(data), norm2(data))
        assert_allclose(norm(data), norm3(data))

        # Default interval: limits are the finite min/max, not inf.
        norm4 = ImageNormalize()
        norm4(data)  # sets vmin/vmax
        assert_equal((norm4.vmin, norm4.vmax), (0, 24))

        norm5 = ImageNormalize(data)
        assert_equal((norm5.vmin, norm5.vmax), (norm4.vmin, norm4.vmax))

    @pytest.mark.parametrize('stretch', STRETCHES)
    def test_invalid_keyword(self, stretch):
        """The ``invalid`` keyword controls the value of invalid outputs."""
        norm1 = ImageNormalize(stretch=stretch, vmin=-1, vmax=1, clip=False,
                               invalid=None)
        norm2 = ImageNormalize(stretch=stretch, vmin=-1, vmax=1, clip=False)
        norm3 = ImageNormalize(DATA3, stretch=stretch, vmin=-1, vmax=1,
                               clip=False, invalid=-1.)
        result1 = norm1(DATA3)
        result2 = norm2(DATA3)
        result3 = norm3(DATA3)
        # DATA3[0:2] (-3 and -2) lie below vmin: NaN with invalid=None,
        # -1 with the default, which matches an explicit invalid=-1.
        assert_equal(result1[0:2], (np.nan, np.nan))
        assert_equal(result2[0:2], (-1., -1.))
        assert_equal(result1[2:], result2[2:])
        assert_equal(result2, result3)


@pytest.mark.skipif('not HAS_MATPLOTLIB')
class TestImageScaling:
    """Tests of the stretches and cut options of `simple_norm`."""

    def test_linear(self):
        """Test linear scaling."""
        norm = simple_norm(DATA2, stretch='linear')
        assert_allclose(norm(DATA2), DATA2SCL, atol=0, rtol=1.e-5)

    def test_sqrt(self):
        """Test sqrt scaling."""
        norm1 = simple_norm(DATA2, stretch='sqrt')
        assert_allclose(norm1(DATA2), np.sqrt(DATA2SCL), atol=0, rtol=1.e-5)

    @pytest.mark.parametrize('invalid', INVALID)
    def test_sqrt_invalid_kw(self, invalid):
        """simple_norm forwards ``invalid`` like ImageNormalize does."""
        stretch = SqrtStretch()
        norm1 = simple_norm(DATA3, stretch='sqrt', min_cut=-1, max_cut=1,
                            clip=False, invalid=invalid)
        norm2 = ImageNormalize(stretch=stretch, vmin=-1, vmax=1, clip=False,
                               invalid=invalid)
        assert_equal(norm1(DATA3), norm2(DATA3))

    def test_power(self):
        """Test power scaling."""
        power = 3.0
        norm = simple_norm(DATA2, stretch='power', power=power)
        assert_allclose(norm(DATA2), DATA2SCL ** power, atol=0, rtol=1.e-5)

    def test_log(self):
        """Test log10 scaling."""
        norm = simple_norm(DATA2, stretch='log')
        ref = np.log10(1000 * DATA2SCL + 1.0) / np.log10(1001.0)
        assert_allclose(norm(DATA2), ref, atol=0, rtol=1.e-5)

    def test_log_with_log_a(self):
        """Test log10 scaling with a custom log_a."""
        log_a = 100
        norm = simple_norm(DATA2, stretch='log', log_a=log_a)
        ref = np.log10(log_a * DATA2SCL + 1.0) / np.log10(log_a + 1)
        assert_allclose(norm(DATA2), ref, atol=0, rtol=1.e-5)

    def test_asinh(self):
        """Test asinh scaling."""
        norm = simple_norm(DATA2, stretch='asinh')
        ref = np.arcsinh(10 * DATA2SCL) / np.arcsinh(10)
        assert_allclose(norm(DATA2), ref, atol=0, rtol=1.e-5)

    def test_asinh_with_asinh_a(self):
        """Test asinh scaling with a custom asinh_a."""
        asinh_a = 0.5
        norm = simple_norm(DATA2, stretch='asinh', asinh_a=asinh_a)
        ref = np.arcsinh(DATA2SCL / asinh_a) / np.arcsinh(1. / asinh_a)
        assert_allclose(norm(DATA2), ref, atol=0, rtol=1.e-5)

    def test_min(self):
        """Test linear scaling with an explicit min_cut and clipping."""
        norm = simple_norm(DATA2, stretch='linear', min_cut=1., clip=True)
        assert_allclose(norm(DATA2), [0., 0., 1.], atol=0, rtol=1.e-5)

    def test_percent(self):
        """Test percent keywords."""
        norm = simple_norm(DATA2, stretch='linear', percent=99., clip=True)
        assert_allclose(norm(DATA2), DATA2SCL, atol=0, rtol=1.e-5)

        # percent=99 is expected to match an explicit 0.5/99.5 split.
        norm2 = simple_norm(DATA2, stretch='linear', min_percent=0.5,
                            max_percent=99.5, clip=True)
        assert_allclose(norm(DATA2), norm2(DATA2), atol=0, rtol=1.e-5)

    def test_invalid_stretch(self):
        """Test invalid stretch keyword."""
        with pytest.raises(ValueError):
            simple_norm(DATA2, stretch='invalid')


@pytest.mark.skipif('not HAS_PLT')
def test_imshow_norm():
    """Test imshow_norm argument validation and its return values."""
    import matplotlib.pyplot as plt

    # A fixed seed keeps the test deterministic from run to run.
    image = np.random.default_rng(seed=0).standard_normal((10, 10))

    try:
        ax = plt.subplot(label='test_imshow_norm')
        imshow_norm(image, ax=ax)

        with pytest.raises(ValueError):
            # X and data are the same, can't give both
            imshow_norm(image, X=image, ax=ax)

        with pytest.raises(ValueError):
            # illegal to manually pass in normalization since that defeats
            # the point
            imshow_norm(image, ax=ax, norm=ImageNormalize())

        imshow_norm(image, ax=ax, vmin=0, vmax=1)

        # make sure the pyplot version works
        _, norm = imshow_norm(image, ax=None)

        assert isinstance(norm, ImageNormalize)
    finally:
        # Close figures even if an assertion above fails, so open figures
        # do not leak into later tests.
        plt.close('all')