1# Licensed under a 3-clause BSD style license - see LICENSE.rst
2
3from packaging.version import Version
4import pytest
5import numpy as np
6from numpy import ma
7from numpy.testing import assert_allclose, assert_equal
8
9from astropy.utils.exceptions import AstropyDeprecationWarning
10from astropy.visualization.mpl_normalize import ImageNormalize, simple_norm, imshow_norm
11from astropy.visualization.interval import ManualInterval, PercentileInterval
12from astropy.visualization.stretch import LogStretch, PowerStretch, SqrtStretch
13from astropy.utils.compat.optional_deps import HAS_MATPLOTLIB, HAS_PLT  # noqa
14
if HAS_MATPLOTLIB:
    import matplotlib
    # Flag that is true when the installed matplotlib is older than 3.2.
    # Note that it is only defined when matplotlib could be imported.
    MATPLOTLIB_LT_32 = Version('3.2') > Version(matplotlib.__version__)

# Shared inputs for the tests below.
# Six evenly spaced samples from 0 to 15.
DATA = np.linspace(0., 15., num=6)
# Integer ramp 0, 1, 2 and the same ramp scaled by one half.
DATA2 = np.arange(3)
DATA2SCL = DATA2 * 0.5
# Seven evenly spaced samples from -3 to 3 (includes negative values).
DATA3 = np.linspace(-3., 3., num=7)
# Stretch instances used to parametrize the ``invalid`` keyword test.
STRETCHES = (SqrtStretch(), PowerStretch(0.5), LogStretch())
# Values of the ``invalid`` keyword used to parametrize the sqrt test.
INVALID = (None, -np.inf, -1)
25
26
@pytest.mark.skipif('HAS_MATPLOTLIB')
def test_normalize_error_message():
    """Check the ImportError message raised by ``ImageNormalize()``.

    The test is skipped whenever ``HAS_MATPLOTLIB`` is true.
    """
    expected_message = "matplotlib is required in order to use this class."
    with pytest.raises(ImportError) as excinfo:
        ImageNormalize()
    assert excinfo.value.args[0] == expected_message
