1import gc
2import pathlib
3import warnings
4
5import pytest
6import numpy as np
7from numpy.testing import assert_allclose, assert_array_equal
8
9from astropy.io.fits.column import (_parse_tdisp_format, _fortran_to_python_format,
10                                    python_to_tdisp)
11
12from astropy.io.fits import HDUList, PrimaryHDU, BinTableHDU, ImageHDU, table_to_hdu
13
14from astropy.io import fits
15
16from astropy import units as u
17from astropy.table import Table, QTable, NdarrayMixin, Column
18from astropy.table.table_helpers import simple_table
19from astropy.units import allclose as quantity_allclose
20from astropy.units.format.fits import UnitScaleError
21from astropy.utils.data import get_pkg_data_filename
22from astropy.utils.exceptions import (AstropyUserWarning,
23                                      AstropyDeprecationWarning)
24from astropy.utils.misc import _NOT_OVERWRITING_MSG_MATCH
25
26from astropy.coordinates import (SkyCoord, Latitude, Longitude, Angle, EarthLocation,
27                                 SphericalRepresentation, CartesianRepresentation,
28                                 SphericalCosLatDifferential)
29from astropy.time import Time, TimeDelta
30from astropy.units.quantity import QuantityInfo
31
32
def equal_data(a, b):
    """Return True if every field of structured/table ``a`` matches ``b``.

    Fields are enumerated from ``a.dtype.names`` and compared element-wise;
    the check stops at the first field that differs.
    """
    return all(np.all(a[field] == b[field]) for field in a.dtype.names)
