1# Licensed under a 3-clause BSD style license - see LICENSE.rst
2
3from astropy.table.table_helpers import ArrayWrapper
4from astropy.coordinates.earth import EarthLocation
5from astropy.units.quantity import Quantity
6from collections import OrderedDict
7from contextlib import nullcontext
8
9import pytest
10import numpy as np
11
12from astropy.table import Table, QTable, TableMergeError, Column, MaskedColumn, NdarrayMixin
13from astropy.table.operations import _get_out_class, join_skycoord, join_distance
14from astropy import units as u
15from astropy.utils import metadata
16from astropy.utils.metadata import MergeConflictError
17from astropy import table
18from astropy.time import Time, TimeDelta
19from astropy.coordinates import (SkyCoord, SphericalRepresentation,
20                                 UnitSphericalRepresentation,
21                                 CartesianRepresentation,
22                                 BaseRepresentationOrDifferential,
23                                 search_around_3d)
24from astropy.coordinates.tests.test_representation import representation_equal
25from astropy.io.misc.asdf.tags.helpers import skycoord_equal
26from astropy.utils.compat.optional_deps import HAS_SCIPY  # noqa
27
28
def sort_eq(list1, list2):
    """Return True if both sequences hold the same items, ignoring order."""
    first_sorted = sorted(list1)
    second_sorted = sorted(list2)
    return first_sorted == second_sorted
31
32
def check_mask(col, exp_mask):
    """Check that col.mask == exp_mask.

    Parameters
    ----------
    col : object
        Column-like object that may or may not expose a ``mask`` attribute.
    exp_mask : array-like of bool
        Expected mask values.

    Returns
    -------
    out : bool
        If ``col`` has a mask: True when every element of ``col.mask`` equals
        the corresponding expected value. If ``col`` has no mask: True when
        every expected value is False.
    """
    if hasattr(col, 'mask'):
        # Coerce expected mask into dtype of col.mask. In particular this is
        # needed for types like EarthLocation where the mask is a structured
        # array.
        exp_mask = np.array(exp_mask).astype(col.mask.dtype)
        out = np.all(col.mask == exp_mask)
    else:
        # With no mask the check is OK if all the expected mask values
        # are False (i.e. no auto-conversion to MaskedQuantity if it was
        # not required by the join).
        # Convert to an array first: comparing a plain list with ``False``
        # is a single list-vs-bool comparison that is always False.
        out = np.all(np.logical_not(np.asarray(exp_mask)))
    return out
