1"""
2A module containing unit tests for the `bitmask` module.
3
4Licensed under a 3-clause BSD style license - see LICENSE.rst
5
6"""
7import warnings
8import numpy as np
9import pytest
10
11from astropy.nddata import bitmask
12
13
# Widest integer types exercised by these tests.  They used to be derived via
# ``np.maximum_sctype``, which was removed in NumPy 2.0; on supported
# platforms that call resolved to the 64-bit types, so name them explicitly.
MAX_INT_TYPE = np.int64
MAX_UINT_TYPE = np.uint64
# Highest single bit of the unsigned type (bit 63 -> 2**63).
MAX_UINT_FLAG = np.left_shift(
    MAX_UINT_TYPE(1),
    MAX_UINT_TYPE(np.iinfo(MAX_UINT_TYPE).bits - 1)
)
# Highest bit that is still positive in the signed type (bit 62 -> 2**62).
MAX_INT_FLAG = np.left_shift(
    MAX_INT_TYPE(1),
    MAX_INT_TYPE(np.iinfo(MAX_INT_TYPE).bits - 2)
)
# One bit beyond what MAX_UINT_TYPE can hold; stays a plain Python int.
SUPER_LARGE_FLAG = 1 << np.iinfo(MAX_UINT_TYPE).bits
# Signed sample data: zero, small values, the top signed bit, all bits set
# (~0 == -1) and values whose sign bit (bit 63) is set.
EXTREME_TEST_DATA = np.array([
    0,
    1,
    # ``+`` binds tighter than ``<<``: this is (1 + 1) << 2 == 8, and the
    # expected masks parametrized elsewhere in this module assume 8.
    (1 + 1) << 2,
    MAX_INT_FLAG,
    ~0,
    # Reinterpret the unsigned top-bit flag as signed (-2**63).
    MAX_INT_TYPE(MAX_UINT_FLAG),
    1 + MAX_INT_TYPE(MAX_UINT_FLAG)
], dtype=MAX_INT_TYPE)
29
30
@pytest.mark.parametrize("flag", (0, -1))
def test_nonpositive_not_a_bit_flag(flag):
    """Zero and negative integers are never reported as bit flags."""
    result = bitmask._is_bit_flag(n=flag)
    assert not result
34
35
@pytest.mark.parametrize(
    "flag",
    (1, MAX_UINT_FLAG, int(MAX_UINT_FLAG), SUPER_LARGE_FLAG),
)
def test_is_bit_flag(flag):
    """Single powers of two are bit flags.

    This holds whether the value is a NumPy scalar, a Python int, or a value
    larger than ``MAX_UINT_TYPE`` can hold.
    """
    recognized = bitmask._is_bit_flag(n=flag)
    assert recognized
41
42
@pytest.mark.parametrize(
    "number",
    (0, 1, MAX_UINT_FLAG, SUPER_LARGE_FLAG),
)
def test_is_int(number):
    """Python ints (of any size) and NumPy integer scalars are integers."""
    is_integer = bitmask._is_int(number)
    assert is_integer
46
47
@pytest.mark.parametrize("number", ("1", True, 1.0))
def test_nonint_is_not_an_int(number):
    """Strings, booleans and floats are rejected even when they look integral."""
    verdict = bitmask._is_int(number)
    assert not verdict
51
52
@pytest.mark.parametrize(
    ("flag", "flip", "expected"),
    (
        (3, None, 3),
        (3, True, -4),
        (3, False, 3),
        ([1, 2], False, 3),
        ([1, 2], True, -4),
    ),
)
def test_interpret_valid_int_bit_flags(flag, flip, expected):
    """Integer flags are OR-ed together; ``flip_bits=True`` complements them."""
    combined = bitmask.interpret_bit_flags(bit_flags=flag, flip_bits=flip)
    assert combined == expected
64
65
@pytest.mark.parametrize("flag", (None, " ", "None", "Indef"))
def test_interpret_none_bit_flags_as_None(flag):
    """``None`` and its string spellings (blank, 'None', 'Indef') give None."""
    interpreted = bitmask.interpret_bit_flags(bit_flags=flag)
    assert interpreted is None
