1# coding: utf-8
2# Licensed under a 3-clause BSD style license - see LICENSE.rst
3"""
4Test Structured units and quantities.
5"""
6import pytest
7import numpy as np
8from numpy.testing import assert_array_equal
9
10from astropy import units as u
11from astropy.units import StructuredUnit, Unit, UnitBase, Quantity
12from astropy.utils.masked import Masked
13
14
class StructuredTestBase:
    """Shared fixtures: simple units plus plain (unit-less) structured arrays.

    ``pv`` holds (position, velocity) records; ``pv_t`` nests such a record
    together with a time, so that recursive handling can be exercised.
    """

    @classmethod
    def setup_class(cls):
        # Component units from which structured units are built.
        cls.p_unit = u.km
        cls.v_unit = u.km / u.s
        cls.t_unit = u.s
        # Structured dtypes; ``pv_t_dtype`` embeds ``pv_dtype`` as a field.
        cls.pv_dtype = np.dtype([('p', 'f8'), ('v', 'f8')])
        cls.pv_t_dtype = np.dtype([('pv', cls.pv_dtype), ('t', 'f8')])
        cls.pv = np.array([(1., 0.25), (2., 0.5), (3., 0.75)],
                          cls.pv_dtype)
        cls.pv_t = np.array([((4., 2.5), 0.),
                             ((5., 5.0), 1.),
                             ((6., 7.5), 2.)], cls.pv_t_dtype)
30
31
class StructuredTestBaseWithUnits(StructuredTestBase):
    """Extend the shared fixtures with structured units matching the dtypes.

    ``pv_unit`` combines the position and velocity units under field names
    ``('p', 'v')``; ``pv_t_unit`` nests it with the time unit under
    ``('pv', 't')``, mirroring ``pv_t_dtype``.
    """

    @classmethod
    def setup_class(cls):
        super().setup_class()
        cls.pv_unit = StructuredUnit((cls.p_unit, cls.v_unit),
                                     ('p', 'v'))
        cls.pv_t_unit = StructuredUnit((cls.pv_unit, cls.t_unit),
                                       ('pv', 't'))
