# Licensed under a 3-clause BSD style license - see LICENSE.rst

"""
This module tests some of the methods related to YAML serialization.
"""

from io import StringIO

import pytest
import numpy as np
import yaml

from astropy.coordinates import (SkyCoord, EarthLocation, Angle, Longitude, Latitude,
                                 SphericalRepresentation, UnitSphericalRepresentation,
                                 CartesianRepresentation, SphericalCosLatDifferential,
                                 SphericalDifferential, CartesianDifferential)
from astropy import units as u
from astropy.time import Time
from astropy.table import QTable, SerializedColumn
from astropy.coordinates.tests.test_representation import representation_equal

from astropy.io.misc.yaml import load, load_all, dump  # noqa


# NOTE: ``np.complex128`` is used instead of the ``np.complex_`` alias, which
# was removed in NumPy 2.0 (it is the same type on NumPy 1.x).
@pytest.mark.parametrize('c', [True, np.uint8(8), np.int16(4),
                               np.int32(1), np.int64(3), np.int64(2**63 - 1),
                               2.0, np.float64(),
                               3+4j, np.complex128(3 + 4j),
                               np.complex64(3 + 4j),
                               np.complex128(1. - 2**-52 + 1j * (1. - 2**-52))])
def test_numpy_types(c):
    """Python and NumPy scalars compare equal after a YAML round trip."""
    cy = load(dump(c))
    assert c == cy


@pytest.mark.parametrize('c', [u.m, u.m / u.s, u.hPa, u.dimensionless_unscaled])
def test_unit(c):
    """Units survive a YAML round trip.

    Composite units are rebuilt, so only equality is required; all other
    units must come back as the very same object.
    """
    cy = load(dump(c))
    if isinstance(c, u.CompositeUnit):
        assert c == cy
    else:
        assert c is cy


@pytest.mark.parametrize('c', [u.Unit('bakers_dozen', 13*u.one),
                               u.def_unit('magic')])
def test_custom_unit(c):
    """A unit not in the registry loads as ``UnrecognizedUnit`` with a warning,
    but loads as the original unit once it is enabled."""
    s = dump(c)
    with pytest.warns(u.UnitsWarning, match=f"'{c!s}' did not parse") as w:
        cy = load(s)
    assert len(w) == 1
    assert isinstance(cy, u.UnrecognizedUnit)
    assert str(cy) == str(c)

    with u.add_enabled_units(c):
        cy2 = load(s)
        assert cy2 is c


@pytest.mark.parametrize('c', [Angle('1 2 3', unit='deg'),
                               Longitude('1 2 3', unit='deg'),
                               Latitude('1 2 3', unit='deg'),
                               [[1], [3]] * u.m,
                               np.array([[1, 2], [3, 4]], order='F'),
                               np.array([[1, 2], [3, 4]], order='C'),
                               np.array([1, 2, 3, 4])[::2]])
def test_ndarray_subclasses(c):
    """ndarray and Quantity subclasses round-trip with values, shape, type,
    memory layout and (where present) unit preserved."""
    cy = load(dump(c))

    assert np.all(c == cy)
    assert c.shape == cy.shape
    assert type(c) is type(cy)

    cc = 'C_CONTIGUOUS'
    fc = 'F_CONTIGUOUS'
    if c.flags[cc] or c.flags[fc]:
        assert c.flags[cc] == cy.flags[cc]
        assert c.flags[fc] == cy.flags[fc]
    else:
        # Original was not contiguous but round-trip version
        # should be c-contig.
        assert cy.flags[cc]

    if hasattr(c, 'unit'):
        assert c.unit == cy.unit


def compare_coord(c, cy):
    """Assert that coordinate ``cy`` matches ``c``.

    The comparison covers:

    - shape;
    - frame name;
    - frame attribute names and values;
    - the values of every representation component.
    """
    assert c.shape == cy.shape
    assert c.frame.name == cy.frame.name

    assert list(c.get_frame_attr_names()) == list(cy.get_frame_attr_names())
    for attr in c.get_frame_attr_names():
        assert getattr(c, attr) == getattr(cy, attr)

    assert (list(c.representation_component_names) ==
            list(cy.representation_component_names))
    for name in c.representation_component_names:
        # Compare each component by its own name. Using the stale ``attr``
        # from the loop above would only re-check the last frame attribute.
        assert np.all(getattr(c, name) == getattr(cy, name))


@pytest.mark.parametrize('frame', ['fk4', 'altaz'])
def test_skycoord(frame):
    """A 2x2 SkyCoord with obstime and location round-trips through YAML."""
    c = SkyCoord([[1, 2], [3, 4]], [[5, 6], [7, 8]],
                 unit='deg', frame=frame,
                 obstime=Time('2016-01-02'),
                 location=EarthLocation(1000, 2000, 3000, unit=u.km))
    cy = load(dump(c))
    compare_coord(c, cy)