38
39
class TestSingleTable:
    """Round-trip tests for a single table written to and read from FITS."""

    def setup_class(self):
        self.data = np.array(list(zip([1, 2, 3, 4],
                                      ['a', 'b', 'c', 'd'],
                                      [2.3, 4.5, 6.7, 8.9])),
                             dtype=[('a', int), ('b', 'U1'), ('c', float)])

    def test_simple(self, tmpdir):
        filename = str(tmpdir.join('test_simple.fts'))
        t1 = Table(self.data)
        t1.write(filename, overwrite=True)
        t2 = Table.read(filename)
        assert equal_data(t1, t2)

    def test_simple_pathlib(self, tmpdir):
        filename = pathlib.Path(str(tmpdir.join('test_simple.fit')))
        t1 = Table(self.data)
        t1.write(filename, overwrite=True)
        t2 = Table.read(filename)
        assert equal_data(t1, t2)

    def test_simple_meta(self, tmpdir):
        filename = str(tmpdir.join('test_simple.fits'))
        t1 = Table(self.data)
        t1.meta['A'] = 1
        t1.meta['B'] = 2.3
        t1.meta['C'] = 'spam'
        t1.meta['comments'] = ['this', 'is', 'a', 'long', 'comment']
        t1.meta['HISTORY'] = ['first', 'second', 'third']
        t1.write(filename, overwrite=True)
        t2 = Table.read(filename)
        assert equal_data(t1, t2)
        # Compare every meta value directly.  This covers list values
        # (comments, HISTORY) too, and unlike an element-by-element loop it
        # also catches extra trailing entries in the read-back list.
        for key in t1.meta:
            assert t1.meta[key] == t2.meta[key]

    def test_simple_meta_conflicting(self, tmpdir):
        filename = str(tmpdir.join('test_simple.fits'))
        t1 = Table(self.data)
        t1.meta['ttype1'] = 'spam'
        with pytest.warns(AstropyUserWarning, match='Meta-data keyword ttype1 '
                          'will be ignored since it conflicts with a FITS '
                          'reserved keyword') as w:
            t1.write(filename, overwrite=True)
        assert len(w) == 1

    def test_simple_noextension(self, tmpdir):
        """
        Test that file type is recognized without extension
        """
        filename = str(tmpdir.join('test_simple'))
        t1 = Table(self.data)
        t1.write(filename, overwrite=True, format='fits')
        t2 = Table.read(filename)
        assert equal_data(t1, t2)

    @pytest.mark.parametrize('table_type', (Table, QTable))
    def test_with_units(self, table_type, tmpdir):
        filename = str(tmpdir.join('test_with_units.fits'))
        t1 = table_type(self.data)
        t1['a'].unit = u.m
        t1['c'].unit = u.km / u.s
        t1.write(filename, overwrite=True)
        t2 = table_type.read(filename)
        assert equal_data(t1, t2)
        assert t2['a'].unit == u.m
        assert t2['c'].unit == u.km / u.s

    def test_with_custom_units_qtable(self, tmpdir):
        # Test only for QTable - for Table's Column, new units are dropped
        # (as is checked in test_write_drop_nonstandard_units).
        filename = str(tmpdir.join('test_with_units.fits'))
        unit = u.def_unit('bandpass_sol_lum')
        t = QTable()
        t['l'] = np.ones(5) * unit
        with pytest.warns(AstropyUserWarning) as w:
            t.write(filename, overwrite=True)
        assert len(w) == 1
        assert 'bandpass_sol_lum' in str(w[0].message)
        # Just reading back, the data is fine but the unit is not recognized.
        with pytest.warns(u.UnitsWarning, match="'bandpass_sol_lum' did not parse") as w:
            t2 = QTable.read(filename)
        assert len(w) == 1
        assert isinstance(t2['l'].unit, u.UnrecognizedUnit)
        assert str(t2['l'].unit) == 'bandpass_sol_lum'
        assert np.all(t2['l'].value == t['l'].value)

        # But if we enable the unit, it should be recognized.
        with u.add_enabled_units(unit):
            t3 = QTable.read(filename)
            assert t3['l'].unit is unit
            assert equal_data(t3, t)

            # Regression check for #8897; write used to fail when a custom
            # unit was enabled.
            with pytest.warns(AstropyUserWarning):
                t3.write(filename, overwrite=True)

        # It should also be possible to read the file in using a unit alias,
        # even to a unit that may not be the same.
        with u.set_enabled_aliases({'bandpass_sol_lum': u.Lsun}):
            t3 = QTable.read(filename)
            assert t3['l'].unit is u.Lsun

    @pytest.mark.parametrize('table_type', (Table, QTable))
    def test_read_with_unit_aliases(self, table_type):
        hdu = BinTableHDU(self.data)
        hdu.columns[0].unit = 'Angstroms'
        hdu.columns[2].unit = 'ergs/(cm.s.Angstroms)'
        with u.set_enabled_aliases(dict(Angstroms=u.AA, ergs=u.erg)):
            t = table_type.read(hdu)
        assert t['a'].unit == u.AA
        assert t['c'].unit == u.erg/(u.cm*u.s*u.AA)

    @pytest.mark.parametrize('table_type', (Table, QTable))
    def test_with_format(self, table_type, tmpdir):
        filename = str(tmpdir.join('test_with_format.fits'))
        t1 = table_type(self.data)
        t1['a'].format = '{:5d}'
        t1['b'].format = '{:>20}'
        t1['c'].format = '{:6.2f}'
        t1.write(filename, overwrite=True)
        t2 = table_type.read(filename)
        assert equal_data(t1, t2)
        assert t2['a'].format == '{:5d}'
        assert t2['b'].format == '{:>20}'
        assert t2['c'].format == '{:6.2f}'

    def test_masked(self, tmpdir):
        filename = str(tmpdir.join('test_masked.fits'))
        t1 = Table(self.data, masked=True)
        t1.mask['a'] = [1, 0, 1, 0]
        t1.mask['b'] = [1, 0, 0, 1]
        t1.mask['c'] = [0, 1, 1, 0]
        t1.write(filename, overwrite=True)
        t2 = Table.read(filename)
        assert equal_data(t1, t2)
        assert np.all(t1['a'].mask == t2['a'].mask)
        assert np.all(t1['b'].mask == t2['b'].mask)
        assert np.all(t1['c'].mask == t2['c'].mask)

    @pytest.mark.parametrize('masked', [True, False])
    def test_masked_nan(self, masked, tmpdir):
        """Check that masked values by default are replaced by NaN.

        This should work for any shape and be independent of whether the
        Table is formally masked or not.

        """
        filename = str(tmpdir.join('test_masked_nan.fits'))
        a = np.ma.MaskedArray([5.25, 8.5, 3.75, 6.25], mask=[1, 0, 1, 0])
        b = np.ma.MaskedArray([2.5, 4.5, 6.75, 8.875], mask=[1, 0, 0, 1], dtype='f4')
        c = np.ma.stack([a, b], axis=-1)
        t1 = Table([a, b, c], names=['a', 'b', 'c'], masked=masked)
        t1.write(filename, overwrite=True)
        t2 = Table.read(filename)
        assert_array_equal(t2['a'].data, [np.nan, 8.5, np.nan, 6.25])
        assert_array_equal(t2['b'].data, [np.nan, 4.5, 6.75, np.nan])
        assert_array_equal(t2['c'].data, np.stack([t2['a'].data, t2['b'].data],
                                                  axis=-1))
        assert np.all(t1['a'].mask == t2['a'].mask)
        assert np.all(t1['b'].mask == t2['b'].mask)
        assert np.all(t1['c'].mask == t2['c'].mask)

    def test_masked_serialize_data_mask(self, tmpdir):
        filename = str(tmpdir.join('test_masked_serialize_data_mask.fits'))
        a = np.ma.MaskedArray([5.25, 8.5, 3.75, 6.25], mask=[1, 0, 1, 0])
        b = np.ma.MaskedArray([2.5, 4.5, 6.75, 8.875], mask=[1, 0, 0, 1])
        c = np.ma.stack([a, b], axis=-1)
        t1 = Table([a, b, c], names=['a', 'b', 'c'])
        t1.write(filename, overwrite=True)
        t2 = Table.read(filename)
        assert_array_equal(t2['a'].data, [5.25, 8.5, 3.75, 6.25])
        assert_array_equal(t2['b'].data, [2.5, 4.5, 6.75, 8.875])
        assert_array_equal(t2['c'].data, np.stack([t2['a'].data, t2['b'].data],
                                                  axis=-1))
        assert np.all(t1['a'].mask == t2['a'].mask)
        assert np.all(t1['b'].mask == t2['b'].mask)
        assert np.all(t1['c'].mask == t2['c'].mask)

    def test_read_from_fileobj(self, tmpdir):
        filename = str(tmpdir.join('test_read_from_fileobj.fits'))
        hdu = BinTableHDU(self.data)
        hdu.writeto(filename, overwrite=True)
        with open(filename, 'rb') as f:
            t = Table.read(f)
        assert equal_data(t, self.data)

    def test_read_with_nonstandard_units(self):
        hdu = BinTableHDU(self.data)
        hdu.columns[0].unit = 'RADIANS'
        hdu.columns[1].unit = 'spam'
        hdu.columns[2].unit = 'millieggs'
        t = Table.read(hdu)
        assert equal_data(t, self.data)

    @pytest.mark.parametrize('table_type', (Table, QTable))
    def test_write_drop_nonstandard_units(self, table_type, tmpdir):
        # While we are generous on input (see above), we are strict on
        # output, dropping units not recognized by the fits standard.
        filename = str(tmpdir.join('test_nonstandard_units.fits'))
        spam = u.def_unit('spam')
        t = table_type()
        t['a'] = [1., 2., 3.] * spam
        with pytest.warns(AstropyUserWarning, match='spam') as w:
            t.write(filename)
        assert len(w) == 1
        if table_type is Table:
            assert 'cannot be recovered in reading. ' in str(w[0].message)
        else:
            assert 'lost to non-astropy fits readers' in str(w[0].message)

        with fits.open(filename) as ff:
            hdu = ff[1]
            assert 'TUNIT1' not in hdu.header

    def test_memmap(self, tmpdir):
        filename = str(tmpdir.join('test_simple.fts'))
        t1 = Table(self.data)
        t1.write(filename, overwrite=True)
        t2 = Table.read(filename, memmap=False)
        t3 = Table.read(filename, memmap=True)
        assert equal_data(t2, t3)
        # To avoid issues with --open-files, we need to remove references to
        # data that uses memory mapping and force the garbage collection
        del t1, t2, t3
        gc.collect()

    @pytest.mark.parametrize('memmap', (False, True))
    def test_character_as_bytes(self, tmpdir, memmap):
        filename = str(tmpdir.join('test_simple.fts'))
        t1 = Table(self.data)
        t1.write(filename, overwrite=True)
        t2 = Table.read(filename, character_as_bytes=False, memmap=memmap)
        t3 = Table.read(filename, character_as_bytes=True, memmap=memmap)
        assert t2['b'].dtype.kind == 'U'
        assert t3['b'].dtype.kind == 'S'
        assert equal_data(t2, t3)
        # To avoid issues with --open-files, we need to remove references to
        # data that uses memory mapping and force the garbage collection
        del t1, t2, t3
        gc.collect()

    def test_oned_single_element(self, tmpdir):
        filename = str(tmpdir.join('test_oned_single_element.fits'))
        table = Table({'x': [[1], [2]]})
        table.write(filename, overwrite=True)

        read = Table.read(filename)
        assert read['x'].shape == (2, 1)
        assert len(read['x'][0]) == 1

    def test_write_append(self, tmpdir):

        t = Table(self.data)
        hdu = table_to_hdu(t)

        def check_equal(filename, expected, start_from=1):
            with fits.open(filename) as hdu_list:
                assert len(hdu_list) == expected
                for hdu_table in hdu_list[start_from:]:
                    assert hdu_table.header == hdu.header
                    assert np.all(hdu_table.data == hdu.data)

        filename = str(tmpdir.join('test_write_append.fits'))
        t.write(filename, append=True)
        t.write(filename, append=True)
        check_equal(filename, 3)

        # Check the overwrite works correctly.
        t.write(filename, append=True, overwrite=True)
        t.write(filename, append=True)
        check_equal(filename, 3)

        # Normal write, check it's not appending.
        t.write(filename, overwrite=True)
        t.write(filename, overwrite=True)
        check_equal(filename, 2)

        # Now write followed by append, with different shaped tables.
        t2 = Table(np.array([1, 2]))
        t2.write(filename, overwrite=True)
        t.write(filename, append=True)
        check_equal(filename, 3, start_from=2)
        assert equal_data(t2, Table.read(filename, hdu=1))

    def test_write_overwrite(self, tmpdir):
        t = Table(self.data)
        filename = str(tmpdir.join('test_write_overwrite.fits'))
        t.write(filename)
        with pytest.raises(OSError, match=_NOT_OVERWRITING_MSG_MATCH):
            t.write(filename)
        t.write(filename, overwrite=True)

    def test_mask_nans_on_read(self, tmpdir):
        filename = str(tmpdir.join('test_inexact_format_parse_on_read.fits'))
        c1 = fits.Column(name='a', array=np.array([1, 2, np.nan]), format='E')
        table_hdu = fits.TableHDU.from_columns([c1])
        table_hdu.writeto(filename)

        tab = Table.read(filename)
        # Check the column's own mask.  ``tab.mask`` is a Table, and iterating
        # or indexing it yields Row objects, which are truthy for any
        # non-empty table, so asserting on it directly could never fail.
        assert_array_equal(tab['a'].mask, [False, False, True])

    def test_mask_null_on_read(self, tmpdir):
        filename = str(tmpdir.join('test_null_format_parse_on_read.fits'))
        col = fits.Column(name='a', array=np.array([1, 2, 99, 60000], dtype='u2'), format='I', null=99, bzero=32768)
        bin_table_hdu = fits.BinTableHDU.from_columns([col])
        bin_table_hdu.writeto(filename, overwrite=True)

        tab = Table.read(filename)
        # Only the element equal to TNULL (99) should be masked; see the
        # note in test_mask_nans_on_read on why ``tab.mask`` is not used.
        assert_array_equal(tab['a'].mask, [False, False, True, False])
