1# Licensed under a 3-clause BSD style license - see LICENSE.rst
2
3import inspect
4
5import pytest
6import numpy as np
7
8from astropy.utils.exceptions import AstropyUserWarning
9from astropy import units as u
10from astropy.wcs import WCS
11
12from astropy.nddata.nddata import NDData
13from astropy.nddata.decorators import support_nddata
14
15
class CCDData(NDData):
    """Bare NDData subclass used to exercise the ``accepts`` option of ``support_nddata``."""
18
19
@support_nddata
def wrapped_function_1(data, wcs=None, unit=None):
    """Echo back the arguments exactly as ``support_nddata`` delivered them."""
    received = (data, wcs, unit)
    return received
23
24
def test_pass_numpy():
    # A plain array is forwarded untouched and the optional arguments keep
    # their ``None`` defaults.
    arr = np.array([1, 2, 3])
    got_data, got_wcs, got_unit = wrapped_function_1(data=arr)

    assert got_data is arr
    assert got_wcs is None
    assert got_unit is None
33
34
def test_pass_all_separate():
    # Explicit keyword arguments alongside a plain array are passed through
    # unchanged.
    inputs = {
        'data': np.array([1, 2, 3]),
        'wcs': WCS(naxis=1),
        'unit': u.Jy,
    }

    got_data, got_wcs, got_unit = wrapped_function_1(**inputs)

    assert got_data is inputs['data']
    assert got_wcs is inputs['wcs']
    assert got_unit is inputs['unit']
46
47
def test_pass_nddata():
    # An NDData argument is unpacked into its data, wcs and unit attributes.
    arr = np.array([1, 2, 3])
    coords = WCS(naxis=1)
    flux_unit = u.Jy

    result = wrapped_function_1(NDData(arr, wcs=coords, unit=flux_unit))
    got_data, got_wcs, got_unit = result

    assert got_data is arr
    assert got_wcs is coords
    assert got_unit is flux_unit
61
62
def test_pass_nddata_and_explicit():
    # An explicitly passed keyword wins over the matching NDData attribute,
    # and exactly one warning reports the conflict.
    arr = np.array([1, 2, 3])
    coords = WCS(naxis=1)
    override_unit = u.mJy
    ndd = NDData(arr, wcs=coords, unit=u.Jy)

    expected_msg = ("Property unit has been passed explicitly and as "
                    "an NDData property, using explicitly specified value")
    with pytest.warns(AstropyUserWarning, match=expected_msg) as record:
        got_data, got_wcs, got_unit = wrapped_function_1(ndd, unit=override_unit)
    assert len(record) == 1

    assert got_data is arr
    assert got_wcs is coords
    assert got_unit is override_unit
80
81
def test_pass_nddata_ignored():
    # A mask set on the NDData cannot be consumed by ``wrapped_function_1``
    # (it has no ``mask`` parameter), so a single warning must mention it.
    arr = np.array([1, 2, 3])
    coords = WCS(naxis=1)
    flux_unit = u.Jy
    ndd = NDData(arr, wcs=coords, unit=flux_unit, mask=[0, 1, 0])

    expected_msg = ("The following attributes were set on the data "
                    "object, but will be ignored by the function: mask")
    with pytest.warns(AstropyUserWarning, match=expected_msg) as record:
        got_data, got_wcs, got_unit = wrapped_function_1(ndd)
    assert len(record) == 1

    assert got_data is arr
    assert got_wcs is coords
    assert got_unit is flux_unit
98
99
def test_incorrect_first_argument():
    # Decorating a function whose first positional parameter is not ``data``
    # must be refused when the decorator is applied.
    expected = "Can only wrap functions whose first positional argument is `data`"

    def wrapped_function_2(something, wcs=None, unit=None):
        pass

    def wrapped_function_3(something, data, wcs=None, unit=None):
        pass

    def wrapped_function_4(wcs=None, unit=None):
        pass

    for candidate in (wrapped_function_2, wrapped_function_3, wrapped_function_4):
        # Calling the decorator directly is what the ``@`` syntax does.
        with pytest.raises(ValueError) as exc:
            support_nddata(candidate)
        assert exc.value.args[0] == expected
