# coding: utf-8
# Licensed under a 3-clause BSD style license - see LICENSE.rst
"""
Test Structured units and quantities.
"""
import pytest
import numpy as np
from numpy.testing import assert_array_equal

from astropy import units as u
from astropy.units import StructuredUnit, Unit, UnitBase, Quantity
from astropy.utils.masked import Masked


class StructuredTestBase:
    """Shared fixtures: structured dtypes, plain units and sample arrays.

    Two layouts are provided: a flat ``(p, v)`` record and a nested
    ``((p, v), t)`` record, each with a three-element sample array.
    """

    @classmethod
    def setup_class(cls):
        cls.pv_dtype = np.dtype([('p', 'f8'), ('v', 'f8')])
        cls.pv_t_dtype = np.dtype([('pv', cls.pv_dtype), ('t', 'f8')])
        cls.p_unit = u.km
        cls.v_unit = u.km / u.s
        cls.t_unit = u.s
        cls.pv = np.array([(1., 0.25), (2., 0.5), (3., 0.75)],
                          cls.pv_dtype)
        cls.pv_t = np.array([((4., 2.5), 0.),
                             ((5., 5.0), 1.),
                             ((6., 7.5), 2.)], cls.pv_t_dtype)


class StructuredTestBaseWithUnits(StructuredTestBase):
    """Extend the base fixtures with matching ``StructuredUnit`` instances."""

    @classmethod
    def setup_class(cls):
        super().setup_class()
        cls.pv_unit = StructuredUnit((cls.p_unit, cls.v_unit),
                                     ('p', 'v'))
        cls.pv_t_unit = StructuredUnit((cls.pv_unit, cls.t_unit),
                                       ('pv', 't'))