40
41
class TestStructuredUnitBasics(StructuredTestBase):
    """Construction, parsing, comparison and representation of StructuredUnit."""

    def test_initialization_and_keying(self):
        su = StructuredUnit((self.p_unit, self.v_unit), ('p', 'v'))
        assert su['p'] is self.p_unit
        assert su['v'] is self.v_unit
        su2 = StructuredUnit((su, self.t_unit), ('pv', 't'))
        assert isinstance(su2['pv'], StructuredUnit)
        assert su2['pv']['p'] is self.p_unit
        assert su2['pv']['v'] is self.v_unit
        assert su2['t'] is self.t_unit
        assert su2['pv'] == su
        su3 = StructuredUnit(('AU', 'AU/day'), ('p', 'v'))
        assert isinstance(su3['p'], UnitBase)
        assert isinstance(su3['v'], UnitBase)
        su4 = StructuredUnit('AU, AU/day', ('p', 'v'))
        assert su4['p'] == u.AU
        assert su4['v'] == u.AU / u.day
        su5 = StructuredUnit(('AU', 'AU/day'))
        assert su5.field_names == ('f0', 'f1')
        assert su5['f0'] == u.AU
        assert su5['f1'] == u.AU / u.day

    def test_recursive_initialization(self):
        su = StructuredUnit(((self.p_unit, self.v_unit), self.t_unit),
                            (('p', 'v'), 't'))
        assert isinstance(su['pv'], StructuredUnit)
        assert su['pv']['p'] is self.p_unit
        assert su['pv']['v'] is self.v_unit
        assert su['t'] is self.t_unit
        su2 = StructuredUnit(((self.p_unit, self.v_unit), self.t_unit),
                             (['p_v', ('p', 'v')], 't'))
        assert isinstance(su2['p_v'], StructuredUnit)
        assert su2['p_v']['p'] is self.p_unit
        assert su2['p_v']['v'] is self.v_unit
        assert su2['t'] is self.t_unit
        su3 = StructuredUnit((('AU', 'AU/day'), 'yr'),
                             (['p_v', ('p', 'v')], 't'))
        assert isinstance(su3['p_v'], StructuredUnit)
        assert su3['p_v']['p'] == u.AU
        assert su3['p_v']['v'] == u.AU / u.day
        assert su3['t'] == u.yr
        su4 = StructuredUnit('(AU, AU/day), yr', (('p', 'v'), 't'))
        assert isinstance(su4['pv'], StructuredUnit)
        assert su4['pv']['p'] == u.AU
        assert su4['pv']['v'] == u.AU / u.day
        assert su4['t'] == u.yr

    def test_extreme_recursive_initialization(self):
        su = StructuredUnit('(yr,(AU,AU/day,(km,(day,day))),m)',
                            ('t', ('p', 'v', ('h', ('d1', 'd2'))), 'l'))
        # Unnamed sub-tuples get a name made by concatenating their fields.
        assert su.field_names == ('t', ['pvhd1d2',
                                        ('p', 'v',
                                         ['hd1d2',
                                          ('h',
                                           ['d1d2',
                                            ('d1', 'd2')])])], 'l')

    @pytest.mark.parametrize('names, invalid', [
        [('t', ['p', 'v']), "['p', 'v']"],
        [('t', ['pv', 'p', 'v']), "['pv', 'p', 'v']"],
        [('t', ['pv', ['p', 'v']]), "['pv', ['p', 'v']"],
        [('t', ()), "()"],
        [('t', ('p', None)), "None"],
        [('t', ['pv', ('p', '')]), "''"]])
    def test_initialization_names_invalid_list_errors(self, names, invalid):
        # The unit string is well formed (balanced parentheses), so the
        # ValueError can only come from the invalid entry in ``names``.
        with pytest.raises(ValueError) as exc:
            StructuredUnit('(yr,(AU,AU/day))', names)
        # Check the exception message itself, not the ExceptionInfo wrapper.
        assert f'invalid entry {invalid}' in str(exc.value)

    def test_looks_like_unit(self):
        su = StructuredUnit((self.p_unit, self.v_unit), ('p', 'v'))
        assert Unit(su) is su

    def test_initialize_with_float_dtype(self):
        su = StructuredUnit(('AU', 'AU/d'), self.pv_dtype)
        assert isinstance(su['p'], UnitBase)
        assert isinstance(su['v'], UnitBase)
        assert su['p'] == u.AU
        assert su['v'] == u.AU / u.day
        su = StructuredUnit((('km', 'km/s'), 'yr'), self.pv_t_dtype)
        assert isinstance(su['pv'], StructuredUnit)
        assert isinstance(su['pv']['p'], UnitBase)
        assert isinstance(su['t'], UnitBase)
        assert su['pv']['v'] == u.km / u.s
        su = StructuredUnit('(km, km/s), yr', self.pv_t_dtype)
        assert isinstance(su['pv'], StructuredUnit)
        assert isinstance(su['pv']['p'], UnitBase)
        assert isinstance(su['t'], UnitBase)
        assert su['pv']['v'] == u.km / u.s

    def test_initialize_with_structured_unit_for_names(self):
        su = StructuredUnit(('AU', 'AU/d'), names=('p', 'v'))
        su2 = StructuredUnit(('km', 'km/s'), names=su)
        assert su2.field_names == ('p', 'v')
        assert su2['p'] == u.km
        assert su2['v'] == u.km / u.s

    def test_initialize_single_field(self):
        su = StructuredUnit('AU', 'p')
        assert isinstance(su, StructuredUnit)
        assert isinstance(su['p'], UnitBase)
        assert su['p'] == u.AU
        su = StructuredUnit('AU')
        assert isinstance(su, StructuredUnit)
        assert isinstance(su['f0'], UnitBase)
        assert su['f0'] == u.AU

    def test_equality(self):
        su = StructuredUnit(('AU', 'AU/d'), self.pv_dtype)
        assert su == StructuredUnit(('AU', 'AU/d'), self.pv_dtype)
        assert su != StructuredUnit(('m', 'AU/d'), self.pv_dtype)
        # Names should be ignored.
        assert su == StructuredUnit(('AU', 'AU/d'))
        assert su == StructuredUnit(('AU', 'AU/d'), names=('q', 'w'))
        assert su != StructuredUnit(('m', 'm/s'))

    def test_parsing(self):
        su = Unit('AU, AU/d')
        assert isinstance(su, StructuredUnit)
        assert isinstance(su['f0'], UnitBase)
        assert isinstance(su['f1'], UnitBase)
        assert su['f0'] == u.AU
        assert su['f1'] == u.AU/u.day
        su2 = Unit('AU, AU/d, yr')
        assert isinstance(su2, StructuredUnit)
        assert su2 == StructuredUnit(('AU', 'AU/d', 'yr'))
        su2a = Unit('(AU, AU/d, yr)')
        assert isinstance(su2a, StructuredUnit)
        assert su2a == su2
        su3 = Unit('(km, km/s), yr')
        assert isinstance(su3, StructuredUnit)
        assert su3 == StructuredUnit((('km', 'km/s'), 'yr'))
        # A trailing comma makes a single-field structured unit.
        su4 = Unit('km,')
        assert isinstance(su4, StructuredUnit)
        assert su4 == StructuredUnit((u.km,))
        su5 = Unit('(m,s),')
        assert isinstance(su5, StructuredUnit)
        assert su5 == StructuredUnit(((u.m, u.s),))
        ldbody_unit = Unit('Msun, 0.5rad^2, (au, au/day)')
        assert ldbody_unit == StructuredUnit(
            (u.Msun, Unit(u.rad**2 / 2), (u.AU, u.AU / u.day)))

    def test_str(self):
        su = StructuredUnit(((u.km, u.km/u.s), u.yr))
        assert str(su) == '((km, km / s), yr)'
        assert Unit(str(su)) == su

    def test_repr(self):
        su = StructuredUnit(((u.km, u.km/u.s), u.yr))
        assert repr(su) == 'Unit("((km, km / s), yr)")'
        # Round trip: the repr must be evaluable back into an equal unit.
        assert eval(repr(su)) == su