47
48
49class TestJoin():
50
51    def _setup(self, t_cls=Table):
52        lines1 = [' a   b   c ',
53                  '  0 foo  L1',
54                  '  1 foo  L2',
55                  '  1 bar  L3',
56                  '  2 bar  L4']
57        lines2 = [' a   b   d ',
58                  '  1 foo  R1',
59                  '  1 foo  R2',
60                  '  2 bar  R3',
61                  '  4 bar  R4']
62        self.t1 = t_cls.read(lines1, format='ascii')
63        self.t2 = t_cls.read(lines2, format='ascii')
64        self.t3 = t_cls(self.t2, copy=True)
65
66        self.t1.meta.update(OrderedDict([('b', [1, 2]), ('c', {'a': 1}), ('d', 1)]))
67        self.t2.meta.update(OrderedDict([('b', [3, 4]), ('c', {'b': 1}), ('a', 1)]))
68        self.t3.meta.update(OrderedDict([('b', 3), ('c', [1, 2]), ('d', 2), ('a', 1)]))
69
70        self.meta_merge = OrderedDict([('b', [1, 2, 3, 4]),
71                                       ('c', {'a': 1, 'b': 1}),
72                                       ('d', 1),
73                                       ('a', 1)])
74
75    def test_table_meta_merge(self, operation_table_type):
76        self._setup(operation_table_type)
77        out = table.join(self.t1, self.t2, join_type='inner')
78        assert out.meta == self.meta_merge
79
80    def test_table_meta_merge_conflict(self, operation_table_type):
81        self._setup(operation_table_type)
82
83        with pytest.warns(metadata.MergeConflictWarning) as w:
84            out = table.join(self.t1, self.t3, join_type='inner')
85        assert len(w) == 3
86
87        assert out.meta == self.t3.meta
88
89        with pytest.warns(metadata.MergeConflictWarning) as w:
90            out = table.join(self.t1, self.t3, join_type='inner', metadata_conflicts='warn')
91        assert len(w) == 3
92
93        assert out.meta == self.t3.meta
94
95        out = table.join(self.t1, self.t3, join_type='inner', metadata_conflicts='silent')
96
97        assert out.meta == self.t3.meta
98
99        with pytest.raises(MergeConflictError):
100            out = table.join(self.t1, self.t3, join_type='inner', metadata_conflicts='error')
101
102        with pytest.raises(ValueError):
103            out = table.join(self.t1, self.t3, join_type='inner', metadata_conflicts='nonsense')
104
105    def test_both_unmasked_inner(self, operation_table_type):
106        self._setup(operation_table_type)
107        t1 = self.t1
108        t2 = self.t2
109
110        # Basic join with default parameters (inner join on common keys)
111        t12 = table.join(t1, t2)
112        assert type(t12) is operation_table_type
113        assert type(t12['a']) is type(t1['a'])  # noqa
114        assert type(t12['b']) is type(t1['b'])  # noqa
115        assert type(t12['c']) is type(t1['c'])  # noqa
116        assert type(t12['d']) is type(t2['d'])  # noqa
117        assert t12.masked is False
118        assert sort_eq(t12.pformat(), [' a   b   c   d ',
119                                       '--- --- --- ---',
120                                       '  1 foo  L2  R1',
121                                       '  1 foo  L2  R2',
122                                       '  2 bar  L4  R3'])
123        # Table meta merged properly
124        assert t12.meta == self.meta_merge
125
126    def test_both_unmasked_left_right_outer(self, operation_table_type):
127        if operation_table_type is QTable:
128            pytest.xfail('Quantity columns do not support masking.')
129        self._setup(operation_table_type)
130        t1 = self.t1
131        t2 = self.t2
132
133        # Left join
134        t12 = table.join(t1, t2, join_type='left')
135        assert t12.has_masked_columns is True
136        assert t12.masked is False
137        for name in ('a', 'b', 'c'):
138            assert type(t12[name]) is Column
139        assert type(t12['d']) is MaskedColumn
140        assert sort_eq(t12.pformat(), [' a   b   c   d ',
141                                       '--- --- --- ---',
142                                       '  0 foo  L1  --',
143                                       '  1 bar  L3  --',
144                                       '  1 foo  L2  R1',
145                                       '  1 foo  L2  R2',
146                                       '  2 bar  L4  R3'])
147
148        # Right join
149        t12 = table.join(t1, t2, join_type='right')
150        assert t12.has_masked_columns is True
151        assert t12.masked is False
152        assert sort_eq(t12.pformat(), [' a   b   c   d ',
153                                       '--- --- --- ---',
154                                       '  1 foo  L2  R1',
155                                       '  1 foo  L2  R2',
156                                       '  2 bar  L4  R3',
157                                       '  4 bar  --  R4'])
158
159        # Outer join
160        t12 = table.join(t1, t2, join_type='outer')
161        assert t12.has_masked_columns is True
162        assert t12.masked is False
163        assert sort_eq(t12.pformat(), [' a   b   c   d ',
164                                       '--- --- --- ---',
165                                       '  0 foo  L1  --',
166                                       '  1 bar  L3  --',
167                                       '  1 foo  L2  R1',
168                                       '  1 foo  L2  R2',
169                                       '  2 bar  L4  R3',
170                                       '  4 bar  --  R4'])
171
172        # Check that the common keys are 'a', 'b'
173        t12a = table.join(t1, t2, join_type='outer')
174        t12b = table.join(t1, t2, join_type='outer', keys=['a', 'b'])
175        assert np.all(t12a.as_array() == t12b.as_array())
176
177    def test_both_unmasked_single_key_inner(self, operation_table_type):
178        self._setup(operation_table_type)
179        t1 = self.t1
180        t2 = self.t2
181
182        # Inner join on 'a' column
183        t12 = table.join(t1, t2, keys='a')
184        assert type(t12) is operation_table_type
185        assert type(t12['a']) is type(t1['a'])  # noqa
186        assert type(t12['b_1']) is type(t1['b'])  # noqa
187        assert type(t12['c']) is type(t1['c'])  # noqa
188        assert type(t12['b_2']) is type(t2['b'])  # noqa
189        assert type(t12['d']) is type(t2['d'])  # noqa
190        assert t12.masked is False
191        assert sort_eq(t12.pformat(), [' a  b_1  c  b_2  d ',
192                                       '--- --- --- --- ---',
193                                       '  1 foo  L2 foo  R1',
194                                       '  1 foo  L2 foo  R2',
195                                       '  1 bar  L3 foo  R1',
196                                       '  1 bar  L3 foo  R2',
197                                       '  2 bar  L4 bar  R3'])
198
199    def test_both_unmasked_single_key_left_right_outer(self, operation_table_type):
200        if operation_table_type is QTable:
201            pytest.xfail('Quantity columns do not support masking.')
202        self._setup(operation_table_type)
203        t1 = self.t1
204        t2 = self.t2
205
206        # Left join
207        t12 = table.join(t1, t2, join_type='left', keys='a')
208        assert t12.has_masked_columns is True
209        assert sort_eq(t12.pformat(), [' a  b_1  c  b_2  d ',
210                                       '--- --- --- --- ---',
211                                       '  0 foo  L1  --  --',
212                                       '  1 foo  L2 foo  R1',
213                                       '  1 foo  L2 foo  R2',
214                                       '  1 bar  L3 foo  R1',
215                                       '  1 bar  L3 foo  R2',
216                                       '  2 bar  L4 bar  R3'])
217
218        # Right join
219        t12 = table.join(t1, t2, join_type='right', keys='a')
220        assert t12.has_masked_columns is True
221        assert sort_eq(t12.pformat(), [' a  b_1  c  b_2  d ',
222                                       '--- --- --- --- ---',
223                                       '  1 foo  L2 foo  R1',
224                                       '  1 foo  L2 foo  R2',
225                                       '  1 bar  L3 foo  R1',
226                                       '  1 bar  L3 foo  R2',
227                                       '  2 bar  L4 bar  R3',
228                                       '  4  --  -- bar  R4'])
229
230        # Outer join
231        t12 = table.join(t1, t2, join_type='outer', keys='a')
232        assert t12.has_masked_columns is True
233        assert sort_eq(t12.pformat(), [' a  b_1  c  b_2  d ',
234                                       '--- --- --- --- ---',
235                                       '  0 foo  L1  --  --',
236                                       '  1 foo  L2 foo  R1',
237                                       '  1 foo  L2 foo  R2',
238                                       '  1 bar  L3 foo  R1',
239                                       '  1 bar  L3 foo  R2',
240                                       '  2 bar  L4 bar  R3',
241                                       '  4  --  -- bar  R4'])
242
243    def test_masked_unmasked(self, operation_table_type):
244        if operation_table_type is QTable:
245            pytest.xfail('Quantity columns do not support masking.')
246        self._setup(operation_table_type)
247        t1 = self.t1
248        t1m = operation_table_type(self.t1, masked=True)
249        t2 = self.t2
250
251        # Result table is never masked
252        t1m2 = table.join(t1m, t2, join_type='inner')
253        assert t1m2.masked is False
254
255        # Result should match non-masked result
256        t12 = table.join(t1, t2)
257        assert np.all(t12.as_array() == np.array(t1m2))
258
259        # Mask out some values in left table and make sure they propagate
260        t1m['b'].mask[1] = True
261        t1m['c'].mask[2] = True
262        t1m2 = table.join(t1m, t2, join_type='inner', keys='a')
263        assert sort_eq(t1m2.pformat(), [' a  b_1  c  b_2  d ',
264                                        '--- --- --- --- ---',
265                                        '  1  --  L2 foo  R1',
266                                        '  1  --  L2 foo  R2',
267                                        '  1 bar  -- foo  R1',
268                                        '  1 bar  -- foo  R2',
269                                        '  2 bar  L4 bar  R3'])
270
271        t21m = table.join(t2, t1m, join_type='inner', keys='a')
272        assert sort_eq(t21m.pformat(), [' a  b_1  d  b_2  c ',
273                                        '--- --- --- --- ---',
274                                        '  1 foo  R2  --  L2',
275                                        '  1 foo  R2 bar  --',
276                                        '  1 foo  R1  --  L2',
277                                        '  1 foo  R1 bar  --',
278                                        '  2 bar  R3 bar  L4'])
279
280    def test_masked_masked(self, operation_table_type):
281        self._setup(operation_table_type)
282        """Two masked tables"""
283        if operation_table_type is QTable:
284            pytest.xfail('Quantity columns do not support masking.')
285        t1 = self.t1
286        t1m = operation_table_type(self.t1, masked=True)
287        t2 = self.t2
288        t2m = operation_table_type(self.t2, masked=True)
289
290        # Result table is never masked but original column types are preserved
291        t1m2m = table.join(t1m, t2m, join_type='inner')
292        assert t1m2m.masked is False
293        for col in t1m2m.itercols():
294            assert type(col) is MaskedColumn
295
296        # Result should match non-masked result
297        t12 = table.join(t1, t2)
298        assert np.all(t12.as_array() == np.array(t1m2m))
299
300        # Mask out some values in both tables and make sure they propagate
301        t1m['b'].mask[1] = True
302        t1m['c'].mask[2] = True
303        t2m['d'].mask[2] = True
304        t1m2m = table.join(t1m, t2m, join_type='inner', keys='a')
305        assert sort_eq(t1m2m.pformat(), [' a  b_1  c  b_2  d ',
306                                         '--- --- --- --- ---',
307                                         '  1  --  L2 foo  R1',
308                                         '  1  --  L2 foo  R2',
309                                         '  1 bar  -- foo  R1',
310                                         '  1 bar  -- foo  R2',
311                                         '  2 bar  L4 bar  --'])
312
313    def test_classes(self):
314        """Ensure that classes and subclasses get through as expected"""
315        class MyCol(Column):
316            pass
317
318        class MyMaskedCol(MaskedColumn):
319            pass
320
321        t1 = Table()
322        t1['a'] = MyCol([1])
323        t1['b'] = MyCol([2])
324        t1['c'] = MyMaskedCol([3])
325
326        t2 = Table()
327        t2['a'] = Column([1, 2])
328        t2['d'] = MyCol([3, 4])
329        t2['e'] = MyMaskedCol([5, 6])
330
331        t12 = table.join(t1, t2, join_type='inner')
332        for name, exp_type in (('a', MyCol), ('b', MyCol), ('c', MyMaskedCol),
333                               ('d', MyCol), ('e', MyMaskedCol)):
334            assert type(t12[name] is exp_type)
335
336        t21 = table.join(t2, t1, join_type='left')
337        # Note col 'b' gets upgraded from MyCol to MaskedColumn since it needs to be
338        # masked, but col 'c' stays since MyMaskedCol supports masking.
339        for name, exp_type in (('a', MyCol), ('b', MaskedColumn), ('c', MyMaskedCol),
340                               ('d', MyCol), ('e', MyMaskedCol)):
341            assert type(t21[name] is exp_type)
342
343    def test_col_rename(self, operation_table_type):
344        self._setup(operation_table_type)
345        """
346        Test auto col renaming when there is a conflict.  Use
347        non-default values of uniq_col_name and table_names.
348        """
349        t1 = self.t1
350        t2 = self.t2
351        t12 = table.join(t1, t2, uniq_col_name='x_{table_name}_{col_name}_y',
352                         table_names=['L', 'R'], keys='a')
353        assert t12.colnames == ['a', 'x_L_b_y', 'c', 'x_R_b_y', 'd']
354
355    def test_rename_conflict(self, operation_table_type):
356        self._setup(operation_table_type)
357        """
358        Test that auto-column rename fails because of a conflict
359        with an existing column
360        """
361        t1 = self.t1
362        t2 = self.t2
363        t1['b_1'] = 1  # Add a new column b_1 that will conflict with auto-rename
364        with pytest.raises(TableMergeError):
365            table.join(t1, t2, keys='a')
366
367    def test_missing_keys(self, operation_table_type):
368        self._setup(operation_table_type)
369        """Merge on a key column that doesn't exist"""
370        t1 = self.t1
371        t2 = self.t2
372        with pytest.raises(TableMergeError):
373            table.join(t1, t2, keys=['a', 'not there'])
374
375    def test_bad_join_type(self, operation_table_type):
376        self._setup(operation_table_type)
377        """Bad join_type input"""
378        t1 = self.t1
379        t2 = self.t2
380        with pytest.raises(ValueError):
381            table.join(t1, t2, join_type='illegal value')
382
383    def test_no_common_keys(self, operation_table_type):
384        self._setup(operation_table_type)
385        """Merge tables with no common keys"""
386        t1 = self.t1
387        t2 = self.t2
388        del t1['a']
389        del t1['b']
390        del t2['a']
391        del t2['b']
392        with pytest.raises(TableMergeError):
393            table.join(t1, t2)
394
395    def test_masked_key_column(self, operation_table_type):
396        self._setup(operation_table_type)
397        """Merge on a key column that has a masked element"""
398        if operation_table_type is QTable:
399            pytest.xfail('Quantity columns do not support masking.')
400        t1 = self.t1
401        t2 = operation_table_type(self.t2, masked=True)
402        table.join(t1, t2)  # OK
403        t2['a'].mask[0] = True
404        with pytest.raises(TableMergeError):
405            table.join(t1, t2)
406
407    def test_col_meta_merge(self, operation_table_type):
408        self._setup(operation_table_type)
409        t1 = self.t1
410        t2 = self.t2
411        t2.rename_column('d', 'c')  # force col conflict and renaming
412        meta1 = OrderedDict([('b', [1, 2]), ('c', {'a': 1}), ('d', 1)])
413        meta2 = OrderedDict([('b', [3, 4]), ('c', {'b': 1}), ('a', 1)])
414
415        # Key col 'a', should first value ('cm')
416        t1['a'].unit = 'cm'
417        t2['a'].unit = 'm'
418        # Key col 'b', take first value 't1_b'
419        t1['b'].info.description = 't1_b'
420        # Key col 'b', take first non-empty value 't1_b'
421        t2['b'].info.format = '%6s'
422        # Key col 'a', should be merged meta
423        t1['a'].info.meta = meta1
424        t2['a'].info.meta = meta2
425        # Key col 'b', should be meta2
426        t2['b'].info.meta = meta2
427
428        # All these should pass through
429        t1['c'].info.format = '%3s'
430        t1['c'].info.description = 't1_c'
431
432        t2['c'].info.format = '%6s'
433        t2['c'].info.description = 't2_c'
434
435        if operation_table_type is Table:
436            ctx = pytest.warns(metadata.MergeConflictWarning, match=r"In merged column 'a' the 'unit' attribute does not match \(cm != m\)")  # noqa
437        else:
438            ctx = nullcontext()
439
440        with ctx:
441            t12 = table.join(t1, t2, keys=['a', 'b'])
442
443        assert t12['a'].unit == 'm'
444        assert t12['b'].info.description == 't1_b'
445        assert t12['b'].info.format == '%6s'
446        assert t12['a'].info.meta == self.meta_merge
447        assert t12['b'].info.meta == meta2
448        assert t12['c_1'].info.format == '%3s'
449        assert t12['c_1'].info.description == 't1_c'
450        assert t12['c_2'].info.format == '%6s'
451        assert t12['c_2'].info.description == 't2_c'
452
453    def test_join_multidimensional(self, operation_table_type):
454        self._setup(operation_table_type)
455
456        # Regression test for #2984, which was an issue where join did not work
457        # on multi-dimensional columns.
458
459        t1 = operation_table_type()
460        t1['a'] = [1, 2, 3]
461        t1['b'] = np.ones((3, 4))
462
463        t2 = operation_table_type()
464        t2['a'] = [1, 2, 3]
465        t2['c'] = [4, 5, 6]
466
467        t3 = table.join(t1, t2)
468
469        np.testing.assert_allclose(t3['a'], t1['a'])
470        np.testing.assert_allclose(t3['b'], t1['b'])
471        np.testing.assert_allclose(t3['c'], t2['c'])
472
473    def test_join_multidimensional_masked(self, operation_table_type):
474        self._setup(operation_table_type)
475        """
476        Test for outer join with multidimensional columns where masking is required.
477        (Issue #4059).
478        """
479        if operation_table_type is QTable:
480            pytest.xfail('Quantity columns do not support masking.')
481
482        a = table.MaskedColumn([1, 2, 3], name='a')
483        a2 = table.Column([1, 3, 4], name='a')
484        b = table.MaskedColumn([[1, 2],
485                                [3, 4],
486                                [5, 6]],
487                               name='b',
488                               mask=[[1, 0],
489                                     [0, 1],
490                                     [0, 0]])
491        c = table.Column([[1, 1],
492                          [2, 2],
493                          [3, 3]],
494                         name='c')
495        t1 = operation_table_type([a, b])
496        t2 = operation_table_type([a2, c])
497        t12 = table.join(t1, t2, join_type='inner')
498
499        assert np.all(t12['b'].mask == [[True, False],
500                                        [False, False]])
501        assert not hasattr(t12['c'], 'mask')
502
503        t12 = table.join(t1, t2, join_type='outer')
504        assert np.all(t12['b'].mask == [[True, False],
505                                        [False, True],
506                                        [False, False],
507                                        [True, True]])
508        assert np.all(t12['c'].mask == [[False, False],
509                                        [True, True],
510                                        [False, False],
511                                        [False, False]])
512
513    def test_mixin_functionality(self, mixin_cols):
514        col = mixin_cols['m']
515        cls_name = type(col).__name__
516        len_col = len(col)
517        idx = np.arange(len_col)
518        t1 = table.QTable([idx, col], names=['idx', 'm1'])
519        t2 = table.QTable([idx, col], names=['idx', 'm2'])
520        # Set up join mismatches for different join_type cases
521        t1 = t1[[0, 1, 3]]
522        t2 = t2[[0, 2, 3]]
523
524        # Test inner join, which works for all mixin_cols
525        out = table.join(t1, t2, join_type='inner')
526        assert len(out) == 2
527        assert out['m2'].__class__ is col.__class__
528        assert np.all(out['idx'] == [0, 3])
529        if cls_name == 'SkyCoord':
530            # SkyCoord doesn't support __eq__ so use our own
531            assert skycoord_equal(out['m1'], col[[0, 3]])
532            assert skycoord_equal(out['m2'], col[[0, 3]])
533        elif 'Repr' in cls_name or 'Diff' in cls_name:
534            assert np.all(representation_equal(out['m1'], col[[0, 3]]))
535            assert np.all(representation_equal(out['m2'], col[[0, 3]]))
536        else:
537            assert np.all(out['m1'] == col[[0, 3]])
538            assert np.all(out['m2'] == col[[0, 3]])
539
540        # Check for left, right, outer join which requires masking. Works for
541        # the listed mixins classes.
542        if isinstance(col, (Quantity, Time, TimeDelta)):
543            out = table.join(t1, t2, join_type='left')
544            assert len(out) == 3
545            assert np.all(out['idx'] == [0, 1, 3])
546            assert np.all(out['m1'] == t1['m1'])
547            assert np.all(out['m2'] == t2['m2'])
548            check_mask(out['m1'], [False, False, False])
549            check_mask(out['m2'], [False, True, False])
550
551            out = table.join(t1, t2, join_type='right')
552            assert len(out) == 3
553            assert np.all(out['idx'] == [0, 2, 3])
554            assert np.all(out['m1'] == t1['m1'])
555            assert np.all(out['m2'] == t2['m2'])
556            check_mask(out['m1'], [False, True, False])
557            check_mask(out['m2'], [False, False, False])
558
559            out = table.join(t1, t2, join_type='outer')
560            assert len(out) == 4
561            assert np.all(out['idx'] == [0, 1, 2, 3])
562            assert np.all(out['m1'] == col)
563            assert np.all(out['m2'] == col)
564            assert check_mask(out['m1'], [False, False, True, False])
565            assert check_mask(out['m2'], [False, True, False, False])
566        else:
567            # Otherwise make sure it fails with the right exception message
568            for join_type in ('outer', 'left', 'right'):
569                with pytest.raises(NotImplementedError) as err:
570                    table.join(t1, t2, join_type=join_type)
571                assert ('join requires masking' in str(err.value)
572                        or 'join unavailable' in str(err.value))
573
574    def test_cartesian_join(self, operation_table_type):
575        t1 = Table(rows=[(1, 'a'),
576                         (2, 'b')], names=['a', 'b'])
577        t2 = Table(rows=[(3, 'c'),
578                         (4, 'd')], names=['a', 'c'])
579        t12 = table.join(t1, t2, join_type='cartesian')
580
581        assert t1.colnames == ['a', 'b']
582        assert t2.colnames == ['a', 'c']
583        assert len(t12) == len(t1) * len(t2)
584        assert str(t12).splitlines() == [
585            'a_1  b  a_2  c ',
586            '--- --- --- ---',
587            '  1   a   3   c',
588            '  1   a   4   d',
589            '  2   b   3   c',
590            '  2   b   4   d']
591
592        with pytest.raises(ValueError, match='cannot supply keys for a cartesian join'):
593            t12 = table.join(t1, t2, join_type='cartesian', keys='a')
594
595    @pytest.mark.skipif('not HAS_SCIPY')
596    def test_join_with_join_skycoord_sky(self):
597        sc1 = SkyCoord([0, 1, 1.1, 2], [0, 0, 0, 0], unit='deg')
598        sc2 = SkyCoord([0.5, 1.05, 2.1], [0, 0, 0], unit='deg')
599        t1 = Table([sc1], names=['sc'])
600        t2 = Table([sc2], names=['sc'])
601        t12 = table.join(t1, t2, join_funcs={'sc': join_skycoord(0.2 * u.deg)})
602        exp = ['sc_id   sc_1    sc_2  ',
603               '      deg,deg deg,deg ',
604               '----- ------- --------',
605               '    1 1.0,0.0 1.05,0.0',
606               '    1 1.1,0.0 1.05,0.0',
607               '    2 2.0,0.0  2.1,0.0']
608        assert str(t12).splitlines() == exp
609
610    @pytest.mark.skipif('not HAS_SCIPY')
611    @pytest.mark.parametrize('distance_func', ['search_around_3d', search_around_3d])
612    def test_join_with_join_skycoord_3d(self, distance_func):
613        sc1 = SkyCoord([0, 1, 1.1, 2]*u.deg, [0, 0, 0, 0]*u.deg, [1, 1, 2, 1]*u.m)
614        sc2 = SkyCoord([0.5, 1.05, 2.1]*u.deg, [0, 0, 0]*u.deg, [1, 1, 1]*u.m)
615        t1 = Table([sc1], names=['sc'])
616        t2 = Table([sc2], names=['sc'])
617        join_func = join_skycoord(np.deg2rad(0.2) * u.m,
618                                  distance_func=distance_func)
619        t12 = table.join(t1, t2, join_funcs={'sc': join_func})
620        exp = ['sc_id     sc_1        sc_2    ',
621               '       deg,deg,m   deg,deg,m  ',
622               '----- ----------- ------------',
623               '    1 1.0,0.0,1.0 1.05,0.0,1.0',
624               '    2 2.0,0.0,1.0  2.1,0.0,1.0']
625        assert str(t12).splitlines() == exp
626
627    @pytest.mark.skipif('not HAS_SCIPY')
628    def test_join_with_join_distance_1d(self):
629        c1 = [0, 1, 1.1, 2]
630        c2 = [0.5, 1.05, 2.1]
631        t1 = Table([c1], names=['col'])
632        t2 = Table([c2], names=['col'])
633        join_func = join_distance(0.2,
634                                  kdtree_args={'leafsize': 32},
635                                  query_args={'p': 2})
636        t12 = table.join(t1, t2, join_type='outer', join_funcs={'col': join_func})
637        exp = ['col_id col_1 col_2',
638               '------ ----- -----',
639               '     1   1.0  1.05',
640               '     1   1.1  1.05',
641               '     2   2.0   2.1',
642               '     3   0.0    --',
643               '     4    --   0.5']
644        assert str(t12).splitlines() == exp
645
646    @pytest.mark.skipif('not HAS_SCIPY')
647    def test_join_with_join_distance_1d_multikey(self):
648        from astropy.table.operations import _apply_join_funcs
649
650        c1 = [0, 1, 1.1, 1.2, 2]
651        id1 = [0, 1, 2, 2, 3]
652        o1 = ['a', 'b', 'c', 'd', 'e']
653        c2 = [0.5, 1.05, 2.1]
654        id2 = [0, 2, 4]
655        o2 = ['z', 'y', 'x']
656        t1 = Table([c1, id1, o1], names=['col', 'id', 'o1'])
657        t2 = Table([c2, id2, o2], names=['col', 'id', 'o2'])
658        join_func = join_distance(0.2)
659        join_funcs = {'col': join_func}
660        t12 = table.join(t1, t2, join_type='outer', join_funcs=join_funcs)
661        exp = ['col_id col_1  id  o1 col_2  o2',
662               '------ ----- --- --- ----- ---',
663               '     1   1.0   1   b    --  --',
664               '     1   1.1   2   c  1.05   y',
665               '     1   1.2   2   d  1.05   y',
666               '     2   2.0   3   e    --  --',
667               '     2    --   4  --   2.1   x',
668               '     3   0.0   0   a    --  --',
669               '     4    --   0  --   0.5   z']
670        assert str(t12).splitlines() == exp
671
672        left, right, keys = _apply_join_funcs(t1, t2, ('col', 'id'), join_funcs)
673        assert keys == ('col_id', 'id')
674
675    @pytest.mark.skipif('not HAS_SCIPY')
676    def test_join_with_join_distance_1d_quantity(self):
677        c1 = [0, 1, 1.1, 2] * u.m
678        c2 = [500, 1050, 2100] * u.mm
679        t1 = QTable([c1], names=['col'])
680        t2 = QTable([c2], names=['col'])
681        join_func = join_distance(20 * u.cm)
682        t12 = table.join(t1, t2, join_funcs={'col': join_func})
683        exp = ['col_id col_1 col_2 ',
684               '         m     mm  ',
685               '------ ----- ------',
686               '     1   1.0 1050.0',
687               '     1   1.1 1050.0',
688               '     2   2.0 2100.0']
689        assert str(t12).splitlines() == exp
690
691        # Generate column name conflict
692        t2['col_id'] = [0, 0, 0]
693        t2['col__id'] = [0, 0, 0]
694        t12 = table.join(t1, t2, join_funcs={'col': join_func})
695        exp = ['col___id col_1 col_2  col_id col__id',
696               '           m     mm                 ',
697               '-------- ----- ------ ------ -------',
698               '       1   1.0 1050.0      0       0',
699               '       1   1.1 1050.0      0       0',
700               '       2   2.0 2100.0      0       0']
701        assert str(t12).splitlines() == exp
702
703    @pytest.mark.skipif('not HAS_SCIPY')
704    def test_join_with_join_distance_2d(self):
705        c1 = np.array([[0, 1, 1.1, 2],
706                       [0, 0, 1, 0]]).transpose()
707        c2 = np.array([[0.5, 1.05, 2.1],
708                       [0, 0, 0]]).transpose()
709        t1 = Table([c1], names=['col'])
710        t2 = Table([c2], names=['col'])
711        join_func = join_distance(0.2,
712                                  kdtree_args={'leafsize': 32},
713                                  query_args={'p': 2})
714        t12 = table.join(t1, t2, join_type='outer', join_funcs={'col': join_func})
715        exp = ['col_id col_1 [2]   col_2 [2] ',
716               '------ ---------- -----------',
717               '     1 1.0 .. 0.0 1.05 .. 0.0',
718               '     2 2.0 .. 0.0  2.1 .. 0.0',
719               '     3 0.0 .. 0.0    -- .. --',
720               '     4 1.1 .. 1.0    -- .. --',
721               '     5   -- .. --  0.5 .. 0.0']
722        assert str(t12).splitlines() == exp
723
724    def test_keys_left_right_basic(self):
725        """Test using the keys_left and keys_right args to specify different
726        join keys. This takes the standard test case but renames column 'a'
727        to 'x' and 'y' respectively for tables 1 and 2. Then it compares the
728        normal join on 'a' to the new join on 'x' and 'y'."""
729        self._setup()
730
731        for join_type in ('inner', 'left', 'right', 'outer'):
732            t1 = self.t1.copy()
733            t2 = self.t2.copy()
734            # Expected is same as joining on 'a' but with names 'x', 'y' instead
735            t12_exp = table.join(t1, t2, keys='a', join_type=join_type)
736            t12_exp.add_column(t12_exp['a'], name='x', index=1)
737            t12_exp.add_column(t12_exp['a'], name='y', index=len(t1.colnames) + 1)
738            del t12_exp['a']
739
740            # Different key names
741            t1.rename_column('a', 'x')
742            t2.rename_column('a', 'y')
743            keys_left_list = ['x']  # Test string key name
744            keys_right_list = [['y']]  # Test list of string key names
745            if join_type == 'outer':
746                # Just do this for the outer join (others are the same)
747                keys_left_list.append([t1['x'].tolist()])  # Test list key column
748                keys_right_list.append([t2['y']])  # Test Column key column
749
750            for keys_left, keys_right in zip(keys_left_list, keys_right_list):
751                t12 = table.join(t1, t2, keys_left=keys_left, keys_right=keys_right,
752                                 join_type=join_type)
753
754                assert t12.colnames == t12_exp.colnames
755                for col in t12.values_equal(t12_exp).itercols():
756                    assert np.all(col)
757                assert t12_exp.meta == t12.meta
758
759    def test_keys_left_right_exceptions(self):
760        """Test exceptions using the keys_left and keys_right args to specify
761        different join keys.
762        """
763        self._setup()
764        t1 = self.t1
765        t2 = self.t2
766
767        msg = r"left table does not have key column 'z'"
768        with pytest.raises(ValueError, match=msg):
769            table.join(t1, t2, keys_left='z', keys_right=['a'])
770
771        msg = r"left table has different length from key \[1, 2\]"
772        with pytest.raises(ValueError, match=msg):
773            table.join(t1, t2, keys_left=[[1, 2]], keys_right=['a'])
774
775        msg = r"keys arg must be None if keys_left and keys_right are supplied"
776        with pytest.raises(ValueError, match=msg):
777            table.join(t1, t2, keys_left='z', keys_right=['a'], keys='a')
778
779        msg = r"keys_left and keys_right args must have same length"
780        with pytest.raises(ValueError, match=msg):
781            table.join(t1, t2, keys_left=['a', 'b'], keys_right=['a'])
782
783        msg = r"keys_left and keys_right must both be provided"
784        with pytest.raises(ValueError, match=msg):
785            table.join(t1, t2, keys_left=['a', 'b'])
786
787        msg = r"cannot supply join_funcs arg and keys_left / keys_right"
788        with pytest.raises(ValueError, match=msg):
789            table.join(t1, t2, keys_left=['a'], keys_right=['a'], join_funcs={})
790
791
class TestSetdiff():
    """Tests of ``table.setdiff`` using small ASCII-defined tables."""

    def _setup(self, t_cls=Table):
        """Create ``self.t1``, ``self.t2`` and ``self.t3`` as ``t_cls`` tables.

        ``t1`` and ``t2`` have columns 'a' and 'b'; ``t3`` additionally has
        a column 'd'.
        """
        ascii_tables = (
            ('t1', [' a   b ',
                    '  0 foo ',
                    '  1 foo ',
                    '  1 bar ',
                    '  2 bar ']),
            ('t2', [' a   b ',
                    '  0 foo ',
                    '  3 foo ',
                    '  4 bar ',
                    '  2 bar ']),
            ('t3', [' a   b   d ',
                    '  0 foo  R1',
                    '  8 foo  R2',
                    '  1 bar  R3',
                    '  4 bar  R4']),
        )
        for attr_name, lines in ascii_tables:
            setattr(self, attr_name, t_cls.read(lines, format='ascii'))

    def _assert_ab_types_match_t1(self, result):
        """Assert columns 'a' and 'b' of ``result`` have the same types as
        the corresponding columns of ``self.t1``."""
        for name in ('a', 'b'):
            assert type(result[name]) is type(self.t1[name])  # noqa

    def test_default_same_columns(self, operation_table_type):
        """Set difference of two tables that have the same columns."""
        self._setup(operation_table_type)
        diff = table.setdiff(self.t1, self.t2)
        self._assert_ab_types_match_t1(diff)
        expected = [' a   b ',
                    '--- ---',
                    '  1 bar',
                    '  1 foo']
        assert diff.pformat() == expected

    def test_default_same_tables(self, operation_table_type):
        """Set difference of a table with itself has no rows."""
        self._setup(operation_table_type)
        diff = table.setdiff(self.t1, self.t1)

        self._assert_ab_types_match_t1(diff)
        expected = [' a   b ',
                    '--- ---']
        assert diff.pformat() == expected

    def test_extra_col_left_table(self, operation_table_type):
        """An extra column in the first table raises ValueError."""
        self._setup(operation_table_type)

        with pytest.raises(ValueError):
            table.setdiff(self.t3, self.t1)

    def test_extra_col_right_table(self, operation_table_type):
        """An extra column in the second table is accepted."""
        self._setup(operation_table_type)
        diff = table.setdiff(self.t1, self.t3)

        self._assert_ab_types_match_t1(diff)
        expected = [' a   b ',
                    '--- ---',
                    '  1 foo',
                    '  2 bar']
        assert diff.pformat() == expected

    def test_keys(self, operation_table_type):
        """Set difference restricted to the explicit keys 'a' and 'b'."""
        self._setup(operation_table_type)
        diff = table.setdiff(self.t3, self.t1, keys=['a', 'b'])

        self._assert_ab_types_match_t1(diff)
        expected = [' a   b   d ',
                    '--- --- ---',
                    '  4 bar  R4',
                    '  8 foo  R2']
        assert diff.pformat() == expected

    def test_missing_key(self, operation_table_type):
        """A key ('d') that is absent from the second table raises
        ValueError."""
        self._setup(operation_table_type)

        with pytest.raises(ValueError):
            table.setdiff(self.t3, self.t1, keys=['a', 'd'])