@pytest.mark.parametrize('rep', [
    CartesianRepresentation(1*u.m, 2.*u.m, 3.*u.m),
    SphericalRepresentation([[1, 2], [3, 4]]*u.deg,
                            [[5, 6], [7, 8]]*u.deg,
                            10*u.pc),
    UnitSphericalRepresentation(0*u.deg, 10*u.deg),
    SphericalCosLatDifferential([[1.], [2.]]*u.mas/u.yr,
                                [4., 5.]*u.mas/u.yr,
                                [[[10]], [[20]]]*u.km/u.s),
    CartesianDifferential([10, 20, 30]*u.km/u.s),
    CartesianRepresentation(
        [1, 2, 3]*u.m,
        differentials=CartesianDifferential([10, 20, 30]*u.km/u.s)),
    SphericalRepresentation(
        [[1, 2], [3, 4]]*u.deg, [[5, 6], [7, 8]]*u.deg, 10*u.pc,
        differentials={
            's': SphericalDifferential([[0., 1.], [2., 3.]]*u.mas/u.yr,
                                       [[4., 5.], [6., 7.]]*u.mas/u.yr,
                                       10*u.km/u.s)})])
def test_representations(rep):
    """Representations and differentials (including broadcast shapes and
    attached differentials) round-trip through YAML."""
    rrep = load(dump(rep))
    assert np.all(representation_equal(rrep, rep))


def _get_time():
    """Build a (2, 1) ``Time`` with many non-default attributes set.

    The non-default attributes are: format, precision, location,
    delta_ut1_utc, delta_tdb_tt and out_subfmt. Setting them means a
    serialization round trip exercises all of them.
    """
    t = Time([[1], [2]], format='cxcsec',
             location=EarthLocation(1000, 2000, 3000, unit=u.km))
    t.format = 'iso'
    t.precision = 5
    t.delta_ut1_utc = np.array([[3.0], [4.0]])
    t.delta_tdb_tt = np.array([[5.0], [6.0]])
    t.out_subfmt = 'date_hm'

    return t


def compare_time(t, ty):
    """Assert that ``ty`` has the same type, values and the listed
    attributes as ``t``."""
    assert type(t) is type(ty)
    assert np.all(t == ty)
    for attr in ('shape', 'jd1', 'jd2', 'format', 'scale', 'precision', 'in_subfmt',
                 'out_subfmt', 'location', 'delta_ut1_utc', 'delta_tdb_tt'):
        assert np.all(getattr(t, attr) == getattr(ty, attr))


def test_time():
    """A fully-configured ``Time`` round-trips through YAML."""
    t = _get_time()
    ty = load(dump(t))
    compare_time(t, ty)


def test_timedelta():
    """A ``TimeDelta`` (from Time - Time + Quantity) round-trips through YAML."""
    t = _get_time()
    dt = t - t + 0.1234556 * u.s
    dty = load(dump(dt))

    assert type(dt) is type(dty)
    for attr in ('shape', 'jd1', 'jd2', 'format', 'scale'):
        assert np.all(getattr(dt, attr) == getattr(dty, attr))


def test_serialized_column():
    """A ``SerializedColumn`` round-trips to an equal object."""
    sc = SerializedColumn({'name': 'hello', 'other': 1, 'other2': 2.0})
    scy = load(dump(sc))

    assert sc == scy


def test_load_all():
    """``load_all`` reads back every document of a multi-document stream."""
    t = _get_time()
    unit = u.m / u.s
    c = SkyCoord([[1, 2], [3, 4]], [[5, 6], [7, 8]],
                 unit='deg', frame='fk4',
                 obstime=Time('2016-01-02'),
                 location=EarthLocation(1000, 2000, 3000, unit=u.km))

    # Make a multi-document stream
    out = ('---\n' + dump(t)
           + '---\n' + dump(unit)
           + '---\n' + dump(c))

    ty, unity, cy = list(load_all(out))

    compare_time(t, ty)
    compare_coord(c, cy)
    assert unity == unit


def test_ecsv_astropy_objects_in_meta():
    """
    Test that astropy core objects in ``meta`` are serialized.
    """
    t = QTable([[1, 2] * u.m, [4, 5]], names=['a', 'b'])
    tm = _get_time()
    c = SkyCoord([[1, 2], [3, 4]], [[5, 6], [7, 8]],
                 unit='deg', frame='fk4',
                 obstime=Time('2016-01-02'),
                 location=EarthLocation(1000, 2000, 3000, unit=u.km))
    unit = u.m / u.s

    t.meta = {'tm': tm, 'c': c, 'unit': unit}
    out = StringIO()
    t.write(out, format='ascii.ecsv')
    t2 = QTable.read(out.getvalue(), format='ascii.ecsv')

    compare_time(tm, t2.meta['tm'])
    compare_coord(c, t2.meta['c'])
    assert t2.meta['unit'] == unit