357
358
class TestMultipleHDU:
    """Tests for selecting a table from files/HDULists with several HDUs."""

    def setup_class(self):
        self.data1 = np.array(list(zip([1, 2, 3, 4],
                                       ['a', 'b', 'c', 'd'],
                                       [2.3, 4.5, 6.7, 8.9])),
                              dtype=[('a', int), ('b', 'U1'), ('c', float)])
        self.data2 = np.array(list(zip([1.4, 2.3, 3.2, 4.7],
                                       [2.3, 4.5, 6.7, 8.9])),
                              dtype=[('p', float), ('q', float)])
        self.data3 = np.array(list(zip([1, 2, 3, 4],
                                       [2.3, 4.5, 6.7, 8.9])),
                              dtype=[('A', int), ('B', float)])
        hdu0 = PrimaryHDU()
        hdu1 = BinTableHDU(self.data1, name='first')
        hdu2 = BinTableHDU(self.data2, name='second')
        hdu3 = ImageHDU(np.ones((3, 3)), name='third')
        hdu4 = BinTableHDU(self.data3)

        self.hdus = HDUList([hdu0, hdu1, hdu2, hdu3, hdu4])
        self.hdusb = HDUList([hdu0, hdu3, hdu2, hdu1])
        self.hdus3 = HDUList([hdu0, hdu3, hdu2])
        self.hdus2 = HDUList([hdu0, hdu1, hdu3])
        self.hdus1 = HDUList([hdu0, hdu1])

    def teardown_class(self):
        # Release every HDUList created in setup_class, not just one of them.
        del self.hdus, self.hdusb, self.hdus3, self.hdus2, self.hdus1

    def setup_method(self, method):
        warnings.filterwarnings('always')

    def test_read(self, tmpdir):
        filename = str(tmpdir.join('test_read.fits'))
        self.hdus.writeto(filename)
        with pytest.warns(AstropyUserWarning,
                          match=r"hdu= was not specified but multiple tables "
                                r"are present, reading in first available "
                                r"table \(hdu=1\)"):
            t = Table.read(filename)
        assert equal_data(t, self.data1)

        filename = str(tmpdir.join('test_read_2.fits'))
        self.hdusb.writeto(filename)
        with pytest.warns(AstropyUserWarning,
                          match=r"hdu= was not specified but multiple tables "
                                r"are present, reading in first available "
                                r"table \(hdu=2\)"):
            t3 = Table.read(filename)
        assert equal_data(t3, self.data2)

    def test_read_with_hdu_0(self, tmpdir):
        filename = str(tmpdir.join('test_read_with_hdu_0.fits'))
        self.hdus.writeto(filename)
        with pytest.raises(ValueError) as exc:
            Table.read(filename, hdu=0)
        assert exc.value.args[0] == 'No table found in hdu=0'

    @pytest.mark.parametrize('hdu', [1, 'first'])
    def test_read_with_hdu_1(self, tmpdir, hdu):
        filename = str(tmpdir.join('test_read_with_hdu_1.fits'))
        self.hdus.writeto(filename)
        t = Table.read(filename, hdu=hdu)
        assert equal_data(t, self.data1)

    @pytest.mark.parametrize('hdu', [2, 'second'])
    def test_read_with_hdu_2(self, tmpdir, hdu):
        filename = str(tmpdir.join('test_read_with_hdu_2.fits'))
        self.hdus.writeto(filename)
        t = Table.read(filename, hdu=hdu)
        assert equal_data(t, self.data2)

    @pytest.mark.parametrize('hdu', [3, 'third'])
    def test_read_with_hdu_3(self, tmpdir, hdu):
        filename = str(tmpdir.join('test_read_with_hdu_3.fits'))
        self.hdus.writeto(filename)
        with pytest.raises(ValueError, match='No table found in hdu=3'):
            Table.read(filename, hdu=hdu)

    def test_read_with_hdu_4(self, tmpdir):
        filename = str(tmpdir.join('test_read_with_hdu_4.fits'))
        self.hdus.writeto(filename)
        t = Table.read(filename, hdu=4)
        assert equal_data(t, self.data3)

    @pytest.mark.parametrize('hdu', [2, 3, '1', 'second', ''])
    def test_read_with_hdu_missing(self, tmpdir, hdu):
        filename = str(tmpdir.join('test_warn_with_hdu_1.fits'))
        self.hdus1.writeto(filename)
        with pytest.warns(AstropyDeprecationWarning,
                          match=rf"Specified hdu={hdu} not found, "
                                r"reading in first available table \(hdu=1\)"):
            t1 = Table.read(filename, hdu=hdu)
        assert equal_data(t1, self.data1)

    @pytest.mark.parametrize('hdu', [0, 2, 'third'])
    def test_read_with_hdu_warning(self, tmpdir, hdu):
        filename = str(tmpdir.join('test_warn_with_hdu_2.fits'))
        self.hdus2.writeto(filename)
        with pytest.warns(AstropyDeprecationWarning,
                          match=rf"No table found in specified hdu={hdu}, "
                                r"reading in first available table \(hdu=1\)"):
            t2 = Table.read(filename, hdu=hdu)
        assert equal_data(t2, self.data1)

    @pytest.mark.parametrize('hdu', [0, 1, 'third'])
    def test_read_in_last_hdu(self, tmpdir, hdu):
        filename = str(tmpdir.join('test_warn_with_hdu_3.fits'))
        self.hdus3.writeto(filename)
        with pytest.warns(AstropyDeprecationWarning,
                          match=rf"No table found in specified hdu={hdu}, "
                                r"reading in first available table \(hdu=2\)"):
            t3 = Table.read(filename, hdu=hdu)
        assert equal_data(t3, self.data2)

    def test_read_from_hdulist(self):
        with pytest.warns(AstropyUserWarning,
                          match=r"hdu= was not specified but multiple tables "
                                r"are present, reading in first available "
                                r"table \(hdu=1\)"):
            t = Table.read(self.hdus)
        assert equal_data(t, self.data1)

        with pytest.warns(AstropyUserWarning,
                          match=r"hdu= was not specified but multiple tables "
                                r"are present, reading in first available "
                                r"table \(hdu=2\)"):
            t3 = Table.read(self.hdusb)
        assert equal_data(t3, self.data2)

    def test_read_from_hdulist_with_hdu_0(self):
        with pytest.raises(ValueError) as exc:
            Table.read(self.hdus, hdu=0)
        assert exc.value.args[0] == 'No table found in hdu=0'

    @pytest.mark.parametrize('hdu', [1, 'first', None])
    def test_read_from_hdulist_with_single_table(self, hdu):
        t = Table.read(self.hdus1, hdu=hdu)
        assert equal_data(t, self.data1)

    @pytest.mark.parametrize('hdu', [1, 'first'])
    def test_read_from_hdulist_with_hdu_1(self, hdu):
        t = Table.read(self.hdus, hdu=hdu)
        assert equal_data(t, self.data1)

    @pytest.mark.parametrize('hdu', [2, 'second'])
    def test_read_from_hdulist_with_hdu_2(self, hdu):
        t = Table.read(self.hdus, hdu=hdu)
        assert equal_data(t, self.data2)

    @pytest.mark.parametrize('hdu', [3, 'third'])
    def test_read_from_hdulist_with_hdu_3(self, hdu):
        with pytest.raises(ValueError, match='No table found in hdu=3'):
            Table.read(self.hdus, hdu=hdu)

    @pytest.mark.parametrize('hdu', [0, 2, 'third'])
    def test_read_from_hdulist_with_hdu_warning(self, hdu):
        with pytest.warns(AstropyDeprecationWarning,
                          match=rf"No table found in specified hdu={hdu}, "
                                r"reading in first available table \(hdu=1\)"):
            t2 = Table.read(self.hdus2, hdu=hdu)
        assert equal_data(t2, self.data1)

    @pytest.mark.parametrize('hdu', [2, 3, '1', 'second', ''])
    def test_read_from_hdulist_with_hdu_missing(self, hdu):
        with pytest.warns(AstropyDeprecationWarning,
                          match=rf"Specified hdu={hdu} not found, "
                                r"reading in first available table \(hdu=1\)"):
            t1 = Table.read(self.hdus1, hdu=hdu)
        assert equal_data(t1, self.data1)

    @pytest.mark.parametrize('hdu', [0, 1, 'third'])
    def test_read_from_hdulist_in_last_hdu(self, hdu):
        with pytest.warns(AstropyDeprecationWarning,
                          match=rf"No table found in specified hdu={hdu}, "
                                r"reading in first available table \(hdu=2\)"):
            t3 = Table.read(self.hdus3, hdu=hdu)
        assert equal_data(t3, self.data2)

    @pytest.mark.parametrize('hdu', [None, 1, 'first'])
    def test_read_from_single_hdu(self, hdu):
        # Actually pass the parametrized ``hdu`` through; previously it was
        # ignored, so all three cases exercised the same call.  For a single
        # table HDU the argument should not change which table is read.
        t = Table.read(self.hdus[1], hdu=hdu)
        assert equal_data(t, self.data1)
