1# Licensed under a 3-clause BSD style license - see LICENSE.rst
2
3"""
4This module tests some of the methods related to YAML serialization.
5"""
6
7from io import StringIO
8
9import pytest
10import numpy as np
11import yaml
12
13from astropy.coordinates import (SkyCoord, EarthLocation, Angle, Longitude, Latitude,
14                                 SphericalRepresentation, UnitSphericalRepresentation,
15                                 CartesianRepresentation, SphericalCosLatDifferential,
16                                 SphericalDifferential, CartesianDifferential)
17from astropy import units as u
18from astropy.time import Time
19from astropy.table import QTable, SerializedColumn
20from astropy.coordinates.tests.test_representation import representation_equal
21
22from astropy.io.misc.yaml import load, load_all, dump  # noqa
23
24
@pytest.mark.parametrize('c', [True, np.uint8(8), np.int16(4),
                               np.int32(1), np.int64(3), np.int64(2**63 - 1),
                               2.0, np.float64(),
                               # ``np.complex_`` was removed in NumPy 2.0;
                               # ``np.cdouble`` is the same scalar type.
                               3+4j, np.cdouble(3 + 4j),
                               np.complex64(3 + 4j),
                               np.complex128(1. - 2**-52 + 1j * (1. - 2**-52))])
def test_numpy_types(c):
    """Python and NumPy scalars compare equal after a YAML dump/load round trip."""
    cy = load(dump(c))
    assert c == cy
34
35
@pytest.mark.parametrize('c', [u.m, u.m / u.s, u.hPa, u.dimensionless_unscaled])
def test_unit(c):
    """
    Units survive a YAML dump/load round trip.

    Composite units are only required to compare equal; any other unit must
    come back as the very same object.
    """
    restored = load(dump(c))
    if not isinstance(c, u.CompositeUnit):
        assert c is restored
        return
    assert c == restored
43
44
@pytest.mark.parametrize('c', [u.Unit('bakers_dozen', 13*u.one),
                               u.def_unit('magic')])
def test_custom_unit(c):
    """
    Custom units load as UnrecognizedUnit unless they are enabled.

    Without the unit enabled, loading emits exactly one UnitsWarning and
    returns an UnrecognizedUnit with the same string form. Once the unit is
    enabled, loading returns the original unit object itself.
    """
    serialized = dump(c)

    with pytest.warns(u.UnitsWarning, match=f"'{c!s}' did not parse") as record:
        unknown = load(serialized)
    assert len(record) == 1
    assert isinstance(unknown, u.UnrecognizedUnit)
    assert str(unknown) == str(c)

    with u.add_enabled_units(c):
        assert load(serialized) is c
58
59
@pytest.mark.parametrize('c', [Angle('1 2 3', unit='deg'),
                               Longitude('1 2 3', unit='deg'),
                               Latitude('1 2 3', unit='deg'),
                               [[1], [3]] * u.m,
                               np.array([[1, 2], [3, 4]], order='F'),
                               np.array([[1, 2], [3, 4]], order='C'),
                               np.array([1, 2, 3, 4])[::2]])
def test_ndarray_subclasses(c):
    """
    Arrays and ndarray subclasses round-trip through YAML.

    Values, shape, exact type, contiguity flags and (if present) the unit
    must all be preserved.
    """
    restored = load(dump(c))

    assert np.all(c == restored)
    assert c.shape == restored.shape
    assert type(c) is type(restored)

    c_flag, f_flag = 'C_CONTIGUOUS', 'F_CONTIGUOUS'
    if not (c.flags[c_flag] or c.flags[f_flag]):
        # A non-contiguous input (e.g. a strided view) is expected to come
        # back as a C-contiguous array.
        assert restored.flags[c_flag]
    else:
        for flag in (c_flag, f_flag):
            assert c.flags[flag] == restored.flags[flag]

    if hasattr(c, 'unit'):
        assert c.unit == restored.unit
86
87
def compare_coord(c, cy):
    """
    Assert that two coordinate objects agree after a serialization round trip.

    Checks the shape, the frame name, the names and values of all frame
    attributes, and the values of every representation component.

    Parameters
    ----------
    c, cy : coordinate objects (e.g. SkyCoord)
        The original object and its round-tripped counterpart.
    """
    assert c.shape == cy.shape
    assert c.frame.name == cy.frame.name

    frame_attrs = list(c.get_frame_attr_names())
    assert frame_attrs == list(cy.get_frame_attr_names())
    for attr in frame_attrs:
        # np.all tolerates array-valued attributes (e.g. a vector obstime).
        assert np.all(getattr(c, attr) == getattr(cy, attr))

    components = list(c.representation_component_names)
    assert components == list(cy.representation_component_names)
    for name in components:
        # Compare each representation component by its own name, not by a
        # frame-attribute name left over from the loop above.
        assert np.all(getattr(c, name) == getattr(cy, name))
100
101
@pytest.mark.parametrize('frame', ['fk4', 'altaz'])
def test_skycoord(frame):
    """A 2x2 SkyCoord with obstime and location round-trips through YAML."""
    obstime = Time('2016-01-02')
    location = EarthLocation(1000, 2000, 3000, unit=u.km)
    coord = SkyCoord([[1, 2], [3, 4]], [[5, 6], [7, 8]],
                     unit='deg', frame=frame,
                     obstime=obstime, location=location)
    compare_coord(coord, load(dump(coord)))