866
867
class TestVStack():
    """Tests of ``table.vstack``."""

    def _setup(self, t_cls=Table):
        """Create the input tables and expected merged meta for the tests.

        Sets ``self.t1`` .. ``self.t5`` (instances of ``t_cls``) and
        ``self.meta_merge``, the meta expected from merging t1, t2 and t4.
        """
        self.t1 = t_cls.read([' a   b',
                              ' 0. foo',
                              ' 1. bar'], format='ascii')

        self.t2 = t_cls.read([' a    b   c',
                              ' 2.  pez  4',
                              ' 3.  sez  5'], format='ascii')

        self.t3 = t_cls.read([' a    b',
                              ' 4.   7',
                              ' 5.   8',
                              ' 6.   9'], format='ascii')
        # Copy of t1 which is requested as masked only when t_cls is Table
        self.t4 = t_cls(self.t1, copy=True, masked=t_cls is Table)

        # The following table has meta-data that conflicts with t1
        self.t5 = t_cls(self.t1, copy=True)

        self.t1.meta.update(OrderedDict([('b', [1, 2]), ('c', {'a': 1}), ('d', 1)]))
        self.t2.meta.update(OrderedDict([('b', [3, 4]), ('c', {'b': 1}), ('a', 1)]))
        self.t4.meta.update(OrderedDict([('b', [5, 6]), ('c', {'c': 1}), ('e', 1)]))
        self.t5.meta.update(OrderedDict([('b', 3), ('c', 'k'), ('d', 1)]))
        self.meta_merge = OrderedDict([('b', [1, 2, 3, 4, 5, 6]),
                                       ('c', {'a': 1, 'b': 1, 'c': 1}),
                                       ('d', 1),
                                       ('a', 1),
                                       ('e', 1)])

    def test_validate_join_type(self):
        """Passing the tables as separate positional arguments instead of a
        list raises TypeError with a hint in the message."""
        self._setup()
        with pytest.raises(TypeError, match='Did you accidentally call vstack'):
            table.vstack(self.t1, self.t2)

    def test_stack_rows(self, operation_table_type):
        """Stack a table with a single row of a table."""
        self._setup(operation_table_type)
        t2 = self.t1.copy()
        t2.meta.clear()
        out = table.vstack([self.t1, t2[1]])
        assert type(out['a']) is type(self.t1['a'])  # noqa
        assert type(out['b']) is type(self.t1['b'])  # noqa
        assert out.pformat() == [' a   b ',
                                 '--- ---',
                                 '0.0 foo',
                                 '1.0 bar',
                                 '1.0 bar']

    def test_stack_table_column(self, operation_table_type):
        """Stack a table with a single column taken from a table."""
        self._setup(operation_table_type)
        t2 = self.t1.copy()
        t2.meta.clear()
        out = table.vstack([self.t1, t2['a']])
        assert out.masked is False
        assert out.pformat() == [' a   b ',
                                 '--- ---',
                                 '0.0 foo',
                                 '1.0 bar',
                                 '0.0  --',
                                 '1.0  --']

    def test_table_meta_merge(self, operation_table_type):
        """Table meta of t1, t2 and t4 is merged into ``self.meta_merge``."""
        self._setup(operation_table_type)
        out = table.vstack([self.t1, self.t2, self.t4], join_type='inner')
        assert out.meta == self.meta_merge

    def test_table_meta_merge_conflict(self, operation_table_type):
        """Check each ``metadata_conflicts`` setting with conflicting meta
        (t1 vs. t5): default, 'warn', 'silent', 'error' and an invalid value.
        """
        self._setup(operation_table_type)

        with pytest.warns(metadata.MergeConflictWarning) as w:
            out = table.vstack([self.t1, self.t5], join_type='inner')
        assert len(w) == 2

        assert out.meta == self.t5.meta

        with pytest.warns(metadata.MergeConflictWarning) as w:
            out = table.vstack([self.t1, self.t5], join_type='inner', metadata_conflicts='warn')
        assert len(w) == 2

        assert out.meta == self.t5.meta

        out = table.vstack([self.t1, self.t5], join_type='inner', metadata_conflicts='silent')

        assert out.meta == self.t5.meta

        # The calls below are expected to raise, so no result is bound.
        with pytest.raises(MergeConflictError):
            table.vstack([self.t1, self.t5], join_type='inner', metadata_conflicts='error')

        with pytest.raises(ValueError):
            table.vstack([self.t1, self.t5], join_type='inner', metadata_conflicts='nonsense')

    def test_bad_input_type(self, operation_table_type):
        """Invalid inputs (empty list, non-table objects, unknown join type)
        raise ValueError or TypeError."""
        self._setup(operation_table_type)
        with pytest.raises(ValueError):
            table.vstack([])
        with pytest.raises(TypeError):
            table.vstack(1)
        with pytest.raises(TypeError):
            table.vstack([self.t2, 1])
        with pytest.raises(ValueError):
            table.vstack([self.t1, self.t2], join_type='invalid join type')

    def test_stack_basic_inner(self, operation_table_type):
        """Inner stacking of two and of three tables keeps only the common
        columns 'a' and 'b'."""
        self._setup(operation_table_type)
        t1 = self.t1
        t2 = self.t2
        t4 = self.t4

        t12 = table.vstack([t1, t2], join_type='inner')
        assert t12.masked is False
        assert type(t12) is operation_table_type
        assert type(t12['a']) is type(t1['a'])  # noqa
        assert type(t12['b']) is type(t1['b'])  # noqa
        assert t12.pformat() == [' a   b ',
                                 '--- ---',
                                 '0.0 foo',
                                 '1.0 bar',
                                 '2.0 pez',
                                 '3.0 sez']

        t124 = table.vstack([t1, t2, t4], join_type='inner')
        assert type(t124) is operation_table_type
        # Check the columns of t124 (not t12 again). t4 may be masked, so
        # compare against the t4 column types, as the dstack inner test does.
        assert type(t124['a']) is type(t4['a'])  # noqa
        assert type(t124['b']) is type(t4['b'])  # noqa
        assert t124.pformat() == [' a   b ',
                                  '--- ---',
                                  '0.0 foo',
                                  '1.0 bar',
                                  '2.0 pez',
                                  '3.0 sez',
                                  '0.0 foo',
                                  '1.0 bar']

    def test_stack_basic_outer(self, operation_table_type):
        """Outer stacking masks the values of column 'c' for tables that do
        not have it."""
        if operation_table_type is QTable:
            pytest.xfail('Quantity columns do not support masking.')
        self._setup(operation_table_type)
        t1 = self.t1
        t2 = self.t2
        t4 = self.t4
        t12 = table.vstack([t1, t2], join_type='outer')
        assert t12.masked is False
        assert t12.pformat() == [' a   b   c ',
                                 '--- --- ---',
                                 '0.0 foo  --',
                                 '1.0 bar  --',
                                 '2.0 pez   4',
                                 '3.0 sez   5']

        t124 = table.vstack([t1, t2, t4], join_type='outer')
        assert t124.masked is False
        assert t124.pformat() == [' a   b   c ',
                                  '--- --- ---',
                                  '0.0 foo  --',
                                  '1.0 bar  --',
                                  '2.0 pez   4',
                                  '3.0 sez   5',
                                  '0.0 foo  --',
                                  '1.0 bar  --']

    def test_stack_incompatible(self, operation_table_type):
        """Incompatible column types, non-matching columns with
        ``join_type='exact'`` and differing column shapes raise
        TableMergeError."""
        self._setup(operation_table_type)
        with pytest.raises(TableMergeError) as excinfo:
            table.vstack([self.t1, self.t3], join_type='inner')
        assert ("The 'b' columns have incompatible types: {}"
                .format([self.t1['b'].dtype.name, self.t3['b'].dtype.name])
                in str(excinfo.value))

        with pytest.raises(TableMergeError) as excinfo:
            table.vstack([self.t1, self.t3], join_type='outer')
        assert "The 'b' columns have incompatible types:" in str(excinfo.value)

        with pytest.raises(TableMergeError):
            table.vstack([self.t1, self.t2], join_type='exact')

        t1_reshape = self.t1.copy()
        t1_reshape['b'].shape = [2, 1]
        with pytest.raises(TableMergeError) as excinfo:
            table.vstack([self.t1, t1_reshape])
        assert "have different shape" in str(excinfo.value)

    def test_vstack_one_masked(self, operation_table_type):
        """Stack an unmasked table with a table holding a masked value."""
        if operation_table_type is QTable:
            pytest.xfail('Quantity columns do not support masking.')
        self._setup(operation_table_type)
        t1 = self.t1
        t4 = self.t4
        t4['b'].mask[1] = True
        t14 = table.vstack([t1, t4])
        assert t14.masked is False
        assert t14.pformat() == [' a   b ',
                                 '--- ---',
                                 '0.0 foo',
                                 '1.0 bar',
                                 '0.0 foo',
                                 '1.0  --']

    def test_col_meta_merge_inner(self, operation_table_type):
        """Column attributes (unit, format, description, meta) are merged as
        expected for an inner stack of t1, t2 and t4."""
        self._setup(operation_table_type)
        t1 = self.t1
        t2 = self.t2
        t4 = self.t4

        # Key col 'a', should take last value ('km')
        t1['a'].info.unit = 'cm'
        t2['a'].info.unit = 'm'
        t4['a'].info.unit = 'km'

        # Key col 'a' format should take last when all match
        t1['a'].info.format = '%f'
        t2['a'].info.format = '%f'
        t4['a'].info.format = '%f'

        # Key col 'b', take first value 't1_b'
        t1['b'].info.description = 't1_b'

        # Key col 'b', take first non-empty value '%6s'
        t4['b'].info.format = '%6s'

        # Key col 'a', should be merged meta
        t1['a'].info.meta.update(OrderedDict([('b', [1, 2]), ('c', {'a': 1}), ('d', 1)]))
        t2['a'].info.meta.update(OrderedDict([('b', [3, 4]), ('c', {'b': 1}), ('a', 1)]))
        t4['a'].info.meta.update(OrderedDict([('b', [5, 6]), ('c', {'c': 1}), ('e', 1)]))

        # Key col 'b', should be meta2
        t2['b'].info.meta.update(OrderedDict([('b', [3, 4]), ('c', {'b': 1}), ('a', 1)]))

        # Unit-conflict warnings are only expected for a regular Table
        if operation_table_type is Table:
            ctx = pytest.warns(metadata.MergeConflictWarning)
        else:
            ctx = nullcontext()

        with ctx as warning_lines:
            out = table.vstack([t1, t2, t4], join_type='inner')

        if operation_table_type is Table:
            assert len(warning_lines) == 2
            assert ("In merged column 'a' the 'unit' attribute does not match (cm != m)"
                    in str(warning_lines[0].message))
            assert ("In merged column 'a' the 'unit' attribute does not match (m != km)"
                    in str(warning_lines[1].message))
            # Check units are suitably ignored for a regular Table
            assert out.pformat() == ['   a       b   ',
                                     '   km          ',
                                     '-------- ------',
                                     '0.000000    foo',
                                     '1.000000    bar',
                                     '2.000000    pez',
                                     '3.000000    sez',
                                     '0.000000    foo',
                                     '1.000000    bar']
        else:
            # Check QTable correctly dealt with units.
            assert out.pformat() == ['   a       b   ',
                                     '   km          ',
                                     '-------- ------',
                                     '0.000000    foo',
                                     '0.000010    bar',
                                     '0.002000    pez',
                                     '0.003000    sez',
                                     '0.000000    foo',
                                     '1.000000    bar']
        assert out['a'].info.unit == 'km'
        assert out['a'].info.format == '%f'
        assert out['b'].info.description == 't1_b'
        assert out['b'].info.format == '%6s'
        assert out['a'].info.meta == self.meta_merge
        assert out['b'].info.meta == OrderedDict([('b', [3, 4]), ('c', {'b': 1}), ('a', 1)])

    def test_col_meta_merge_outer(self, operation_table_type):
        """Column attributes (unit, format, description, meta) are merged as
        expected for an outer stack of t1, t2 and t4."""
        if operation_table_type is QTable:
            pytest.xfail('Quantity columns do not support masking.')
        self._setup(operation_table_type)
        t1 = self.t1
        t2 = self.t2
        t4 = self.t4

        # Key col 'a', should take last value ('km')
        t1['a'].unit = 'cm'
        t2['a'].unit = 'm'
        t4['a'].unit = 'km'

        # Key col 'a' format should take last when all match
        t1['a'].info.format = '%0d'
        t2['a'].info.format = '%0d'
        t4['a'].info.format = '%0d'

        # Key col 'b', take first value 't1_b'
        t1['b'].info.description = 't1_b'

        # Key col 'b', take first non-empty value '%6s'
        t4['b'].info.format = '%6s'

        # Key col 'a', should be merged meta
        t1['a'].info.meta.update(OrderedDict([('b', [1, 2]), ('c', {'a': 1}), ('d', 1)]))
        t2['a'].info.meta.update(OrderedDict([('b', [3, 4]), ('c', {'b': 1}), ('a', 1)]))
        t4['a'].info.meta.update(OrderedDict([('b', [5, 6]), ('c', {'c': 1}), ('e', 1)]))

        # Key col 'b', should be meta2
        t2['b'].info.meta.update(OrderedDict([('b', [3, 4]), ('c', {'b': 1}), ('a', 1)]))

        # All these should pass through
        t2['c'].unit = 'm'
        t2['c'].info.format = '%6s'
        t2['c'].info.description = 't2_c'

        with pytest.warns(metadata.MergeConflictWarning) as warning_lines:
            out = table.vstack([t1, t2, t4], join_type='outer')

        assert len(warning_lines) == 2
        assert ("In merged column 'a' the 'unit' attribute does not match (cm != m)"
                in str(warning_lines[0].message))
        assert ("In merged column 'a' the 'unit' attribute does not match (m != km)"
                in str(warning_lines[1].message))
        assert out['a'].unit == 'km'
        assert out['a'].info.format == '%0d'
        assert out['b'].info.description == 't1_b'
        assert out['b'].info.format == '%6s'
        assert out['a'].info.meta == self.meta_merge
        assert out['b'].info.meta == OrderedDict([('b', [3, 4]), ('c', {'b': 1}), ('a', 1)])
        assert out['c'].info.unit == 'm'
        assert out['c'].info.format == '%6s'
        assert out['c'].info.description == 't2_c'

    def test_vstack_one_table(self, operation_table_type):
        """Regression test for issue #3313"""
        self._setup(operation_table_type)
        assert (self.t1 == table.vstack(self.t1)).all()
        assert (self.t1 == table.vstack([self.t1])).all()

    def test_mixin_functionality(self, mixin_cols):
        """Stack QTables holding a mixin column, both with the default join
        type and with an outer join that requires masking."""
        col = mixin_cols['m']
        len_col = len(col)
        t = table.QTable([col], names=['a'])
        cls_name = type(col).__name__

        # Vstack works for these classes:
        if isinstance(col, (u.Quantity, Time, TimeDelta, SkyCoord, EarthLocation,
                            BaseRepresentationOrDifferential)):
            out = table.vstack([t, t])
            assert len(out) == len_col * 2
            if cls_name == 'SkyCoord':
                # Argh, SkyCoord needs __eq__!!
                assert skycoord_equal(out['a'][len_col:], col)
                assert skycoord_equal(out['a'][:len_col], col)
            elif 'Repr' in cls_name or 'Diff' in cls_name:
                assert np.all(representation_equal(out['a'][:len_col], col))
                assert np.all(representation_equal(out['a'][len_col:], col))
            else:
                assert np.all(out['a'][:len_col] == col)
                assert np.all(out['a'][len_col:] == col)
        else:
            with pytest.raises(NotImplementedError) as err:
                table.vstack([t, t])
            assert ('vstack unavailable for mixin column type(s): {}'
                    .format(cls_name) in str(err.value))

        # Check for outer stack which requires masking.  This is expected to
        # work only for Time, TimeDelta and Quantity columns.
        t2 = table.QTable([col], names=['b'])  # different from col name for t
        if isinstance(col, (Time, TimeDelta, Quantity)):
            out = table.vstack([t, t2], join_type='outer')
            assert len(out) == len_col * 2
            assert np.all(out['a'][:len_col] == col)
            assert np.all(out['b'][len_col:] == col)
            assert check_mask(out['a'], [False] * len_col + [True] * len_col)
            assert check_mask(out['b'], [True] * len_col + [False] * len_col)
            # check directly stacking mixin columns:
            out2 = table.vstack([t, t2['b']])
            assert np.all(out['a'] == out2['a'])
            assert np.all(out['b'] == out2['b'])
        else:
            with pytest.raises(NotImplementedError) as err:
                table.vstack([t, t2], join_type='outer')
            assert ('vstack requires masking' in str(err.value)
                    or 'vstack unavailable' in str(err.value))

    def test_vstack_different_representation(self):
        """Test that representations can be mixed together."""
        rep1 = CartesianRepresentation([1, 2]*u.km, [3, 4]*u.km, 1*u.km)
        rep2 = SphericalRepresentation([0]*u.deg, [0]*u.deg, 10*u.km)
        t1 = Table([rep1])
        t2 = Table([rep2])
        t12 = table.vstack([t1, t2])
        expected = CartesianRepresentation([1, 2, 10]*u.km,
                                           [3, 4, 0]*u.km,
                                           [1, 1, 0]*u.km)
        assert np.all(representation_equal(t12['col0'], expected))

        rep3 = UnitSphericalRepresentation([0]*u.deg, [0]*u.deg)
        t3 = Table([rep3])
        with pytest.raises(ValueError, match='representations are inconsistent'):
            table.vstack([t1, t3])