class TestStructuredUnitBasics(StructuredTestBase):
    """Construction, parsing, equality and string forms of StructuredUnit."""

    def test_initialization_and_keying(self):
        su = StructuredUnit((self.p_unit, self.v_unit), ('p', 'v'))
        assert su['p'] is self.p_unit
        assert su['v'] is self.v_unit
        su2 = StructuredUnit((su, self.t_unit), ('pv', 't'))
        assert isinstance(su2['pv'], StructuredUnit)
        assert su2['pv']['p'] is self.p_unit
        assert su2['pv']['v'] is self.v_unit
        assert su2['t'] is self.t_unit
        assert su2['pv'] == su
        su3 = StructuredUnit(('AU', 'AU/day'), ('p', 'v'))
        assert isinstance(su3['p'], UnitBase)
        assert isinstance(su3['v'], UnitBase)
        su4 = StructuredUnit('AU, AU/day', ('p', 'v'))
        assert su4['p'] == u.AU
        assert su4['v'] == u.AU / u.day
        # Without explicit names, fields are expected to be named f0, f1.
        su5 = StructuredUnit(('AU', 'AU/day'))
        assert su5.field_names == ('f0', 'f1')
        assert su5['f0'] == u.AU
        assert su5['f1'] == u.AU / u.day

    def test_recursive_initialization(self):
        su = StructuredUnit(((self.p_unit, self.v_unit), self.t_unit),
                            (('p', 'v'), 't'))
        assert isinstance(su['pv'], StructuredUnit)
        assert su['pv']['p'] is self.p_unit
        assert su['pv']['v'] is self.v_unit
        assert su['t'] is self.t_unit
        su2 = StructuredUnit(((self.p_unit, self.v_unit), self.t_unit),
                             (['p_v', ('p', 'v')], 't'))
        assert isinstance(su2['p_v'], StructuredUnit)
        assert su2['p_v']['p'] is self.p_unit
        assert su2['p_v']['v'] is self.v_unit
        assert su2['t'] is self.t_unit
        su3 = StructuredUnit((('AU', 'AU/day'), 'yr'),
                             (['p_v', ('p', 'v')], 't'))
        assert isinstance(su3['p_v'], StructuredUnit)
        assert su3['p_v']['p'] == u.AU
        assert su3['p_v']['v'] == u.AU / u.day
        assert su3['t'] == u.yr
        su4 = StructuredUnit('(AU, AU/day), yr', (('p', 'v'), 't'))
        assert isinstance(su4['pv'], StructuredUnit)
        assert su4['pv']['p'] == u.AU
        assert su4['pv']['v'] == u.AU / u.day
        assert su4['t'] == u.yr

    def test_extreme_recursive_initialization(self):
        su = StructuredUnit('(yr,(AU,AU/day,(km,(day,day))),m)',
                            ('t', ('p', 'v', ('h', ('d1', 'd2'))), 'l'))
        assert su.field_names == ('t', ['pvhd1d2',
                                        ('p', 'v',
                                         ['hd1d2',
                                          ('h',
                                           ['d1d2',
                                            ('d1', 'd2')])])], 'l')

    @pytest.mark.parametrize('names, invalid', [
        [('t', ['p', 'v']), "['p', 'v']"],
        [('t', ['pv', 'p', 'v']), "['pv', 'p', 'v']"],
        [('t', ['pv', ['p', 'v']]), "['pv', ['p', 'v']"],
        [('t', ()), "()"],
        [('t', ('p', None)), "None"],
        [('t', ['pv', ('p', '')]), "''"]])
    def test_initialization_names_invalid_list_errors(self, names, invalid):
        with pytest.raises(ValueError) as exc:
            StructuredUnit('(yr,(AU,AU/day)', names)
        # The check must sit outside the ``with`` block: statements after
        # the raising call inside the block would never execute.  Inspect
        # the exception itself rather than the ExceptionInfo wrapper.
        assert f'invalid entry {invalid}' in str(exc.value)

    def test_looks_like_unit(self):
        su = StructuredUnit((self.p_unit, self.v_unit), ('p', 'v'))
        assert Unit(su) is su

    def test_initialize_with_float_dtype(self):
        su = StructuredUnit(('AU', 'AU/d'), self.pv_dtype)
        assert isinstance(su['p'], UnitBase)
        assert isinstance(su['v'], UnitBase)
        assert su['p'] == u.AU
        assert su['v'] == u.AU / u.day
        su = StructuredUnit((('km', 'km/s'), 'yr'), self.pv_t_dtype)
        assert isinstance(su['pv'], StructuredUnit)
        assert isinstance(su['pv']['p'], UnitBase)
        assert isinstance(su['t'], UnitBase)
        assert su['pv']['v'] == u.km / u.s
        su = StructuredUnit('(km, km/s), yr', self.pv_t_dtype)
        assert isinstance(su['pv'], StructuredUnit)
        assert isinstance(su['pv']['p'], UnitBase)
        assert isinstance(su['t'], UnitBase)
        assert su['pv']['v'] == u.km / u.s

    def test_initialize_with_structured_unit_for_names(self):
        su = StructuredUnit(('AU', 'AU/d'), names=('p', 'v'))
        su2 = StructuredUnit(('km', 'km/s'), names=su)
        assert su2.field_names == ('p', 'v')
        assert su2['p'] == u.km
        assert su2['v'] == u.km / u.s

    def test_initialize_single_field(self):
        su = StructuredUnit('AU', 'p')
        assert isinstance(su, StructuredUnit)
        assert isinstance(su['p'], UnitBase)
        assert su['p'] == u.AU
        su = StructuredUnit('AU')
        assert isinstance(su, StructuredUnit)
        assert isinstance(su['f0'], UnitBase)
        assert su['f0'] == u.AU

    def test_equality(self):
        su = StructuredUnit(('AU', 'AU/d'), self.pv_dtype)
        assert su == StructuredUnit(('AU', 'AU/d'), self.pv_dtype)
        assert su != StructuredUnit(('m', 'AU/d'), self.pv_dtype)
        # Names should be ignored.
        assert su == StructuredUnit(('AU', 'AU/d'))
        assert su == StructuredUnit(('AU', 'AU/d'), names=('q', 'w'))
        assert su != StructuredUnit(('m', 'm/s'))

    def test_parsing(self):
        su = Unit('AU, AU/d')
        assert isinstance(su, StructuredUnit)
        assert isinstance(su['f0'], UnitBase)
        assert isinstance(su['f1'], UnitBase)
        assert su['f0'] == u.AU
        assert su['f1'] == u.AU/u.day
        su2 = Unit('AU, AU/d, yr')
        assert isinstance(su2, StructuredUnit)
        assert su2 == StructuredUnit(('AU', 'AU/d', 'yr'))
        su2a = Unit('(AU, AU/d, yr)')
        assert isinstance(su2a, StructuredUnit)
        assert su2a == su2
        su3 = Unit('(km, km/s), yr')
        assert isinstance(su3, StructuredUnit)
        assert su3 == StructuredUnit((('km', 'km/s'), 'yr'))
        # A trailing comma is expected to yield a single-field structure.
        su4 = Unit('km,')
        assert isinstance(su4, StructuredUnit)
        assert su4 == StructuredUnit((u.km,))
        su5 = Unit('(m,s),')
        assert isinstance(su5, StructuredUnit)
        assert su5 == StructuredUnit(((u.m, u.s),))
        ldbody_unit = Unit('Msun, 0.5rad^2, (au, au/day)')
        assert ldbody_unit == StructuredUnit(
            (u.Msun, Unit(u.rad**2 / 2), (u.AU, u.AU / u.day)))

    def test_str(self):
        su = StructuredUnit(((u.km, u.km/u.s), u.yr))
        assert str(su) == '((km, km / s), yr)'
        assert Unit(str(su)) == su

    def test_repr(self):
        su = StructuredUnit(((u.km, u.km/u.s), u.yr))
        assert repr(su) == 'Unit("((km, km / s), yr)")'
        # Round-trip check; the evaluated text is produced by this test.
        assert eval(repr(su)) == su


