1"""
2High-level table operations:
3
- join()
- setdiff()
- hstack()
- vstack()
- dstack()
- unique()
9"""
10# Licensed under a 3-clause BSD style license - see LICENSE.rst
11
12from copy import deepcopy
13import collections
14import itertools
15from collections import OrderedDict, Counter
16from collections.abc import Mapping, Sequence
17
18import numpy as np
19
20from astropy.utils import metadata
21from astropy.utils.masked import Masked
22from .table import Table, QTable, Row, Column, MaskedColumn
23from astropy.units import Quantity
24
25from . import _np_utils
26from .np_utils import fix_column_name, TableMergeError
27
# Public API of this module; ``dstack`` is a public high-level operation
# alongside ``vstack`` and ``hstack`` and is listed here for consistency.
__all__ = ['join', 'setdiff', 'hstack', 'vstack', 'dstack', 'unique',
           'join_skycoord', 'join_distance']

# The doctests of these helpers need scipy to run.
__doctest_requires__ = {'join_skycoord': ['scipy'], 'join_distance': ['scipy']}
32
33
34def _merge_table_meta(out, tables, metadata_conflicts='warn'):
35    out_meta = deepcopy(tables[0].meta)
36    for table in tables[1:]:
37        out_meta = metadata.merge(out_meta, table.meta, metadata_conflicts=metadata_conflicts)
38    out.meta.update(out_meta)
39
40
41def _get_list_of_tables(tables):
42    """
43    Check that tables is a Table or sequence of Tables.  Returns the
44    corresponding list of Tables.
45    """
46
47    # Make sure we have a list of things
48    if not isinstance(tables, Sequence):
49        tables = [tables]
50
51    # Make sure there is something to stack
52    if len(tables) == 0:
53        raise ValueError('no values provided to stack.')
54
55    # Convert inputs (Table, Row, or anything column-like) to Tables.
56    # Special case that Quantity converts to a QTable.
57    for ii, val in enumerate(tables):
58        if isinstance(val, Table):
59            pass
60        elif isinstance(val, Row):
61            tables[ii] = Table(val)
62        elif isinstance(val, Quantity):
63            tables[ii] = QTable([val])
64        else:
65            try:
66                tables[ii] = Table([val])
67            except (ValueError, TypeError) as err:
68                raise TypeError(f'Cannot convert {val} to table column.') from err
69
70    return tables
71
72
73def _get_out_class(objs):
74    """
75    From a list of input objects ``objs`` get merged output object class.
76
77    This is just taken as the deepest subclass. This doesn't handle complicated
78    inheritance schemes, but as a special case, classes which share ``info``
79    are taken to be compatible.
80    """
81    out_class = objs[0].__class__
82    for obj in objs[1:]:
83        if issubclass(obj.__class__, out_class):
84            out_class = obj.__class__
85
86    if any(not (issubclass(out_class, obj.__class__)
87                or out_class.info is obj.__class__.info) for obj in objs):
88        raise ValueError('unmergeable object classes {}'
89                         .format([obj.__class__.__name__ for obj in objs]))
90
91    return out_class
92
93
94def join_skycoord(distance, distance_func='search_around_sky'):
95    """Helper function to join on SkyCoord columns using distance matching.
96
97    This function is intended for use in ``table.join()`` to allow performing a
98    table join where the key columns are both ``SkyCoord`` objects, matched by
99    computing the distance between points and accepting values below
100    ``distance``.
101
102    The distance cross-matching is done using either
103    `~astropy.coordinates.search_around_sky` or
104    `~astropy.coordinates.search_around_3d`, depending on the value of
105    ``distance_func``.  The default is ``'search_around_sky'``.
106
107    One can also provide a function object for ``distance_func``, in which case
108    it must be a function that follows the same input and output API as
109    `~astropy.coordinates.search_around_sky`. In this case the function will
110    be called with ``(skycoord1, skycoord2, distance)`` as arguments.
111
112    Parameters
113    ----------
114    distance : `~astropy.units.Quantity` ['angle', 'length']
115        Maximum distance between points to be considered a join match.
116        Must have angular or distance units.
117    distance_func : str or function
118        Specifies the function for performing the cross-match based on
119        ``distance``. If supplied as a string this specifies the name of a
120        function in `astropy.coordinates`. If supplied as a function then that
121        function is called directly.
122
123    Returns
124    -------
125    join_func : function
126        Function that accepts two ``SkyCoord`` columns (col1, col2) and returns
127        the tuple (ids1, ids2) of pair-matched unique identifiers.
128
129    Examples
130    --------
131    This example shows an inner join of two ``SkyCoord`` columns, taking any
132    sources within 0.2 deg to be a match.  Note the new ``sc_id`` column which
133    is added and provides a unique source identifier for the matches.
134
135      >>> from astropy.coordinates import SkyCoord
136      >>> import astropy.units as u
137      >>> from astropy.table import Table, join_skycoord
138      >>> from astropy import table
139
140      >>> sc1 = SkyCoord([0, 1, 1.1, 2], [0, 0, 0, 0], unit='deg')
141      >>> sc2 = SkyCoord([0.5, 1.05, 2.1], [0, 0, 0], unit='deg')
142
143      >>> join_func = join_skycoord(0.2 * u.deg)
144      >>> join_func(sc1, sc2)  # Associate each coordinate with unique source ID
145      (array([3, 1, 1, 2]), array([4, 1, 2]))
146
147      >>> t1 = Table([sc1], names=['sc'])
148      >>> t2 = Table([sc2], names=['sc'])
149      >>> t12 = table.join(t1, t2, join_funcs={'sc': join_skycoord(0.2 * u.deg)})
150      >>> print(t12)  # Note new `sc_id` column with the IDs from join_func()
151      sc_id   sc_1    sc_2
152            deg,deg deg,deg
153      ----- ------- --------
154          1 1.0,0.0 1.05,0.0
155          1 1.1,0.0 1.05,0.0
156          2 2.0,0.0  2.1,0.0
157
158    """
159    if isinstance(distance_func, str):
160        import astropy.coordinates as coords
161        try:
162            distance_func = getattr(coords, distance_func)
163        except AttributeError as err:
164            raise ValueError('distance_func must be a function in astropy.coordinates') from err
165    else:
166        from inspect import isfunction
167        if not isfunction(distance_func):
168            raise ValueError('distance_func must be a str or function')
169
170    def join_func(sc1, sc2):
171
172        # Call the appropriate SkyCoord method to find pairs within distance
173        idxs1, idxs2, d2d, d3d = distance_func(sc1, sc2, distance)
174
175        # Now convert that into unique identifiers for each near-pair. This is
176        # taken to be transitive, so that if points 1 and 2 are "near" and points
177        # 1 and 3 are "near", then 1, 2, and 3 are all given the same identifier.
178        # This identifier will then be used in the table join matching.
179
180        # Identifiers for each column, initialized to all zero.
181        ids1 = np.zeros(len(sc1), dtype=int)
182        ids2 = np.zeros(len(sc2), dtype=int)
183
184        # Start the identifier count at 1
185        id_ = 1
186        for idx1, idx2 in zip(idxs1, idxs2):
187            # If this col1 point is previously identified then set corresponding
188            # col2 point to same identifier.  Likewise for col2 and col1.
189            if ids1[idx1] > 0:
190                ids2[idx2] = ids1[idx1]
191            elif ids2[idx2] > 0:
192                ids1[idx1] = ids2[idx2]
193            else:
194                # Not yet seen so set identifier for col1 and col2
195                ids1[idx1] = id_
196                ids2[idx2] = id_
197                id_ += 1
198
199        # Fill in unique identifiers for points with no near neighbor
200        for ids in (ids1, ids2):
201            for idx in np.flatnonzero(ids == 0):
202                ids[idx] = id_
203                id_ += 1
204
205        # End of enclosure join_func()
206        return ids1, ids2
207
208    return join_func
209
210
211def join_distance(distance, kdtree_args=None, query_args=None):
212    """Helper function to join table columns using distance matching.
213
214    This function is intended for use in ``table.join()`` to allow performing
215    a table join where the key columns are matched by computing the distance
216    between points and accepting values below ``distance``. This numerical
217    "fuzzy" match can apply to 1-D or 2-D columns, where in the latter case
218    the distance is a vector distance.
219
220    The distance cross-matching is done using `scipy.spatial.cKDTree`. If
221    necessary you can tweak the default behavior by providing ``dict`` values
222    for the ``kdtree_args`` or ``query_args``.
223
224    Parameters
225    ----------
226    distance : float or `~astropy.units.Quantity` ['length']
227        Maximum distance between points to be considered a join match
228    kdtree_args : dict, None
229        Optional extra args for `~scipy.spatial.cKDTree`
230    query_args : dict, None
231        Optional extra args for `~scipy.spatial.cKDTree.query_ball_tree`
232
233    Returns
234    -------
235    join_func : function
236        Function that accepts (skycoord1, skycoord2) and returns the tuple
237        (ids1, ids2) of pair-matched unique identifiers.
238
239    Examples
240    --------
241
242      >>> from astropy.table import Table, join_distance
243      >>> from astropy import table
244
245      >>> c1 = [0, 1, 1.1, 2]
246      >>> c2 = [0.5, 1.05, 2.1]
247
248      >>> t1 = Table([c1], names=['col'])
249      >>> t2 = Table([c2], names=['col'])
250      >>> t12 = table.join(t1, t2, join_type='outer', join_funcs={'col': join_distance(0.2)})
251      >>> print(t12)
252      col_id col_1 col_2
253      ------ ----- -----
254           1   1.0  1.05
255           1   1.1  1.05
256           2   2.0   2.1
257           3   0.0    --
258           4    --   0.5
259
260    """
261    try:
262        from scipy.spatial import cKDTree
263    except ImportError as exc:
264        raise ImportError('scipy is required to use join_distance()') from exc
265
266    if kdtree_args is None:
267        kdtree_args = {}
268    if query_args is None:
269        query_args = {}
270
271    def join_func(col1, col2):
272        if col1.ndim > 2 or col2.ndim > 2:
273            raise ValueError('columns for isclose_join must be 1- or 2-dimensional')
274
275        if isinstance(distance, Quantity):
276            # Convert to np.array with common unit
277            col1 = col1.to_value(distance.unit)
278            col2 = col2.to_value(distance.unit)
279            dist = distance.value
280        else:
281            # Convert to np.array to allow later in-place shape changing
282            col1 = np.asarray(col1)
283            col2 = np.asarray(col2)
284            dist = distance
285
286        # Ensure columns are pure np.array and are 2-D for use with KDTree
287        if col1.ndim == 1:
288            col1.shape = col1.shape + (1,)
289        if col2.ndim == 1:
290            col2.shape = col2.shape + (1,)
291
292        # Cross-match col1 and col2 within dist using KDTree
293        kd1 = cKDTree(col1, **kdtree_args)
294        kd2 = cKDTree(col2, **kdtree_args)
295        nears = kd1.query_ball_tree(kd2, r=dist, **query_args)
296
297        # Output of above is nears which is a list of lists, where the outer
298        # list corresponds to each item in col1, and where the inner lists are
299        # indexes into col2 of elements within the distance tolerance.  This
300        # identifies col1 / col2 near pairs.
301
302        # Now convert that into unique identifiers for each near-pair. This is
303        # taken to be transitive, so that if points 1 and 2 are "near" and points
304        # 1 and 3 are "near", then 1, 2, and 3 are all given the same identifier.
305        # This identifier will then be used in the table join matching.
306
307        # Identifiers for each column, initialized to all zero.
308        ids1 = np.zeros(len(col1), dtype=int)
309        ids2 = np.zeros(len(col2), dtype=int)
310
311        # Start the identifier count at 1
312        id_ = 1
313        for idx1, idxs2 in enumerate(nears):
314            for idx2 in idxs2:
315                # If this col1 point is previously identified then set corresponding
316                # col2 point to same identifier.  Likewise for col2 and col1.
317                if ids1[idx1] > 0:
318                    ids2[idx2] = ids1[idx1]
319                elif ids2[idx2] > 0:
320                    ids1[idx1] = ids2[idx2]
321                else:
322                    # Not yet seen so set identifier for col1 and col2
323                    ids1[idx1] = id_
324                    ids2[idx2] = id_
325                    id_ += 1
326
327        # Fill in unique identifiers for points with no near neighbor
328        for ids in (ids1, ids2):
329            for idx in np.flatnonzero(ids == 0):
330                ids[idx] = id_
331                id_ += 1
332
333        # End of enclosure join_func()
334        return ids1, ids2
335
336    return join_func
337
338
def join(left, right, keys=None, join_type='inner', *,
         keys_left=None, keys_right=None,
         uniq_col_name='{col_name}_{table_name}',
         table_names=['1', '2'], metadata_conflicts='warn',
         join_funcs=None):
    """
    Join the ``left`` table with the ``right`` table on matching key values.

    Parameters
    ----------
    left : `~astropy.table.Table`-like object
        Left side table in the join. If not a Table, will call ``Table(left)``
    right : `~astropy.table.Table`-like object
        Right side table in the join. If not a Table, will call ``Table(right)``
    keys : str or list of str
        Name(s) of column(s) used to match rows of left and right tables.
        Default is to use all columns which are common to both tables.
    join_type : str
        Join type ('inner' | 'outer' | 'left' | 'right' | 'cartesian'), default is 'inner'
    keys_left : str or list of str or list of column-like, optional
        Left column(s) used to match rows instead of ``keys`` arg. This can be
        a single left table column name or list of column names, or a list of
        column-like values with the same lengths as the left table.
    keys_right : str or list of str or list of column-like, optional
        Same as ``keys_left``, but for the right side of the join.
    uniq_col_name : str or None
        String used to generate a unique output column name in case of a
        conflict.  The default is '{col_name}_{table_name}'.
    table_names : list of str or None
        Two-element list of table names used when generating unique output
        column names.  The default is ['1', '2'].
    metadata_conflicts : str
        How to proceed with metadata conflicts. This should be one of:
            * ``'silent'``: silently pick the last conflicting meta-data value
            * ``'warn'``: pick the last conflicting meta-data value, but emit a warning (default)
            * ``'error'``: raise an exception.
    join_funcs : dict, None
        Dict of functions to use for matching the corresponding key column(s).
        See `~astropy.table.join_skycoord` for an example and details.

    Returns
    -------
    joined_table : `~astropy.table.Table` object
        New table containing the result of the join operation.
    """
    # Anything that is not already a Table goes through the Table
    # constructor (left first, then right).
    left, right = [tbl if isinstance(tbl, Table) else Table(tbl)
                   for tbl in (left, right)]

    joined = _join(left, right, keys, join_type,
                   uniq_col_name, table_names, OrderedDict(),
                   metadata_conflicts, join_funcs,
                   keys_left=keys_left, keys_right=keys_right)

    # Fold the table-level meta of both inputs into the result.
    _merge_table_meta(joined, [left, right],
                      metadata_conflicts=metadata_conflicts)
    return joined