1261
1262
class TestDStack():
    """Tests of ``table.dstack``."""

    def _setup(self, t_cls=Table):
        """Create the input tables ``self.t1`` .. ``self.t6`` as ``t_cls``
        tables for the tests."""
        self.t1 = t_cls.read([' a   b',
                              ' 0. foo',
                              ' 1. bar'], format='ascii')

        self.t2 = t_cls.read([' a    b   c',
                              ' 2.  pez  4',
                              ' 3.  sez  5'], format='ascii')
        self.t2['d'] = Time([1, 2], format='cxcsec')

        self.t3 = t_cls({'a': [[5., 6.], [4., 3.]],
                         'b': [['foo', 'bar'], ['pez', 'sez']]},
                        names=('a', 'b'))

        # Copy of t1 which is requested as masked only when t_cls is Table
        self.t4 = t_cls(self.t1, copy=True, masked=t_cls is Table)

        self.t5 = t_cls({'a': [[4., 2.], [1., 6.]],
                         'b': [['foo', 'pez'], ['bar', 'sez']]},
                        names=('a', 'b'))
        self.t6 = t_cls.read([' a    b   c',
                              ' 7.  pez  2',
                              ' 4.  sez  6',
                              ' 6.  foo  3'], format='ascii')

    def test_validate_join_type(self):
        """Passing the tables as separate positional arguments instead of a
        list raises TypeError with a hint in the message."""
        self._setup()
        with pytest.raises(TypeError, match='Did you accidentally call dstack'):
            table.dstack(self.t1, self.t2)

    @staticmethod
    def compare_dstack(tables, out):
        """Assert that ``out`` is consistent with depth-stacking ``tables``.

        For the ``ii``-th input table, element ``[:, ii]`` of every output
        column is compared with the input column of the same name (values
        and mask).  Output columns that are absent from the input table must
        be fully masked at ``[:, ii]``.

        Parameters
        ----------
        tables : list
            The input tables, in the order they were stacked.
        out : table
            The stacked result to check.
        """
        for ii, tbl in enumerate(tables):
            for name, out_col in out.columns.items():
                if name in tbl.colnames:
                    in_col = tbl[name]

                    # Columns always compare equal
                    assert np.all(in_col == out_col[:, ii])

                    # If input has a mask then output must have same mask
                    if hasattr(in_col, 'mask'):
                        assert np.all(in_col.mask == out_col.mask[:, ii])

                    # If input has no mask then output might have a mask (if other table
                    # is missing that column). If so then all mask values should be False.
                    elif hasattr(out_col, 'mask'):
                        assert not np.any(out_col.mask[:, ii])

                else:
                    # Column missing for this table, out must have a mask with all True.
                    assert np.all(out_col.mask[:, ii])

    def test_dstack_table_column(self, operation_table_type):
        """Stack a table with 2 cols and one column (gets auto-converted to Table).
        """
        self._setup(operation_table_type)
        t2 = self.t1.copy()
        out = table.dstack([self.t1, t2['a']])
        self.compare_dstack([self.t1, t2[('a',)]], out)

    def test_dstack_basic_outer(self, operation_table_type):
        """Outer depth-stacking of unmasked tables, and of those tables
        together with a table that has a masked value."""
        if operation_table_type is QTable:
            pytest.xfail('Quantity columns do not support masking.')
        self._setup(operation_table_type)
        t1 = self.t1
        t2 = self.t2
        t4 = self.t4
        t4['a'].mask[0] = True
        # Test for non-masked table
        t12 = table.dstack([t1, t2], join_type='outer')
        assert type(t12) is operation_table_type
        assert type(t12['a']) is type(t1['a'])  # noqa
        assert type(t12['b']) is type(t1['b'])  # noqa
        self.compare_dstack([t1, t2], t12)

        # Test for masked table
        t124 = table.dstack([t1, t2, t4], join_type='outer')
        assert type(t124) is operation_table_type
        assert type(t124['a']) is type(t4['a'])  # noqa
        assert type(t124['b']) is type(t4['b'])  # noqa
        self.compare_dstack([t1, t2, t4], t124)

    def test_dstack_basic_inner(self, operation_table_type):
        """Inner depth-stacking of t1, t2 and t4."""
        self._setup(operation_table_type)
        t1 = self.t1
        t2 = self.t2
        t4 = self.t4

        # Test for masked table
        t124 = table.dstack([t1, t2, t4], join_type='inner')
        assert type(t124) is operation_table_type
        assert type(t124['a']) is type(t4['a'])  # noqa
        assert type(t124['b']) is type(t4['b'])  # noqa
        self.compare_dstack([t1, t2, t4], t124)

    def test_dstack_multi_dimension_column(self, operation_table_type):
        """Depth-stack tables whose columns are multi-dimensional; stacking
        t2 with t3 is expected to raise TableMergeError."""
        self._setup(operation_table_type)
        t3 = self.t3
        t5 = self.t5
        t2 = self.t2
        t35 = table.dstack([t3, t5])
        assert type(t35) is operation_table_type
        assert type(t35['a']) is type(t3['a'])  # noqa
        assert type(t35['b']) is type(t3['b'])  # noqa
        self.compare_dstack([t3, t5], t35)

        with pytest.raises(TableMergeError):
            table.dstack([t2, t3])

    def test_dstack_different_length_table(self, operation_table_type):
        """Depth-stacking tables with different numbers of rows raises
        ValueError."""
        self._setup(operation_table_type)
        t2 = self.t2
        t6 = self.t6
        with pytest.raises(ValueError):
            table.dstack([t2, t6])

    def test_dstack_single_table(self):
        """Depth-stacking a single table gives a table equal to the input."""
        self._setup(Table)
        out = table.dstack(self.t1)
        assert np.all(out == self.t1)

    def test_dstack_representation(self):
        """Depth-stack tables holding SphericalRepresentation columns."""
        rep1 = SphericalRepresentation([1, 2]*u.deg, [3, 4]*u.deg, 1*u.kpc)
        rep2 = SphericalRepresentation([10, 20]*u.deg, [30, 40]*u.deg, 10*u.kpc)
        t1 = Table([rep1])
        t2 = Table([rep2])
        t12 = table.dstack([t1, t2])
        assert np.all(representation_equal(t12['col0'][:, 0], rep1))
        assert np.all(representation_equal(t12['col0'][:, 1], rep2))

    def test_dstack_skycoord(self):
        """Depth-stack tables holding SkyCoord columns."""
        sc1 = SkyCoord([1, 2]*u.deg, [3, 4]*u.deg)
        sc2 = SkyCoord([10, 20]*u.deg, [30, 40]*u.deg)
        t1 = Table([sc1])
        t2 = Table([sc2])
        t12 = table.dstack([t1, t2])
        assert skycoord_equal(sc1, t12['col0'][:, 0])
        assert skycoord_equal(sc2, t12['col0'][:, 1])
