1""" 2High-level table operations: 3 4- join() 5- setdiff() 6- hstack() 7- vstack() 8- dstack() 9""" 10# Licensed under a 3-clause BSD style license - see LICENSE.rst 11 12from copy import deepcopy 13import collections 14import itertools 15from collections import OrderedDict, Counter 16from collections.abc import Mapping, Sequence 17 18import numpy as np 19 20from astropy.utils import metadata 21from astropy.utils.masked import Masked 22from .table import Table, QTable, Row, Column, MaskedColumn 23from astropy.units import Quantity 24 25from . import _np_utils 26from .np_utils import fix_column_name, TableMergeError 27 28__all__ = ['join', 'setdiff', 'hstack', 'vstack', 'unique', 29 'join_skycoord', 'join_distance'] 30 31__doctest_requires__ = {'join_skycoord': ['scipy'], 'join_distance': ['scipy']} 32 33 34def _merge_table_meta(out, tables, metadata_conflicts='warn'): 35 out_meta = deepcopy(tables[0].meta) 36 for table in tables[1:]: 37 out_meta = metadata.merge(out_meta, table.meta, metadata_conflicts=metadata_conflicts) 38 out.meta.update(out_meta) 39 40 41def _get_list_of_tables(tables): 42 """ 43 Check that tables is a Table or sequence of Tables. Returns the 44 corresponding list of Tables. 45 """ 46 47 # Make sure we have a list of things 48 if not isinstance(tables, Sequence): 49 tables = [tables] 50 51 # Make sure there is something to stack 52 if len(tables) == 0: 53 raise ValueError('no values provided to stack.') 54 55 # Convert inputs (Table, Row, or anything column-like) to Tables. 56 # Special case that Quantity converts to a QTable. 
57 for ii, val in enumerate(tables): 58 if isinstance(val, Table): 59 pass 60 elif isinstance(val, Row): 61 tables[ii] = Table(val) 62 elif isinstance(val, Quantity): 63 tables[ii] = QTable([val]) 64 else: 65 try: 66 tables[ii] = Table([val]) 67 except (ValueError, TypeError) as err: 68 raise TypeError(f'Cannot convert {val} to table column.') from err 69 70 return tables 71 72 73def _get_out_class(objs): 74 """ 75 From a list of input objects ``objs`` get merged output object class. 76 77 This is just taken as the deepest subclass. This doesn't handle complicated 78 inheritance schemes, but as a special case, classes which share ``info`` 79 are taken to be compatible. 80 """ 81 out_class = objs[0].__class__ 82 for obj in objs[1:]: 83 if issubclass(obj.__class__, out_class): 84 out_class = obj.__class__ 85 86 if any(not (issubclass(out_class, obj.__class__) 87 or out_class.info is obj.__class__.info) for obj in objs): 88 raise ValueError('unmergeable object classes {}' 89 .format([obj.__class__.__name__ for obj in objs])) 90 91 return out_class 92 93 94def join_skycoord(distance, distance_func='search_around_sky'): 95 """Helper function to join on SkyCoord columns using distance matching. 96 97 This function is intended for use in ``table.join()`` to allow performing a 98 table join where the key columns are both ``SkyCoord`` objects, matched by 99 computing the distance between points and accepting values below 100 ``distance``. 101 102 The distance cross-matching is done using either 103 `~astropy.coordinates.search_around_sky` or 104 `~astropy.coordinates.search_around_3d`, depending on the value of 105 ``distance_func``. The default is ``'search_around_sky'``. 106 107 One can also provide a function object for ``distance_func``, in which case 108 it must be a function that follows the same input and output API as 109 `~astropy.coordinates.search_around_sky`. In this case the function will 110 be called with ``(skycoord1, skycoord2, distance)`` as arguments. 
111 112 Parameters 113 ---------- 114 distance : `~astropy.units.Quantity` ['angle', 'length'] 115 Maximum distance between points to be considered a join match. 116 Must have angular or distance units. 117 distance_func : str or function 118 Specifies the function for performing the cross-match based on 119 ``distance``. If supplied as a string this specifies the name of a 120 function in `astropy.coordinates`. If supplied as a function then that 121 function is called directly. 122 123 Returns 124 ------- 125 join_func : function 126 Function that accepts two ``SkyCoord`` columns (col1, col2) and returns 127 the tuple (ids1, ids2) of pair-matched unique identifiers. 128 129 Examples 130 -------- 131 This example shows an inner join of two ``SkyCoord`` columns, taking any 132 sources within 0.2 deg to be a match. Note the new ``sc_id`` column which 133 is added and provides a unique source identifier for the matches. 134 135 >>> from astropy.coordinates import SkyCoord 136 >>> import astropy.units as u 137 >>> from astropy.table import Table, join_skycoord 138 >>> from astropy import table 139 140 >>> sc1 = SkyCoord([0, 1, 1.1, 2], [0, 0, 0, 0], unit='deg') 141 >>> sc2 = SkyCoord([0.5, 1.05, 2.1], [0, 0, 0], unit='deg') 142 143 >>> join_func = join_skycoord(0.2 * u.deg) 144 >>> join_func(sc1, sc2) # Associate each coordinate with unique source ID 145 (array([3, 1, 1, 2]), array([4, 1, 2])) 146 147 >>> t1 = Table([sc1], names=['sc']) 148 >>> t2 = Table([sc2], names=['sc']) 149 >>> t12 = table.join(t1, t2, join_funcs={'sc': join_skycoord(0.2 * u.deg)}) 150 >>> print(t12) # Note new `sc_id` column with the IDs from join_func() 151 sc_id sc_1 sc_2 152 deg,deg deg,deg 153 ----- ------- -------- 154 1 1.0,0.0 1.05,0.0 155 1 1.1,0.0 1.05,0.0 156 2 2.0,0.0 2.1,0.0 157 158 """ 159 if isinstance(distance_func, str): 160 import astropy.coordinates as coords 161 try: 162 distance_func = getattr(coords, distance_func) 163 except AttributeError as err: 164 raise 
def join_distance(distance, kdtree_args=None, query_args=None):
    """Helper function to join table columns using distance matching.

    This function is intended for use in ``table.join()`` to allow performing
    a table join where the key columns are matched by computing the distance
    between points and accepting values below ``distance``. This numerical
    "fuzzy" match can apply to 1-D or 2-D columns, where in the latter case
    the distance is a vector distance.

    The distance cross-matching is done using `scipy.spatial.cKDTree`. If
    necessary you can tweak the default behavior by providing ``dict`` values
    for the ``kdtree_args`` or ``query_args``.

    Parameters
    ----------
    distance : float or `~astropy.units.Quantity` ['length']
        Maximum distance between points to be considered a join match
    kdtree_args : dict, None
        Optional extra args for `~scipy.spatial.cKDTree`
    query_args : dict, None
        Optional extra args for `~scipy.spatial.cKDTree.query_ball_tree`

    Returns
    -------
    join_func : function
        Function that accepts two columns (col1, col2) and returns the tuple
        (ids1, ids2) of pair-matched unique identifiers.  The input columns
        are not modified.

    Raises
    ------
    ImportError
        If scipy cannot be imported.

    Examples
    --------

      >>> from astropy.table import Table, join_distance
      >>> from astropy import table

      >>> c1 = [0, 1, 1.1, 2]
      >>> c2 = [0.5, 1.05, 2.1]

      >>> t1 = Table([c1], names=['col'])
      >>> t2 = Table([c2], names=['col'])
      >>> t12 = table.join(t1, t2, join_type='outer', join_funcs={'col': join_distance(0.2)})
      >>> print(t12)
      col_id col_1 col_2
      ------ ----- -----
           1   1.0  1.05
           1   1.1  1.05
           2   2.0   2.1
           3   0.0    --
           4    --   0.5

    """
    try:
        from scipy.spatial import cKDTree
    except ImportError as exc:
        raise ImportError('scipy is required to use join_distance()') from exc

    if kdtree_args is None:
        kdtree_args = {}
    if query_args is None:
        query_args = {}

    def join_func(col1, col2):
        # np.ndim() also works for plain sequences, which np.asarray() below
        # accepts as well.
        if np.ndim(col1) > 2 or np.ndim(col2) > 2:
            raise ValueError('columns for join_distance must be 1- or 2-dimensional')

        if isinstance(distance, Quantity):
            # Convert to np.array with common unit
            col1 = col1.to_value(distance.unit)
            col2 = col2.to_value(distance.unit)
            dist = distance.value
        else:
            col1 = np.asarray(col1)
            col2 = np.asarray(col2)
            dist = distance

        # KDTree needs 2-D input.  Add the trailing axis through indexing,
        # which yields a new view; assigning to ``.shape`` instead would
        # reshape the caller's own array whenever np.asarray() handed back
        # the input object unchanged.
        if col1.ndim == 1:
            col1 = col1[:, np.newaxis]
        if col2.ndim == 1:
            col2 = col2[:, np.newaxis]

        # Cross-match col1 and col2 within dist using KDTree
        kd1 = cKDTree(col1, **kdtree_args)
        kd2 = cKDTree(col2, **kdtree_args)
        nears = kd1.query_ball_tree(kd2, r=dist, **query_args)

        # Output of above is nears which is a list of lists, where the outer
        # list corresponds to each item in col1, and where the inner lists are
        # indexes into col2 of elements within the distance tolerance.  This
        # identifies col1 / col2 near pairs.

        # Now convert that into identifiers for each near-pair.  An identifier
        # already given to one member of a pair is propagated to the other
        # member, so points linked through a common neighbor share it.  This
        # identifier will then be used in the table join matching.

        # Identifiers for each column, initialized to all zero (= unassigned).
        ids1 = np.zeros(len(col1), dtype=int)
        ids2 = np.zeros(len(col2), dtype=int)

        # Start the identifier count at 1
        id_ = 1
        for idx1, idxs2 in enumerate(nears):
            for idx2 in idxs2:
                # If this col1 point is previously identified then set corresponding
                # col2 point to same identifier.  Likewise for col2 and col1.
                if ids1[idx1] > 0:
                    ids2[idx2] = ids1[idx1]
                elif ids2[idx2] > 0:
                    ids1[idx1] = ids2[idx2]
                else:
                    # Not yet seen so set identifier for col1 and col2
                    ids1[idx1] = id_
                    ids2[idx2] = id_
                    id_ += 1

        # Fill in unique identifiers for points with no near neighbor
        for ids in (ids1, ids2):
            for idx in np.flatnonzero(ids == 0):
                ids[idx] = id_
                id_ += 1

        # End of enclosure join_func()
        return ids1, ids2

    return join_func