119
120
def test_wrap_function_no_kwargs():
    # A function with only positional parameters can still be wrapped, and
    # the NDData input is reduced to its underlying array.
    @support_nddata
    def wrapped_function_5(data, other_data):
        return data

    arr = np.array([1, 2, 3])
    result = wrapped_function_5(NDData(arr), [1, 2, 3])

    assert result is arr
131
132
def test_wrap_function_repack_valid():
    # With ``repack=True`` the returned array is wrapped back into an NDData
    # whose ``data`` is the very same array object.
    @support_nddata(repack=True, returns=['data'])
    def wrapped_function_5(data, other_data):
        return data

    arr = np.array([1, 2, 3])
    repacked = wrapped_function_5(NDData(arr), [1, 2, 3])

    assert isinstance(repacked, NDData)
    assert repacked.data is arr
146
147
def test_wrap_function_accepts():
    # ``accepts`` restricts unpacking to a given NDData subclass; other NDData
    # instances are rejected with a TypeError.
    class MyData(NDData):
        pass

    @support_nddata(accepts=MyData)
    def wrapped_function_5(data, other_data):
        return data

    arr = np.array([1, 2, 3])
    plain = NDData(arr)
    subclassed = MyData(arr)

    assert wrapped_function_5(subclassed, [1, 2, 3]) is arr

    expected_msg = ("Only NDData sub-classes that inherit "
                    "from MyData can be used by this function")
    with pytest.raises(TypeError, match=expected_msg):
        wrapped_function_5(plain, [1, 2, 3])
166
167
def test_wrap_preserve_signature_docstring():
    # The decorator must keep the wrapped function's docstring and signature.
    @support_nddata
    def wrapped_function_6(data, wcs=None, unit=None):
        """
        An awesome function
        """
        pass

    doc = wrapped_function_6.__doc__
    # Only checked when present, presumably so the test tolerates runs where
    # docstrings are stripped (e.g. ``python -OO``).
    if doc is not None:
        assert doc.strip() == "An awesome function"

    assert str(inspect.signature(wrapped_function_6)) == "(data, wcs=None, unit=None)"
183
184
def test_setup_failures1():
    # ``repack=True`` without ``returns`` must be refused up front.
    bad_options = {'repack': True}
    with pytest.raises(ValueError):
        support_nddata(**bad_options)
189
190
def test_setup_failures2():
    # ``returns`` is meaningless without ``repack=True`` and must be refused.
    bad_options = {'returns': ['data']}
    with pytest.raises(ValueError):
        support_nddata(**bad_options)
195
196
def test_setup_failures9():
    # ``keeps`` is meaningless without ``repack=True`` and must be refused.
    bad_options = {'keeps': ['unit']}
    with pytest.raises(ValueError):
        support_nddata(**bad_options)
201
202
def test_setup_failures3():
    # An attribute may not appear in both ``keeps`` and ``returns``.
    bad_options = {
        'repack': True,
        'keeps': ['mask'],
        'returns': ['data', 'mask'],
    }
    with pytest.raises(ValueError):
        support_nddata(**bad_options)
207
208
def test_setup_failures4():
    # A wrapped function may not take variadic positional arguments.
    def takes_varargs(data, *args):
        pass

    with pytest.raises(ValueError):
        support_nddata(takes_varargs)
215
216
def test_setup_failures10():
    # A wrapped function may not take variadic keyword arguments.
    def takes_varkw(data, **kwargs):
        pass

    with pytest.raises(ValueError):
        support_nddata(takes_varkw)
223
224
def test_setup_failures5():
    # Function accepts both *args and **kwargs at once. Each of them alone is
    # already covered by test_setup_failures4 and test_setup_failures10; this
    # case previously duplicated test_setup_failures4 verbatim.
    with pytest.raises(ValueError):
        @support_nddata
        def test(data, *args, **kwargs):
            pass
231
232
def test_setup_failures6():
    # The first parameter must be named ``data``; ``img`` is refused.
    def takes_img(img):
        pass

    with pytest.raises(ValueError):
        support_nddata(takes_img)
239
240
def test_setup_failures7():
    # accepts CCDData but is given just an NDData. Decorating is valid, so it
    # stays outside ``pytest.raises``. That way the TypeError is guaranteed to
    # come from the call, not from a broken decoration.
    @support_nddata(accepts=CCDData)
    def test(data):
        pass

    with pytest.raises(TypeError, match="Only NDData sub-classes that inherit "
                       "from CCDData can be used by this function"):
        test(NDData(np.ones((3, 3))))