541
542
def test_masking_regression_1795():
    """
    Regression test for #1795 - this bug originally caused columns where TNULL
    was not defined to have their first element masked.
    """
    tbl = Table.read(get_pkg_data_filename('data/tb.fits'))

    # Only c1 carries a mask, and none of its elements are masked.
    assert np.all(tbl['c1'].mask == np.array([False, False]))
    for unmasked_name in ('c2', 'c3', 'c4'):
        assert not hasattr(tbl[unmasked_name], 'mask')

    # The stored values themselves must be unaffected.
    assert np.all(tbl['c1'].data == np.array([1, 2]))
    assert np.all(tbl['c2'].data == np.array([b'abc', b'xy ']))
    assert_allclose(tbl['c3'].data, np.array([3.70000007153, 6.6999997139]))
    assert np.all(tbl['c4'].data == np.array([False, True]))
557
558
def test_scale_error(tmpdir):
    """Writing a column whose unit has a non-FITS scale raises UnitScaleError.

    The target path lives in ``tmpdir`` so that nothing is written to the
    current working directory if the write unexpectedly succeeds.
    """
    a = [1, 4, 5]
    b = [2.0, 5.0, 8.2]
    c = ['x', 'y', 'z']
    t = Table([a, b, c], names=('a', 'b', 'c'), meta={'name': 'first table'})
    t['a'].unit = '1.2'
    filename = str(tmpdir.join('t.fits'))
    with pytest.raises(UnitScaleError, match=r"The column 'a' could not be "
                       r"stored in FITS format because it has a scale '\(1\.2\)'"
                       r" that is not recognized by the FITS standard\. Either "
                       r"scale the data or change the units\."):
        t.write(filename, format='fits', overwrite=True)