1401
1402
class TestHStack:
    """Tests of ``table.hstack`` for plain, masked and mixin-column inputs."""

    def _setup(self, t_cls=Table):
        """Create fixture tables ``t1``-``t5`` and the expected merged meta.

        Parameters
        ----------
        t_cls : type, optional
            Table class used to build every fixture table (default ``Table``).
        """
        self.t1 = t_cls.read([' a    b',
                              ' 0. foo',
                              ' 1. bar'], format='ascii')

        self.t2 = t_cls.read([' a    b   c',
                              ' 2.  pez  4',
                              ' 3.  sez  5'], format='ascii')

        self.t3 = t_cls.read([' d    e',
                              ' 4.   7',
                              ' 5.   8',
                              ' 6.   9'], format='ascii')

        # Masked copy of t1 whose columns are renamed to 'f' and 'g'.
        self.t4 = t_cls(self.t1, copy=True, masked=True)
        self.t4['a'].name = 'f'
        self.t4['b'].name = 'g'

        # The following table has meta-data that conflicts with t1
        self.t5 = t_cls(self.t1, copy=True)

        self.t1.meta.update(OrderedDict([('b', [1, 2]), ('c', {'a': 1}), ('d', 1)]))
        self.t2.meta.update(OrderedDict([('b', [3, 4]), ('c', {'b': 1}), ('a', 1)]))
        self.t4.meta.update(OrderedDict([('b', [5, 6]), ('c', {'c': 1}), ('e', 1)]))
        self.t5.meta.update(OrderedDict([('b', 3), ('c', 'k'), ('d', 1)]))

        # Expected result of merging the meta of t1, t2 and t4 (in that order).
        self.meta_merge = OrderedDict([('b', [1, 2, 3, 4, 5, 6]),
                                       ('c', {'a': 1, 'b': 1, 'c': 1}),
                                       ('d', 1),
                                       ('a', 1),
                                       ('e', 1)])

    def test_validate_join_type(self):
        """Passing tables as separate positional arguments raises TypeError."""
        self._setup()
        with pytest.raises(TypeError, match='Did you accidentally call hstack'):
            table.hstack(self.t1, self.t2)

    def test_stack_same_table(self, operation_table_type):
        """
        From #2995, test that hstack'ing references to the same table has the
        expected output.
        """
        self._setup(operation_table_type)
        out = table.hstack([self.t1, self.t1])
        assert out.masked is False
        assert out.pformat() == ['a_1 b_1 a_2 b_2',
                                 '--- --- --- ---',
                                 '0.0 foo 0.0 foo',
                                 '1.0 bar 1.0 bar']

    def test_stack_rows(self, operation_table_type):
        """Single rows taken from two tables can be hstacked."""
        self._setup(operation_table_type)
        out = table.hstack([self.t1[0], self.t2[1]])
        assert out.masked is False
        assert out.pformat() == ['a_1 b_1 a_2 b_2  c ',
                                 '--- --- --- --- ---',
                                 '0.0 foo 3.0 sez   5']

    def test_stack_columns(self, operation_table_type):
        """A table can be hstacked with a bare column; column types are kept."""
        self._setup(operation_table_type)
        out = table.hstack([self.t1, self.t2['c']])
        assert type(out['a']) is type(self.t1['a'])  # noqa
        assert type(out['b']) is type(self.t1['b'])  # noqa
        assert type(out['c']) is type(self.t2['c'])  # noqa
        assert out.pformat() == [' a   b   c ',
                                 '--- --- ---',
                                 '0.0 foo   4',
                                 '1.0 bar   5']

    def test_table_meta_merge(self, operation_table_type):
        """Table meta from all inputs is merged into the expected dict."""
        self._setup(operation_table_type)
        out = table.hstack([self.t1, self.t2, self.t4], join_type='inner')
        assert out.meta == self.meta_merge

    def test_table_meta_merge_conflict(self, operation_table_type):
        """Check each ``metadata_conflicts`` mode when table meta conflicts."""
        self._setup(operation_table_type)

        # Default behaviour: warn, and the output meta equals that of t5.
        with pytest.warns(metadata.MergeConflictWarning) as w:
            out = table.hstack([self.t1, self.t5], join_type='inner')
        assert len(w) == 2

        assert out.meta == self.t5.meta

        with pytest.warns(metadata.MergeConflictWarning) as w:
            out = table.hstack([self.t1, self.t5], join_type='inner', metadata_conflicts='warn')
        assert len(w) == 2

        assert out.meta == self.t5.meta

        out = table.hstack([self.t1, self.t5], join_type='inner', metadata_conflicts='silent')

        assert out.meta == self.t5.meta

        # The calls below are expected to raise, so no result is bound.
        with pytest.raises(MergeConflictError):
            table.hstack([self.t1, self.t5], join_type='inner', metadata_conflicts='error')

        with pytest.raises(ValueError):
            table.hstack([self.t1, self.t5], join_type='inner', metadata_conflicts='nonsense')

    def test_bad_input_type(self, operation_table_type):
        """Invalid inputs or an unknown ``join_type`` raise the expected errors."""
        self._setup(operation_table_type)
        with pytest.raises(ValueError):
            table.hstack([])
        with pytest.raises(TypeError):
            table.hstack(1)
        with pytest.raises(TypeError):
            table.hstack([self.t2, 1])
        with pytest.raises(ValueError):
            table.hstack([self.t1, self.t2], join_type='invalid join type')

    def test_stack_basic(self, operation_table_type):
        """Basic inner/outer hstack of two and of four tables."""
        self._setup(operation_table_type)
        t1 = self.t1
        t2 = self.t2
        t3 = self.t3
        t4 = self.t4

        out = table.hstack([t1, t2], join_type='inner')
        assert out.masked is False
        assert type(out) is operation_table_type
        assert type(out['a_1']) is type(t1['a'])  # noqa
        assert type(out['b_1']) is type(t1['b'])  # noqa
        assert type(out['a_2']) is type(t2['a'])  # noqa
        assert type(out['b_2']) is type(t2['b'])  # noqa
        assert out.pformat() == ['a_1 b_1 a_2 b_2  c ',
                                 '--- --- --- --- ---',
                                 '0.0 foo 2.0 pez   4',
                                 '1.0 bar 3.0 sez   5']

        # Repeating the call with the same list input gives the same result
        out_list = table.hstack([t1, t2], join_type='inner')
        assert out.pformat() == out_list.pformat()

        # t1 and t2 have the same number of rows, so outer matches inner here
        out = table.hstack([t1, t2], join_type='outer')
        assert out.pformat() == out_list.pformat()

        # t3 has one more row than the others; the shorter tables show as
        # masked ('--') in the extra row of the outer stack.
        out = table.hstack([t1, t2, t3, t4], join_type='outer')
        assert out.masked is False
        assert out.pformat() == ['a_1 b_1 a_2 b_2  c   d   e   f   g ',
                                 '--- --- --- --- --- --- --- --- ---',
                                 '0.0 foo 2.0 pez   4 4.0   7 0.0 foo',
                                 '1.0 bar 3.0 sez   5 5.0   8 1.0 bar',
                                 ' --  --  --  --  -- 6.0   9  --  --']

        out = table.hstack([t1, t2, t3, t4], join_type='inner')
        assert out.masked is False
        assert out.pformat() == ['a_1 b_1 a_2 b_2  c   d   e   f   g ',
                                 '--- --- --- --- --- --- --- --- ---',
                                 '0.0 foo 2.0 pez   4 4.0   7 0.0 foo',
                                 '1.0 bar 3.0 sez   5 5.0   8 1.0 bar']

    def test_stack_incompatible(self, operation_table_type):
        """``join_type='exact'`` raises when the row counts differ."""
        self._setup(operation_table_type)
        # For join_type exact, which will fail here because n_rows
        # does not match
        with pytest.raises(TableMergeError):
            table.hstack([self.t1, self.t3], join_type='exact')

    def test_hstack_one_masked(self, operation_table_type):
        """hstack of an unmasked table with a masked copy keeps the mask."""
        # NOTE(review): no reason is recorded for the QTable xfail.
        if operation_table_type is QTable:
            pytest.xfail()
        self._setup(operation_table_type)
        t1 = self.t1
        t2 = operation_table_type(t1, copy=True, masked=True)
        t2.meta.clear()
        t2['b'].mask[1] = True
        out = table.hstack([t1, t2])
        assert out.pformat() == ['a_1 b_1 a_2 b_2',
                                 '--- --- --- ---',
                                 '0.0 foo 0.0 foo',
                                 '1.0 bar 1.0  --']

    def test_table_col_rename(self, operation_table_type):
        """``uniq_col_name`` and ``table_names`` control renaming of clashing columns."""
        self._setup(operation_table_type)
        out = table.hstack([self.t1, self.t2], join_type='inner',
                           uniq_col_name='{table_name}_{col_name}',
                           table_names=('left', 'right'))
        assert out.masked is False
        assert out.pformat() == ['left_a left_b right_a right_b  c ',
                                 '------ ------ ------- ------- ---',
                                 '   0.0    foo     2.0     pez   4',
                                 '   1.0    bar     3.0     sez   5']

    def test_col_meta_merge(self, operation_table_type):
        """Column meta, unit, format and description are copied to the output."""
        self._setup(operation_table_type)
        t1 = self.t1
        t3 = self.t3[:2]
        t4 = self.t4

        # Just set a bunch of meta and make sure it is the same in output
        meta1 = OrderedDict([('b', [1, 2]), ('c', {'a': 1}), ('d', 1)])
        t1['a'].unit = 'cm'
        t1['b'].info.description = 't1_b'
        t4['f'].info.format = '%6s'
        t1['b'].info.meta.update(meta1)
        t3['d'].info.meta.update(OrderedDict([('b', [3, 4]), ('c', {'b': 1}), ('a', 1)]))
        t4['g'].info.meta.update(OrderedDict([('b', [5, 6]), ('c', {'c': 1}), ('e', 1)]))
        t3['e'].info.meta.update(OrderedDict([('b', [3, 4]), ('c', {'b': 1}), ('a', 1)]))
        t3['d'].unit = 'm'
        t3['d'].info.format = '%6s'
        t3['d'].info.description = 't3_c'

        out = table.hstack([t1, t3, t4], join_type='exact')

        for t in [t1, t3, t4]:
            for name in t.colnames:
                for attr in ('meta', 'unit', 'format', 'description'):
                    assert getattr(out[name].info, attr) == getattr(t[name].info, attr)

        # Make sure we got a copy of meta, not ref
        t1['b'].info.meta['b'] = None
        assert out['b'].info.meta['b'] == [1, 2]

    def test_hstack_one_table(self, operation_table_type):
        """Regression test for issue #3313"""
        self._setup(operation_table_type)
        assert (self.t1 == table.hstack(self.t1)).all()
        assert (self.t1 == table.hstack([self.t1])).all()

    def test_mixin_functionality(self, mixin_cols):
        """hstack of mixin columns for inner joins, and outer joins where supported."""
        col1 = mixin_cols['m']
        col2 = col1[2:4]  # Shorter version of col1
        t1 = table.QTable([col1])
        t2 = table.QTable([col2])

        cls_name = type(col1).__name__

        out = table.hstack([t1, t2], join_type='inner')
        assert type(out['col0_1']) is type(out['col0_2'])  # noqa
        assert len(out) == len(col2)

        # Check that columns are as expected.
        if cls_name == 'SkyCoord':
            assert skycoord_equal(out['col0_1'], col1[:len(col2)])
            assert skycoord_equal(out['col0_2'], col2)
        elif 'Repr' in cls_name or 'Diff' in cls_name:
            assert np.all(representation_equal(out['col0_1'], col1[:len(col2)]))
            assert np.all(representation_equal(out['col0_2'], col2))
        else:
            assert np.all(out['col0_1'] == col1[:len(col2)])
            assert np.all(out['col0_2'] == col2)

        # Time, TimeDelta and Quantity are expected to work in an outer
        # hstack; every other mixin must raise NotImplementedError.
        if isinstance(col1, (Time, TimeDelta, Quantity)):
            out = table.hstack([t1, t2], join_type='outer')
            assert len(out) == len(t1)
            assert np.all(out['col0_1'] == col1)
            assert np.all(out['col0_2'][:len(col2)] == col2)
            assert check_mask(out['col0_2'], [False, False, True, True])

            # check directly stacking mixin columns:
            out2 = table.hstack([t1, t2['col0']], join_type='outer')
            assert np.all(out['col0_1'] == out2['col0_1'])
            assert np.all(out['col0_2'] == out2['col0_2'])
        else:
            with pytest.raises(NotImplementedError) as err:
                table.hstack([t1, t2], join_type='outer')
            assert 'hstack requires masking' in str(err.value)