69
70
@pytest.mark.parametrize(
    ("flag", "expected"),
    (
        # single values, optionally inverted
        ("1", 1),
        ("~-1", ~(-1)),
        ("~1", ~1),
        # ",", "|" and "+" all combine flags, with or without parentheses
        ("1,2", 3),
        ("1|2", 3),
        ("1+2", 3),
        ("(1,2)", 3),
        ("(1+2)", 3),
        # a leading "~" inverts the whole combined expression
        ("~1,2", ~3),
        ("~1+2", ~3),
        ("~(1,2)", ~3),
        ("~(1+2)", ~3),
    ),
)
def test_interpret_valid_str_bit_flags(flag, expected):
    """Well-formed string expressions evaluate to the combined integer mask."""
    parsed = bitmask.interpret_bit_flags(bit_flags=flag)
    assert parsed == expected
89
90
@pytest.mark.parametrize(
    ("flag", "expected"),
    (
        # one name, optionally inverted
        ("CR", 1),
        ("~CR", ~1),
        # names combined with any separator, string or list form
        ("CR|HOT", 3),
        ("CR,HOT", 3),
        ("CR+HOT", 3),
        (["CR", "HOT"], 3),
        ("(CR,HOT)", 3),
        ("(HOT+CR)", 3),
        # a leading "~" inverts the whole combination
        ("~HOT,CR", ~3),
        ("~CR+HOT", ~3),
        ("~(HOT,CR)", ~3),
        ("~(HOT|CR)", ~3),
        ("~(CR+HOT)", ~3),
    ),
)
def test_interpret_valid_mnemonic_bit_flags(flag, expected):
    """Flag names from a name map can be used in place of integer values."""
    name_map = bitmask.extend_bit_flag_map("DetectorMap", CR=1, HOT=2)
    value = bitmask.interpret_bit_flags(bit_flags=flag, flag_name_map=name_map)
    assert value == expected
113
114
@pytest.mark.parametrize(
    "flag,flip",
    [
        (none_like, flip_value)
        for flip_value in (True, False)
        for none_like in (None, " ", "None", "Indef")
    ] + [("1", True), ("1", False)],
)
def test_interpret_None_or_str_and_flip_incompatibility(flag, flip):
    """Passing ``flip_bits`` together with a None-like or string flag is a TypeError."""
    pytest.raises(
        TypeError, bitmask.interpret_bit_flags, bit_flags=flag, flip_bits=flip
    )
130
131
@pytest.mark.parametrize("flag", (True, 1.0, [1.0], object))
def test_interpret_wrong_flag_type(flag):
    """Booleans, floats, lists of floats and arbitrary objects raise TypeError."""
    pytest.raises(TypeError, bitmask.interpret_bit_flags, bit_flags=flag)
136
137
@pytest.mark.parametrize("flag", ("SOMETHING", "1.0,2,3"))
def test_interpret_wrong_string_int_format(flag):
    """A name with no flag map given, or a non-integer number, raises ValueError."""
    pytest.raises(ValueError, bitmask.interpret_bit_flags, bit_flags=flag)
142
143
def test_interpret_duplicate_flag_warning():
    """Duplicate flags in a list are merged, but a "Duplicate" UserWarning is issued."""
    # ``pytest.warns`` succeeds if *any* matching warning is emitted.  The
    # previous ``catch_warnings`` version only inspected the last recorded
    # warning, so an unrelated warning raised afterwards would mask this one.
    with pytest.warns(UserWarning, match="Duplicate"):
        assert bitmask.interpret_bit_flags([2, 4, 4]) == 6
151
152
@pytest.mark.parametrize("flag", ([1, 2, 3], "1, 2, 3"))
def test_interpret_non_flag(flag):
    """Combining several values fails when one of them (3) is not a power of two."""
    pytest.raises(ValueError, bitmask.interpret_bit_flags, bit_flags=flag)