570
571
@pytest.mark.parametrize(
    'tdisp_str, format_return',
    [
        ('EN10.5', ('EN', '10', '5', None)),
        ('F6.2', ('F', '6', '2', None)),
        ('B5.10', ('B', '5', '10', None)),
        ('E10.5E3', ('E', '10', '5', '3')),
        ('A21', ('A', '21', None, None)),
    ])
def test_parse_tdisp_format(tdisp_str, format_return):
    """A TDISP string splits into (code, width, precision, exponent) strings."""
    parsed = _parse_tdisp_format(tdisp_str)
    assert parsed == format_return
580
581
@pytest.mark.parametrize(
    'tdisp_str, format_str_return',
    [
        ('G15.4E2', '{:15.4g}'),
        ('Z5.10', '{:5x}'),
        ('I6.5', '{:6d}'),
        ('L8', '{:>8}'),
        ('E20.7', '{:20.7e}'),
    ])
def test_fortran_to_python_format(tdisp_str, format_str_return):
    """Fortran-style TDISP codes map onto the expected Python format specs."""
    converted = _fortran_to_python_format(tdisp_str)
    assert converted == format_str_return
590
591
@pytest.mark.parametrize(
    'fmt_str, tdisp_str',
    [
        ('{:3d}', 'I3'),
        ('3d', 'I3'),
        ('7.3f', 'F7.3'),
        ('{:>4}', 'A4'),
        ('{:7.4f}', 'F7.4'),
        ('%5.3g', 'G5.3'),
        ('%10s', 'A10'),
        ('%.4f', 'F13.4'),
    ])
def test_python_to_tdisp(fmt_str, tdisp_str):
    """New-style, bare and %-style Python formats all convert to TDISP."""
    converted = python_to_tdisp(fmt_str)
    assert converted == tdisp_str
603
604
def test_logical_python_to_tdisp():
    """A right-aligned format for a logical column converts to an ``L`` code."""
    converted = python_to_tdisp('{:>7}', logical_dtype=True)
    assert converted == 'L7'
607
608
def test_bool_column(tmpdir):
    """
    Regression test for https://github.com/astropy/astropy/issues/1953

    Ensures that Table columns of bools are properly written to a FITS table.
    """
    filename = str(tmpdir.join('test.fits'))

    # Mix True and False values so both are exercised.  The previous
    # ``arr[::2] == np.False_`` was a discarded comparison, which left the
    # array all-True.
    arr = np.ones(5, dtype=bool)
    arr[::2] = np.False_

    t = Table([arr])
    t.write(filename, overwrite=True)

    with fits.open(filename) as hdul:
        assert hdul[1].data['col0'].dtype == np.dtype('bool')
        assert np.all(hdul[1].data['col0'] == arr)
625
626
def test_unicode_column(tmpdir):
    """
    Test that a column of unicode strings is still written as one
    byte-per-character in the FITS table (so long as the column can be ASCII
    encoded).

    Regression test for one of the issues fixed in
    https://github.com/astropy/astropy/pull/4228
    """
    path = str(tmpdir.join('test.fits'))

    # ASCII-encodable unicode: stored as a 2-character 'A' column.
    ascii_table = Table([np.array(['a', 'b', 'cd'])])
    ascii_table.write(path, overwrite=True)

    with fits.open(path) as hdul:
        table_hdu = hdul[1]
        assert np.all(table_hdu.data['col0'] == ['a', 'b', 'cd'])
        assert table_hdu.header['TFORM1'] == '2A'

    # Non-ASCII content cannot be encoded and must fail loudly.
    snowman_table = Table([np.array(['\N{SNOWMAN}'])])

    with pytest.raises(UnicodeEncodeError):
        snowman_table.write(path, overwrite=True)