111
112
@pytest.mark.parametrize('rep', [
    CartesianRepresentation(1*u.m, 2.*u.m, 3.*u.m),
    SphericalRepresentation([[1, 2], [3, 4]]*u.deg,
                            [[5, 6], [7, 8]]*u.deg,
                            10*u.pc),
    UnitSphericalRepresentation(0*u.deg, 10*u.deg),
    SphericalCosLatDifferential([[1.], [2.]]*u.mas/u.yr,
                                [4., 5.]*u.mas/u.yr,
                                [[[10]], [[20]]]*u.km/u.s),
    CartesianDifferential([10, 20, 30]*u.km/u.s),
    CartesianRepresentation(
        [1, 2, 3]*u.m,
        differentials=CartesianDifferential([10, 20, 30]*u.km/u.s)),
    SphericalRepresentation(
        [[1, 2], [3, 4]]*u.deg, [[5, 6], [7, 8]]*u.deg, 10*u.pc,
        differentials={
            's': SphericalDifferential([[0., 1.], [2., 3.]]*u.mas/u.yr,
                                       [[4., 5.], [6., 7.]]*u.mas/u.yr,
                                       10*u.km/u.s)})])
def test_representations(rep):
    """Representations and differentials (attached or standalone) round-trip through YAML."""
    restored = load(dump(rep))
    matches = representation_equal(restored, rep)
    assert np.all(matches)
135
136
def _get_time():
    """
    Build a 2x1 Time fixture with many non-default attributes.

    The Time starts in 'cxcsec' format with an EarthLocation. It is then
    switched to 'iso' and given a precision, delta_ut1_utc/delta_tdb_tt
    arrays and an out_subfmt. This exercises as much serialized state as
    possible.
    """
    site = EarthLocation(1000, 2000, 3000, unit=u.km)
    tm = Time([[1], [2]], format='cxcsec', location=site)

    # Applied in this order; 'date_hm' is an output subformat of 'iso'.
    overrides = (
        ('format', 'iso'),
        ('precision', 5),
        ('delta_ut1_utc', np.array([[3.0], [4.0]])),
        ('delta_tdb_tt', np.array([[5.0], [6.0]])),
        ('out_subfmt', 'date_hm'),
    )
    for attr_name, value in overrides:
        setattr(tm, attr_name, value)

    return tm
147
148
def compare_time(t, ty):
    """
    Assert that two Time-like objects are equivalent.

    Requires the same type, equal values, and equal serialized attributes:
    shape, jd1/jd2, format, scale, precision, subformats, location and the
    two delta corrections.
    """
    assert type(t) is type(ty)
    assert np.all(t == ty)

    checked = ('shape', 'jd1', 'jd2', 'format', 'scale', 'precision',
               'in_subfmt', 'out_subfmt', 'location',
               'delta_ut1_utc', 'delta_tdb_tt')
    for attr_name in checked:
        expected = getattr(t, attr_name)
        actual = getattr(ty, attr_name)
        assert np.all(expected == actual)
155
156
def test_time():
    """The Time fixture round-trips through YAML with all attributes intact."""
    original = _get_time()
    compare_time(original, load(dump(original)))
161
162
def test_timedelta():
    """A TimeDelta round-trips with identical type, shape, jd1/jd2, format and scale."""
    base = _get_time()
    delta = (base - base) + 0.1234556 * u.s
    restored = load(dump(delta))

    assert type(delta) is type(restored)
    for attr_name in ('shape', 'jd1', 'jd2', 'format', 'scale'):
        before = getattr(delta, attr_name)
        after = getattr(restored, attr_name)
        assert np.all(before == after)
171
172
def test_serialized_column():
    """A SerializedColumn mapping compares equal after a YAML round trip."""
    original = SerializedColumn(dict(name='hello', other=1, other2=2.0))
    assert original == load(dump(original))
178
179
def test_load_all():
    """A multi-document YAML stream loads back into its individual objects, in order."""
    t = _get_time()
    unit = u.m / u.s
    c = SkyCoord([[1, 2], [3, 4]], [[5, 6], [7, 8]],
                 unit='deg', frame='fk4',
                 obstime=Time('2016-01-02'),
                 location=EarthLocation(1000, 2000, 3000, unit=u.km))

    # Each document is introduced by an explicit '---' separator.
    stream = ''.join('---\n' + dump(doc) for doc in (t, unit, c))
    ty, unity, cy = list(load_all(stream))

    compare_time(t, ty)
    compare_coord(c, cy)
    assert unity == unit
198
199
def test_ecsv_astropy_objects_in_meta():
    """
    Test that astropy core objects in ``meta`` are serialized.

    A Time, a SkyCoord and a unit are stored in a QTable's metadata. The
    table is written to ECSV in memory and read back, and each object must
    survive the round trip.
    """
    table = QTable([[1, 2] * u.m, [4, 5]], names=['a', 'b'])
    tm = _get_time()
    c = SkyCoord([[1, 2], [3, 4]], [[5, 6], [7, 8]],
                 unit='deg', frame='fk4',
                 obstime=Time('2016-01-02'),
                 location=EarthLocation(1000, 2000, 3000, unit=u.km))
    unit = u.m / u.s
    table.meta = dict(tm=tm, c=c, unit=unit)

    buffer = StringIO()
    table.write(buffer, format='ascii.ecsv')
    meta_back = QTable.read(buffer.getvalue(), format='ascii.ecsv').meta

    compare_time(tm, meta_back['tm'])
    compare_coord(c, meta_back['c'])
    assert meta_back['unit'] == unit
220