194
195
class TestStructuredUnitAsMapping(StructuredTestBaseWithUnits):
    """A StructuredUnit should act like a read-only mapping of field units."""

    def test_len(self):
        for unit in (self.pv_unit, self.pv_t_unit):
            assert len(unit) == 2

    def test_keys(self):
        keys = [key for key in self.pv_t_unit.keys()]
        assert keys == ['pv', 't']

    def test_values(self):
        expected = (self.pv_unit, self.t_unit)
        assert self.pv_t_unit.values() == expected

    def test_field_names(self):
        names = self.pv_t_unit.field_names
        assert isinstance(names, tuple)
        assert names == (['pv', ('p', 'v')], 't')

    @pytest.mark.parametrize('iterable', [list, set])
    def test_as_iterable(self, iterable):
        converted = iterable(self.pv_unit)
        assert isinstance(converted, iterable)
        assert converted == iterable(['p', 'v'])

    def test_as_dict(self):
        expected = {'pv': self.pv_unit, 't': self.t_unit}
        assert dict(self.pv_t_unit) == expected

    def test_contains(self):
        for name in ('p', 'v'):
            assert name in self.pv_unit
        assert 't' not in self.pv_unit

    def test_setitem_fails(self):
        unit = self.pv_t_unit
        with pytest.raises(TypeError, match='item assignment'):
            unit['t'] = u.Gyr