class TestStructuredUnitAsMapping(StructuredTestBaseWithUnits):
    """Mapping-style access (len, keys, values, iteration) on a unit."""

    def test_len(self):
        assert len(self.pv_unit) == 2
        assert len(self.pv_t_unit) == 2

    def test_keys(self):
        slv = list(self.pv_t_unit.keys())
        assert slv == ['pv', 't']

    def test_values(self):
        values = self.pv_t_unit.values()
        assert values == (self.pv_unit, self.t_unit)

    def test_field_names(self):
        field_names = self.pv_t_unit.field_names
        assert isinstance(field_names, tuple)
        assert field_names == (['pv', ('p', 'v')], 't')

    @pytest.mark.parametrize('iterable', [list, set])
    def test_as_iterable(self, iterable):
        sl = iterable(self.pv_unit)
        assert isinstance(sl, iterable)
        assert sl == iterable(['p', 'v'])

    def test_as_dict(self):
        sd = dict(self.pv_t_unit)
        assert sd == {'pv': self.pv_unit, 't': self.t_unit}

    def test_contains(self):
        assert 'p' in self.pv_unit
        assert 'v' in self.pv_unit
        assert 't' not in self.pv_unit

    def test_setitem_fails(self):
        with pytest.raises(TypeError, match='item assignment'):
            self.pv_t_unit['t'] = u.Gyr