157
158
def test_interpret_allow_single_value_str_nonflags():
    """A single-value string is accepted even if it is not a power of two."""
    single_value = "3"
    assert bitmask.interpret_bit_flags(bit_flags=single_value) == 3
161
162
@pytest.mark.parametrize(
    "flag",
    (
        "~",        # operator without an operand
        "( )",      # empty parentheses
        "(~1,2)",   # "~" inside the parentheses
        "~(1,2",    # unbalanced parenthesis
        "1,~2",     # "~" not at the start
        "1,(2,4)",  # parentheses not enclosing the whole expression
        "1,2+4",    # mixed separators
        "1+4,2",
        "1|4+2",
    ),
)
def test_interpret_bad_str_syntax(flag):
    """Malformed string expressions raise ValueError."""
    pytest.raises(ValueError, bitmask.interpret_bit_flags, bit_flags=flag)
177
178
def test_bitfield_must_be_integer_check():
    """A floating-point bitfield is rejected with TypeError."""
    pytest.raises(TypeError, bitmask.bitfield_to_boolean_mask, 1.0, 1)
182
183
@pytest.mark.parametrize('data,flags,flip,goodval,dtype,ref', [
    # No flags given: every element gets the "good" value.
    (EXTREME_TEST_DATA, None, None, True, np.bool_,
     EXTREME_TEST_DATA.size * [1]),
    (EXTREME_TEST_DATA, None, None, False, np.bool_,
     EXTREME_TEST_DATA.size * [0]),
    # Ignore bit 0 and the top unsigned bit; any other set bit is "bad".
    (EXTREME_TEST_DATA, [1, MAX_UINT_FLAG], False, True, np.bool_,
     [1, 1, 0, 0, 0, 1, 1]),
    # No flags given, with a non-boolean output dtype.  This slot used to
    # repeat the first case verbatim and so added no coverage.
    (EXTREME_TEST_DATA, None, None, True, np.int8,
     EXTREME_TEST_DATA.size * [1]),
    (EXTREME_TEST_DATA, [1, MAX_UINT_FLAG], False, False, np.bool_,
     [0, 0, 1, 1, 1, 0, 0]),
    # flip_bits=True: only bit 0 and the top bit decide the result.
    (EXTREME_TEST_DATA, [1, MAX_UINT_FLAG], True, True, np.int8,
     [1, 0, 1, 1, 0, 0, 0])
])
def test_bitfield_to_boolean_mask(data, flags, flip, goodval, dtype, ref):
    """The mask has the requested dtype and the expected good/bad values."""
    mask = bitmask.bitfield_to_boolean_mask(
        bitfield=data,
        ignore_flags=flags,
        flip_bits=flip,
        good_mask_value=goodval,
        dtype=dtype
    )

    assert mask.dtype == dtype
    assert np.all(mask == ref)
209
210
@pytest.mark.parametrize("flag", ((4, "flag1"), 8))
def test_bitflag(flag):
    """A BitFlag equals its value and carries its doc string, if one is given.

    A ``(value, doc)`` pair may be passed either packed as one tuple or
    unpacked as two arguments.
    """
    created = bitmask.BitFlag(flag)
    if not isinstance(flag, tuple):
        assert created == flag
        return

    value, doc = flag
    assert created == value
    assert created.__doc__ == doc

    from_args = bitmask.BitFlag(value, doc)
    assert from_args == value
    assert from_args.__doc__ == doc
224
225
def test_bitflag_docs2():
    """Giving a doc both inside the tuple and as a second argument is a ValueError."""
    value_with_doc = (1, "docs1")
    pytest.raises(ValueError, bitmask.BitFlag, value_with_doc, "docs2")
229
230
@pytest.mark.parametrize("flag", (0, 3))
def test_bitflag_not_pow2(flag):
    """Values that are not powers of two (0, 3) raise InvalidBitFlag."""
    pytest.raises(bitmask.InvalidBitFlag, bitmask.BitFlag, flag, "custom flag")