402
403
def setdiff(table1, table2, keys=None):
    """
    Take a set difference of table rows.

    The result holds every row of ``table1`` whose key values do not occur
    in ``table2``. When ``keys`` is not given, all columns of ``table1`` are
    used as keys.

    Parameters
    ----------
    table1 : `~astropy.table.Table`
        ``table1`` is on the left side of the set difference.
    table2 : `~astropy.table.Table`
        ``table2`` is on the right side of the set difference.
    keys : str or list of str
        Name(s) of column(s) used to match rows of left and right tables.
        Default is to use all columns in ``table1``.

    Returns
    -------
    diff_table : `~astropy.table.Table`
        New table containing the set difference between tables. If the set
        difference is none, an empty table will be returned.

    Raises
    ------
    ValueError
        If any key column is missing from ``table1`` or ``table2``.

    Examples
    --------
    To get a set difference between two tables::

      >>> from astropy.table import setdiff, Table
      >>> t1 = Table({'a': [1, 4, 9], 'b': ['c', 'd', 'f']}, names=('a', 'b'))
      >>> t2 = Table({'a': [1, 5, 9], 'b': ['c', 'b', 'f']}, names=('a', 'b'))
      >>> print(t1)
       a   b
      --- ---
        1   c
        4   d
        9   f
      >>> print(t2)
       a   b
      --- ---
        1   c
        5   b
        9   f
      >>> print(setdiff(t1, t2))
       a   b
      --- ---
        4   d

      >>> print(setdiff(t2, t1))
       a   b
      --- ---
        5   b
    """
    if keys is None:
        keys = table1.colnames

    # Both inputs must provide every key column.
    for tbl, tbl_str in ((table1, 'table1'), (table2, 'table2')):
        missing = np.setdiff1d(keys, tbl.colnames)
        if len(missing) != 0:
            raise ValueError("The {} columns are missing from {}, cannot take "
                             "a set difference.".format(missing, tbl_str))

    def _keys_only_copy(tbl):
        # Light copy (data not copied) restricted to the key columns, with
        # empty meta, so the caller's tables are never touched.
        light = tbl.copy(copy_data=False)
        light.meta = {}
        light.keep_columns(keys)
        return light

    left = _keys_only_copy(table1)
    left['__index1__'] = np.arange(len(table1))  # original row positions

    right = _keys_only_copy(table2)
    # Marker column: it comes back masked on rows that found no match.
    right['__index2__'] = np.zeros(len(right), dtype=np.uint8)

    joined = _join(left, right, join_type='left', keys=keys,
                   metadata_conflicts='silent')

    marker = joined['__index2__']
    if not hasattr(marker, 'mask'):
        # Unmasked marker: every table1 row matched something in table2.
        return table1[[]]

    # Pull the unmatched rows straight from table1 so the output keeps the
    # original table and column types.
    return table1[joined['__index1__'][marker.mask]]