233
234
class TestStructuredUnitMethods(StructuredTestBaseWithUnits):
    """Physical types, system conversions, equivalence and value conversion."""

    def test_physical_type_id(self):
        pv_id = self.pv_unit._get_physical_type_id()
        assert len(pv_id) == 2
        assert pv_id.dtype.names == ('p', 'v')
        p_id, v_id = (self.pv_unit[name]._get_physical_type_id()
                      for name in ('p', 'v'))
        # Expected should be (subclass of) void, with structured object dtype.
        named = np.array((p_id, v_id), [('p', 'O'), ('v', 'O')])[()]
        assert pv_id == named
        # Names should be ignored in comparison.
        unnamed = np.array((p_id, v_id), 'O,O')[()]
        assert pv_id == unnamed
        # Should be possible to address by field and by number.
        for key, expected in (('p', p_id), ('v', v_id),
                              (0, p_id), (1, v_id)):
            assert pv_id[key] == expected
        # Nested structured unit.
        pv_t_id = self.pv_t_unit._get_physical_type_id()
        t_id = self.t_unit._get_physical_type_id()
        assert pv_t_id == np.array((pv_id, t_id), 'O,O')[()]
        assert pv_t_id['pv'] == pv_id
        assert pv_t_id['t'] == t_id
        assert pv_t_id['pv'][1] == v_id

    def test_physical_type(self):
        pv_type = self.pv_unit.physical_type
        expected_pv = np.array(('length', 'speed'), 'O,O')[()]
        assert pv_type == expected_pv

        pv_t_type = self.pv_t_unit.physical_type
        expected_pv_t = np.array((pv_type, 'time'), 'O,O')[()]
        assert pv_t_type == expected_pv_t

    def test_si(self):
        si_unit = self.pv_t_unit.si
        assert si_unit == self.pv_t_unit
        assert si_unit['pv']['v'].scale == 1000

    def test_cgs(self):
        cgs_unit = self.pv_t_unit.cgs
        assert cgs_unit == self.pv_t_unit
        assert cgs_unit['pv']['v'].scale == 100000

    def test_decompose(self):
        velocity = self.pv_t_unit.decompose()['pv']['v']
        assert velocity.scale == 1000

    def test_is_equivalent(self):
        pv = self.pv_unit
        assert pv.is_equivalent(('AU', 'AU/day'))
        assert not pv.is_equivalent('m')
        assert not pv.is_equivalent(('AU', 'AU'))
        # Field names take no part in the equivalence check.
        renamed = StructuredUnit('m,m/s', names=('q', 'w'))
        assert renamed.field_names != pv.field_names
        assert pv.is_equivalent(renamed)
        # A regular unit compared against a structured one.
        assert not u.m.is_equivalent(pv)

    def test_conversion(self):
        au_units = ('AU', 'AU/day')
        converted = self.pv_unit.to(au_units, self.pv)
        assert isinstance(converted, np.ndarray)
        assert converted.dtype == self.pv.dtype
        assert np.all(converted['p'] * u.AU == self.pv['p'] * self.p_unit)
        assert np.all(converted['v'] * u.AU / u.day
                      == self.pv['v'] * self.v_unit)
        # The result takes its field names from the value, not the unit.
        renamed_unit = StructuredUnit((self.p_unit, self.v_unit),
                                      ('position', 'velocity'))
        renamed = renamed_unit.to(('Mm', 'mm/s'), self.pv)
        assert renamed.dtype.names == ('p', 'v')
        assert renamed.dtype == self.pv.dtype
        # Nested structured units convert recursively.
        target = (au_units, 'Myr')
        nested = self.pv_t_unit.to(target, self.pv_t)
        assert isinstance(nested, np.ndarray)
        assert nested.dtype == self.pv_t.dtype
        assert np.all(nested['pv']['p'] * u.AU
                      == self.pv_t['pv']['p'] * self.p_unit)
        assert np.all(nested['pv']['v'] * u.AU / u.day
                      == self.pv_t['pv']['v'] * self.v_unit)
        assert np.all(nested['t'] * u.Myr == self.pv_t['t'] * self.t_unit)
        # Plain (nested) tuples and lists of them are accepted as values.
        single = self.pv_t_unit.to(target, ((1., 0.1), 10.))
        assert single['pv']['p'] == self.p_unit.to('AU', 1.)
        assert single['pv']['v'] == self.v_unit.to('AU/day', 0.1)
        assert single['t'] == self.t_unit.to('Myr', 10.)
        several = self.pv_t_unit.to(target, [((1., 0.1), 10.),
                                             ((2., 0.2), 20.)])
        assert np.all(several['pv']['p'] == self.p_unit.to('AU', [1., 2.]))
        assert np.all(several['pv']['v']
                      == self.v_unit.to('AU/day', [0.1, 0.2]))
        assert np.all(several['t'] == self.t_unit.to('Myr', [10., 20.]))
325
326
class TestStructuredUnitArithmatic(StructuredTestBaseWithUnits):
    """Multiplying and dividing a StructuredUnit by a regular unit."""

    def test_multiplication(self):
        expected = self.pv_unit * u.au
        assert isinstance(expected, StructuredUnit)
        assert expected.field_names == ('p', 'v')
        assert expected['p'] == self.p_unit * u.AU
        assert expected['v'] == self.v_unit * u.AU
        # Reflected and string operands must give the same result.
        for product in (u.au * self.pv_unit,
                        self.pv_unit * 'au',
                        'AU' * self.pv_unit):
            assert product == expected
        # Two structured operands cannot be multiplied.
        with pytest.raises(TypeError):
            self.pv_unit * self.pv_unit
        with pytest.raises(TypeError):
            's,s' * self.pv_unit

    def test_division(self):
        quotient = self.pv_unit / u.s
        assert isinstance(quotient, StructuredUnit)
        assert quotient.field_names == ('p', 'v')
        assert quotient['p'] == self.p_unit / u.s
        assert quotient['v'] == self.v_unit / u.s
        assert self.pv_unit / 's' == quotient
        # A structured unit may not appear in the denominator.
        for numerator in (1., u.s):
            with pytest.raises(TypeError):
                numerator / self.pv_unit