367 table_names : list of str or None 368 Two-element list of table names used when generating unique output 369 column names. The default is ['1', '2']. 370 metadata_conflicts : str 371 How to proceed with metadata conflicts. This should be one of: 372 * ``'silent'``: silently pick the last conflicting meta-data value 373 * ``'warn'``: pick the last conflicting meta-data value, but emit a warning (default) 374 * ``'error'``: raise an exception. 375 join_funcs : dict, None 376 Dict of functions to use for matching the corresponding key column(s). 377 See `~astropy.table.join_skycoord` for an example and details. 378 379 Returns 380 ------- 381 joined_table : `~astropy.table.Table` object 382 New table containing the result of the join operation. 383 """ 384 385 # Try converting inputs to Table as needed 386 if not isinstance(left, Table): 387 left = Table(left) 388 if not isinstance(right, Table): 389 right = Table(right) 390 391 col_name_map = OrderedDict() 392 out = _join(left, right, keys, join_type, 393 uniq_col_name, table_names, col_name_map, metadata_conflicts, 394 join_funcs, 395 keys_left=keys_left, keys_right=keys_right) 396 397 # Merge the column and table meta data. Table subclasses might override 398 # these methods for custom merge behavior. 399 _merge_table_meta(out, [left, right], metadata_conflicts=metadata_conflicts) 400 401 return out 402 403 404def setdiff(table1, table2, keys=None): 405 """ 406 Take a set difference of table rows. 407 408 The row set difference will contain all rows in ``table1`` that are not 409 present in ``table2``. If the keys parameter is not defined, all columns in 410 ``table1`` will be included in the output table. 411 412 Parameters 413 ---------- 414 table1 : `~astropy.table.Table` 415 ``table1`` is on the left side of the set difference. 416 table2 : `~astropy.table.Table` 417 ``table2`` is on the right side of the set difference. 
418 keys : str or list of str 419 Name(s) of column(s) used to match rows of left and right tables. 420 Default is to use all columns in ``table1``. 421 422 Returns 423 ------- 424 diff_table : `~astropy.table.Table` 425 New table containing the set difference between tables. If the set 426 difference is none, an empty table will be returned. 427 428 Examples 429 -------- 430 To get a set difference between two tables:: 431 432 >>> from astropy.table import setdiff, Table 433 >>> t1 = Table({'a': [1, 4, 9], 'b': ['c', 'd', 'f']}, names=('a', 'b')) 434 >>> t2 = Table({'a': [1, 5, 9], 'b': ['c', 'b', 'f']}, names=('a', 'b')) 435 >>> print(t1) 436 a b 437 --- --- 438 1 c 439 4 d 440 9 f 441 >>> print(t2) 442 a b 443 --- --- 444 1 c 445 5 b 446 9 f 447 >>> print(setdiff(t1, t2)) 448 a b 449 --- --- 450 4 d 451 452 >>> print(setdiff(t2, t1)) 453 a b 454 --- --- 455 5 b 456 """ 457 if keys is None: 458 keys = table1.colnames 459 460 # Check that all keys are in table1 and table2 461 for tbl, tbl_str in ((table1, 'table1'), (table2, 'table2')): 462 diff_keys = np.setdiff1d(keys, tbl.colnames) 463 if len(diff_keys) != 0: 464 raise ValueError("The {} columns are missing from {}, cannot take " 465 "a set difference.".format(diff_keys, tbl_str)) 466 467 # Make a light internal copy of both tables 468 t1 = table1.copy(copy_data=False) 469 t1.meta = {} 470 t1.keep_columns(keys) 471 t1['__index1__'] = np.arange(len(table1)) # Keep track of rows indices 472 473 # Make a light internal copy to avoid touching table2 474 t2 = table2.copy(copy_data=False) 475 t2.meta = {} 476 t2.keep_columns(keys) 477 # Dummy column to recover rows after join 478 t2['__index2__'] = np.zeros(len(t2), dtype=np.uint8) # dummy column 479 480 t12 = _join(t1, t2, join_type='left', keys=keys, 481 metadata_conflicts='silent') 482 483 # If t12 index2 is masked then that means some rows were in table1 but not table2. 
def dstack(tables, join_type='outer', metadata_conflicts='warn'):
    """
    Stack columns within tables depth-wise

    A ``join_type`` of 'exact' means that the tables must all have exactly
    the same column names (though the order can vary).  If ``join_type``
    is 'inner' then the intersection of common columns will be the output.
    A value of 'outer' (default) means the output will have the union of
    all columns, with table values being masked where no common values are
    available.

    Parameters
    ----------
    tables : `~astropy.table.Table` or `~astropy.table.Row` or list thereof
        Table(s) to stack along depth-wise with the current table
        Table columns should have same shape and name for depth-wise stacking
    join_type : str
        Join type ('inner' | 'exact' | 'outer'), default is 'outer'
    metadata_conflicts : str
        How to proceed with metadata conflicts. This should be one of:
            * ``'silent'``: silently pick the last conflicting meta-data value
            * ``'warn'``: pick the last conflicting meta-data value, but emit a warning (default)
            * ``'error'``: raise an exception.

    Returns
    -------
    stacked_table : `~astropy.table.Table` object
        New table containing the stacked data from the input tables.  If only
        one table is supplied it is returned as is.

    Raises
    ------
    ValueError
        If the input tables do not all have the same number of rows.

    Examples
    --------
    To stack two tables depth-wise do::

      >>> from astropy.table import dstack, Table
      >>> t1 = Table({'a': [1, 2], 'b': [3, 4]}, names=('a', 'b'))
      >>> t2 = Table({'a': [5, 6], 'b': [7, 8]}, names=('a', 'b'))
      >>> print(t1)
       a   b
      --- ---
        1   3
        2   4
      >>> print(t2)
       a   b
      --- ---
        5   7
        6   8
      >>> print(dstack([t1, t2]))
      a [2]  b [2]
      ------ ------
      1 .. 5 3 .. 7
      2 .. 6 4 .. 8
    """
    _check_join_type(join_type, 'dstack')

    tables = _get_list_of_tables(tables)
    if len(tables) == 1:
        return tables[0]  # no point in stacking a single table

    n_rows = {len(table) for table in tables}
    if len(n_rows) != 1:
        raise ValueError('Table lengths must all match for dstack')
    n_row = n_rows.pop()

    out = vstack(tables, join_type, metadata_conflicts)

    # Iterate over a snapshot of the names since the columns are replaced
    # inside the loop.
    for name in list(out.columns):
        col = out[name]

        # Reshape to so each original column is now in a row.
        # If entries are not 0-dim then those additional shape dims
        # are just carried along.
        # [x x x y y y] => [[x x x],
        #                   [y y y]]
        new_shape = (len(tables), n_row) + col.shape[1:]
        try:
            col.shape = new_shape
        except AttributeError:
            # Fall back to reshape() for columns whose shape cannot be
            # assigned directly.
            col = col.reshape(new_shape)

        # Transpose the table and row axes to get to
        # [[x, y],
        #  [x, y]
        #  [x, y]]
        axes = np.arange(len(col.shape))
        axes[:2] = [1, 0]

        # This temporarily makes `out` be corrupted (columns of different
        # length) but it all works out in the end.
        out.columns.__setitem__(name, col.transpose(axes), validated=True)

    return out
If ``join_type`` 597 is 'inner' then the intersection of common columns will be the output. 598 A value of 'outer' (default) means the output will have the union of 599 all columns, with table values being masked where no common values are 600 available. 601 602 Parameters 603 ---------- 604 tables : `~astropy.table.Table` or `~astropy.table.Row` or list thereof 605 Table(s) to stack along rows (vertically) with the current table 606 join_type : str 607 Join type ('inner' | 'exact' | 'outer'), default is 'outer' 608 metadata_conflicts : str 609 How to proceed with metadata conflicts. This should be one of: 610 * ``'silent'``: silently pick the last conflicting meta-data value 611 * ``'warn'``: pick the last conflicting meta-data value, but emit a warning (default) 612 * ``'error'``: raise an exception. 613 614 Returns 615 ------- 616 stacked_table : `~astropy.table.Table` object 617 New table containing the stacked data from the input tables. 618 619 Examples 620 -------- 621 To stack two tables along rows do:: 622 623 >>> from astropy.table import vstack, Table 624 >>> t1 = Table({'a': [1, 2], 'b': [3, 4]}, names=('a', 'b')) 625 >>> t2 = Table({'a': [5, 6], 'b': [7, 8]}, names=('a', 'b')) 626 >>> print(t1) 627 a b 628 --- --- 629 1 3 630 2 4 631 >>> print(t2) 632 a b 633 --- --- 634 5 7 635 6 8 636 >>> print(vstack([t1, t2])) 637 a b 638 --- --- 639 1 3 640 2 4 641 5 7 642 6 8 643 """ 644 _check_join_type(join_type, 'vstack') 645 646 tables = _get_list_of_tables(tables) # validates input 647 if len(tables) == 1: 648 return tables[0] # no point in stacking a single table 649 col_name_map = OrderedDict() 650 651 out = _vstack(tables, join_type, col_name_map, metadata_conflicts) 652 653 # Merge table metadata 654 _merge_table_meta(out, tables, metadata_conflicts=metadata_conflicts) 655 656 return out 657 658 659def hstack(tables, join_type='outer', 660 uniq_col_name='{col_name}_{table_name}', table_names=None, 661 metadata_conflicts='warn'): 662 """ 663 Stack tables 
along columns (horizontally) 664 665 A ``join_type`` of 'exact' means that the tables must all 666 have exactly the same number of rows. If ``join_type`` is 'inner' then 667 the intersection of rows will be the output. A value of 'outer' (default) 668 means the output will have the union of all rows, with table values being 669 masked where no common values are available. 670 671 Parameters 672 ---------- 673 tables : `~astropy.table.Table` or `~astropy.table.Row` or list thereof 674 Tables to stack along columns (horizontally) with the current table 675 join_type : str 676 Join type ('inner' | 'exact' | 'outer'), default is 'outer' 677 uniq_col_name : str or None 678 String generate a unique output column name in case of a conflict. 679 The default is '{col_name}_{table_name}'. 680 table_names : list of str or None 681 Two-element list of table names used when generating unique output 682 column names. The default is ['1', '2', ..]. 683 metadata_conflicts : str 684 How to proceed with metadata conflicts. This should be one of: 685 * ``'silent'``: silently pick the last conflicting meta-data value 686 * ``'warn'``: pick the last conflicting meta-data value, 687 but emit a warning (default) 688 * ``'error'``: raise an exception. 689 690 Returns 691 ------- 692 stacked_table : `~astropy.table.Table` object 693 New table containing the stacked data from the input tables. 
def unique(input_table, keys=None, silent=False, keep='first'):
    """
    Returns the unique rows of a table.

    Parameters
    ----------
    input_table : table-like
    keys : str or list of str
        Name(s) of column(s) used to create unique rows.
        Default is to use all columns.  The supplied object is never
        modified.
    silent : bool
        If `True`, masked value column(s) are silently removed from
        ``keys``. If `False`, an exception is raised when ``keys``
        contains masked value column(s).
        Default is `False`.
    keep : {'first', 'last', 'none'}
        Whether to keep the first or last row for each set of
        duplicates. If 'none', all rows that are duplicate are
        removed, leaving only rows that are already unique in
        the input.
        Default is 'first'.

    Returns
    -------
    unique_table : `~astropy.table.Table` object
        New table containing only the unique rows of ``input_table``.

    Raises
    ------
    ValueError
        If ``keep`` is not one of the allowed values, if ``keys`` contains
        duplicate names, if a key column has masked values and ``silent`` is
        `False`, or if no key column remains.

    Examples
    --------
    >>> from astropy.table import unique, Table
    >>> import numpy as np
    >>> table = Table(data=[[1,2,3,2,3,3],
    ... [2,3,4,5,4,6],
    ... [3,4,5,6,7,8]],
    ... names=['col1', 'col2', 'col3'],
    ... dtype=[np.int32, np.int32, np.int32])
    >>> table
    <Table length=6>
     col1  col2  col3
    int32 int32 int32
    ----- ----- -----
        1     2     3
        2     3     4
        3     4     5
        2     5     6
        3     4     7
        3     6     8
    >>> unique(table, keys='col1')
    <Table length=3>
     col1  col2  col3
    int32 int32 int32
    ----- ----- -----
        1     2     3
        2     3     4
        3     4     5
    >>> unique(table, keys=['col1'], keep='last')
    <Table length=3>
     col1  col2  col3
    int32 int32 int32
    ----- ----- -----
        1     2     3
        2     5     6
        3     6     8
    >>> unique(table, keys=['col1', 'col2'])
    <Table length=5>
     col1  col2  col3
    int32 int32 int32
    ----- ----- -----
        1     2     3
        2     3     4
        2     5     6
        3     4     5
        3     6     8
    >>> unique(table, keys=['col1', 'col2'], keep='none')
    <Table length=4>
     col1  col2  col3
    int32 int32 int32
    ----- ----- -----
        1     2     3
        2     3     4
        2     5     6
        3     6     8
    >>> unique(table, keys=['col1'], keep='none')
    <Table length=1>
     col1  col2  col3
    int32 int32 int32
    ----- ----- -----
        1     2     3

    """

    if keep not in ('first', 'last', 'none'):
        raise ValueError("'keep' should be one of 'first', 'last', 'none'")

    if keys is None:
        keys = input_table.colnames
    elif isinstance(keys, str):
        keys = [keys]
    else:
        # Private copy so any sequence type is accepted and the caller's
        # object is left alone.
        keys = list(keys)
        if len(set(keys)) != len(keys):
            raise ValueError("duplicate key names")

    # Check for columns with masked values.  Collect the usable keys in a new
    # list rather than deleting from ``keys`` in place.
    usable_keys = []
    for key in keys:
        col = input_table[key]
        if hasattr(col, 'mask') and np.any(col.mask):
            if not silent:
                raise ValueError(
                    "cannot use columns with masked values as keys; "
                    "remove column '{}' from keys and rerun "
                    "unique()".format(key))
            continue
        usable_keys.append(key)

    if len(usable_keys) == 0:
        raise ValueError("no column remained in ``keys``; "
                         "unique() cannot work with masked value "
                         "key columns")

    grouped_table = input_table.group_by(usable_keys)
    # ``indices`` holds the start row of every group plus a final end marker.
    indices = grouped_table.groups.indices
    if keep == 'first':
        indices = indices[:-1]
    elif keep == 'last':
        indices = indices[1:] - 1
    else:
        # Keep only groups that contain exactly one row
        indices = indices[:-1][np.diff(indices) == 1]

    return grouped_table[indices]