33
34
@pytest.mark.skipif('not HAS_MATPLOTLIB')
class TestNormalize:
    """Checks of `ImageNormalize`; skipped unless ``HAS_MATPLOTLIB`` is set."""

    def test_invalid_interval(self):
        """An interval passed as a class instead of an instance is rejected."""
        with pytest.raises(TypeError):
            ImageNormalize(interval=ManualInterval, vmin=2., vmax=10.,
                           clip=True)

    def test_invalid_stretch(self):
        """A stretch passed as a class instead of an instance is rejected."""
        with pytest.raises(TypeError):
            ImageNormalize(stretch=SqrtStretch, vmin=2., vmax=10.,
                           clip=True)

    def test_stretch_none(self):
        """``stretch=None`` is rejected with a ValueError."""
        with pytest.raises(ValueError):
            ImageNormalize(stretch=None, vmin=2., vmax=10.)

    def test_scalar(self):
        """A scalar is normalized the same way by both constructions."""
        value = 6
        by_limits = ImageNormalize(stretch=SqrtStretch(), vmin=2., vmax=10.,
                                   clip=True)
        by_interval = ImageNormalize(data=value, stretch=SqrtStretch(),
                                     interval=ManualInterval(2, 10),
                                     clip=True)
        assert_allclose(by_limits(value), 0.70710678)
        assert_allclose(by_limits(value), by_interval(value))

    def test_clip(self):
        """With ``clip=True`` the output is confined to [0, 1], unmasked."""
        by_limits = ImageNormalize(stretch=SqrtStretch(), vmin=2., vmax=10.,
                                   clip=True)
        by_interval = ImageNormalize(DATA, stretch=SqrtStretch(),
                                     interval=ManualInterval(2, 10),
                                     clip=True)
        normalized = by_limits(DATA)
        assert_allclose(normalized,
                        [0., 0.35355339, 0.70710678, 0.93541435, 1., 1.])
        assert_allclose(normalized.mask, [0] * 6)
        assert_allclose(normalized, by_interval(DATA))

    def test_noclip(self):
        """With ``clip=False`` and ``invalid=None`` out-of-range data stay."""
        by_limits = ImageNormalize(stretch=SqrtStretch(), vmin=2., vmax=10.,
                                   clip=False, invalid=None)
        by_interval = ImageNormalize(DATA, stretch=SqrtStretch(),
                                     interval=ManualInterval(2, 10),
                                     clip=False, invalid=None)
        normalized = by_limits(DATA)
        # The first sample (0) is below vmin=2; NaN is expected there.
        assert_allclose(normalized,
                        [np.nan, 0.35355339, 0.70710678, 0.93541435,
                         1.11803399, 1.27475488])
        assert_allclose(normalized.mask, [0] * 6)
        # Round trip through the inverse, skipping the NaN entry.
        round_trip = by_limits.inverse(by_limits(DATA))
        assert_allclose(round_trip[1:], DATA[1:])
        assert_allclose(normalized, by_interval(DATA))

    def test_implicit_autoscale(self):
        """A limit left as None is taken from the data on first call."""
        scenarios = (
            # vmin, vmax, interval limits, expected vmin, expected vmax
            (None, 10., (None, 10), np.min(DATA), 10.),
            (2., None, (2, None), 2., np.max(DATA)),
        )
        for vmin, vmax, limits, want_vmin, want_vmax in scenarios:
            by_limits = ImageNormalize(stretch=SqrtStretch(), vmin=vmin,
                                       vmax=vmax, clip=False)
            by_interval = ImageNormalize(DATA, stretch=SqrtStretch(),
                                         interval=ManualInterval(*limits),
                                         clip=False)
            normalized = by_limits(DATA)
            assert by_limits.vmin == want_vmin
            assert by_limits.vmax == want_vmax
            assert_allclose(normalized, by_interval(DATA))

    def test_call_clip(self):
        """The ``clip`` keyword given at call time is the one applied."""
        samples = np.arange(5)
        norm = ImageNormalize(vmin=1., vmax=3., clip=False)
        expectations = (
            (True, [0, 0, 0.5, 1.0, 1.0]),
            (False, [-0.5, 0, 0.5, 1.0, 1.5]),
        )
        for clip, expected in expectations:
            result = norm(samples, clip=clip)
            assert_equal(result.data, expected)
            assert np.all(~result.mask)

    def test_masked_clip(self):
        """Masked input with ``clip=True`` yields a fully unmasked output."""
        masked = ma.array(DATA, mask=[0, 0, 1, 0, 0, 0])
        by_limits = ImageNormalize(stretch=SqrtStretch(), vmin=2., vmax=10.,
                                   clip=True)
        by_interval = ImageNormalize(masked, stretch=SqrtStretch(),
                                     interval=ManualInterval(2, 10),
                                     clip=True)
        normalized = by_limits(masked)
        assert_allclose(normalized.filled(-10),
                        [0., 0.35355339, 1., 0.93541435, 1., 1.])
        assert_allclose(normalized.mask, [0] * 6)
        assert_allclose(normalized, by_interval(masked))

    def test_masked_noclip(self):
        """Masked input with ``clip=False`` keeps its mask in the output."""
        input_mask = [0, 0, 1, 0, 0, 0]
        masked = ma.array(DATA, mask=input_mask)
        by_limits = ImageNormalize(stretch=SqrtStretch(), vmin=2., vmax=10.,
                                   clip=False, invalid=None)
        by_interval = ImageNormalize(masked, stretch=SqrtStretch(),
                                     interval=ManualInterval(2, 10),
                                     clip=False, invalid=None)
        normalized = by_limits(masked)
        # The masked entry is replaced by the fill value (-10) for comparison.
        assert_allclose(normalized.filled(-10),
                        [np.nan, 0.35355339, -10, 0.93541435, 1.11803399,
                         1.27475488])
        assert_allclose(normalized.mask, [0, 0, 1, 0, 0, 0])

        # Round trip on the unmasked data, skipping the NaN entry.
        round_trip = by_limits.inverse(by_limits(DATA))
        assert_allclose(round_trip[1:], DATA[1:])
        assert_allclose(normalized, by_interval(masked))

    def test_invalid_data(self):
        """Limits are computed consistently for data holding NaN and inf."""
        def limits(norm):
            # Collect the (vmin, vmax) pair of a normalization object.
            return norm.vmin, norm.vmax

        image = np.arange(25.).reshape(5, 5)
        image[2, 2], image[1, 2] = np.nan, np.inf
        percent = 85.0
        interval = PercentileInterval(percent)

        # Limits are set on the first call when no data are given at init.
        lazy = ImageNormalize(interval=interval)
        lazy(image)
        assert_equal(limits(lazy), (1.65, 22.35))

        # Limits are already available when data are given at init.
        eager = ImageNormalize(image, interval=interval)
        assert_equal(limits(eager), limits(lazy))

        simple = simple_norm(image, 'linear', percent=percent)
        assert_equal(limits(simple), limits(lazy))

        assert_allclose(lazy(image), eager(image))
        assert_allclose(lazy(image), simple(image))

        # Same comparison without an explicit interval.
        lazy_default = ImageNormalize()
        lazy_default(image)
        assert_equal(limits(lazy_default), (0, 24))

        eager_default = ImageNormalize(image)
        assert_equal(limits(eager_default), limits(lazy_default))

    @pytest.mark.parametrize('stretch', STRETCHES)
    def test_invalid_keyword(self, stretch):
        """Compare ``invalid=None``, the default and ``invalid=-1.``."""
        shared = dict(stretch=stretch, vmin=-1, vmax=1, clip=False)
        keep_nan = ImageNormalize(invalid=None, **shared)
        default_fill = ImageNormalize(**shared)
        explicit_fill = ImageNormalize(DATA3, invalid=-1., **shared)

        with_nan = keep_nan(DATA3)
        with_default = default_fill(DATA3)
        with_explicit = explicit_fill(DATA3)

        # The first two samples of DATA3 lie below vmin.
        assert_equal(with_nan[0:2], (np.nan, np.nan))
        assert_equal(with_default[0:2], (-1., -1.))
        assert_equal(with_nan[2:], with_default[2:])
        assert_equal(with_default, with_explicit)