357
358
class TestStructuredQuantity(StructuredTestBaseWithUnits):
    """Creation, indexing, conversion and comparison of structured Quantity."""

    def test_initialization_and_keying(self):
        q_pv = Quantity(self.pv, self.pv_unit)
        q_p = q_pv['p']
        assert isinstance(q_p, Quantity)
        assert isinstance(q_p.unit, UnitBase)
        assert np.all(q_p == self.pv['p'] * self.pv_unit['p'])
        q_v = q_pv['v']
        assert isinstance(q_v, Quantity)
        assert isinstance(q_v.unit, UnitBase)
        assert np.all(q_v == self.pv['v'] * self.pv_unit['v'])
        q_pv_t = Quantity(self.pv_t, self.pv_t_unit)
        q_t = q_pv_t['t']
        assert np.all(q_t == self.pv_t['t'] * self.pv_t_unit['t'])
        q_pv2 = q_pv_t['pv']
        assert isinstance(q_pv2, Quantity)
        assert q_pv2.unit == self.pv_unit
        # Unit structure must match the dtype structure.
        with pytest.raises(ValueError):
            Quantity(self.pv, self.pv_t_unit)
        with pytest.raises(ValueError):
            Quantity(self.pv_t, self.pv_unit)

    def test_initialization_with_unit_tuples(self):
        q_pv_t = Quantity(self.pv_t, (('km', 'km/s'), 's'))
        assert isinstance(q_pv_t.unit, StructuredUnit)
        assert q_pv_t.unit == self.pv_t_unit

    def test_initialization_with_string(self):
        q_pv_t = Quantity(self.pv_t, '(km, km/s), s')
        assert isinstance(q_pv_t.unit, StructuredUnit)
        assert q_pv_t.unit == self.pv_t_unit

    def test_initialization_by_multiplication_with_unit(self):
        q_pv_t = self.pv_t * self.pv_t_unit
        assert q_pv_t.unit is self.pv_t_unit
        assert np.all(q_pv_t.value == self.pv_t)
        # Multiplication creates a copy of the data.
        assert not np.may_share_memory(q_pv_t, self.pv_t)
        # Multiplying with the unit on the left must give the same result;
        # check the new quantity's unit (not the first one's again).
        q_pv_t2 = self.pv_t_unit * self.pv_t
        assert q_pv_t2.unit is self.pv_t_unit
        # Not testing equality of structured Quantity here.
        assert np.all(q_pv_t2.value == q_pv_t.value)

    def test_initialization_by_shifting_to_unit(self):
        q_pv_t = self.pv_t << self.pv_t_unit
        assert q_pv_t.unit is self.pv_t_unit
        assert np.all(q_pv_t.value == self.pv_t)
        # ``<<`` attaches the unit without copying the data.
        assert np.may_share_memory(q_pv_t, self.pv_t)

    def test_getitem(self):
        q_pv_t = Quantity(self.pv_t, self.pv_t_unit)
        q_pv_t01 = q_pv_t[:2]
        assert isinstance(q_pv_t01, Quantity)
        assert q_pv_t01.unit == q_pv_t.unit
        assert np.all(q_pv_t01['t'] == q_pv_t['t'][:2])
        q_pv_t1 = q_pv_t[1]
        assert isinstance(q_pv_t1, Quantity)
        assert q_pv_t1.unit == q_pv_t.unit
        assert q_pv_t1.shape == ()
        assert q_pv_t1['t'] == q_pv_t['t'][1]

    def test_value(self):
        q_pv_t = Quantity(self.pv_t, self.pv_t_unit)
        value = q_pv_t.value
        assert type(value) is np.ndarray
        assert np.all(value == self.pv_t)
        # A scalar structured quantity's value is a numpy void.
        value1 = q_pv_t[1].value
        assert type(value1) is np.void
        assert np.all(value1 == self.pv_t[1])

    def test_conversion(self):
        q_pv = Quantity(self.pv, self.pv_unit)
        q1 = q_pv.to(('AU', 'AU/day'))
        assert isinstance(q1, Quantity)
        assert q1['p'].unit == u.AU
        assert q1['v'].unit == u.AU / u.day
        assert np.all(q1['p'] == q_pv['p'].to(u.AU))
        assert np.all(q1['v'] == q_pv['v'].to(u.AU/u.day))
        # Converting to the same unit still returns a copy.
        q2 = q_pv.to(self.pv_unit)
        assert q2['p'].unit == self.p_unit
        assert q2['v'].unit == self.v_unit
        assert np.all(q2['p'].value == self.pv['p'])
        assert np.all(q2['v'].value == self.pv['v'])
        assert not np.may_share_memory(q2, q_pv)
        pv1 = q_pv.to_value(('AU', 'AU/day'))
        assert type(pv1) is np.ndarray
        assert np.all(pv1['p'] == q_pv['p'].to_value(u.AU))
        assert np.all(pv1['v'] == q_pv['v'].to_value(u.AU/u.day))
        pv11 = q_pv[1].to_value(('AU', 'AU/day'))
        assert type(pv11) is np.void
        assert pv11 == pv1[1]
        q_pv_t = Quantity(self.pv_t, self.pv_t_unit)
        q2 = q_pv_t.to((('kpc', 'kpc/Myr'), 'Myr'))
        assert q2['pv']['p'].unit == u.kpc
        assert q2['pv']['v'].unit == u.kpc / u.Myr
        assert q2['t'].unit == u.Myr
        assert np.all(q2['pv']['p'] == q_pv_t['pv']['p'].to(u.kpc))
        assert np.all(q2['pv']['v'] == q_pv_t['pv']['v'].to(u.kpc/u.Myr))
        assert np.all(q2['t'] == q_pv_t['t'].to(u.Myr))

    def test_conversion_via_lshift(self):
        q_pv = Quantity(self.pv, self.pv_unit)
        q1 = q_pv << StructuredUnit(('AU', 'AU/day'))
        assert isinstance(q1, Quantity)
        assert q1['p'].unit == u.AU
        assert q1['v'].unit == u.AU / u.day
        assert np.all(q1['p'] == q_pv['p'].to(u.AU))
        assert np.all(q1['v'] == q_pv['v'].to(u.AU/u.day))
        # Shifting to the same unit returns a view, unlike ``.to``.
        q2 = q_pv << self.pv_unit
        assert q2['p'].unit == self.p_unit
        assert q2['v'].unit == self.v_unit
        assert np.all(q2['p'].value == self.pv['p'])
        assert np.all(q2['v'].value == self.pv['v'])
        assert np.may_share_memory(q2, q_pv)
        q_pv_t = Quantity(self.pv_t, self.pv_t_unit)
        q2 = q_pv_t << '(kpc,kpc/Myr),Myr'
        assert q2['pv']['p'].unit == u.kpc
        assert q2['pv']['v'].unit == u.kpc / u.Myr
        assert q2['t'].unit == u.Myr
        assert np.all(q2['pv']['p'] == q_pv_t['pv']['p'].to(u.kpc))
        assert np.all(q2['pv']['v'] == q_pv_t['pv']['v'].to(u.kpc/u.Myr))
        assert np.all(q2['t'] == q_pv_t['t'].to(u.Myr))

    def test_inplace_conversion(self):
        q_pv = Quantity(self.pv, self.pv_unit)
        q1 = q_pv.copy()
        q_link = q1
        q1 <<= StructuredUnit(('AU', 'AU/day'))
        # In-place conversion must keep the same object.
        assert q1 is q_link
        assert q1['p'].unit == u.AU
        assert q1['v'].unit == u.AU / u.day
        assert np.all(q1['p'] == q_pv['p'].to(u.AU))
        assert np.all(q1['v'] == q_pv['v'].to(u.AU/u.day))
        q_pv_t = Quantity(self.pv_t, self.pv_t_unit)
        q2 = q_pv_t.copy()
        q_link = q2
        q2 <<= '(kpc,kpc/Myr),Myr'
        assert q2 is q_link
        assert q2['pv']['p'].unit == u.kpc
        assert q2['pv']['v'].unit == u.kpc / u.Myr
        assert q2['t'].unit == u.Myr
        assert np.all(q2['pv']['p'] == q_pv_t['pv']['p'].to(u.kpc))
        assert np.all(q2['pv']['v'] == q_pv_t['pv']['v'].to(u.kpc/u.Myr))
        assert np.all(q2['t'] == q_pv_t['t'].to(u.Myr))

    def test_si(self):
        q_pv_t = Quantity(self.pv_t, self.pv_t_unit)
        q_pv_t_si = q_pv_t.si
        assert_array_equal(q_pv_t_si, q_pv_t.to('(m,m/s),s'))

    def test_cgs(self):
        q_pv_t = Quantity(self.pv_t, self.pv_t_unit)
        q_pv_t_cgs = q_pv_t.cgs
        assert_array_equal(q_pv_t_cgs, q_pv_t.to('(cm,cm/s),s'))

    def test_equality(self):
        q_pv = Quantity(self.pv, self.pv_unit)
        equal = q_pv == q_pv
        not_equal = q_pv != q_pv
        assert np.all(equal)
        assert not np.any(not_equal)
        equal2 = q_pv == q_pv[1]
        not_equal2 = q_pv != q_pv[1]
        assert np.all(equal2 == [False, True, False])
        assert np.all(not_equal2 != equal2)
        q1 = q_pv.to(('AU', 'AU/day'))
        # Ensure same conversion is done, by placing q1 first.
        assert np.all(q1 == q_pv)
        assert not np.any(q1 != q_pv)
        # Check different names in dtype.
        assert np.all(q1.value * u.Unit('AU, AU/day') == q_pv)
        assert not np.any(q1.value * u.Unit('AU, AU/day') != q_pv)
        # Comparisons with incompatible objects give plain booleans.
        assert (q_pv == 'b') is False
        assert ('b' != q_pv) is True
        q_pv_t = Quantity(self.pv_t, self.pv_t_unit)
        assert np.all((q_pv_t[2] == q_pv_t) == [False, False, True])
        assert np.all((q_pv_t[2] != q_pv_t) != [False, False, True])
        assert (q_pv == q_pv_t) is False
        assert (q_pv_t != q_pv) is True

    def test_setitem(self):
        q_pv = Quantity(self.pv, self.pv_unit)
        q_pv[1] = (2., 2.) * self.pv_unit
        assert q_pv[1].value == np.array((2., 2.), self.pv_dtype)
        # Assigned values are converted to the quantity's own units.
        q_pv[1:2] = (1., 0.5) * u.Unit('AU, AU/day')
        assert q_pv['p'][1] == 1. * u.AU
        assert q_pv['v'][1] == 0.5 * u.AU / u.day
        q_pv['v'] = 1. * u.km / u.s
        assert np.all(q_pv['v'] == 1. * u.km / u.s)
        with pytest.raises(u.UnitsError):
            q_pv[1] = (1., 1.) * u.Unit('AU, AU')
        with pytest.raises(u.UnitsError):
            q_pv['v'] = 1. * u.km
        q_pv_t = Quantity(self.pv_t, self.pv_t_unit)
        q_pv_t[1] = ((2., 2.), 3.) * self.pv_t_unit
        assert q_pv_t[1].value == np.array(((2., 2.), 3.), self.pv_t_dtype)
        q_pv_t[1:2] = ((1., 0.5), 5.) * u.Unit('(AU, AU/day), yr')
        assert q_pv_t['pv'][1] == (1., 0.5) * u.Unit('AU, AU/day')
        assert q_pv_t['t'][1] == 5. * u.yr
        q_pv_t['pv'] = (1., 0.5) * self.pv_unit
        assert np.all(q_pv_t['pv'] == (1., 0.5) * self.pv_unit)