496
497
def dstack(tables, join_type='outer', metadata_conflicts='warn'):
    """
    Stack columns within tables depth-wise

    A ``join_type`` of 'exact' means that the tables must all have exactly
    the same column names (though the order can vary).  If ``join_type``
    is 'inner' then the intersection of common columns will be the output.
    A value of 'outer' (default) means the output will have the union of
    all columns, with table values being masked where no common values are
    available.

    Each output column gains a new second axis indexing the input tables,
    so a column of shape ``(n_row, ...)`` in each of ``N`` tables becomes
    shape ``(n_row, N, ...)``.  A single input table is returned unchanged.

    Parameters
    ----------
    tables : `~astropy.table.Table` or `~astropy.table.Row` or list thereof
        Table(s) to stack depth-wise with the current table
        Table columns should have same shape and name for depth-wise stacking
    join_type : str
        Join type ('inner' | 'exact' | 'outer'), default is 'outer'
    metadata_conflicts : str
        How to proceed with metadata conflicts. This should be one of:
            * ``'silent'``: silently pick the last conflicting meta-data value
            * ``'warn'``: pick the last conflicting meta-data value, but emit a warning (default)
            * ``'error'``: raise an exception.

    Returns
    -------
    stacked_table : `~astropy.table.Table` object
        New table containing the stacked data from the input tables.

    Raises
    ------
    ValueError
        If the input tables do not all have the same length.

    Examples
    --------
    To stack two tables depth-wise do::

      >>> from astropy.table import dstack, Table
      >>> t1 = Table({'a': [1, 2], 'b': [3, 4]}, names=('a', 'b'))
      >>> t2 = Table({'a': [5, 6], 'b': [7, 8]}, names=('a', 'b'))
      >>> print(t1)
       a   b
      --- ---
        1   3
        2   4
      >>> print(t2)
       a   b
      --- ---
        5   7
        6   8
      >>> print(dstack([t1, t2]))
      a [2]  b [2]
      ------ ------
      1 .. 5 3 .. 7
      2 .. 6 4 .. 8
    """
    _check_join_type(join_type, 'dstack')

    tables = _get_list_of_tables(tables)
    if len(tables) == 1:
        return tables[0]  # no point in stacking a single table

    n_rows = {len(table) for table in tables}
    if len(n_rows) != 1:
        raise ValueError('Table lengths must all match for dstack')
    n_row = n_rows.pop()
    n_table = len(tables)

    out = vstack(tables, join_type, metadata_conflicts)

    # Iterate over a snapshot since each column is replaced in the loop.
    for name, col in list(out.columns.items()):
        # Reshape so each original table's column is now in a row.
        # If entries are not 0-dim then those additional shape dims
        # are just carried along.
        # [x x x y y y] => [[x x x],
        #                   [y y y]]
        new_shape = (n_table, n_row) + col.shape[1:]
        try:
            col.shape = new_shape
        except AttributeError:
            # Some column types have no settable ``shape``; reshape instead.
            col = col.reshape(new_shape)

        # Transpose the table and row axes to get to
        # [[x, y],
        #  [x, y]
        #  [x, y]]
        axes = np.arange(len(col.shape))
        axes[:2] = [1, 0]

        # This temporarily makes `out` be corrupted (columns of different
        # length) but it all works out in the end.
        out.columns.__setitem__(name, col.transpose(axes), validated=True)

    return out