235
236
@pytest.mark.parametrize("flag", (0.0, True, "1"))
def test_bitflag_not_int_flag(flag):
    """A non-integer value in a ``(value, doc)`` tuple raises InvalidBitFlag."""
    value_and_doc = (flag, "custom flag")
    with pytest.raises(bitmask.InvalidBitFlag):
        bitmask.BitFlag(value_and_doc)
241
242
@pytest.mark.parametrize('caching', [True, False])
def test_basic_map(monkeypatch, caching):
    """Build flag maps by subclassing and look flags up by name.

    Attributes starting with an underscore are not treated as flags (note the
    non-power-of-two ``_not_a_flag = 181``), and lookup is case-insensitive
    (``cr`` finds ``CR``).
    """
    # Use the parametrized value.  This previously hard-coded ``False``, so
    # the ``caching=True`` variant never ran with caching enabled.
    monkeypatch.setattr(bitmask, '_ENABLE_BITFLAG_CACHING', caching)

    class ObservatoryDQMap(bitmask.BitFlagNameMap):
        _not_a_flag = 1
        CR = 1, 'cosmic ray'
        HOT = 2
        DEAD = 4

    class DetectorMap(ObservatoryDQMap):
        __version__ = '1.0'
        _not_a_flag = 181
        READOUT_ERR = 16

    assert ObservatoryDQMap.cr == 1
    assert ObservatoryDQMap.cr.__doc__ == 'cosmic ray'
    assert DetectorMap.READOUT_ERR == 16
261
262
@pytest.mark.parametrize("caching", (True, False))
def test_extend_map(monkeypatch, caching):
    """extend_bit_flag_map keeps the base flags and adds new ones."""
    monkeypatch.setattr(bitmask, "_ENABLE_BITFLAG_CACHING", caching)

    class ObservatoryDQMap(bitmask.BitFlagNameMap):
        CR = 1
        HOT = 2
        DEAD = 4

    # DEAD is given again with the same value as in the base map; the test
    # expects this to be accepted.
    new_flags = {"__version__": "1.0", "DEAD": 4, "READOUT_ERR": 16}
    extended = bitmask.extend_bit_flag_map(
        "DetectorMap", ObservatoryDQMap, **new_flags
    )

    assert extended.CR == 1
    assert extended.readout_err == 16
281
282
@pytest.mark.parametrize("caching", (True, False))
def test_extend_map_redefine_flag(monkeypatch, caching):
    """Redefining an inherited flag with a different value raises AttributeError."""
    monkeypatch.setattr(bitmask, "_ENABLE_BITFLAG_CACHING", caching)

    class ObservatoryDQMap(bitmask.BitFlagNameMap):
        CR = 1
        HOT = 2
        DEAD = 4

    for conflicting in ({"DEAD": 32}, {"DEAD": 32, "dead": 64}):
        with pytest.raises(AttributeError):
            bitmask.extend_bit_flag_map(
                "DetectorMap",
                ObservatoryDQMap,
                __version__="1.0",
                **conflicting,
            )
308
309
@pytest.mark.parametrize("caching", (True, False))
def test_map_redefine_flag(monkeypatch, caching):
    """Flags cannot be redefined by subclassing, later assignment or extension."""
    monkeypatch.setattr(bitmask, "_ENABLE_BITFLAG_CACHING", caching)

    class ObservatoryDQMap(bitmask.BitFlagNameMap):
        _not_a_flag = 8
        CR = 1
        HOT = 2
        DEAD = 4

    # A subclass may not give an inherited flag a new value.
    with pytest.raises(AttributeError):
        class DetectorMap1(ObservatoryDQMap):
            __version__ = "1.0"
            CR = 16

    # Either defining this map or adding a flag to it afterwards must raise.
    with pytest.raises(AttributeError):
        class DetectorMap2(ObservatoryDQMap):
            SHADE = 8
            _FROZEN = 16

        DetectorMap2.novel = 32

    # READOUT_ERR and readout_err differ only in case; the call must raise.
    with pytest.raises(AttributeError):
        bitmask.extend_bit_flag_map(
            "DetectorMap",
            ObservatoryDQMap,
            READOUT_ERR=16,
            SHADE=32,
            readout_err=128,
        )