1661
1662
def test_unique(operation_table_type):
    """Exercise ``table.unique``: key selection, ``keep`` modes and masked keys."""
    tbl = operation_table_type.read(
        [' a b  c  d',
         ' 2 b 7.0 0',
         ' 1 c 3.0 5',
         ' 2 b 6.0 2',
         ' 2 a 4.0 3',
         ' 1 a 1.0 7',
         ' 2 b 5.0 1',
         ' 0 a 0.0 4',
         ' 1 a 2.0 6',
         ' 1 c 3.0 5',
         ], format='ascii')

    # pformat header lines shared by every four-column expectation below.
    header = [' a   b   c   d ',
              '--- --- --- ---']

    # The final input row repeats an earlier one, so the sorted table minus
    # that row is the expected fully-unique result.
    expected_all = operation_table_type(np.sort(tbl[:-1]))
    assert sort_eq(table.unique(tbl).pformat(), expected_all.pformat())

    # With only column 'a' left, unique() collapses to the distinct 'a' values.
    only_a = tbl.copy()
    del only_a['b', 'c', 'd']
    assert sort_eq(table.unique(only_a).pformat(), [' a ',
                                                    '---',
                                                    '  0',
                                                    '  1',
                                                    '  2'])

    # --- single key column ---
    single_key = 'a'
    first_by_a = table.unique(tbl, single_key)
    assert sort_eq(first_by_a.pformat(), header + ['  0   a 0.0   4',
                                                   '  1   c 3.0   5',
                                                   '  2   b 7.0   0'])
    last_by_a = table.unique(tbl, single_key, keep='last')
    assert sort_eq(last_by_a.pformat(), header + ['  0   a 0.0   4',
                                                  '  1   c 3.0   5',
                                                  '  2   b 5.0   1'])
    none_by_a = table.unique(tbl, single_key, keep='none')
    assert sort_eq(none_by_a.pformat(), header + ['  0   a 0.0   4'])

    # --- two key columns ---
    pair_key = ['a', 'b']
    first_by_ab = table.unique(tbl, pair_key)
    assert sort_eq(first_by_ab.pformat(), header + ['  0   a 0.0   4',
                                                    '  1   a 1.0   7',
                                                    '  1   c 3.0   5',
                                                    '  2   a 4.0   3',
                                                    '  2   b 7.0   0'])

    last_by_ab = table.unique(tbl, pair_key, keep='last')
    assert sort_eq(last_by_ab.pformat(), header + ['  0   a 0.0   4',
                                                   '  1   a 2.0   6',
                                                   '  1   c 3.0   5',
                                                   '  2   a 4.0   3',
                                                   '  2   b 5.0   1'])
    none_by_ab = table.unique(tbl, pair_key, keep='none')
    assert sort_eq(none_by_ab.pformat(), header + ['  0   a 0.0   4',
                                                   '  2   a 4.0   3'])

    # --- invalid arguments ---
    repeated_key = ['a', 'a']
    with pytest.raises(ValueError) as exc:
        table.unique(tbl, repeated_key)
    assert exc.value.args[0] == "duplicate key names"

    with pytest.raises(ValueError) as exc:
        table.unique(tbl, repeated_key, keep=True)
    assert exc.value.args[0] == (
        "'keep' should be one of 'first', 'last', 'none'")

    # --- masked key columns ---
    masked_tbl = operation_table_type(first_by_a, masked=True)
    masked_tbl['a'].mask[1] = True

    with pytest.raises(ValueError) as exc:
        table.unique(masked_tbl)
    assert exc.value.args[0] == (
        "cannot use columns with masked values as keys; "
        "remove column 'a' from keys and rerun unique()")

    deduped = table.unique(masked_tbl, silent=True)
    assert deduped.masked is False
    assert deduped.pformat() == header + ['  0   a 0.0   4',
                                          '  2   b 7.0   0',
                                          ' --   c 3.0   5']

    with pytest.raises(ValueError):
        table.unique(masked_tbl, silent=True, keys='a')

    masked_tbl = operation_table_type(tbl, masked=True)
    masked_tbl['a'].mask[1] = True
    masked_tbl['d'].mask[3] = True

    # Test that multiple masked key columns get removed in the correct
    # order
    deduped = table.unique(masked_tbl, keys=['d', 'a', 'b'], silent=True)
    assert deduped.masked is False
    assert deduped.pformat() == header + ['  2   a 4.0  --',
                                          '  2   b 7.0   0',
                                          ' --   c 3.0   5']