589
590
def vstack(tables, join_type='outer', metadata_conflicts='warn'):
    """
    Stack tables vertically (along rows)

    With ``join_type='exact'`` every table must have exactly the same column
    names (in any order).  With ``'inner'`` only the columns common to all
    tables are kept.  With ``'outer'`` (the default) the output has the union
    of all columns, masked wherever a table does not provide a value.

    Parameters
    ----------
    tables : `~astropy.table.Table` or `~astropy.table.Row` or list thereof
        Table(s) to stack along rows (vertically) with the current table
    join_type : str
        Join type ('inner' | 'exact' | 'outer'), default is 'outer'
    metadata_conflicts : str
        How to proceed with metadata conflicts. This should be one of:
            * ``'silent'``: silently pick the last conflicting meta-data value
            * ``'warn'``: pick the last conflicting meta-data value, but emit a warning (default)
            * ``'error'``: raise an exception.

    Returns
    -------
    stacked_table : `~astropy.table.Table` object
        New table containing the stacked data from the input tables.
        A single input table is returned as is.

    Examples
    --------
    To stack two tables along rows do::

      >>> from astropy.table import vstack, Table
      >>> t1 = Table({'a': [1, 2], 'b': [3, 4]}, names=('a', 'b'))
      >>> t2 = Table({'a': [5, 6], 'b': [7, 8]}, names=('a', 'b'))
      >>> print(t1)
       a   b
      --- ---
        1   3
        2   4
      >>> print(t2)
       a   b
      --- ---
        5   7
        6   8
      >>> print(vstack([t1, t2]))
       a   b
      --- ---
        1   3
        2   4
        5   7
        6   8
    """
    _check_join_type(join_type, 'vstack')

    tables = _get_list_of_tables(tables)  # validates and normalizes input
    if len(tables) == 1:
        # Nothing to stack: hand back the lone input unchanged.
        return tables[0]

    stacked = _vstack(tables, join_type, OrderedDict(), metadata_conflicts)

    # Table-level meta from all inputs is merged into the result.
    _merge_table_meta(stacked, tables, metadata_conflicts=metadata_conflicts)
    return stacked
657
658
def hstack(tables, join_type='outer',
           uniq_col_name='{col_name}_{table_name}', table_names=None,
           metadata_conflicts='warn'):
    """
    Stack tables along columns (horizontally)

    A ``join_type`` of 'exact' means that the tables must all
    have exactly the same number of rows.  If ``join_type`` is 'inner' then
    the intersection of rows will be the output.  A value of 'outer' (default)
    means the output will have the union of all rows, with table values being
    masked where no common values are available.

    Parameters
    ----------
    tables : `~astropy.table.Table` or `~astropy.table.Row` or list thereof
        Tables to stack along columns (horizontally) with the current table
    join_type : str
        Join type ('inner' | 'exact' | 'outer'), default is 'outer'
    uniq_col_name : str or None
        String used to generate a unique output column name in case of a
        conflict.  The default is '{col_name}_{table_name}'.
    table_names : list of str or None
        List of table names (one per input table) used when generating
        unique output column names.  The default is ['1', '2', ..].
    metadata_conflicts : str
        How to proceed with metadata conflicts. This should be one of:
            * ``'silent'``: silently pick the last conflicting meta-data value
            * ``'warn'``: pick the last conflicting meta-data value,
              but emit a warning (default)
            * ``'error'``: raise an exception.

    Returns
    -------
    stacked_table : `~astropy.table.Table` object
        New table containing the stacked data from the input tables.
        A single input table is returned as is.

    See Also
    --------
    Table.add_columns, Table.replace_column, Table.update

    Examples
    --------
    To stack two tables horizontally (along columns) do::

      >>> from astropy.table import Table, hstack
      >>> t1 = Table({'a': [1, 2], 'b': [3, 4]}, names=('a', 'b'))
      >>> t2 = Table({'c': [5, 6], 'd': [7, 8]}, names=('c', 'd'))
      >>> print(t1)
       a   b
      --- ---
        1   3
        2   4
      >>> print(t2)
       c   d
      --- ---
        5   7
        6   8
      >>> print(hstack([t1, t2]))
       a   b   c   d
      --- --- --- ---
        1   3   5   7
        2   4   6   8
    """
    _check_join_type(join_type, 'hstack')

    tables = _get_list_of_tables(tables)  # validates input
    if len(tables) == 1:
        return tables[0]  # no point in stacking a single table
    col_name_map = OrderedDict()

    out = _hstack(tables, join_type, uniq_col_name, table_names,
                  col_name_map)

    # Table-level meta from all inputs is merged into the result.
    _merge_table_meta(out, tables, metadata_conflicts=metadata_conflicts)

    return out
735
736
def unique(input_table, keys=None, silent=False, keep='first'):
    """
    Returns the unique rows of a table.

    Parameters
    ----------
    input_table : table-like
    keys : str or list of str
        Name(s) of column(s) used to create unique rows.
        Default is to use all columns.  The caller's sequence is never
        modified, even when masked columns are dropped with ``silent=True``.
    keep : {'first', 'last', 'none'}
        Whether to keep the first or last row for each set of
        duplicates. If 'none', all rows that are duplicate are
        removed, leaving only rows that are already unique in
        the input.
        Default is 'first'.
    silent : bool
        If `True`, masked value column(s) are silently removed from
        ``keys``. If `False`, an exception is raised when ``keys``
        contains masked value column(s).
        Default is `False`.

    Returns
    -------
    unique_table : `~astropy.table.Table` object
        New table containing only the unique rows of ``input_table``.

    Raises
    ------
    ValueError
        If ``keep`` is invalid, ``keys`` contains duplicates, a key column
        has masked values and ``silent`` is False, or no key column remains.

    Examples
    --------
    >>> from astropy.table import unique, Table
    >>> import numpy as np
    >>> table = Table(data=[[1,2,3,2,3,3],
    ... [2,3,4,5,4,6],
    ... [3,4,5,6,7,8]],
    ... names=['col1', 'col2', 'col3'],
    ... dtype=[np.int32, np.int32, np.int32])
    >>> table
    <Table length=6>
     col1  col2  col3
    int32 int32 int32
    ----- ----- -----
        1     2     3
        2     3     4
        3     4     5
        2     5     6
        3     4     7
        3     6     8
    >>> unique(table, keys='col1')
    <Table length=3>
     col1  col2  col3
    int32 int32 int32
    ----- ----- -----
        1     2     3
        2     3     4
        3     4     5
    >>> unique(table, keys=['col1'], keep='last')
    <Table length=3>
     col1  col2  col3
    int32 int32 int32
    ----- ----- -----
        1     2     3
        2     5     6
        3     6     8
    >>> unique(table, keys=['col1', 'col2'])
    <Table length=5>
     col1  col2  col3
    int32 int32 int32
    ----- ----- -----
        1     2     3
        2     3     4
        2     5     6
        3     4     5
        3     6     8
    >>> unique(table, keys=['col1', 'col2'], keep='none')
    <Table length=4>
     col1  col2  col3
    int32 int32 int32
    ----- ----- -----
        1     2     3
        2     3     4
        2     5     6
        3     6     8
    >>> unique(table, keys=['col1'], keep='none')
    <Table length=1>
     col1  col2  col3
    int32 int32 int32
    ----- ----- -----
        1     2     3

    """

    if keep not in ('first', 'last', 'none'):
        raise ValueError("'keep' should be one of 'first', 'last', 'none'")

    if isinstance(keys, str):
        keys = [keys]
    if keys is None:
        keys = list(input_table.colnames)
    else:
        # Work on a private list: masked key columns may be dropped below,
        # which must neither mutate the caller's list nor fail for a tuple.
        keys = list(keys)
        if len(set(keys)) != len(keys):
            raise ValueError("duplicate key names")

    # Check for columns with masked values
    for key in keys[:]:
        col = input_table[key]
        if hasattr(col, 'mask') and np.any(col.mask):
            if not silent:
                raise ValueError(
                    "cannot use columns with masked values as keys; "
                    "remove column '{}' from keys and rerun "
                    "unique()".format(key))
            keys.remove(key)
    if len(keys) == 0:
        raise ValueError("no column remained in ``keys``; "
                         "unique() cannot work with masked value "
                         "key columns")

    grouped_table = input_table.group_by(keys)
    # ``indices`` holds the start of each group followed by the total length.
    indices = grouped_table.groups.indices
    if keep == 'first':
        indices = indices[:-1]
    elif keep == 'last':
        indices = indices[1:] - 1
    else:
        # Keep only groups of exactly one row.
        indices = indices[:-1][np.diff(indices) == 1]

    return grouped_table[indices]