class TestStructuredUnitMethods(StructuredTestBaseWithUnits):
    """Physical types, unit-system conversion, equivalence and ``to``."""

    def test_physical_type_id(self):
        pv_ptid = self.pv_unit._get_physical_type_id()
        assert len(pv_ptid) == 2
        assert pv_ptid.dtype.names == ('p', 'v')
        p_ptid = self.pv_unit['p']._get_physical_type_id()
        v_ptid = self.pv_unit['v']._get_physical_type_id()
        # Expected should be (subclass of) void, with structured object dtype.
        expected = np.array((p_ptid, v_ptid), [('p', 'O'), ('v', 'O')])[()]
        assert pv_ptid == expected
        # Names should be ignored in comparison.
        assert pv_ptid == np.array((p_ptid, v_ptid), 'O,O')[()]
        # Should be possible to address by field and by number.
        assert pv_ptid['p'] == p_ptid
        assert pv_ptid['v'] == v_ptid
        assert pv_ptid[0] == p_ptid
        assert pv_ptid[1] == v_ptid
        # More complicated version.
        pv_t_ptid = self.pv_t_unit._get_physical_type_id()
        t_ptid = self.t_unit._get_physical_type_id()
        assert pv_t_ptid == np.array((pv_ptid, t_ptid), 'O,O')[()]
        assert pv_t_ptid['pv'] == pv_ptid
        assert pv_t_ptid['t'] == t_ptid
        assert pv_t_ptid['pv'][1] == v_ptid

    def test_physical_type(self):
        pv_pt = self.pv_unit.physical_type
        assert pv_pt == np.array(('length', 'speed'), 'O,O')[()]

        pv_t_pt = self.pv_t_unit.physical_type
        assert pv_t_pt == np.array((pv_pt, 'time'), 'O,O')[()]

    def test_si(self):
        pv_t_si = self.pv_t_unit.si
        assert pv_t_si == self.pv_t_unit
        assert pv_t_si['pv']['v'].scale == 1000

    def test_cgs(self):
        pv_t_cgs = self.pv_t_unit.cgs
        assert pv_t_cgs == self.pv_t_unit
        assert pv_t_cgs['pv']['v'].scale == 100000

    def test_decompose(self):
        pv_t_decompose = self.pv_t_unit.decompose()
        assert pv_t_decompose['pv']['v'].scale == 1000

    def test_is_equivalent(self):
        assert self.pv_unit.is_equivalent(('AU', 'AU/day'))
        assert not self.pv_unit.is_equivalent('m')
        assert not self.pv_unit.is_equivalent(('AU', 'AU'))
        # Names should be ignored.
        pv_alt = StructuredUnit('m,m/s', names=('q', 'w'))
        assert pv_alt.field_names != self.pv_unit.field_names
        assert self.pv_unit.is_equivalent(pv_alt)
        # Regular units should work too.
        assert not u.m.is_equivalent(self.pv_unit)

    def test_conversion(self):
        pv1 = self.pv_unit.to(('AU', 'AU/day'), self.pv)
        assert isinstance(pv1, np.ndarray)
        assert pv1.dtype == self.pv.dtype
        assert np.all(pv1['p'] * u.AU == self.pv['p'] * self.p_unit)
        assert np.all(pv1['v'] * u.AU / u.day == self.pv['v'] * self.v_unit)
        # Names should be from value.
        su2 = StructuredUnit((self.p_unit, self.v_unit),
                             ('position', 'velocity'))
        pv2 = su2.to(('Mm', 'mm/s'), self.pv)
        assert pv2.dtype.names == ('p', 'v')
        assert pv2.dtype == self.pv.dtype
        # Check recursion.
        pv_t1 = self.pv_t_unit.to((('AU', 'AU/day'), 'Myr'), self.pv_t)
        assert isinstance(pv_t1, np.ndarray)
        assert pv_t1.dtype == self.pv_t.dtype
        assert np.all(pv_t1['pv']['p'] * u.AU ==
                      self.pv_t['pv']['p'] * self.p_unit)
        assert np.all(pv_t1['pv']['v'] * u.AU / u.day ==
                      self.pv_t['pv']['v'] * self.v_unit)
        assert np.all(pv_t1['t'] * u.Myr == self.pv_t['t'] * self.t_unit)
        # Passing in tuples should work.
        pv_t2 = self.pv_t_unit.to((('AU', 'AU/day'), 'Myr'),
                                  ((1., 0.1), 10.))
        assert pv_t2['pv']['p'] == self.p_unit.to('AU', 1.)
        assert pv_t2['pv']['v'] == self.v_unit.to('AU/day', 0.1)
        assert pv_t2['t'] == self.t_unit.to('Myr', 10.)
        pv_t3 = self.pv_t_unit.to((('AU', 'AU/day'), 'Myr'),
                                  [((1., 0.1), 10.),
                                   ((2., 0.2), 20.)])
        assert np.all(pv_t3['pv']['p'] == self.p_unit.to('AU', [1., 2.]))
        assert np.all(pv_t3['pv']['v'] == self.v_unit.to('AU/day', [0.1, 0.2]))
        assert np.all(pv_t3['t'] == self.t_unit.to('Myr', [10., 20.]))


class TestStructuredUnitArithmatic(StructuredTestBaseWithUnits):
    """Multiplication and division of a structured unit by a plain unit."""

    def test_multiplication(self):
        pv_times_au = self.pv_unit * u.au
        assert isinstance(pv_times_au, StructuredUnit)
        assert pv_times_au.field_names == ('p', 'v')
        assert pv_times_au['p'] == self.p_unit * u.AU
        assert pv_times_au['v'] == self.v_unit * u.AU
        au_times_pv = u.au * self.pv_unit
        assert au_times_pv == pv_times_au
        pv_times_au2 = self.pv_unit * 'au'
        assert pv_times_au2 == pv_times_au
        au_times_pv2 = 'AU' * self.pv_unit
        assert au_times_pv2 == pv_times_au
        with pytest.raises(TypeError):
            self.pv_unit * self.pv_unit
        with pytest.raises(TypeError):
            's,s' * self.pv_unit

    def test_division(self):
        pv_by_s = self.pv_unit / u.s
        assert isinstance(pv_by_s, StructuredUnit)
        assert pv_by_s.field_names == ('p', 'v')
        assert pv_by_s['p'] == self.p_unit / u.s
        assert pv_by_s['v'] == self.v_unit / u.s
        pv_by_s2 = self.pv_unit / 's'
        assert pv_by_s2 == pv_by_s
        with pytest.raises(TypeError):
            1. / self.pv_unit
        with pytest.raises(TypeError):
            u.s / self.pv_unit