559
560
class TestStructuredQuantityFunctions(StructuredTestBaseWithUnits):
    """numpy ``*_like`` array-creation functions on structured quantities."""

    @classmethod
    def setup_class(cls):
        super().setup_class()
        cls.q_pv = cls.pv << cls.pv_unit
        cls.q_pv_t = cls.pv_t << cls.pv_t_unit

    def test_empty_like(self):
        z = np.empty_like(self.q_pv)
        assert z.dtype == self.pv_dtype
        assert z.unit == self.pv_unit
        assert z.shape == self.pv.shape

    @pytest.mark.parametrize('func', [np.zeros_like, np.ones_like])
    def test_zeros_ones_like(self, func):
        z = func(self.q_pv)
        assert z.dtype == self.pv_dtype
        assert z.unit == self.pv_unit
        assert z.shape == self.pv.shape
        assert_array_equal(z, func(self.pv) << self.pv_unit)
581
582
class TestStructuredSpecificTypeQuantity(StructuredTestBaseWithUnits):
    """A SpecificTypeQuantity subclass whose equivalent unit is structured."""

    @classmethod
    def setup_class(cls):
        super().setup_class()

        class PositionVelocity(u.SpecificTypeQuantity):
            _equivalent_unit = cls.pv_unit

        cls.PositionVelocity = PositionVelocity

    def test_init(self):
        pv = self.PositionVelocity(self.pv, self.pv_unit)
        assert isinstance(pv, self.PositionVelocity)
        # Individual fields are plain quantities, not the specific subclass.
        assert type(pv['p']) is u.Quantity
        assert_array_equal(pv['p'], self.pv['p'] << self.pv_unit['p'])

        pv2 = self.PositionVelocity(self.pv, 'AU,AU/day')
        assert_array_equal(pv2['p'], self.pv['p'] << u.AU)

    def test_error_on_non_equivalent_unit(self):
        with pytest.raises(u.UnitsError):
            self.PositionVelocity(self.pv, 'AU')
        with pytest.raises(u.UnitsError):
            self.PositionVelocity(self.pv, 'AU,yr')