864
865
def get_col_name_map(arrays, common_names, uniq_col_name='{col_name}_{table_name}',
                     table_names=None):
    """
    Find the column names mapping when merging the list of tables
    ``arrays``.  It is assumed that col names in ``common_names`` are to be
    merged into a single column while the rest will be uniquely represented
    in the output.  The args ``uniq_col_name`` and ``table_names`` specify
    how to rename columns in case of conflicts.

    Parameters
    ----------
    arrays : list
        Input tables (anything exposing a ``colnames`` sequence of str).
    common_names : container of str
        Column names that are merged into a single output column.
    uniq_col_name : str
        Format string with ``{col_name}`` and ``{table_name}`` fields, used to
        rename a non-common column whose name also appears in another input.
    table_names : list of str or None
        Values for ``{table_name}``, one per input; default ``['1', '2', ...]``.

    Returns
    -------
    col_name_map : OrderedDict
        Maps each output column name to a list with one entry per input
        table, e.g. ``{outname: [col_name_0, col_name_1, ...], ...}``.  For
        common (key) columns all input names are present, while for the
        other columns the value is e.g. ``[col_name_0, None, ..]`` or
        ``[None, col_name_1, ..]``.

    Raises
    ------
    TableMergeError
        If the resulting output column names contain duplicates.
    """
    n_arrays = len(arrays)
    col_name_map = collections.defaultdict(lambda: [None] * n_arrays)
    col_name_list = []
    # Common names already emitted; a set gives O(1) membership tests.
    seen_common = set()

    if table_names is None:
        table_names = [str(ii + 1) for ii in range(n_arrays)]

    for idx, array in enumerate(arrays):
        table_name = table_names[idx]

        # Union of the column names of every *other* array, computed once per
        # array rather than re-scanning all other arrays for every column.
        other_names = set()
        for jdx, other in enumerate(arrays):
            if jdx != idx:
                other_names.update(other.colnames)

        for name in array.colnames:
            out_name = name

            if name in common_names:
                # If name is in the list of common_names then insert into
                # the column name list, but just once.
                if name not in seen_common:
                    seen_common.add(name)
                    col_name_list.append(name)
            else:
                # If name is not one of the common column outputs, and it collides
                # with the names in one of the other arrays, then rename
                if name in other_names:
                    out_name = uniq_col_name.format(table_name=table_name, col_name=name)
                col_name_list.append(out_name)

            col_name_map[out_name][idx] = name

    # Check for duplicate output column names (including a renamed column that
    # happens to collide with a common column name).
    col_name_count = Counter(col_name_list)
    repeated_names = [name for name, count in col_name_count.items() if count > 1]
    if repeated_names:
        raise TableMergeError('Merging column names resulted in duplicates: {}.  '
                              'Change uniq_col_name or table_names args to fix this.'
                              .format(repeated_names))

    # Convert to an OrderedDict in output-column order; values stay lists.
    return OrderedDict((name, col_name_map[name]) for name in col_name_list)
920
921
def get_descrs(arrays, col_name_map):
    """
    Compute the output column descriptions for merging ``arrays``.

    For each output column in ``col_name_map`` the contributing input columns
    are collected, their common dtype is determined, and their per-row
    (trailing) shapes are required to agree.

    Parameters
    ----------
    arrays : list of Table
        Input tables being merged.
    col_name_map : dict
        Output column name -> sequence of input column names, with ``None``
        where the corresponding table does not contribute.

    Returns
    -------
    list of tuple
        ``(name, dtype, shape)`` per output column, in ``col_name_map`` order.

    Raises
    ------
    TableMergeError
        If contributing columns have incompatible dtypes or shapes.
    """
    descrs = []

    for out_name, in_names in col_name_map.items():
        present_names = [nm for nm in in_names if nm is not None]
        src_cols = [tbl[nm] for tbl, nm in zip(arrays, in_names) if nm is not None]

        try:
            out_dtype = common_dtype(src_cols)
        except TableMergeError as exc:
            # Re-raise naming the offending column for a clearer message.
            raise TableMergeError("The '{}' columns have incompatible types: {}"
                                  .format(present_names[0], exc._incompat_types)) from exc

        # Every contributing column must share the same per-row shape.
        row_shapes = {c.shape[1:] for c in src_cols}
        if len(row_shapes) != 1:
            raise TableMergeError(f'Key columns {present_names!r} have different shape')

        descrs.append((fix_column_name(out_name), out_dtype, row_shapes.pop()))

    return descrs
957
958
def common_dtype(cols):
    """
    Return the common numpy dtype for a list of columns.

    Wraps ``metadata.common_dtype``; only columns within the fundamental
    numpy data types np.bool_, np.object_, np.number, np.character and
    np.void are allowed.

    Raises
    ------
    TableMergeError
        If the columns have incompatible types.  The offending types are
        attached to the exception as ``_incompat_types``.
    """
    try:
        result = metadata.common_dtype(cols)
    except metadata.MergeConflictError as conflict:
        bad_types = conflict._incompat_types
        merge_err = TableMergeError(f'Columns have incompatible types {bad_types}')
        merge_err._incompat_types = bad_types
        raise merge_err from conflict
    return result
972
973
def _get_join_sort_idxs(keys, left, right):
    """
    Lexically sort the key values of ``left`` followed by ``right``.

    Each key column may expand into several sortable arrays (via
    ``info.get_sortable_arrays()``; e.g. Time), and each such array becomes
    one field of a structured array holding the left rows then the right
    rows.  That structured array is argsorted with the fields in key order.

    Returns
    -------
    idxs : ndarray
        Positions in the sorted order where a new run of identical keys
        starts, followed by a final entry equal to the total length.
    idx_sort : ndarray
        Argsort indices into the concatenated left + right rows.
    """
    field_dtypes = []   # (field_name, dtype) pairs for the structured array
    field_names = []    # field names in lexical sort priority order
    left_arrays = {}
    right_arrays = {}

    for key in keys:
        lhs = left[key].info.get_sortable_arrays()
        rhs = right[key].info.get_sortable_arrays()

        # Should never happen because cols are screened beforehand for compatibility
        if len(lhs) != len(rhs):
            raise RuntimeError('mismatch in sort cols lengths')

        for lcol, rcol in zip(lhs, rhs):
            row_shape = lcol.shape[1:]
            # Mismatch should never happen.
            if row_shape != rcol.shape[1:]:
                raise RuntimeError('mismatch in shape of left vs. right sort array')
            if row_shape != ():
                raise ValueError(f'sort key column {key!r} must be 1-d')

            # Fields are named '0', '1', ... in the order they are created.
            field = str(len(field_names))
            field_names.append(field)
            left_arrays[field] = lcol
            right_arrays[field] = rcol
            field_dtypes.append((field, common_dtype([lcol, rcol])))

    n_left = len(left)
    combined = np.empty(n_left + len(right), dtype=field_dtypes)
    for field in field_names:
        column = combined[field]  # view into the structured array
        column[:n_left] = left_arrays[field]
        column[n_left:] = right_arrays[field]

    idx_sort = combined.argsort(order=field_names)
    ordered = combined[idx_sort]

    # True marks the first row of each run of equal keys; the trailing True
    # supplies the end-of-array sentinel.
    is_boundary = np.concatenate(([True], ordered[1:] != ordered[:-1], [True]))
    return np.flatnonzero(is_boundary), idx_sort