1775
1776
def test_vstack_bytes(operation_table_type):
    """
    Regression test for issue #5617 (vstack of bytes columns in Py3), which
    is really an upstream numpy issue, numpy/numpy/#8403.

    The itemsize of a one-byte column must still be 1 after stacking.
    """
    single = operation_table_type([[b'a']], names=['a'])
    assert single['a'].itemsize == 1

    doubled = table.vstack([single, single])
    assert len(doubled) == 2
    assert doubled['a'].itemsize == 1
1788
1789
def test_vstack_unicode():
    """
    Regression test for a problem related to issue #5617 when vstack'ing
    *unicode* columns, where the character size got multiplied by 4.

    The itemsize of a one-character column must be unchanged by stacking.
    """
    single = table.Table([['a']], names=['a'])
    assert single['a'].itemsize == 4  # 4-byte / char for U dtype

    doubled = table.vstack([single, single])
    assert len(doubled) == 2
    assert doubled['a'].itemsize == 4
1801
1802
def test_join_mixins_time_quantity():
    """
    Outer join of two tables on non-ndarray key columns (a Time and a
    Quantity column).
    """
    col_names = ['tm', 'q', 'idx']
    left = Table([Time([2, 1, 2], format='cxcsec'), [2, 1, 1] * u.m, [1, 2, 3]],
                 names=col_names)
    right = Table([Time([2, 3], format='cxcsec'), [2, 3] * u.m, [10, 20]],
                  names=col_names)
    # Output:
    #
    # <Table length=4>
    #         tm            q    idx_1 idx_2
    #                       m
    #       object       float64 int64 int64
    # ------------------ ------- ----- -----
    # 0.9999999999969589     1.0     2    --
    #   2.00000000000351     1.0     3    --
    #   2.00000000000351     2.0     1    10
    #  3.000000000000469     3.0    --    20

    joined = table.join(left, right, join_type='outer', keys=['tm', 'q'])

    # Key cols are lexically sorted
    expected_tm = Time([1, 2, 2, 3], format='cxcsec')
    expected_q = [1, 1, 2, 3] * u.m
    assert np.all(joined['tm'] == expected_tm)
    assert np.all(joined['q'] == expected_q)

    # Rows present in only one input are masked in the other's idx column.
    expected_idx_1 = np.ma.array([2, 3, 1, 0], mask=[0, 0, 0, 1])
    expected_idx_2 = np.ma.array([0, 0, 10, 20], mask=[1, 1, 0, 0])
    assert np.all(joined['idx_1'] == expected_idx_1)
    assert np.all(joined['idx_2'] == expected_idx_2)