class TestStructuredQuantity(StructuredTestBaseWithUnits):
    """Quantity behaviour when both value and unit are structured."""

    def test_initialization_and_keying(self):
        q_pv = Quantity(self.pv, self.pv_unit)
        q_p = q_pv['p']
        assert isinstance(q_p, Quantity)
        assert isinstance(q_p.unit, UnitBase)
        assert np.all(q_p == self.pv['p'] * self.pv_unit['p'])
        q_v = q_pv['v']
        assert isinstance(q_v, Quantity)
        assert isinstance(q_v.unit, UnitBase)
        assert np.all(q_v == self.pv['v'] * self.pv_unit['v'])
        q_pv_t = Quantity(self.pv_t, self.pv_t_unit)
        q_t = q_pv_t['t']
        assert np.all(q_t == self.pv_t['t'] * self.pv_t_unit['t'])
        q_pv2 = q_pv_t['pv']
        assert isinstance(q_pv2, Quantity)
        assert q_pv2.unit == self.pv_unit
        # Value and unit structures that do not match must be rejected.
        with pytest.raises(ValueError):
            Quantity(self.pv, self.pv_t_unit)
        with pytest.raises(ValueError):
            Quantity(self.pv_t, self.pv_unit)

    def test_initialization_with_unit_tuples(self):
        q_pv_t = Quantity(self.pv_t, (('km', 'km/s'), 's'))
        assert isinstance(q_pv_t.unit, StructuredUnit)
        assert q_pv_t.unit == self.pv_t_unit

    def test_initialization_with_string(self):
        q_pv_t = Quantity(self.pv_t, '(km, km/s), s')
        assert isinstance(q_pv_t.unit, StructuredUnit)
        assert q_pv_t.unit == self.pv_t_unit

    def test_initialization_by_multiplication_with_unit(self):
        q_pv_t = self.pv_t * self.pv_t_unit
        assert q_pv_t.unit is self.pv_t_unit
        assert np.all(q_pv_t.value == self.pv_t)
        assert not np.may_share_memory(q_pv_t, self.pv_t)
        q_pv_t2 = self.pv_t_unit * self.pv_t
        # Check the result of the reversed operand order, not the first
        # quantity again.
        assert q_pv_t2.unit is self.pv_t_unit
        # Not testing equality of structured Quantity here.
        assert np.all(q_pv_t2.value == q_pv_t.value)

    def test_initialization_by_shifting_to_unit(self):
        q_pv_t = self.pv_t << self.pv_t_unit
        assert q_pv_t.unit is self.pv_t_unit
        assert np.all(q_pv_t.value == self.pv_t)
        # Unlike multiplication, ``<<`` is expected to share memory.
        assert np.may_share_memory(q_pv_t, self.pv_t)

    def test_getitem(self):
        q_pv_t = Quantity(self.pv_t, self.pv_t_unit)
        q_pv_t01 = q_pv_t[:2]
        assert isinstance(q_pv_t01, Quantity)
        assert q_pv_t01.unit == q_pv_t.unit
        assert np.all(q_pv_t01['t'] == q_pv_t['t'][:2])
        q_pv_t1 = q_pv_t[1]
        assert isinstance(q_pv_t1, Quantity)
        assert q_pv_t1.unit == q_pv_t.unit
        assert q_pv_t1.shape == ()
        assert q_pv_t1['t'] == q_pv_t['t'][1]

    def test_value(self):
        q_pv_t = Quantity(self.pv_t, self.pv_t_unit)
        value = q_pv_t.value
        # Exact type checks are intentional: no subclass should leak out.
        assert type(value) is np.ndarray
        assert np.all(value == self.pv_t)
        value1 = q_pv_t[1].value
        assert type(value1) is np.void
        assert np.all(value1 == self.pv_t[1])

    def test_conversion(self):
        q_pv = Quantity(self.pv, self.pv_unit)
        q1 = q_pv.to(('AU', 'AU/day'))
        assert isinstance(q1, Quantity)
        assert q1['p'].unit == u.AU
        assert q1['v'].unit == u.AU / u.day
        assert np.all(q1['p'] == q_pv['p'].to(u.AU))
        assert np.all(q1['v'] == q_pv['v'].to(u.AU/u.day))
        q2 = q_pv.to(self.pv_unit)
        assert q2['p'].unit == self.p_unit
        assert q2['v'].unit == self.v_unit
        assert np.all(q2['p'].value == self.pv['p'])
        assert np.all(q2['v'].value == self.pv['v'])
        assert not np.may_share_memory(q2, q_pv)
        pv1 = q_pv.to_value(('AU', 'AU/day'))
        assert type(pv1) is np.ndarray
        assert np.all(pv1['p'] == q_pv['p'].to_value(u.AU))
        assert np.all(pv1['v'] == q_pv['v'].to_value(u.AU/u.day))
        pv11 = q_pv[1].to_value(('AU', 'AU/day'))
        assert type(pv11) is np.void
        assert pv11 == pv1[1]
        q_pv_t = Quantity(self.pv_t, self.pv_t_unit)
        q2 = q_pv_t.to((('kpc', 'kpc/Myr'), 'Myr'))
        assert q2['pv']['p'].unit == u.kpc
        assert q2['pv']['v'].unit == u.kpc / u.Myr
        assert q2['t'].unit == u.Myr
        assert np.all(q2['pv']['p'] == q_pv_t['pv']['p'].to(u.kpc))
        assert np.all(q2['pv']['v'] == q_pv_t['pv']['v'].to(u.kpc/u.Myr))
        assert np.all(q2['t'] == q_pv_t['t'].to(u.Myr))

    def test_conversion_via_lshift(self):
        q_pv = Quantity(self.pv, self.pv_unit)
        q1 = q_pv << StructuredUnit(('AU', 'AU/day'))
        assert isinstance(q1, Quantity)
        assert q1['p'].unit == u.AU
        assert q1['v'].unit == u.AU / u.day
        assert np.all(q1['p'] == q_pv['p'].to(u.AU))
        assert np.all(q1['v'] == q_pv['v'].to(u.AU/u.day))
        q2 = q_pv << self.pv_unit
        assert q2['p'].unit == self.p_unit
        assert q2['v'].unit == self.v_unit
        assert np.all(q2['p'].value == self.pv['p'])
        assert np.all(q2['v'].value == self.pv['v'])
        assert np.may_share_memory(q2, q_pv)
        q_pv_t = Quantity(self.pv_t, self.pv_t_unit)
        q2 = q_pv_t << '(kpc,kpc/Myr),Myr'
        assert q2['pv']['p'].unit == u.kpc
        assert q2['pv']['v'].unit == u.kpc / u.Myr
        assert q2['t'].unit == u.Myr
        assert np.all(q2['pv']['p'] == q_pv_t['pv']['p'].to(u.kpc))
        assert np.all(q2['pv']['v'] == q_pv_t['pv']['v'].to(u.kpc/u.Myr))
        assert np.all(q2['t'] == q_pv_t['t'].to(u.Myr))

    def test_inplace_conversion(self):
        q_pv = Quantity(self.pv, self.pv_unit)
        q1 = q_pv.copy()
        # Keep a second reference to confirm ``<<=`` works in place.
        q_link = q1
        q1 <<= StructuredUnit(('AU', 'AU/day'))
        assert q1 is q_link
        assert q1['p'].unit == u.AU
        assert q1['v'].unit == u.AU / u.day
        assert np.all(q1['p'] == q_pv['p'].to(u.AU))
        assert np.all(q1['v'] == q_pv['v'].to(u.AU/u.day))
        q_pv_t = Quantity(self.pv_t, self.pv_t_unit)
        q2 = q_pv_t.copy()
        q_link = q2
        q2 <<= '(kpc,kpc/Myr),Myr'
        assert q2 is q_link
        assert q2['pv']['p'].unit == u.kpc
        assert q2['pv']['v'].unit == u.kpc / u.Myr
        assert q2['t'].unit == u.Myr
        assert np.all(q2['pv']['p'] == q_pv_t['pv']['p'].to(u.kpc))
        assert np.all(q2['pv']['v'] == q_pv_t['pv']['v'].to(u.kpc/u.Myr))
        assert np.all(q2['t'] == q_pv_t['t'].to(u.Myr))

    def test_si(self):
        q_pv_t = Quantity(self.pv_t, self.pv_t_unit)
        q_pv_t_si = q_pv_t.si
        assert_array_equal(q_pv_t_si, q_pv_t.to('(m,m/s),s'))

    def test_cgs(self):
        q_pv_t = Quantity(self.pv_t, self.pv_t_unit)
        q_pv_t_cgs = q_pv_t.cgs
        assert_array_equal(q_pv_t_cgs, q_pv_t.to('(cm,cm/s),s'))

    def test_equality(self):
        q_pv = Quantity(self.pv, self.pv_unit)
        equal = q_pv == q_pv
        not_equal = q_pv != q_pv
        assert np.all(equal)
        assert not np.any(not_equal)
        equal2 = q_pv == q_pv[1]
        not_equal2 = q_pv != q_pv[1]
        assert np.all(equal2 == [False, True, False])
        assert np.all(not_equal2 != equal2)
        q1 = q_pv.to(('AU', 'AU/day'))
        # Ensure same conversion is done, by placing q1 first.
        assert np.all(q1 == q_pv)
        assert not np.any(q1 != q_pv)
        # Check different names in dtype.
        assert np.all(q1.value * u.Unit('AU, AU/day') == q_pv)
        assert not np.any(q1.value * u.Unit('AU, AU/day') != q_pv)
        assert (q_pv == 'b') is False
        assert ('b' != q_pv) is True
        q_pv_t = Quantity(self.pv_t, self.pv_t_unit)
        assert np.all((q_pv_t[2] == q_pv_t) == [False, False, True])
        assert np.all((q_pv_t[2] != q_pv_t) != [False, False, True])
        assert (q_pv == q_pv_t) is False
        assert (q_pv_t != q_pv) is True

    def test_setitem(self):
        q_pv = Quantity(self.pv, self.pv_unit)
        q_pv[1] = (2., 2.) * self.pv_unit
        assert q_pv[1].value == np.array((2., 2.), self.pv_dtype)
        q_pv[1:2] = (1., 0.5) * u.Unit('AU, AU/day')
        assert q_pv['p'][1] == 1. * u.AU
        assert q_pv['v'][1] == 0.5 * u.AU / u.day
        q_pv['v'] = 1. * u.km / u.s
        assert np.all(q_pv['v'] == 1. * u.km / u.s)
        # Assigning values with inconvertible units must fail.
        with pytest.raises(u.UnitsError):
            q_pv[1] = (1., 1.) * u.Unit('AU, AU')
        with pytest.raises(u.UnitsError):
            q_pv['v'] = 1. * u.km
        q_pv_t = Quantity(self.pv_t, self.pv_t_unit)
        q_pv_t[1] = ((2., 2.), 3.) * self.pv_t_unit
        assert q_pv_t[1].value == np.array(((2., 2.), 3.), self.pv_t_dtype)
        q_pv_t[1:2] = ((1., 0.5), 5.) * u.Unit('(AU, AU/day), yr')
        assert q_pv_t['pv'][1] == (1., 0.5) * u.Unit('AU, AU/day')
        assert q_pv_t['t'][1] == 5. * u.yr
        q_pv_t['pv'] = (1., 0.5) * self.pv_unit
        assert np.all(q_pv_t['pv'] == (1., 0.5) * self.pv_unit)