def get_descrs(arrays, col_name_map):
    """
    Build the output column descriptions for a merge of ``arrays``.

    For every output column in ``col_name_map`` the contributing input
    columns are looked up, their common dtype is determined with
    ``common_dtype`` and their per-row shapes are required to agree.

    Returns a list of ``(name, dtype, shape)`` tuples, one per output column,
    where ``name`` has been passed through ``fix_column_name``.

    Raises ``TableMergeError`` if the contributing columns have incompatible
    dtypes or different shapes.
    """
    descrs = []

    for out_name, in_names in col_name_map.items():
        # Input columns feeding this output column (None marks "no input
        # from that array").
        in_cols = [array[in_name]
                   for array, in_name in zip(arrays, in_names)
                   if in_name is not None]

        # Matching input column names, used for error reporting.
        names = [in_name for in_name in in_names if in_name is not None]

        try:
            dtype = common_dtype(in_cols)
        except TableMergeError as tme:
            # Re-raise with the originating column name in the message.
            raise TableMergeError("The '{}' columns have incompatible types: {}"
                                  .format(names[0], tme._incompat_types)) from tme

        # Everything past the first (row) axis must be identical.
        shapes = {col.shape[1:] for col in in_cols}
        if len(shapes) != 1:
            raise TableMergeError(f'Key columns {names!r} have different shape')

        descrs.append((fix_column_name(out_name), dtype, shapes.pop()))

    return descrs