606
607
class TestStructuredLogUnit:
    """Structured units that include a logarithmic (magnitude) component."""

    @classmethod
    def setup_class(cls):
        cls.mag_time_dtype = np.dtype([('mag', 'f8'), ('t', 'f8')])
        cls.mag_time = np.array([(20., 10.), (25., 100.)], cls.mag_time_dtype)

    def test_unit_initialization(self):
        mag_time_unit = StructuredUnit((u.STmag, u.s), self.mag_time_dtype)
        assert mag_time_unit['mag'] == u.STmag
        assert mag_time_unit['t'] == u.s

        mag_time_unit2 = u.Unit('mag(ST),s')
        assert mag_time_unit2 == mag_time_unit

    def test_quantity_initialization(self):
        su = u.Unit('mag(ST),s')
        mag_time = self.mag_time << su
        # The magnitude field becomes a Magnitude; the time field a Quantity.
        assert isinstance(mag_time['mag'], u.Magnitude)
        assert isinstance(mag_time['t'], u.Quantity)
        assert mag_time.unit == su
        assert_array_equal(mag_time['mag'], self.mag_time['mag'] << u.STmag)
        assert_array_equal(mag_time['t'], self.mag_time['t'] << u.s)

    def test_quantity_si(self):
        mag_time = self.mag_time << u.Unit('mag(ST),yr')
        mag_time_si = mag_time.si
        assert_array_equal(mag_time_si['mag'], mag_time['mag'].si)
        assert_array_equal(mag_time_si['t'], mag_time['t'].si)