class TestStructuredQuantityFunctions(StructuredTestBaseWithUnits):
    """numpy ``*_like`` creation functions applied to structured quantities."""

    @classmethod
    def setup_class(cls):
        super().setup_class()
        cls.q_pv = cls.pv << cls.pv_unit
        cls.q_pv_t = cls.pv_t << cls.pv_t_unit

    def test_empty_like(self):
        z = np.empty_like(self.q_pv)
        assert z.dtype == self.pv_dtype
        assert z.unit == self.pv_unit
        assert z.shape == self.pv.shape

    @pytest.mark.parametrize('func', [np.zeros_like, np.ones_like])
    def test_zeros_ones_like(self, func):
        z = func(self.q_pv)
        assert z.dtype == self.pv_dtype
        assert z.unit == self.pv_unit
        assert z.shape == self.pv.shape
        assert_array_equal(z, func(self.pv) << self.pv_unit)


class TestStructuredSpecificTypeQuantity(StructuredTestBaseWithUnits):
    """A ``SpecificTypeQuantity`` subclass tied to a structured unit."""

    @classmethod
    def setup_class(cls):
        super().setup_class()

        # Defined here because the equivalent unit comes from the fixtures.
        class PositionVelocity(u.SpecificTypeQuantity):
            _equivalent_unit = cls.pv_unit

        cls.PositionVelocity = PositionVelocity

    def test_init(self):
        pv = self.PositionVelocity(self.pv, self.pv_unit)
        assert isinstance(pv, self.PositionVelocity)
        # Fields are expected to come back as plain Quantity.
        assert type(pv['p']) is u.Quantity
        assert_array_equal(pv['p'], self.pv['p'] << self.pv_unit['p'])

        pv2 = self.PositionVelocity(self.pv, 'AU,AU/day')
        assert_array_equal(pv2['p'], self.pv['p'] << u.AU)

    def test_error_on_non_equivalent_unit(self):
        with pytest.raises(u.UnitsError):
            self.PositionVelocity(self.pv, 'AU')
        with pytest.raises(u.UnitsError):
            self.PositionVelocity(self.pv, 'AU,yr')