def _get_join_sort_idxs(keys, left, right):
    """
    Lexically sort the stacked key columns of ``left`` and ``right``.

    Each key column is expanded (via ``info.get_sortable_arrays()``) into one
    or more plain arrays; some columns (e.g. Time) may need more than one
    array to represent their ordering. These arrays become the fields of a
    structured array holding the left rows followed by the right rows, which
    is then argsort'ed.

    Returns
    -------
    idxs : ndarray
        Indices into the sorted array where the key value changes (i.e. the
        group boundaries), including 0 and the total length.
    idx_sort : ndarray
        Indices that sort the stacked (left then right) key rows.
    """
    field_names = []    # structured-array field names, in sort priority order
    field_dtypes = []   # matching (name, dtype) tuples
    left_parts = {}     # field name -> sortable array from the left table
    right_parts = {}    # field name -> sortable array from the right table

    for key in keys:
        left_arrays = left[key].info.get_sortable_arrays()
        right_arrays = right[key].info.get_sortable_arrays()

        if len(left_arrays) != len(right_arrays):
            # Should never happen because cols are screened beforehand for compatibility
            raise RuntimeError('mismatch in sort cols lengths')

        for left_arr, right_arr in zip(left_arrays, right_arrays):
            # Shapes beyond the row axis must agree. Mismatch should never happen.
            trailing = left_arr.shape[1:]
            if trailing != right_arr.shape[1:]:
                raise RuntimeError('mismatch in shape of left vs. right sort array')

            if trailing != ():
                raise ValueError(f'sort key column {key!r} must be 1-d')

            # Fields are simply numbered "0", "1", ... in the order added
            field = str(len(field_names))
            field_names.append(field)
            left_parts[field] = left_arr
            right_parts[field] = right_arr
            field_dtypes.append((field, common_dtype([left_arr, right_arr])))

    # Allocate the stacked structured array and fill left rows then right rows
    n_left = len(left)
    stacked = np.empty(n_left + len(right), dtype=field_dtypes)
    for field in field_names:
        field_view = stacked[field]
        field_view[:n_left] = left_parts[field]
        field_view[n_left:] = right_parts[field]

    # Lexical argsort, then a sorted copy for locating group boundaries
    idx_sort = stacked.argsort(order=field_names)
    ordered = stacked[idx_sort]

    # A boundary is wherever a row differs from its predecessor, plus both ends
    changed = ordered[1:] != ordered[:-1]
    idxs = np.flatnonzero(np.concatenate(([True], changed, [True])))

    return idxs, idx_sort