635
636
class TestStructuredMaskedQuantity(StructuredTestBaseWithUnits):
    """Somewhat minimal tests.  Conversion is most stringent."""

    @classmethod
    def setup_class(cls):
        super().setup_class()
        cls.qpv = cls.pv << cls.pv_unit
        # Structured mask: one boolean per field, per element.
        cls.pv_mask = np.array([(True, False),
                                (False, False),
                                (False, True)], [('p', bool), ('v', bool)])
        cls.mpv = Masked(cls.qpv, mask=cls.pv_mask)

    def test_init(self):
        assert isinstance(self.mpv, Masked)
        assert isinstance(self.mpv, Quantity)
        assert_array_equal(self.mpv.unmasked, self.qpv)
        assert_array_equal(self.mpv.mask, self.pv_mask)

    def test_slicing(self):
        mp = self.mpv['p']
        assert isinstance(mp, Masked)
        assert isinstance(mp, Quantity)
        assert_array_equal(mp.unmasked, self.qpv['p'])
        assert_array_equal(mp.mask, self.pv_mask['p'])

    def test_conversion(self):
        mpv = self.mpv.to('AU,AU/day')
        assert isinstance(mpv, Masked)
        assert isinstance(mpv, Quantity)
        assert_array_equal(mpv.unmasked, self.qpv.to('AU,AU/day'))
        # The mask must be carried through unchanged.
        assert_array_equal(mpv.mask, self.pv_mask)
        assert np.all(mpv == self.mpv)

    def test_si(self):
        mpv = self.mpv.si
        assert isinstance(mpv, Masked)
        assert isinstance(mpv, Quantity)
        assert_array_equal(mpv.unmasked, self.qpv.si)
        assert_array_equal(mpv.mask, self.pv_mask)
        assert np.all(mpv == self.mpv)
675