class TestStructuredLogUnit:
    """Structured units in which one field is a logarithmic (magnitude) unit."""

    @classmethod
    def setup_class(cls):
        cls.mag_time_dtype = np.dtype([('mag', 'f8'), ('t', 'f8')])
        cls.mag_time = np.array([(20., 10.), (25., 100.)], cls.mag_time_dtype)

    def test_unit_initialization(self):
        mag_time_unit = StructuredUnit((u.STmag, u.s), self.mag_time_dtype)
        assert mag_time_unit['mag'] == u.STmag
        assert mag_time_unit['t'] == u.s

        mag_time_unit2 = u.Unit('mag(ST),s')
        assert mag_time_unit2 == mag_time_unit

    def test_quantity_initialization(self):
        su = u.Unit('mag(ST),s')
        mag_time = self.mag_time << su
        # Each field is expected to come back as its own quantity class.
        assert isinstance(mag_time['mag'], u.Magnitude)
        assert isinstance(mag_time['t'], u.Quantity)
        assert mag_time.unit == su
        assert_array_equal(mag_time['mag'], self.mag_time['mag'] << u.STmag)
        assert_array_equal(mag_time['t'], self.mag_time['t'] << u.s)

    def test_quantity_si(self):
        mag_time = self.mag_time << u.Unit('mag(ST),yr')
        mag_time_si = mag_time.si
        assert_array_equal(mag_time_si['mag'], mag_time['mag'].si)
        assert_array_equal(mag_time_si['t'], mag_time['t'].si)