1833
1834
def test_join_mixins_not_sortable():
    """
    Joining on a non-ndarray key column that cannot be sorted (a SkyCoord)
    raises TypeError.
    """
    coords = SkyCoord([1, 2], [3, 4], unit='deg,deg')
    left = Table([coords, [1, 2]], names=['sc', 'idx1'])
    right = Table([coords, [10, 20]], names=['sc', 'idx2'])

    expected_msg = 'one or more key columns are not sortable'
    with pytest.raises(TypeError, match=expected_msg):
        table.join(left, right, keys='sc')
1845
1846
def test_join_non_1d_key_column():
    """Joining on a two-dimensional key column raises ValueError."""
    left = Table([[[1, 2], [3, 4]], [1, 2]], names=['a', 'b'])
    right = left.copy()
    expected_msg = "key column 'a' must be 1-d"
    with pytest.raises(ValueError, match=expected_msg):
        table.join(left, right, keys='a')
1854
1855
def test_argsort_time_column():
    """Regression test for #10823.

    ``Table.argsort`` on a Time column must agree with ``Time.argsort``.
    """
    dates = Time(['2016-01-01', '2018-01-01', '2017-01-01'])
    tbl = Table([dates], names=['time'])
    order = tbl.argsort('time')
    assert np.all(order == dates.argsort())
1862
1863
def test_sort_indexed_table():
    """Test fix for #9473 and #6545 - and another regression test for #10823.

    Sorting is checked on an indexed Table, on a Table holding a Time
    column, and on a TimeSeries.
    """
    indexed = Table([[1, 3, 2], [6, 4, 5]], names=('a', 'b'))
    indexed.add_index('a')

    indexed.sort('a')
    assert np.all(indexed['a'] == [1, 2, 3])
    assert np.all(indexed['b'] == [6, 5, 4])

    indexed.sort('b')
    assert np.all(indexed['b'] == [4, 5, 6])
    assert np.all(indexed['a'] == [3, 2, 1])

    date_strings = ['2016-01-01', '2018-01-01', '2017-01-01']
    dates = Time(date_strings)
    # Expected 'time' column after sorting by time: input rows 0, 2, 1.
    dates_sorted = dates[[0, 2, 1]]

    timed = Table([dates, [3, 2, 1]], names=['time', 'flux'])
    timed.sort('flux')
    assert np.all(timed['flux'] == [1, 2, 3])
    timed.sort('time')
    assert np.all(timed['flux'] == [3, 1, 2])
    assert np.all(timed['time'] == dates_sorted)

    # Using the table as a TimeSeries implicitly sets the index, so
    # this test is a bit different from the above.
    from astropy.timeseries import TimeSeries
    series = TimeSeries(time=date_strings)
    series['flux'] = [3, 2, 1]
    series.sort('flux')
    assert np.all(series['flux'] == [1, 2, 3])
    series.sort('time')
    assert np.all(series['flux'] == [3, 1, 2])
    assert np.all(series['time'] == dates_sorted)
1894
1895
def test_get_out_class():
    """Check the class ``_get_out_class`` picks for combinations of column types."""
    plain = table.Column([1, 2])
    masked = table.MaskedColumn([1, 2])
    quantity = [1, 2] * u.m

    plain_cls = plain.__class__
    masked_cls = masked.__class__

    # (input columns, expected output class)
    accepted = (
        ([plain, masked], masked_cls),
        ([masked, plain], masked_cls),
        ([plain, plain], plain_cls),
        ([plain], plain_cls),
    )
    for cols, expected_cls in accepted:
        assert _get_out_class(cols) is expected_cls

    # A Column together with a Quantity is rejected in either order.
    for cols in ([plain, quantity], [quantity, plain]):
        with pytest.raises(ValueError):
            _get_out_class(cols)
1911
1912
def test_masking_required_exception():
    """
    Outer vstack, hstack and join must each raise NotImplementedError when a
    mixin column that does not support masking is involved.
    """
    mixin = table.NdarrayMixin([0, 1, 2, 3])
    long_tbl = table.QTable([[1, 2, 3, 4], mixin], names=['a', 'b'])
    short_tbl = table.QTable([[1, 2], mixin[:2]], names=['a', 'c'])

    # (operation to attempt, text expected in the error message)
    attempts = (
        (lambda: table.vstack([long_tbl, short_tbl], join_type='outer'),
         'vstack unavailable'),
        (lambda: table.hstack([long_tbl, short_tbl], join_type='outer'),
         'hstack requires masking'),
        (lambda: table.join(long_tbl, short_tbl, join_type='outer'),
         'join requires masking'),
    )
    for attempt, fragment in attempts:
        with pytest.raises(NotImplementedError) as err:
            attempt()
        assert fragment in str(err.value)
1933
1934
def test_stack_columns():
    """Stack bare columns/mixins and check the class of the resulting table."""
    plain = table.Column([1, 2])
    masked = table.MaskedColumn([1, 2])
    quantity = [1, 2] * u.m
    times = Time(['2001-01-02T12:34:56', '2001-02-03T00:01:02'])
    coords = SkyCoord([1, 2], [3, 4], unit='deg')
    plain_with_unit = table.Column([11, 22], unit=u.m)

    # hstack cases where the result must also report ``masked is False``.
    for cols, expected_cls in (
            ([plain, quantity], table.QTable),
            ([quantity, plain], table.QTable),
            ([masked, quantity], table.QTable),
            ([plain, masked], table.Table)):
        stacked = table.hstack(cols)
        assert stacked.__class__ is expected_cls
        assert stacked.masked is False

    # vstack of like-typed columns.
    for cols, expected_cls in (
            ([quantity, quantity], table.QTable),
            ([plain, plain], table.Table)):
        stacked = table.vstack(cols)
        assert stacked.__class__ is expected_cls

    # hstack involving Time and SkyCoord mixins.
    for cols, expected_cls in (
            ([plain, times], table.Table),
            ([plain, coords], table.Table),
            ([quantity, times, coords], table.QTable)):
        stacked = table.hstack(cols)
        assert stacked.__class__ is expected_cls

    # vstack of a Quantity with a Column (with or without unit) is rejected.
    for cols in ([plain, quantity], [quantity, plain_with_unit]):
        with pytest.raises(ValueError):
            table.vstack(cols)
1976
1977
def test_mixin_join_regression():
    """Outer join on an index column plus two Quantity key columns gives 6 rows."""
    # This used to trigger a ValueError:
    # ValueError: NumPy boolean array indexing assignment cannot assign
    # 6 input values to the 4 output values where the mask is true

    left = QTable()
    left['index'] = [1, 2, 3, 4, 5]
    left_flux = [2, 3, 2, 1, 1]
    left['flux1'] = left_flux * u.Jy
    left['flux2'] = left_flux * u.Jy

    right = QTable()
    right['index'] = [3, 4, 5, 6]
    right_flux = [2, 1, 1, 3]
    right['flux1'] = right_flux * u.Jy
    right['flux2'] = right_flux * u.Jy

    key_cols = ('index', 'flux1', 'flux2')
    joined = table.join(left, right, keys=key_cols, join_type='outer')

    assert len(joined) == 6
1996