648
649
def test_unit_warnings_read_write(tmpdir):
    """Only the unparseable unit warns on write; the file can be read back."""
    out_path = str(tmpdir.join('test_unit.fits'))
    tbl = Table([[1, 2], [3, 4]], names=['a', 'b'])
    for colname, unit_str in (('a', 'm/s'), ('b', 'not-a-unit')):
        tbl[colname].unit = unit_str

    with pytest.warns(u.UnitsWarning, match="'not-a-unit' did not parse as fits unit") as record:
        tbl.write(out_path, overwrite=True)
    assert len(record) == 1

    Table.read(out_path, hdu=1)
661
662
def test_convert_comment_convention(tmpdir):
    """
    Regression test for https://github.com/astropy/astropy/issues/6079
    """
    expected_comments = [
        '',
        ' *** End of mandatory fields ***',
        '',
        '',
        ' *** Column names ***',
        '',
        '',
        ' *** Column formats ***',
        '',
    ]

    path = get_pkg_data_filename('data/stddata.fits')
    with pytest.warns(AstropyUserWarning, match=r'hdu= was not specified but '
                      r'multiple tables are present'):
        tbl = Table.read(path)

    assert tbl.meta['comments'] == expected_comments
683
684
def _resolve_attr_path(obj, path):
    """Follow dotted ``path`` from ``obj``, trying attribute then item access per step."""
    for name in path.split('.'):
        try:
            obj = getattr(obj, name)
        except AttributeError:
            obj = obj[name]
    return obj


def assert_objects_equal(obj1, obj2, attrs, compare_class=True):
    """
    Assert that two objects agree on a set of (dotted) attributes.

    Parameters
    ----------
    obj1, obj2 : object
        Objects to compare.
    attrs : sequence of str
        Dotted attribute paths to compare, e.g. ``'frame.name'``.  At each
        step an attribute lookup is tried first, falling back to item access
        (``obj[name]``).  The ``info`` attributes name, format, unit,
        description and meta are always compared in addition.
    compare_class : bool, optional
        If `True` (default), also require both objects to have exactly the
        same class.
    """
    if compare_class:
        assert obj1.__class__ is obj2.__class__

    info_attrs = ['info.name', 'info.format', 'info.unit', 'info.description', 'info.meta']
    for attr in list(attrs) + info_attrs:
        # Resolve each object independently so that an attribute present on
        # one object but only item-accessible on the other does not corrupt
        # the already-resolved value of the first.
        a1 = _resolve_attr_path(obj1, attr)
        a2 = _resolve_attr_path(obj2, attr)

        # Mixin info.meta can be None instead of an empty OrderedDict();
        # #6720 would fix this.
        if attr == 'info.meta':
            if a1 is None:
                a1 = {}
            if a2 is None:
                a2 = {}

        if isinstance(a1, np.ndarray) and a1.dtype.kind == 'f':
            assert quantity_allclose(a1, a2, rtol=1e-15)
        else:
            assert np.all(a1 == a2)
713
714# Testing FITS table read/write with mixins.  This is mostly
715# copied from ECSV mixin testing.  Analogous tests also exist for HDF5.
716
717
# Mixin objects used to populate the mixin round-trip test columns.

# Scalar and length-2 EarthLocations given as Cartesian x, y, z in km.
el = EarthLocation(x=1 * u.km, y=3 * u.km, z=5 * u.km)
el2 = EarthLocation(x=[1, 2] * u.km, y=[3, 4] * u.km, z=[5, 6] * u.km)
# Bare representations and a differential; ``srd`` attaches ``sd`` to ``sr``.
sr = SphericalRepresentation(
    [0, 1]*u.deg, [2, 3]*u.deg, 1*u.kpc)
cr = CartesianRepresentation(
    [0, 1]*u.pc, [4, 5]*u.pc, [8, 6]*u.pc)
sd = SphericalCosLatDifferential(
    [0, 1]*u.mas/u.yr, [0, 1]*u.mas/u.yr, 10*u.km/u.s)
srd = SphericalRepresentation(sr, differentials=sd)
# FK4 SkyCoords: ``sc`` has a scalar obstime; ``scd`` adds distances and an
# array-valued obstime.
sc = SkyCoord([1, 2], [3, 4], unit='deg,deg', frame='fk4',
              obstime='J1990.5')
scd = SkyCoord([1, 2], [3, 4], [5, 6], unit='deg,deg,m', frame='fk4',
               obstime=['J1990.5', 'J1991.5'])
# Copy of ``scd`` switched to a Cartesian representation.
scdc = scd.copy()
scdc.representation_type = 'cartesian'
# SkyCoords (default frame) with proper motions and/or radial velocities.
scpm = SkyCoord([1, 2], [3, 4], [5, 6], unit='deg,deg,pc',
                pm_ra_cosdec=[7, 8]*u.mas/u.yr, pm_dec=[9, 10]*u.mas/u.yr)
scpmrv = SkyCoord([1, 2], [3, 4], [5, 6], unit='deg,deg,pc',
                  pm_ra_cosdec=[7, 8]*u.mas/u.yr, pm_dec=[9, 10]*u.mas/u.yr,
                  radial_velocity=[11, 12]*u.km/u.s)
scrv = SkyCoord([1, 2], [3, 4], [5, 6], unit='deg,deg,pc',
                radial_velocity=[11, 12]*u.km/u.s)