1033
1034
1035def _apply_join_funcs(left, right, keys, join_funcs):
1036    """Apply join_funcs
1037    """
1038    # Make light copies of left and right, then add new index columns.
1039    left = left.copy(copy_data=False)
1040    right = right.copy(copy_data=False)
1041    for key, join_func in join_funcs.items():
1042        ids1, ids2 = join_func(left[key], right[key])
1043        # Define a unique id_key name, and keep adding underscores until we have
1044        # a name not yet present.
1045        id_key = key + '_id'
1046        while id_key in left.columns or id_key in right.columns:
1047            id_key = id_key[:-2] + '_id'
1048
1049        keys = tuple(id_key if orig_key == key else orig_key for orig_key in keys)
1050        left.add_column(ids1, index=0, name=id_key)  # [id_key] = ids1
1051        right.add_column(ids2, index=0, name=id_key)  # [id_key] = ids2
1052
1053    return left, right, keys
1054
1055
def _join(left, right, keys=None, join_type='inner',
          uniq_col_name='{col_name}_{table_name}',
          table_names=['1', '2'],
          col_name_map=None, metadata_conflicts='warn',
          join_funcs=None,
          keys_left=None, keys_right=None):
    """
    Perform a join of the left and right Tables on specified keys.

    Parameters
    ----------
    left : Table
        Left side table in the join
    right : Table
        Right side table in the join
    keys : str or list of str
        Name(s) of column(s) used to match rows of left and right tables.
        Default is to use all columns which are common to both tables.
    join_type : str
        Join type ('inner' | 'outer' | 'left' | 'right' | 'cartesian'), default is 'inner'
    uniq_col_name : str or None
        String generate a unique output column name in case of a conflict.
        The default is '{col_name}_{table_name}'.
    table_names : list of str or None
        Two-element list of table names used when generating unique output
        column names.  The default is ['1', '2'].
    col_name_map : empty dict or None
        If passed as a dict then it will be updated in-place with the
        mapping of output to input column names.
    metadata_conflicts : str
        How to proceed with metadata conflicts. This should be one of:
            * ``'silent'``: silently pick the last conflicting meta-data value
            * ``'warn'``: pick the last conflicting meta-data value, but emit a warning (default)
            * ``'error'``: raise an exception.
    join_funcs : dict, None
        Dict of functions to use for matching the corresponding key column(s).
        See `~astropy.table.join_skycoord` for an example and details.
    keys_left : str or list of str or list of column-like, optional
        Left column(s) used to match rows with the right table.  Each item is
        either a column name in ``left`` or column data with the same length
        as ``left``.  Must be given together with ``keys_right``; ``keys`` and
        ``join_funcs`` must then be None.
    keys_right : str or list of str or list of column-like, optional
        Same as ``keys_left`` but for the right table.

    Returns
    -------
    joined_table : `~astropy.table.Table` object
        New table containing the result of the join operation.

    Raises
    ------
    ValueError
        For an invalid ``join_type``; ``keys`` or ``join_funcs`` given for a
        cartesian join; ``join_funcs`` keys that are not join keys; invalid
        ``keys_left``/``keys_right`` usage; a key column that is not 1-d; or
        an input table with no rows.
    TableMergeError
        If there are no common keys, a key column is missing, or a key
        column has masked values.
    TypeError
        If one or more key columns are not sortable.
    """
    # Store user-provided col_name_map until the end
    _col_name_map = col_name_map

    # Special column name for cartesian join, should never collide with real column
    cartesian_index_name = '__table_cartesian_join_temp_index__'

    if join_type not in ('inner', 'outer', 'left', 'right', 'cartesian'):
        raise ValueError("The 'join_type' argument should be in 'inner', "
                         "'outer', 'left', 'right', or 'cartesian' "
                         "(got '{}' instead)".
                         format(join_type))

    if join_type == 'cartesian':
        if keys:
            raise ValueError('cannot supply keys for a cartesian join')

        if join_funcs:
            raise ValueError('cannot supply join_funcs for a cartesian join')

        # Make light copies of left and right, then add temporary index columns
        # with all the same value so later an outer join turns into a cartesian join.
        left = left.copy(copy_data=False)
        right = right.copy(copy_data=False)
        left[cartesian_index_name] = np.uint8(0)
        right[cartesian_index_name] = np.uint8(0)
        keys = (cartesian_index_name, )

    # Handle the case of join key columns that are different between left and
    # right via keys_left/keys_right args. This is done by saving the original
    # input tables and making new left and right tables that contain only the
    # key cols but with common column names ['0', '1', etc]. This sets `keys` to
    # those fake key names in the left and right tables
    if keys_left is not None or keys_right is not None:
        left_orig = left
        right_orig = right
        left, right, keys = _join_keys_left_right(
            left, right, keys, keys_left, keys_right, join_funcs)

    if keys is None:
        keys = tuple(name for name in left.colnames if name in right.colnames)
        if len(keys) == 0:
            raise TableMergeError('No keys in common between left and right tables')
    elif isinstance(keys, str):
        # If we have a single key, put it in a tuple
        keys = (keys,)

    # Check the key columns: each must exist in both tables and must not
    # contain masked (missing) values.
    for arr, arr_label in ((left, 'Left'), (right, 'Right')):
        for name in keys:
            if name not in arr.colnames:
                raise TableMergeError('{} table does not have key column {!r}'
                                      .format(arr_label, name))
            if hasattr(arr[name], 'mask') and np.any(arr[name].mask):
                raise TableMergeError('{} key column {!r} has missing values'
                                      .format(arr_label, name))

    # Replace each key that has a join function by a matched-id column.
    if join_funcs is not None:
        if not all(key in keys for key in join_funcs):
            raise ValueError(f'join_funcs keys {join_funcs.keys()} must be a '
                             f'subset of join keys {keys}')
        left, right, keys = _apply_join_funcs(left, right, keys, join_funcs)

    len_left, len_right = len(left), len(right)

    if len_left == 0 or len_right == 0:
        raise ValueError('input tables for join must both have at least one row')

    # Sort left+right key values together; idxs are the group boundaries.
    try:
        idxs, idx_sort = _get_join_sort_idxs(keys, left, right)
    except NotImplementedError:
        raise TypeError('one or more key columns are not sortable')

    # Now that we have idxs and idx_sort, revert to the original table args to
    # carry on with making the output joined table. `keys` is set to to an empty
    # list so that all original left and right columns are included in the
    # output table.
    if keys_left is not None or keys_right is not None:
        keys = []
        left = left_orig
        right = right_orig

    # Joined array dtype as a list of descr (name, type_str, shape) tuples
    col_name_map = get_col_name_map([left, right], keys, uniq_col_name, table_names)
    out_descrs = get_descrs([left, right], col_name_map)

    # Main inner loop in Cython to compute the cartesian product
    # indices for the given join type.  'cartesian' uses the outer code (1):
    # with the constant temporary index key every row matches every row.
    int_join_type = {'inner': 0, 'outer': 1, 'left': 2, 'right': 3,
                     'cartesian': 1}[join_type]
    # left_out/right_out are row indices into left/right for each output row;
    # left_mask/right_mask flag output rows where that side has no match
    # (presumably; inferred from how they are used to mask values below).
    masked, n_out, left_out, left_mask, right_out, right_mask = \
        _np_utils.join_inner(idxs, idx_sort, len_left, int_join_type)

    out = _get_out_class([left, right])()

    for out_name, dtype, shape in out_descrs:
        # The temporary cartesian index column is never part of the output.
        if out_name == cartesian_index_name:
            continue

        left_name, right_name = col_name_map[out_name]
        if left_name and right_name:  # this is a key which comes from left and right
            cols = [left[left_name], right[right_name]]

            col_cls = _get_out_class(cols)
            if not hasattr(col_cls.info, 'new_like'):
                raise NotImplementedError('join unavailable for mixin column type(s): {}'
                                          .format(col_cls.__name__))

            # Key values come from the left table where the right row is
            # missing (right_mask True) and from the right table otherwise.
            out[out_name] = col_cls.info.new_like(cols, n_out, metadata_conflicts, out_name)
            out[out_name][:] = np.where(right_mask,
                                        left[left_name].take(left_out),
                                        right[right_name].take(right_out))
            continue
        elif left_name:  # out_name came from the left table
            name, array, array_out, array_mask = left_name, left, left_out, left_mask
        elif right_name:
            name, array, array_out, array_mask = right_name, right, right_out, right_mask
        else:
            raise TableMergeError('Unexpected column names (maybe one is ""?)')

        # Select the correct elements from the original table
        col = array[name][array_out]

        # If the output column is masked then set the output column masking
        # accordingly.  Check for columns that don't support a mask attribute.
        if masked and np.any(array_mask):
            # If col is a Column but not MaskedColumn then upgrade at this point
            # because masking is required.
            if isinstance(col, Column) and not isinstance(col, MaskedColumn):
                col = out.MaskedColumn(col, copy=False)

            if isinstance(col, Quantity) and not isinstance(col, Masked):
                col = Masked(col, copy=False)

            # array_mask is 1-d corresponding to length of output column.  We need
            # make it have the correct shape for broadcasting, i.e. (length, 1, 1, ..).
            # Mixin columns might not have ndim attribute so use len(col.shape).
            # NOTE(review): array_mask is the same object as left_mask/right_mask,
            # so this reshapes those in place.  The shape is reset from col.shape
            # on every use, so this appears harmless, but a .reshape() copy
            # would be less fragile.
            array_mask.shape = (col.shape[0],) + (1,) * (len(col.shape) - 1)

            # Now broadcast to the correct final shape
            array_mask = np.broadcast_to(array_mask, col.shape)

            try:
                col[array_mask] = col.info.mask_val
            except Exception as err:  # Not clear how different classes will fail here
                raise NotImplementedError(
                    "join requires masking column '{}' but column"
                    " type {} does not support masking"
                    .format(out_name, col.__class__.__name__)) from err

        # Set the output table column to the new joined column
        out[out_name] = col

    # If col_name_map supplied as a dict input, then update.
    if isinstance(_col_name_map, Mapping):
        _col_name_map.update(col_name_map)

    return out