1039 left = left.copy(copy_data=False) 1040 right = right.copy(copy_data=False) 1041 for key, join_func in join_funcs.items(): 1042 ids1, ids2 = join_func(left[key], right[key]) 1043 # Define a unique id_key name, and keep adding underscores until we have 1044 # a name not yet present. 1045 id_key = key + '_id' 1046 while id_key in left.columns or id_key in right.columns: 1047 id_key = id_key[:-2] + '_id' 1048 1049 keys = tuple(id_key if orig_key == key else orig_key for orig_key in keys) 1050 left.add_column(ids1, index=0, name=id_key) # [id_key] = ids1 1051 right.add_column(ids2, index=0, name=id_key) # [id_key] = ids2 1052 1053 return left, right, keys 1054 1055 1056def _join(left, right, keys=None, join_type='inner', 1057 uniq_col_name='{col_name}_{table_name}', 1058 table_names=['1', '2'], 1059 col_name_map=None, metadata_conflicts='warn', 1060 join_funcs=None, 1061 keys_left=None, keys_right=None): 1062 """ 1063 Perform a join of the left and right Tables on specified keys. 1064 1065 Parameters 1066 ---------- 1067 left : Table 1068 Left side table in the join 1069 right : Table 1070 Right side table in the join 1071 keys : str or list of str 1072 Name(s) of column(s) used to match rows of left and right tables. 1073 Default is to use all columns which are common to both tables. 1074 join_type : str 1075 Join type ('inner' | 'outer' | 'left' | 'right' | 'cartesian'), default is 'inner' 1076 uniq_col_name : str or None 1077 String generate a unique output column name in case of a conflict. 1078 The default is '{col_name}_{table_name}'. 1079 table_names : list of str or None 1080 Two-element list of table names used when generating unique output 1081 column names. The default is ['1', '2']. 1082 col_name_map : empty dict or None 1083 If passed as a dict then it will be updated in-place with the 1084 mapping of output to input column names. 1085 metadata_conflicts : str 1086 How to proceed with metadata conflicts. 
def _join(left, right, keys=None, join_type='inner',
          uniq_col_name='{col_name}_{table_name}',
          table_names=['1', '2'],
          col_name_map=None, metadata_conflicts='warn',
          join_funcs=None,
          keys_left=None, keys_right=None):
    """
    Perform a join of the left and right Tables on specified keys.

    Parameters
    ----------
    left : Table
        Left side table in the join
    right : Table
        Right side table in the join
    keys : str or list of str
        Name(s) of column(s) used to match rows of left and right tables.
        Default is to use all columns which are common to both tables.
        Must not be supplied for a 'cartesian' join, nor together with
        ``keys_left`` / ``keys_right``.
    join_type : str
        Join type ('inner' | 'outer' | 'left' | 'right' | 'cartesian'), default is 'inner'
    uniq_col_name : str or None
        Format string used to generate a unique output column name in case
        of a conflict. The default is '{col_name}_{table_name}'.
    table_names : list of str or None
        Two-element list of table names used when generating unique output
        column names.  The default is ['1', '2'].
    col_name_map : empty dict or None
        If passed as a dict then it will be updated in-place with the
        mapping of output to input column names.
    metadata_conflicts : str
        How to proceed with metadata conflicts. This should be one of:
            * ``'silent'``: silently pick the last conflicting meta-data value
            * ``'warn'``: pick the last conflicting meta-data value, but emit a warning (default)
            * ``'error'``: raise an exception.
    join_funcs : dict, None
        Dict of functions to use for matching the corresponding key column(s).
        See `~astropy.table.join_skycoord` for an example and details.
        The dict keys must be a subset of the join keys. Cannot be used for
        a 'cartesian' join or together with ``keys_left`` / ``keys_right``.
    keys_left : str, list of str, list of column-like, or None
        Left-table key column name(s), or column-like objects with the same
        length as ``left``, to match against ``keys_right``. If given, then
        ``keys_right`` must also be given and ``keys`` must be None.
    keys_right : str, list of str, list of column-like, or None
        Same as ``keys_left``, but for the right-side table.

    Returns
    -------
    joined_table : `~astropy.table.Table` object
        New table containing the result of the join operation.

    Raises
    ------
    ValueError
        If ``join_type`` is not one of the allowed values, if ``keys`` or
        ``join_funcs`` are supplied for a cartesian join, if ``join_funcs``
        has keys that are not join keys, or if either input table has no rows.
    TableMergeError
        If there are no common key columns, a key column is missing from one
        of the tables, or a key column has masked (missing) values.
    TypeError
        If one or more key columns are not sortable.
    NotImplementedError
        If a column type does not support the operations (``new_like`` or
        masking) needed to build the output.
    """
    # Store user-provided col_name_map until the end
    _col_name_map = col_name_map

    # Special column name for cartesian join, should never collide with real column
    cartesian_index_name = '__table_cartesian_join_temp_index__'

    if join_type not in ('inner', 'outer', 'left', 'right', 'cartesian'):
        raise ValueError("The 'join_type' argument should be in 'inner', "
                         "'outer', 'left', 'right', or 'cartesian' "
                         "(got '{}' instead)".
                         format(join_type))

    if join_type == 'cartesian':
        if keys:
            raise ValueError('cannot supply keys for a cartesian join')

        if join_funcs:
            raise ValueError('cannot supply join_funcs for a cartesian join')

        # Make light copies of left and right, then add temporary index columns
        # with all the same value so later an outer join turns into a cartesian join.
        left = left.copy(copy_data=False)
        right = right.copy(copy_data=False)
        left[cartesian_index_name] = np.uint8(0)
        right[cartesian_index_name] = np.uint8(0)
        keys = (cartesian_index_name, )

    # Handle the case of join key columns that are different between left and
    # right via keys_left/keys_right args. This is done by saving the original
    # input tables and making new left and right tables that contain only the
    # key cols but with common column names ['0', '1', etc]. This sets `keys` to
    # those fake key names in the left and right tables
    if keys_left is not None or keys_right is not None:
        left_orig = left
        right_orig = right
        left, right, keys = _join_keys_left_right(
            left, right, keys, keys_left, keys_right, join_funcs)

    if keys is None:
        # Default: join on every column name that appears in both tables
        keys = tuple(name for name in left.colnames if name in right.colnames)
        if len(keys) == 0:
            raise TableMergeError('No keys in common between left and right tables')
    elif isinstance(keys, str):
        # If we have a single key, put it in a tuple
        keys = (keys,)

    # Check the key columns: each must exist in both tables and have no
    # masked values.
    for arr, arr_label in ((left, 'Left'), (right, 'Right')):
        for name in keys:
            if name not in arr.colnames:
                raise TableMergeError('{} table does not have key column {!r}'
                                      .format(arr_label, name))
            if hasattr(arr[name], 'mask') and np.any(arr[name].mask):
                raise TableMergeError('{} key column {!r} has missing values'
                                      .format(arr_label, name))

    if join_funcs is not None:
        if not all(key in keys for key in join_funcs):
            raise ValueError(f'join_funcs keys {join_funcs.keys()} must be a '
                             f'subset of join keys {keys}')
        # Swap each join-func key for a generated id column in light copies
        # of left and right.
        left, right, keys = _apply_join_funcs(left, right, keys, join_funcs)

    len_left, len_right = len(left), len(right)

    if len_left == 0 or len_right == 0:
        raise ValueError('input tables for join must both have at least one row')

    try:
        idxs, idx_sort = _get_join_sort_idxs(keys, left, right)
    except NotImplementedError:
        raise TypeError('one or more key columns are not sortable')

    # Now that we have idxs and idx_sort, revert to the original table args to
    # carry on with making the output joined table. `keys` is set to an empty
    # list so that all original left and right columns are included in the
    # output table. (left_orig / right_orig were bound above under this same
    # condition.)
    if keys_left is not None or keys_right is not None:
        keys = []
        left = left_orig
        right = right_orig

    # Joined array dtype as a list of descr (name, type_str, shape) tuples
    col_name_map = get_col_name_map([left, right], keys, uniq_col_name, table_names)
    out_descrs = get_descrs([left, right], col_name_map)

    # Main inner loop in Cython to compute the cartesian product
    # indices for the given join type. A cartesian join uses the same code
    # as an outer join (on the constant temporary index column).
    int_join_type = {'inner': 0, 'outer': 1, 'left': 2, 'right': 3,
                     'cartesian': 1}[join_type]
    masked, n_out, left_out, left_mask, right_out, right_mask = \
        _np_utils.join_inner(idxs, idx_sort, len_left, int_join_type)

    out = _get_out_class([left, right])()

    for out_name, dtype, shape in out_descrs:
        # The temporary cartesian-join index column is not part of the output
        if out_name == cartesian_index_name:
            continue

        left_name, right_name = col_name_map[out_name]
        if left_name and right_name:  # this is a key which comes from left and right
            cols = [left[left_name], right[right_name]]

            col_cls = _get_out_class(cols)
            if not hasattr(col_cls.info, 'new_like'):
                raise NotImplementedError('join unavailable for mixin column type(s): {}'
                                          .format(col_cls.__name__))

            out[out_name] = col_cls.info.new_like(cols, n_out, metadata_conflicts, out_name)
            # Take the key value from the left table where right_mask is set,
            # otherwise from the right table.
            out[out_name][:] = np.where(right_mask,
                                        left[left_name].take(left_out),
                                        right[right_name].take(right_out))
            continue
        elif left_name:  # out_name came from the left table
            name, array, array_out, array_mask = left_name, left, left_out, left_mask
        elif right_name:  # out_name came from the right table
            name, array, array_out, array_mask = right_name, right, right_out, right_mask
        else:
            raise TableMergeError('Unexpected column names (maybe one is ""?)')

        # Select the correct elements from the original table
        col = array[name][array_out]

        # If the output column is masked then set the output column masking
        # accordingly.  Check for columns that don't support a mask attribute.
        if masked and np.any(array_mask):
            # If col is a Column but not MaskedColumn then upgrade at this point
            # because masking is required.
            if isinstance(col, Column) and not isinstance(col, MaskedColumn):
                col = out.MaskedColumn(col, copy=False)

            if isinstance(col, Quantity) and not isinstance(col, Masked):
                col = Masked(col, copy=False)

            # array_mask is 1-d corresponding to length of output column.  We need
            # make it have the correct shape for broadcasting, i.e. (length, 1, 1, ..).
            # Mixin columns might not have ndim attribute so use len(col.shape).
            array_mask.shape = (col.shape[0],) + (1,) * (len(col.shape) - 1)

            # Now broadcast to the correct final shape
            array_mask = np.broadcast_to(array_mask, col.shape)

            try:
                col[array_mask] = col.info.mask_val
            except Exception as err:  # Not clear how different classes will fail here
                raise NotImplementedError(
                    "join requires masking column '{}' but column"
                    " type {} does not support masking"
                    .format(out_name, col.__class__.__name__)) from err

        # Set the output table column to the new joined column
        out[out_name] = col

    # If col_name_map supplied as a dict input, then update.
    if isinstance(_col_name_map, Mapping):
        _col_name_map.update(col_name_map)

    return out