339
340
def test_map_cant_modify_version():
    """A ``__version__`` set in the class body can be read but not reassigned."""
    class ObservatoryDQMap(bitmask.BitFlagNameMap):
        __version__ = "1.2.3"
        CR = 1

    assert ObservatoryDQMap.__version__ == "1.2.3"
    assert ObservatoryDQMap.CR == 1

    pytest.raises(
        AttributeError, setattr, ObservatoryDQMap, "__version__", "3.2.1"
    )
351
352
@pytest.mark.parametrize("flag", (0, 3))
def test_map_not_bit_flag(flag):
    """Non-power-of-two values are rejected by both ways of building a map."""
    pytest.raises(
        ValueError, bitmask.extend_bit_flag_map, "DetectorMap", DEAD=flag
    )

    with pytest.raises(ValueError):
        class DetectorMap(bitmask.BitFlagNameMap):
            DEAD = flag
361
362
@pytest.mark.parametrize("flag", (0.0, True, "1"))
def test_map_not_int_flag(flag):
    """Non-integer values raise InvalidBitFlag in both ways of building a map."""
    pytest.raises(
        bitmask.InvalidBitFlag,
        bitmask.extend_bit_flag_map,
        "DetectorMap",
        DEAD=flag,
    )

    with pytest.raises(bitmask.InvalidBitFlag):
        class ObservatoryDQMap(bitmask.BitFlagNameMap):
            CR = flag
371
372
def test_map_access_undefined_flag():
    """Looking up an unknown flag by attribute or by key raises AttributeError."""
    flag_map = bitmask.extend_bit_flag_map("DetectorMap", DEAD=1)

    with pytest.raises(AttributeError):
        getattr(flag_map, "DEAD1")

    with pytest.raises(AttributeError):
        flag_map["DEAD1"]
381
382
def test_map_delete_flag():
    """Flags cannot be deleted from a map by attribute or by key."""
    DetectorMap = bitmask.extend_bit_flag_map('DetectorMap', DEAD=1)

    # Deleting an undefined name must fail...
    with pytest.raises(AttributeError):
        del DetectorMap.DEAD1

    with pytest.raises(AttributeError):
        del DetectorMap['DEAD1']

    # ...and so must deleting a flag that *is* defined.  Checking only
    # undefined names is weak: deleting a missing class attribute raises
    # AttributeError even without any protection, so it cannot show that
    # the map's real flags are guarded.
    with pytest.raises(AttributeError):
        del DetectorMap.DEAD

    with pytest.raises(AttributeError):
        del DetectorMap['DEAD']

    # The flag is still present after the failed deletion attempts.
    assert DetectorMap.DEAD == 1
391
392
def test_map_repr():
    """A map's repr has the form ``<BitFlagNameMap 'NAME'>``."""
    flag_map = bitmask.extend_bit_flag_map("DetectorMap", DEAD=1)
    expected = "<BitFlagNameMap 'DetectorMap'>"
    assert repr(flag_map) == expected
396
397
def test_map_add_flags():
    """``map + items`` yields a map with the original flags plus the added ones.

    ``items`` may be a dict (values optionally ``(value, doc)`` pairs), a list
    of ``(name, value)`` pairs, or a single ``(name, value)`` pair.
    """
    base = bitmask.extend_bit_flag_map("DetectorMap", CR=1)

    from_dict = base + {"HOT": 2, "DEAD": (4, "a really dead pixel")}
    assert from_dict.CR == 1
    assert from_dict.HOT == 2
    assert from_dict.DEAD.__doc__ == "a really dead pixel"
    assert from_dict.DEAD == 4

    from_pairs = base + [("HOT", 2), ("DEAD", 4)]
    assert from_pairs.CR == 1
    assert from_pairs.HOT == 2

    from_single = base + ("HOT", 2)
    assert from_single.CR == 1
    assert from_single.HOT == 2
414