1256
1257
1258def _join_keys_left_right(left, right, keys, keys_left, keys_right, join_funcs):
1259    """Do processing to handle keys_left / keys_right args for join.
1260
1261    This takes the keys_left/right inputs and turns them into a list of left/right
1262    columns corresponding to those inputs (which can be column names or column
1263    data values). It also generates the list of fake key column names (strings
1264    of "1", "2", etc.) that correspond to the input keys.
1265    """
1266    def _keys_to_cols(keys, table, label):
1267        # Process input `keys`, which is a str or list of str column names in
1268        # `table` or a list of column-like objects. The `label` is just for
1269        # error reporting.
1270        if isinstance(keys, str):
1271            keys = [keys]
1272        cols = []
1273        for key in keys:
1274            if isinstance(key, str):
1275                try:
1276                    cols.append(table[key])
1277                except KeyError:
1278                    raise ValueError(f'{label} table does not have key column {key!r}')
1279            else:
1280                if len(key) != len(table):
1281                    raise ValueError(f'{label} table has different length from key {key}')
1282                cols.append(key)
1283        return cols
1284
1285    if join_funcs is not None:
1286        raise ValueError('cannot supply join_funcs arg and keys_left / keys_right')
1287
1288    if keys_left is None or keys_right is None:
1289        raise ValueError('keys_left and keys_right must both be provided')
1290
1291    if keys is not None:
1292        raise ValueError('keys arg must be None if keys_left and keys_right are supplied')
1293
1294    cols_left = _keys_to_cols(keys_left, left, 'left')
1295    cols_right = _keys_to_cols(keys_right, right, 'right')
1296
1297    if len(cols_left) != len(cols_right):
1298        raise ValueError('keys_left and keys_right args must have same length')
1299
1300    # Make two new temp tables for the join with only the join columns and
1301    # key columns in common.
1302    keys = [f'{ii}' for ii in range(len(cols_left))]
1303
1304    left = left.__class__(cols_left, names=keys, copy=False)
1305    right = right.__class__(cols_right, names=keys, copy=False)
1306
1307    return left, right, keys
1308
1309
1310def _check_join_type(join_type, func_name):
1311    """Check join_type arg in hstack and vstack.
1312
1313    This specifically checks for the common mistake of call vstack(t1, t2)
1314    instead of vstack([t1, t2]). The subsequent check of
1315    ``join_type in ('inner', ..)`` does not raise in this case.
1316    """
1317    if not isinstance(join_type, str):
1318        msg = '`join_type` arg must be a string'
1319        if isinstance(join_type, Table):
1320            msg += ('. Did you accidentally '
1321                    f'call {func_name}(t1, t2, ..) instead of '
1322                    f'{func_name}([t1, t2], ..)?')
1323        raise TypeError(msg)
1324
1325    if join_type not in ('inner', 'exact', 'outer'):
1326        raise ValueError("`join_type` arg must be one of 'inner', 'exact' or 'outer'")
1327
1328
def _vstack(arrays, join_type='outer', col_name_map=None, metadata_conflicts='warn'):
    """
    Stack Tables vertically (by rows)

    A ``join_type`` of 'exact' means that the arrays must all have exactly
    the same column names (though the order can vary).  If ``join_type`` is
    'inner' then the intersection of common columns will be the output.  A
    value of 'outer' (the default) means the output will have the union of
    all columns, with array values being masked where no common values are
    available.

    Parameters
    ----------
    arrays : list of Tables
        Tables to stack by rows (vertically)
    join_type : str
        Join type ('inner' | 'exact' | 'outer'), default is 'outer'
    col_name_map : empty dict or None
        If passed as a dict then it will be updated in-place with the
        mapping of output to input column names.
    metadata_conflicts : str
        How to proceed with column metadata conflicts ('silent', 'warn'
        (default) or 'error'); passed on to ``info.new_like``.

    Returns
    -------
    stacked_table : `~astropy.table.Table` object
        New table containing the stacked data from the input tables.  With a
        single input, that input table itself is returned (not a copy).

    Raises
    ------
    TableMergeError
        If the columns differ for ``join_type='exact'``, there are no common
        columns for ``join_type='inner'``, or columns have incompatible types.
    NotImplementedError
        If a mixin column type cannot be stacked or cannot be masked.
    """
    # Store user-provided col_name_map until the end
    _col_name_map = col_name_map

    # Trivial case of one input array
    if len(arrays) == 1:
        return arrays[0]

    # Start by assuming an outer match where all names go to output
    all_names = set(itertools.chain.from_iterable(arr.colnames for arr in arrays))
    col_name_map = get_col_name_map(arrays, all_names)

    # For an exact join every input must provide every output column; once
    # that is verified the stacking proceeds exactly like an outer join.
    if join_type == 'exact':
        for in_names in col_name_map.values():
            if any(x is None for x in in_names):
                raise TableMergeError('Inconsistent columns in input arrays '
                                      "(use 'inner' or 'outer' join_type to "
                                      "allow non-matching columns)")
        join_type = 'outer'

    # For an inner join, keep only columns where all input arrays have that column
    if join_type == 'inner':
        col_name_map = OrderedDict((name, in_names) for name, in_names in col_name_map.items()
                                   if all(x is not None for x in in_names))
        if len(col_name_map) == 0:
            raise TableMergeError('Input arrays have no columns in common')

    n_rows = sum(len(arr) for arr in arrays)
    out = _get_out_class(arrays)()

    for out_name, in_names in col_name_map.items():
        # List of input arrays that contribute to this output column
        cols = [arr[name] for arr, name in zip(arrays, in_names) if name is not None]

        col_cls = _get_out_class(cols)
        if not hasattr(col_cls.info, 'new_like'):
            raise NotImplementedError('vstack unavailable for mixin column type(s): {}'
                                      .format(col_cls.__name__))
        try:
            col = col_cls.info.new_like(cols, n_rows, metadata_conflicts, out_name)
        except metadata.MergeConflictError as err:
            # Beautify the error message when we are trying to merge columns with incompatible
            # types by including the name of the columns that originated the error.
            raise TableMergeError("The '{}' columns have incompatible types: {}"
                                  .format(out_name, err._incompat_types)) from err

        idx0 = 0
        for name, array in zip(in_names, arrays):
            idx1 = idx0 + len(array)
            # A non-None entry in col_name_map is always a column of that array.
            if name is not None:
                col[idx0:idx1] = array[name]
            else:
                # If col is a Column but not MaskedColumn then upgrade at this point
                # because masking is required.
                if isinstance(col, Column) and not isinstance(col, MaskedColumn):
                    col = out.MaskedColumn(col, copy=False)

                if isinstance(col, Quantity) and not isinstance(col, Masked):
                    col = Masked(col, copy=False)

                try:
                    col[idx0:idx1] = col.info.mask_val
                except Exception as err:
                    raise NotImplementedError(
                        "vstack requires masking column '{}' but column"
                        " type {} does not support masking"
                        .format(out_name, col.__class__.__name__)) from err
            idx0 = idx1

        out[out_name] = col

    # If col_name_map supplied as a dict input, then update.
    if isinstance(_col_name_map, Mapping):
        _col_name_map.update(col_name_map)

    return out