1265 """ 1266 def _keys_to_cols(keys, table, label): 1267 # Process input `keys`, which is a str or list of str column names in 1268 # `table` or a list of column-like objects. The `label` is just for 1269 # error reporting. 1270 if isinstance(keys, str): 1271 keys = [keys] 1272 cols = [] 1273 for key in keys: 1274 if isinstance(key, str): 1275 try: 1276 cols.append(table[key]) 1277 except KeyError: 1278 raise ValueError(f'{label} table does not have key column {key!r}') 1279 else: 1280 if len(key) != len(table): 1281 raise ValueError(f'{label} table has different length from key {key}') 1282 cols.append(key) 1283 return cols 1284 1285 if join_funcs is not None: 1286 raise ValueError('cannot supply join_funcs arg and keys_left / keys_right') 1287 1288 if keys_left is None or keys_right is None: 1289 raise ValueError('keys_left and keys_right must both be provided') 1290 1291 if keys is not None: 1292 raise ValueError('keys arg must be None if keys_left and keys_right are supplied') 1293 1294 cols_left = _keys_to_cols(keys_left, left, 'left') 1295 cols_right = _keys_to_cols(keys_right, right, 'right') 1296 1297 if len(cols_left) != len(cols_right): 1298 raise ValueError('keys_left and keys_right args must have same length') 1299 1300 # Make two new temp tables for the join with only the join columns and 1301 # key columns in common. 1302 keys = [f'{ii}' for ii in range(len(cols_left))] 1303 1304 left = left.__class__(cols_left, names=keys, copy=False) 1305 right = right.__class__(cols_right, names=keys, copy=False) 1306 1307 return left, right, keys 1308 1309 1310def _check_join_type(join_type, func_name): 1311 """Check join_type arg in hstack and vstack. 1312 1313 This specifically checks for the common mistake of call vstack(t1, t2) 1314 instead of vstack([t1, t2]). The subsequent check of 1315 ``join_type in ('inner', ..)`` does not raise in this case. 
1316 """ 1317 if not isinstance(join_type, str): 1318 msg = '`join_type` arg must be a string' 1319 if isinstance(join_type, Table): 1320 msg += ('. Did you accidentally ' 1321 f'call {func_name}(t1, t2, ..) instead of ' 1322 f'{func_name}([t1, t2], ..)?') 1323 raise TypeError(msg) 1324 1325 if join_type not in ('inner', 'exact', 'outer'): 1326 raise ValueError("`join_type` arg must be one of 'inner', 'exact' or 'outer'") 1327 1328 1329def _vstack(arrays, join_type='outer', col_name_map=None, metadata_conflicts='warn'): 1330 """ 1331 Stack Tables vertically (by rows) 1332 1333 A ``join_type`` of 'exact' (default) means that the arrays must all 1334 have exactly the same column names (though the order can vary). If 1335 ``join_type`` is 'inner' then the intersection of common columns will 1336 be the output. A value of 'outer' means the output will have the union of 1337 all columns, with array values being masked where no common values are 1338 available. 1339 1340 Parameters 1341 ---------- 1342 arrays : list of Tables 1343 Tables to stack by rows (vertically) 1344 join_type : str 1345 Join type ('inner' | 'exact' | 'outer'), default is 'outer' 1346 col_name_map : empty dict or None 1347 If passed as a dict then it will be updated in-place with the 1348 mapping of output to input column names. 1349 1350 Returns 1351 ------- 1352 stacked_table : `~astropy.table.Table` object 1353 New table containing the stacked data from the input tables. 
1354 """ 1355 # Store user-provided col_name_map until the end 1356 _col_name_map = col_name_map 1357 1358 # Trivial case of one input array 1359 if len(arrays) == 1: 1360 return arrays[0] 1361 1362 # Start by assuming an outer match where all names go to output 1363 names = set(itertools.chain(*[arr.colnames for arr in arrays])) 1364 col_name_map = get_col_name_map(arrays, names) 1365 1366 # If require_match is True then the output must have exactly the same 1367 # number of columns as each input array 1368 if join_type == 'exact': 1369 for names in col_name_map.values(): 1370 if any(x is None for x in names): 1371 raise TableMergeError('Inconsistent columns in input arrays ' 1372 "(use 'inner' or 'outer' join_type to " 1373 "allow non-matching columns)") 1374 join_type = 'outer' 1375 1376 # For an inner join, keep only columns where all input arrays have that column 1377 if join_type == 'inner': 1378 col_name_map = OrderedDict((name, in_names) for name, in_names in col_name_map.items() 1379 if all(x is not None for x in in_names)) 1380 if len(col_name_map) == 0: 1381 raise TableMergeError('Input arrays have no columns in common') 1382 1383 lens = [len(arr) for arr in arrays] 1384 n_rows = sum(lens) 1385 out = _get_out_class(arrays)() 1386 1387 for out_name, in_names in col_name_map.items(): 1388 # List of input arrays that contribute to this output column 1389 cols = [arr[name] for arr, name in zip(arrays, in_names) if name is not None] 1390 1391 col_cls = _get_out_class(cols) 1392 if not hasattr(col_cls.info, 'new_like'): 1393 raise NotImplementedError('vstack unavailable for mixin column type(s): {}' 1394 .format(col_cls.__name__)) 1395 try: 1396 col = col_cls.info.new_like(cols, n_rows, metadata_conflicts, out_name) 1397 except metadata.MergeConflictError as err: 1398 # Beautify the error message when we are trying to merge columns with incompatible 1399 # types by including the name of the columns that originated the error. 
1400 raise TableMergeError("The '{}' columns have incompatible types: {}" 1401 .format(out_name, err._incompat_types)) from err 1402 1403 idx0 = 0 1404 for name, array in zip(in_names, arrays): 1405 idx1 = idx0 + len(array) 1406 if name in array.colnames: 1407 col[idx0:idx1] = array[name] 1408 else: 1409 # If col is a Column but not MaskedColumn then upgrade at this point 1410 # because masking is required. 1411 if isinstance(col, Column) and not isinstance(col, MaskedColumn): 1412 col = out.MaskedColumn(col, copy=False) 1413 1414 if isinstance(col, Quantity) and not isinstance(col, Masked): 1415 col = Masked(col, copy=False) 1416 1417 try: 1418 col[idx0:idx1] = col.info.mask_val 1419 except Exception as err: 1420 raise NotImplementedError( 1421 "vstack requires masking column '{}' but column" 1422 " type {} does not support masking" 1423 .format(out_name, col.__class__.__name__)) from err 1424 idx0 = idx1 1425 1426 out[out_name] = col 1427 1428 # If col_name_map supplied as a dict input, then update. 1429 if isinstance(_col_name_map, Mapping): 1430 _col_name_map.update(col_name_map) 1431 1432 return out 1433 1434 1435def _hstack(arrays, join_type='outer', uniq_col_name='{col_name}_{table_name}', 1436 table_names=None, col_name_map=None): 1437 """ 1438 Stack tables horizontally (by columns) 1439 1440 A ``join_type`` of 'exact' (default) means that the arrays must all 1441 have exactly the same number of rows. If ``join_type`` is 'inner' then 1442 the intersection of rows will be the output. A value of 'outer' means 1443 the output will have the union of all rows, with array values being 1444 masked where no common values are available. 1445 1446 Parameters 1447 ---------- 1448 arrays : List of tables 1449 Tables to stack by columns (horizontally) 1450 join_type : str 1451 Join type ('inner' | 'exact' | 'outer'), default is 'outer' 1452 uniq_col_name : str or None 1453 String generate a unique output column name in case of a conflict. 
1454 The default is '{col_name}_{table_name}'. 1455 table_names : list of str or None 1456 Two-element list of table names used when generating unique output 1457 column names. The default is ['1', '2', ..]. 1458 1459 Returns 1460 ------- 1461 stacked_table : `~astropy.table.Table` object 1462 New table containing the stacked data from the input tables. 1463 """ 1464 1465 # Store user-provided col_name_map until the end 1466 _col_name_map = col_name_map 1467 1468 if table_names is None: 1469 table_names = [f'{ii + 1}' for ii in range(len(arrays))] 1470 if len(arrays) != len(table_names): 1471 raise ValueError('Number of arrays must match number of table_names') 1472 1473 # Trivial case of one input arrays 1474 if len(arrays) == 1: 1475 return arrays[0] 1476 1477 col_name_map = get_col_name_map(arrays, [], uniq_col_name, table_names) 1478 1479 # If require_match is True then all input arrays must have the same length 1480 arr_lens = [len(arr) for arr in arrays] 1481 if join_type == 'exact': 1482 if len(set(arr_lens)) > 1: 1483 raise TableMergeError("Inconsistent number of rows in input arrays " 1484 "(use 'inner' or 'outer' join_type to allow " 1485 "non-matching rows)") 1486 join_type = 'outer' 1487 1488 # For an inner join, keep only the common rows 1489 if join_type == 'inner': 1490 min_arr_len = min(arr_lens) 1491 if len(set(arr_lens)) > 1: 1492 arrays = [arr[:min_arr_len] for arr in arrays] 1493 arr_lens = [min_arr_len for arr in arrays] 1494 1495 # If there are any output rows where one or more input arrays are missing 1496 # then the output must be masked. If any input arrays are masked then 1497 # output is masked. 
1498 1499 n_rows = max(arr_lens) 1500 out = _get_out_class(arrays)() 1501 1502 for out_name, in_names in col_name_map.items(): 1503 for name, array, arr_len in zip(in_names, arrays, arr_lens): 1504 if name is None: 1505 continue 1506 1507 if n_rows > arr_len: 1508 indices = np.arange(n_rows) 1509 indices[arr_len:] = 0 1510 col = array[name][indices] 1511 1512 # If col is a Column but not MaskedColumn then upgrade at this point 1513 # because masking is required. 1514 if isinstance(col, Column) and not isinstance(col, MaskedColumn): 1515 col = out.MaskedColumn(col, copy=False) 1516 1517 if isinstance(col, Quantity) and not isinstance(col, Masked): 1518 col = Masked(col, copy=False) 1519 1520 try: 1521 col[arr_len:] = col.info.mask_val 1522 except Exception as err: 1523 raise NotImplementedError( 1524 "hstack requires masking column '{}' but column" 1525 " type {} does not support masking" 1526 .format(out_name, col.__class__.__name__)) from err 1527 else: 1528 col = array[name][:n_rows] 1529 1530 out[out_name] = col 1531 1532 # If col_name_map supplied as a dict input, then update. 1533 if isinstance(_col_name_map, Mapping): 1534 _col_name_map.update(col_name_map) 1535 1536 return out 1537