248
249
def test_setup_failures8():
    # Function returns a different number of values than ``returns`` lists,
    # so repacking must fail when the function is called. Decorating itself
    # is valid and is kept outside ``pytest.raises`` so it cannot mask the
    # failure. Using NDData here so we don't get into trouble when creating a
    # CCDData without a unit!
    @support_nddata(repack=True, returns=['data', 'mask'])
    def test(data):
        return 10

    with pytest.raises(ValueError):
        test(NDData(np.ones((3, 3))))  # do NOT use CCDData here.
259
260
def test_setup_failures11():
    # A function without any parameters cannot be wrapped.
    def no_params():
        pass

    with pytest.raises(ValueError):
        support_nddata(no_params)
267
268
def test_setup_numpyarray_default():
    # It should be possible (even if it's not advisable to use mutable
    # defaults) to have a numpy array as default value. Besides decorating,
    # actually call the wrapper. When the input supplies no wcs (a plain
    # array, or an NDData whose wcs is None), the default array must be used.
    default_wcs = np.array([1, 2, 3])

    @support_nddata
    def func(data, wcs=default_wcs):
        return wcs

    assert func(np.ones(3)) is default_wcs
    assert func(NDData(np.ones(3))) is default_wcs
275
276
def test_still_accepts_other_input():
    # Only NDData inputs are repacked; any other input comes back as its own
    # type.
    @support_nddata(repack=True, returns=['data'])
    def test(data):
        return data

    cases = [
        (NDData(np.ones((3, 3))), NDData),
        (10, int),
        ([1, 2, 3], list),
    ]
    for value, expected_type in cases:
        assert isinstance(test(value), expected_type)
284
285
def test_accepting_property_normal():
    # The NDData mask is forwarded to a parameter of the same name.
    @support_nddata
    def test(data, mask=None):
        return mask

    ndd = NDData(np.ones((3, 3)))
    # No mask set yet, so the function default is used.
    assert test(ndd) is None

    ndd._mask = np.zeros((3, 3))
    forwarded = test(ndd)
    assert np.all(forwarded == 0)

    # An explicit mask overrides the attribute and emits exactly one warning.
    with pytest.warns(AstropyUserWarning) as record:
        assert test(ndd, mask=10) == 10
    assert len(record) == 1
300
301
def test_parameter_default_identical_to_explicit_passed_argument():
    # A value that merely equals the default still counts as explicitly
    # passed. It must win over the NDData attribute and emit one warning.
    @support_nddata
    def func(data, meta={'a': 1}):
        return meta

    explicit_meta = {'a': 1}
    with pytest.warns(AstropyUserWarning) as record:
        assert func(NDData(1, meta={'b': 2}), explicit_meta) == explicit_meta
    assert len(record) == 1

    # With nothing passed explicitly, the NDData meta replaces the default.
    assert func(NDData(1, meta={'b': 2})) == {'b': 2}
314
315
def test_accepting_property_notexist():
    # The function accepts a ``flags`` argument, but NDData has no ``flags``
    # attribute. The wrapper must not fail and must leave the function's own
    # default in place. Previously the result was never checked.
    @support_nddata
    def test(data, flags=10):
        return flags

    ndd = NDData(np.ones((3, 3)))
    assert test(ndd) == 10
324
325
def test_accepting_property_translated():
    # The function calls its mask parameter ``masked``. The ``mask='masked'``
    # mapping must route the NDData ``mask`` attribute to that parameter.
    @support_nddata(mask='masked')
    def test(data, masked=None):
        return masked

    ndd = NDData(np.ones((3, 3)))
    assert test(ndd) is None
    ndd._mask = np.zeros((3, 3))
    assert np.all(test(ndd) == 0)
    # Use the explicitly given one (raises a Warning)
    with pytest.warns(AstropyUserWarning) as w:
        assert test(ndd, masked=10) == 10
    assert len(w) == 1
340
341
def test_accepting_property_meta_empty():
    # NDData always carries a meta mapping. An empty one is treated as
    # "not set" (the function default wins); a non-empty one is forwarded.
    @support_nddata
    def test(data, meta=None):
        return meta

    ndd = NDData(np.ones((3, 3)))
    assert test(ndd) is None

    ndd._meta = {'a': 10}
    assert test(ndd) == {'a': 10}
353