class TestStructuredMaskedQuantity(StructuredTestBaseWithUnits):
    """Somewhat minimal tests.  Conversion is most stringent."""

    @classmethod
    def setup_class(cls):
        super().setup_class()
        cls.qpv = cls.pv << cls.pv_unit
        # Mask with the same field layout as the data, one flag per field.
        cls.pv_mask = np.array([(True, False),
                                (False, False),
                                (False, True)], [('p', bool), ('v', bool)])
        cls.mpv = Masked(cls.qpv, mask=cls.pv_mask)

    def test_init(self):
        assert isinstance(self.mpv, Masked)
        assert isinstance(self.mpv, Quantity)
        assert_array_equal(self.mpv.unmasked, self.qpv)
        assert_array_equal(self.mpv.mask, self.pv_mask)

    def test_slicing(self):
        mp = self.mpv['p']
        assert isinstance(mp, Masked)
        assert isinstance(mp, Quantity)
        assert_array_equal(mp.unmasked, self.qpv['p'])
        assert_array_equal(mp.mask, self.pv_mask['p'])

    def test_conversion(self):
        mpv = self.mpv.to('AU,AU/day')
        assert isinstance(mpv, Masked)
        assert isinstance(mpv, Quantity)
        assert_array_equal(mpv.unmasked, self.qpv.to('AU,AU/day'))
        assert_array_equal(mpv.mask, self.pv_mask)
        assert np.all(mpv == self.mpv)

    def test_si(self):
        mpv = self.mpv.si
        assert isinstance(mpv, Masked)
        assert isinstance(mpv, Quantity)
        assert_array_equal(mpv.unmasked, self.qpv.si)
        assert_array_equal(mpv.mask, self.pv_mask)
        assert np.all(mpv == self.mpv)