187
188
@pytest.mark.skipif('not HAS_MATPLOTLIB')
class TestImageScaling:
    """Checks of `simple_norm`; skipped unless ``HAS_MATPLOTLIB`` is set."""

    @staticmethod
    def _assert_scaled(norm, expected):
        # Compare ``norm(DATA2)`` with ``expected`` using a purely relative
        # tolerance of 1e-5.
        assert_allclose(norm(DATA2), expected, atol=0, rtol=1.e-5)

    def test_linear(self):
        """Test linear scaling."""
        self._assert_scaled(simple_norm(DATA2, stretch='linear'), DATA2SCL)

    def test_sqrt(self):
        """Test sqrt scaling."""
        self._assert_scaled(simple_norm(DATA2, stretch='sqrt'),
                            np.sqrt(DATA2SCL))

    @pytest.mark.parametrize('invalid', INVALID)
    def test_sqrt_invalid_kw(self, invalid):
        """Test that `simple_norm` and `ImageNormalize` agree on ``invalid``."""
        from_simple = simple_norm(DATA3, stretch='sqrt', min_cut=-1,
                                  max_cut=1, clip=False, invalid=invalid)
        from_class = ImageNormalize(stretch=SqrtStretch(), vmin=-1, vmax=1,
                                    clip=False, invalid=invalid)
        assert_equal(from_simple(DATA3), from_class(DATA3))

    def test_power(self):
        """Test power scaling."""
        power = 3.0
        self._assert_scaled(
            simple_norm(DATA2, stretch='power', power=power),
            DATA2SCL ** power)

    def test_log(self):
        """Test log10 scaling."""
        expected = np.log10(1000 * DATA2SCL + 1.0) / np.log10(1001.0)
        self._assert_scaled(simple_norm(DATA2, stretch='log'), expected)

    def test_log_with_log_a(self):
        """Test log10 scaling with a custom log_a."""
        log_a = 100
        expected = np.log10(log_a * DATA2SCL + 1.0) / np.log10(log_a + 1)
        self._assert_scaled(simple_norm(DATA2, stretch='log', log_a=log_a),
                            expected)

    def test_asinh(self):
        """Test asinh scaling."""
        expected = np.arcsinh(10 * DATA2SCL) / np.arcsinh(10)
        self._assert_scaled(simple_norm(DATA2, stretch='asinh'), expected)

    def test_asinh_with_asinh_a(self):
        """Test asinh scaling with a custom asinh_a."""
        asinh_a = 0.5
        expected = np.arcsinh(DATA2SCL / asinh_a) / np.arcsinh(1. / asinh_a)
        self._assert_scaled(
            simple_norm(DATA2, stretch='asinh', asinh_a=asinh_a),
            expected)

    def test_min(self):
        """Test linear scaling with a minimum cut and clipping."""
        clipped = simple_norm(DATA2, stretch='linear', min_cut=1., clip=True)
        self._assert_scaled(clipped, [0., 0., 1.])

    def test_percent(self):
        """Test percent keywords."""
        symmetric = simple_norm(DATA2, stretch='linear', percent=99.,
                                clip=True)
        self._assert_scaled(symmetric, DATA2SCL)

        # The same limits spelled out as explicit lower/upper percentiles.
        explicit = simple_norm(DATA2, stretch='linear', min_percent=0.5,
                               max_percent=99.5, clip=True)
        self._assert_scaled(symmetric, explicit(DATA2))

    def test_invalid_stretch(self):
        """Test invalid stretch keyword."""
        with pytest.raises(ValueError):
            simple_norm(DATA2, stretch='invalid')
261
262
@pytest.mark.skipif('not HAS_PLT')
def test_imshow_norm():
    """Exercise `imshow_norm` with an explicit axes and through pyplot.

    The checks are that conflicting or disallowed arguments raise
    ``ValueError`` and that the returned normalization is an
    `ImageNormalize` instance.  Figures are closed in a ``finally``
    clause so that a failure here does not leave open figures behind
    for later tests.
    """
    import matplotlib.pyplot as plt
    image = np.random.randn(10, 10)

    try:
        ax = plt.subplot(label='test_imshow_norm')
        imshow_norm(image, ax=ax)

        with pytest.raises(ValueError):
            # X and data are the same, can't give both
            imshow_norm(image, X=image, ax=ax)

        with pytest.raises(ValueError):
            # illegal to manually pass in normalization since that defeats
            # the point
            imshow_norm(image, ax=ax, norm=ImageNormalize())

        imshow_norm(image, ax=ax, vmin=0, vmax=1)

        # make sure the pyplot version works; only the norm is inspected
        _, norm = imshow_norm(image, ax=None)

        assert isinstance(norm, ImageNormalize)
    finally:
        # Always release the figures, even when a check above fails.
        plt.close('all')
287