# Time with an attached location (``location`` is one of the compared
# ``time_attrs``).
tm = Time([2450814.5, 2450815.5], format='jd', scale='tai', location=el)
741
# Mixin columns exercised by the tests below, keyed by the column name each
# one is given in the test tables.  Iteration order matters: it fixes the
# parametrization order of test_fits_mixins_per_column.
#
# NOTE: in the test below the name of the column "x" for the Quantity is
# important since it tests the fix for #10215 (namespace clash, where "x"
# clashes with "el2.x").
mixin_cols = {
    'tm': tm,
    'dt': TimeDelta([1, 2] * u.day),
    'sc': sc,
    'scd': scd,
    'scdc': scdc,
    'scpm': scpm,
    'scpmrv': scpmrv,
    'scrv': scrv,
    'x': [1, 2] * u.m,
    'lat': Latitude([1, 2] * u.deg),
    'lon': Longitude([1, 2] * u.deg, wrap_angle=180. * u.deg),
    'ang': Angle([1, 2] * u.deg),
    'el2': el2,
    'sr': sr,
    'cr': cr,
    'sd': sd,
    'srd': srd,
}
764
# Attributes compared between an original column and its round-tripped
# counterpart, keyed by column name.  Dotted names (e.g. 'frame.name',
# 'differentials.s.d_lat') are resolved one component at a time by
# assert_objects_equal, falling back to item lookup when getattr fails.
# 'c1' and 'c2' are the plain float columns that test_fits_mixins_per_column
# places around each mixin.
time_attrs = ['value', 'shape', 'format', 'scale', 'location']
compare_attrs = {
    'c1': ['data'],
    'c2': ['data'],
    'tm': time_attrs,
    'dt': ['shape', 'value', 'format', 'scale'],
    'sc': ['ra', 'dec', 'representation_type', 'frame.name'],
    'scd': ['ra', 'dec', 'distance', 'representation_type', 'frame.name'],
    'scdc': ['x', 'y', 'z', 'representation_type', 'frame.name'],
    'scpm': ['ra', 'dec', 'distance', 'pm_ra_cosdec', 'pm_dec',
             'representation_type', 'frame.name'],
    'scpmrv': ['ra', 'dec', 'distance', 'pm_ra_cosdec', 'pm_dec',
               'radial_velocity', 'representation_type', 'frame.name'],
    'scrv': ['ra', 'dec', 'distance', 'radial_velocity', 'representation_type',
             'frame.name'],
    'x': ['value', 'unit'],
    'lon': ['value', 'unit', 'wrap_angle'],
    'lat': ['value', 'unit'],
    'ang': ['value', 'unit'],
    'el2': ['x', 'y', 'z', 'ellipsoid'],
    # NOTE(review): there is no 'nd' entry in mixin_cols; this looks like a
    # leftover (possibly from an NdarrayMixin case) -- confirm before removing.
    'nd': ['x', 'y', 'z'],
    'sr': ['lon', 'lat', 'distance'],
    'cr': ['x', 'y', 'z'],
    'sd': ['d_lon_coslat', 'd_lat', 'd_distance'],
    'srd': ['lon', 'lat', 'distance', 'differentials.s.d_lon_coslat',
            'differentials.s.d_lat', 'differentials.s.d_distance'],
}
792
793
def test_fits_mixins_qtable_to_table(tmpdir):
    """Write every mixin via a QTable, read it back as a plain Table, and
    check that each column comes out as the expected class.

    Quantity-based columns are expected to be downgraded to ``Column`` objects
    that keep only their unit; all other mixins are compared attribute by
    attribute using ``compare_attrs``.
    """
    path = str(tmpdir.join('test_simple.fits'))

    ordered = sorted(mixin_cols)
    qt = QTable([mixin_cols[key] for key in ordered], names=ordered)
    qt.write(path, format='fits')
    readback = Table.read(path, format='fits', astropy_native=True)

    assert qt.colnames == readback.colnames

    for key, orig_col in qt.columns.items():
        new_col = readback[key]

        # Time does not yet round-trip its ``format``; copy it over so the
        # comparison below is meaningful.
        if isinstance(new_col, Time):
            new_col.format = orig_col.format

        attrs = compare_attrs[key]
        if isinstance(orig_col.info, QuantityInfo):
            # Read as a plain Table, a Quantity becomes Column + unit.
            assert type(new_col) is Column
            # assert_objects_equal cannot compare the data in this case, so
            # check the values directly.
            assert np.all(orig_col.value == new_col)
            # Class-specific attributes such as ``value`` or ``wrap_angle``
            # are lost, and the class itself differs.
            attrs, compare_class = ['unit'], False
        else:
            compare_class = True

        assert_objects_equal(orig_col, new_col, attrs, compare_class)
829
830
@pytest.mark.parametrize('table_cls', (Table, QTable))
def test_fits_mixins_as_one(table_cls, tmpdir):
    """Write all mixin columns in one table, read them back, and check the
    table metadata plus the names of the plain FITS columns each mixin was
    serialized into."""
    path = str(tmpdir.join('test_simple.fits'))
    colnames = sorted(mixin_cols)

    # Expected FITS column names, in order, one mixin per line.  scd/scdc
    # carry a per-row obstime, which appears as extra jd1/jd2 columns; sc's
    # scalar obstime does not.
    expected_fits_names = [
        'ang',
        'cr.x', 'cr.y', 'cr.z',
        'dt.jd1', 'dt.jd2',
        'el2.x', 'el2.y', 'el2.z',
        'lat',
        'lon',
        'sc.ra', 'sc.dec',
        'scd.ra', 'scd.dec', 'scd.distance', 'scd.obstime.jd1', 'scd.obstime.jd2',
        'scdc.x', 'scdc.y', 'scdc.z', 'scdc.obstime.jd1', 'scdc.obstime.jd2',
        'scpm.ra', 'scpm.dec', 'scpm.distance', 'scpm.pm_ra_cosdec', 'scpm.pm_dec',
        'scpmrv.ra', 'scpmrv.dec', 'scpmrv.distance', 'scpmrv.pm_ra_cosdec',
        'scpmrv.pm_dec', 'scpmrv.radial_velocity',
        'scrv.ra', 'scrv.dec', 'scrv.distance', 'scrv.radial_velocity',
        'sd.d_lon_coslat', 'sd.d_lat', 'sd.d_distance',
        'sr.lon', 'sr.lat', 'sr.distance',
        'srd.lon', 'srd.lat', 'srd.distance', 'srd.differentials.s.d_lon_coslat',
        'srd.differentials.s.d_lat', 'srd.differentials.s.d_distance',
        'tm',  # serialize_method is formatted_value
        'x',
    ]

    tbl = table_cls([mixin_cols[n] for n in colnames], names=colnames)
    tbl.meta['C'] = 'spam'
    tbl.meta['comments'] = ['this', 'is', 'a', 'comment']
    tbl.meta['history'] = ['first', 'second', 'third']
    tbl.write(path, format="fits")

    back = table_cls.read(path, format='fits', astropy_native=True)
    assert back.meta['C'] == 'spam'
    assert back.meta['comments'] == ['this', 'is', 'a', 'comment']
    # 'history' is written lower-case but read back under 'HISTORY'.
    assert back.meta['HISTORY'] == ['first', 'second', 'third']
    assert tbl.colnames == back.colnames

    # Open the raw file to see how the mixins were serialized on disk.
    with fits.open(path) as hdul:
        assert hdul[1].columns.names == expected_fits_names