1433
1434
def _hstack(arrays, join_type='outer', uniq_col_name='{col_name}_{table_name}',
            table_names=None, col_name_map=None):
    """
    Stack tables horizontally (by columns)

    A ``join_type`` of 'exact' means that the arrays must all have exactly
    the same number of rows.  If ``join_type`` is 'inner' then the
    intersection of rows will be the output.  A value of 'outer' (the
    default) means the output will have the union of all rows, with array
    values being masked where no common values are available.

    Parameters
    ----------
    arrays : List of tables
        Tables to stack by columns (horizontally)
    join_type : str
        Join type ('inner' | 'exact' | 'outer'), default is 'outer'
    uniq_col_name : str
        String generate a unique output column name in case of a conflict.
        The default is '{col_name}_{table_name}'.
    table_names : list of str or None
        One name per input table, used when generating unique output
        column names.  The default is ['1', '2', ..].
    col_name_map : empty dict or None
        If passed as a dict then it will be updated in-place with the
        mapping of output to input column names.

    Returns
    -------
    stacked_table : `~astropy.table.Table` object
        New table containing the stacked data from the input tables.  With a
        single input, that input table itself is returned (not a copy).

    Raises
    ------
    ValueError
        If the number of ``table_names`` differs from the number of arrays.
    TableMergeError
        If ``join_type='exact'`` and the arrays have different lengths, or
        if the output column names collide.
    NotImplementedError
        If a column that requires masking does not support it.
    """

    # Store user-provided col_name_map until the end
    _col_name_map = col_name_map

    if table_names is None:
        table_names = [f'{ii + 1}' for ii in range(len(arrays))]
    if len(arrays) != len(table_names):
        raise ValueError('Number of arrays must match number of table_names')

    # Trivial case of one input arrays
    if len(arrays) == 1:
        return arrays[0]

    col_name_map = get_col_name_map(arrays, [], uniq_col_name, table_names)

    # For an exact join all input arrays must have the same length; once
    # that is verified the stacking proceeds exactly like an outer join.
    arr_lens = [len(arr) for arr in arrays]
    if join_type == 'exact':
        if len(set(arr_lens)) > 1:
            raise TableMergeError("Inconsistent number of rows in input arrays "
                                  "(use 'inner' or 'outer' join_type to allow "
                                  "non-matching rows)")
        join_type = 'outer'

    # For an inner join, keep only the common rows
    if join_type == 'inner':
        min_arr_len = min(arr_lens)
        if len(set(arr_lens)) > 1:
            arrays = [arr[:min_arr_len] for arr in arrays]
        arr_lens = [min_arr_len] * len(arrays)

    # The output has as many rows as the longest input; the rows missing from
    # shorter inputs are filled with masked values.
    n_rows = max(arr_lens)
    out = _get_out_class(arrays)()

    for out_name, in_names in col_name_map.items():
        for name, array, arr_len in zip(in_names, arrays, arr_lens):
            if name is None:
                continue

            if n_rows > arr_len:
                src = array[name]
                if arr_len > 0:
                    # Pad to n_rows by repeating row 0; padded rows get masked below.
                    indices = np.arange(n_rows)
                    indices[arr_len:] = 0
                    col = src[indices]
                else:
                    # An empty input has no row 0 to repeat (indexing it would
                    # raise IndexError), so allocate a new column of the output
                    # length instead; every row of it is masked below.
                    if not hasattr(src.info, 'new_like'):
                        raise NotImplementedError(
                            'hstack unavailable for mixin column type(s): {}'
                            .format(src.__class__.__name__))
                    col = src.info.new_like([src], n_rows, 'silent', out_name)

                # If col is a Column but not MaskedColumn then upgrade at this point
                # because masking is required.
                if isinstance(col, Column) and not isinstance(col, MaskedColumn):
                    col = out.MaskedColumn(col, copy=False)

                if isinstance(col, Quantity) and not isinstance(col, Masked):
                    col = Masked(col, copy=False)

                try:
                    col[arr_len:] = col.info.mask_val
                except Exception as err:
                    raise NotImplementedError(
                        "hstack requires masking column '{}' but column"
                        " type {} does not support masking"
                        .format(out_name, col.__class__.__name__)) from err
            else:
                col = array[name][:n_rows]

            out[out_name] = col

    # If col_name_map supplied as a dict input, then update.
    if isinstance(_col_name_map, Mapping):
        _col_name_map.update(col_name_map)

    return out
1537