882
883
@pytest.mark.parametrize('name_col', list(mixin_cols.items()))
@pytest.mark.parametrize('table_cls', (Table, QTable))
def test_fits_mixins_per_column(table_cls, name_col, tmpdir):
    """Round-trip one mixin at a time, placed between two plain float
    columns, and compare every column attribute by attribute."""
    path = str(tmpdir.join('test_simple.fits'))
    colname, mixin = name_col

    filler = [1.0, 2.0]
    tbl = table_cls([filler, mixin, filler], names=['c1', colname, 'c2'])
    # Awkward info values (embedded newlines, long list/str) that must
    # survive serialization.
    tbl[colname].info.description = 'my \n\n\n description'
    tbl[colname].info.meta = {'list': list(range(50)), 'dict': {'a': 'b' * 200}}

    if not tbl.has_mixin_columns:
        pytest.skip('column is not a mixin (e.g. Quantity subclass in Table)')
    if isinstance(tbl[colname], NdarrayMixin):
        pytest.xfail('NdarrayMixin not supported')

    tbl.write(path, format="fits")
    back = table_cls.read(path, format='fits', astropy_native=True)

    assert tbl.colnames == back.colnames
    for cname in tbl.colnames:
        assert_objects_equal(tbl[cname], back[cname], compare_attrs[cname])

    # The Column class used while reading must not leak into the jd1/jd2
    # arrays stored inside a Time object.
    if colname.startswith('tm'):
        internal = back[colname]._time
        assert type(internal.jd1) is np.ndarray
        assert type(internal.jd2) is np.ndarray
914
915
def test_info_attributes_with_no_mixins(tmpdir):
    """Column info attributes that would otherwise be lost (description,
    format, meta) are still serialized even when the table contains no
    mixin columns.
    """
    path = str(tmpdir.join('test.fits'))
    long_description = 'hello' * 40

    tbl = Table([[1.0, 2.0]])
    col = tbl['col0']
    col.description = long_description
    col.format = '{:8.4f}'
    col.meta['a'] = {'b': 'c'}
    tbl.write(path, overwrite=True)

    tbl2 = Table.read(path)
    col2 = tbl2['col0']
    assert col2.description == long_description
    assert col2.format == '{:8.4f}'
    assert col2.meta['a'] == {'b': 'c'}
931
932
@pytest.mark.parametrize('method', ['set_cols', 'names', 'class'])
def test_round_trip_masked_table_serialize_mask(tmpdir, method):
    """
    Round-trip a masked table through FITS with the serialize_method set to
    'data_mask', so the mask is written out explicitly and both the mask and
    the data underneath it are restored on read.

    ``method`` selects how 'data_mask' is requested:

    - 'set_cols': set ``info.serialize_method['fits']`` on every column;
    - 'names': pass ``write`` a ``serialize_method`` dict keyed by column name;
    - 'class': pass ``write`` a single ``serialize_method`` value for all columns.
    """
    filename = str(tmpdir.join('test.fits'))

    t = simple_table(masked=True)  # int, float, and str cols with one masked element

    # A MaskedColumn with no masked elements.  See the _represent_as_dict()
    # method of the MaskedColumnInfo class in astropy.table for why a column
    # without masked elements needs its own test case.
    t['d'] = [1, 2, 3]

    if method == 'set_cols':
        for col in t.itercols():
            col.info.serialize_method['fits'] = 'data_mask'
        t.write(filename)
    elif method == 'names':
        t.write(filename, serialize_method={'a': 'data_mask', 'b': 'data_mask',
                                            'c': 'data_mask', 'd': 'data_mask'})
    elif method == 'class':
        t.write(filename, serialize_method='data_mask')

    t2 = Table.read(filename)
    # The table-level ``masked`` flag is False; the masks are checked per
    # column below.
    assert t2.masked is False
    assert t2.colnames == t.colnames
    for name in t2.colnames:
        assert np.all(t2[name].mask == t[name].mask)
        assert np.all(t2[name] == t[name])

        # Data under the mask round-trips also (unmask data to show this).
        t[name].mask = False
        t2[name].mask = False
        assert np.all(t2[name] == t[name])
968
969
def test_meta_not_modified(tmpdir):
    """Writing a table to FITS must leave its in-memory ``meta`` untouched."""
    path = str(tmpdir.join('test.fits'))

    # The column description presumably forces extra info to be serialized
    # into the file header; none of that may leak back into tbl.meta.
    tbl = Table(data=[Column([1, 2], 'a', description='spam')])
    tbl.meta['comments'] = ['a', 'b']
    assert len(tbl.meta) == 1

    tbl.write(path)

    # No keys added, and the existing comments unchanged.
    assert len(tbl.meta) == 1
    assert tbl.meta['comments'] == ['a', 'b']
978