1# Licensed under a 3-clause BSD style license - see LICENSE.rst 2 3from astropy.table.table_helpers import ArrayWrapper 4from astropy.coordinates.earth import EarthLocation 5from astropy.units.quantity import Quantity 6from collections import OrderedDict 7from contextlib import nullcontext 8 9import pytest 10import numpy as np 11 12from astropy.table import Table, QTable, TableMergeError, Column, MaskedColumn, NdarrayMixin 13from astropy.table.operations import _get_out_class, join_skycoord, join_distance 14from astropy import units as u 15from astropy.utils import metadata 16from astropy.utils.metadata import MergeConflictError 17from astropy import table 18from astropy.time import Time, TimeDelta 19from astropy.coordinates import (SkyCoord, SphericalRepresentation, 20 UnitSphericalRepresentation, 21 CartesianRepresentation, 22 BaseRepresentationOrDifferential, 23 search_around_3d) 24from astropy.coordinates.tests.test_representation import representation_equal 25from astropy.io.misc.asdf.tags.helpers import skycoord_equal 26from astropy.utils.compat.optional_deps import HAS_SCIPY # noqa 27 28 29def sort_eq(list1, list2): 30 return sorted(list1) == sorted(list2) 31 32 33def check_mask(col, exp_mask): 34 """Check that col.mask == exp_mask""" 35 if hasattr(col, 'mask'): 36 # Coerce expected mask into dtype of col.mask. In particular this is 37 # needed for types like EarthLocation where the mask is a structured 38 # array. 39 exp_mask = np.array(exp_mask).astype(col.mask.dtype) 40 out = np.all(col.mask == exp_mask) 41 else: 42 # With no mask the check is OK if all the expected mask values 43 # are False (i.e. no auto-conversion to MaskedQuantity if it was 44 # not required by the join). 45 out = np.all(exp_mask == False) 46 return out 47 48 49class TestJoin(): 50 51 def _setup(self, t_cls=Table): 52 lines1 = [' a b c ', 53 ' 0 foo L1', 54 ' 1 foo L2', 55 ' 1 bar L3', 56 ' 2 bar L4'] 57 lines2 = [' a b d ', 58 ' 1 foo R1', 59 ' 1 foo R2', 60 ' 2 bar R3', 61 ' 4 bar R4'] 62 self.t1 = t_cls.read(lines1, format='ascii') 63 self.t2 = t_cls.read(lines2, format='ascii') 64 self.t3 = t_cls(self.t2, copy=True) 65 66 self.t1.meta.update(OrderedDict([('b', [1, 2]), ('c', {'a': 1}), ('d', 1)])) 67 self.t2.meta.update(OrderedDict([('b', [3, 4]), ('c', {'b': 1}), ('a', 1)])) 68 self.t3.meta.update(OrderedDict([('b', 3), ('c', [1, 2]), ('d', 2), ('a', 1)])) 69 70 self.meta_merge = OrderedDict([('b', [1, 2, 3, 4]), 71 ('c', {'a': 1, 'b': 1}), 72 ('d', 1), 73 ('a', 1)]) 74 75 def test_table_meta_merge(self, operation_table_type): 76 self._setup(operation_table_type) 77 out = table.join(self.t1, self.t2, join_type='inner') 78 assert out.meta == self.meta_merge 79 80 def test_table_meta_merge_conflict(self, operation_table_type): 81 self._setup(operation_table_type) 82 83 with pytest.warns(metadata.MergeConflictWarning) as w: 84 out = table.join(self.t1, self.t3, join_type='inner') 85 assert len(w) == 3 86 87 assert out.meta == self.t3.meta 88 89 with pytest.warns(metadata.MergeConflictWarning) as w: 90 out = table.join(self.t1, self.t3, join_type='inner', metadata_conflicts='warn') 91 assert len(w) == 3 92 93 assert out.meta == self.t3.meta 94 95 out = table.join(self.t1, self.t3, join_type='inner', metadata_conflicts='silent') 96 97 assert out.meta == self.t3.meta 98 99 with pytest.raises(MergeConflictError): 100 out = table.join(self.t1, self.t3, join_type='inner', metadata_conflicts='error') 101 102 with pytest.raises(ValueError): 103 out = table.join(self.t1, self.t3, join_type='inner', metadata_conflicts='nonsense') 104 105 def test_both_unmasked_inner(self, operation_table_type): 106 self._setup(operation_table_type) 107 t1 = self.t1 108 t2 = self.t2 109 110 # Basic join with default parameters (inner join on common keys) 111 t12 = table.join(t1, t2) 112 assert type(t12) is operation_table_type 113 assert type(t12['a']) is type(t1['a']) # noqa 114 assert type(t12['b']) is type(t1['b']) # noqa 115 assert type(t12['c']) is type(t1['c']) # noqa 116 assert type(t12['d']) is type(t2['d']) # noqa 117 assert t12.masked is False 118 assert sort_eq(t12.pformat(), [' a b c d ', 119 '--- --- --- ---', 120 ' 1 foo L2 R1', 121 ' 1 foo L2 R2', 122 ' 2 bar L4 R3']) 123 # Table meta merged properly 124 assert t12.meta == self.meta_merge 125 126 def test_both_unmasked_left_right_outer(self, operation_table_type): 127 if operation_table_type is QTable: 128 pytest.xfail('Quantity columns do not support masking.') 129 self._setup(operation_table_type) 130 t1 = self.t1 131 t2 = self.t2 132 133 # Left join 134 t12 = table.join(t1, t2, join_type='left') 135 assert t12.has_masked_columns is True 136 assert t12.masked is False 137 for name in ('a', 'b', 'c'): 138 assert type(t12[name]) is Column 139 assert type(t12['d']) is MaskedColumn 140 assert sort_eq(t12.pformat(), [' a b c d ', 141 '--- --- --- ---', 142 ' 0 foo L1 --', 143 ' 1 bar L3 --', 144 ' 1 foo L2 R1', 145 ' 1 foo L2 R2', 146 ' 2 bar L4 R3']) 147 148 # Right join 149 t12 = table.join(t1, t2, join_type='right') 150 assert t12.has_masked_columns is True 151 assert t12.masked is False 152 assert sort_eq(t12.pformat(), [' a b c d ', 153 '--- --- --- ---', 154 ' 1 foo L2 R1', 155 ' 1 foo L2 R2', 156 ' 2 bar L4 R3', 157 ' 4 bar -- R4']) 158 159 # Outer join 160 t12 = table.join(t1, t2, join_type='outer') 161 assert t12.has_masked_columns is True 162 assert t12.masked is False 163 assert sort_eq(t12.pformat(), [' a b c d ', 164 '--- --- --- ---', 165 ' 0 foo L1 --', 166 ' 1 bar L3 --', 167 ' 1 foo L2 R1', 168 ' 1 foo L2 R2', 169 ' 2 bar L4 R3', 170 ' 4 bar -- R4']) 171 172 # Check that the common keys are 'a', 'b' 173 t12a = table.join(t1, t2, join_type='outer') 174 t12b = table.join(t1, t2, join_type='outer', keys=['a', 'b']) 175 assert np.all(t12a.as_array() == t12b.as_array()) 176 177 def test_both_unmasked_single_key_inner(self, operation_table_type): 178 self._setup(operation_table_type) 179 t1 = self.t1 180 t2 = self.t2 181 182 # Inner join on 'a' column 183 t12 = table.join(t1, t2, keys='a') 184 assert type(t12) is operation_table_type 185 assert type(t12['a']) is type(t1['a']) # noqa 186 assert type(t12['b_1']) is type(t1['b']) # noqa 187 assert type(t12['c']) is type(t1['c']) # noqa 188 assert type(t12['b_2']) is type(t2['b']) # noqa 189 assert type(t12['d']) is type(t2['d']) # noqa 190 assert t12.masked is False 191 assert sort_eq(t12.pformat(), [' a b_1 c b_2 d ', 192 '--- --- --- --- ---', 193 ' 1 foo L2 foo R1', 194 ' 1 foo L2 foo R2', 195 ' 1 bar L3 foo R1', 196 ' 1 bar L3 foo R2', 197 ' 2 bar L4 bar R3']) 198 199 def test_both_unmasked_single_key_left_right_outer(self, operation_table_type): 200 if operation_table_type is QTable: 201 pytest.xfail('Quantity columns do not support masking.') 202 self._setup(operation_table_type) 203 t1 = self.t1 204 t2 = self.t2 205 206 # Left join 207 t12 = table.join(t1, t2, join_type='left', keys='a') 208 assert t12.has_masked_columns is True 209 assert sort_eq(t12.pformat(), [' a b_1 c b_2 d ', 210 '--- --- --- --- ---', 211 ' 0 foo L1 -- --', 212 ' 1 foo L2 foo R1', 213 ' 1 foo L2 foo R2', 214 ' 1 bar L3 foo R1', 215 ' 1 bar L3 foo R2', 216 ' 2 bar L4 bar R3']) 217 218 # Right join 219 t12 = table.join(t1, t2, join_type='right', keys='a') 220 assert t12.has_masked_columns is True 221 assert sort_eq(t12.pformat(), [' a b_1 c b_2 d ', 222 '--- --- --- --- ---', 223 ' 1 foo L2 foo R1', 224 ' 1 foo L2 foo R2', 225 ' 1 bar L3 foo R1', 226 ' 1 bar L3 foo R2', 227 ' 2 bar L4 bar R3', 228 ' 4 -- -- bar R4']) 229 230 # Outer join 231 t12 = table.join(t1, t2, join_type='outer', keys='a') 232 assert t12.has_masked_columns is True 233 assert sort_eq(t12.pformat(), [' a b_1 c b_2 d ', 234 '--- --- --- --- ---', 235 ' 0 foo L1 -- --', 236 ' 1 foo L2 foo R1', 237 ' 1 foo L2 foo R2', 238 ' 1 bar L3 foo R1', 239 ' 1 bar L3 foo R2', 240 ' 2 bar L4 bar R3', 241 ' 4 -- -- bar R4']) 242 243 def test_masked_unmasked(self, operation_table_type): 244 if operation_table_type is QTable: 245 pytest.xfail('Quantity columns do not support masking.') 246 self._setup(operation_table_type) 247 t1 = self.t1 248 t1m = operation_table_type(self.t1, masked=True) 249 t2 = self.t2 250 251 # Result table is never masked 252 t1m2 = table.join(t1m, t2, join_type='inner') 253 assert t1m2.masked is False 254 255 # Result should match non-masked result 256 t12 = table.join(t1, t2) 257 assert np.all(t12.as_array() == np.array(t1m2)) 258 259 # Mask out some values in left table and make sure they propagate 260 t1m['b'].mask[1] = True 261 t1m['c'].mask[2] = True 262 t1m2 = table.join(t1m, t2, join_type='inner', keys='a') 263 assert sort_eq(t1m2.pformat(), [' a b_1 c b_2 d ', 264 '--- --- --- --- ---', 265 ' 1 -- L2 foo R1', 266 ' 1 -- L2 foo R2', 267 ' 1 bar -- foo R1', 268 ' 1 bar -- foo R2', 269 ' 2 bar L4 bar R3']) 270 271 t21m = table.join(t2, t1m, join_type='inner', keys='a') 272 assert sort_eq(t21m.pformat(), [' a b_1 d b_2 c ', 273 '--- --- --- --- ---', 274 ' 1 foo R2 -- L2', 275 ' 1 foo R2 bar --', 276 ' 1 foo R1 -- L2', 277 ' 1 foo R1 bar --', 278 ' 2 bar R3 bar L4']) 279 280 def test_masked_masked(self, operation_table_type): 281 self._setup(operation_table_type) 282 """Two masked tables""" 283 if operation_table_type is QTable: 284 pytest.xfail('Quantity columns do not support masking.') 285 t1 = self.t1 286 t1m = operation_table_type(self.t1, masked=True) 287 t2 = self.t2 288 t2m = operation_table_type(self.t2, masked=True) 289 290 # Result table is never masked but original column types are preserved 291 t1m2m = table.join(t1m, t2m, join_type='inner') 292 assert t1m2m.masked is False 293 for col in t1m2m.itercols(): 294 assert type(col) is MaskedColumn 295 296 # Result should match non-masked result 297 t12 = table.join(t1, t2) 298 assert np.all(t12.as_array() == np.array(t1m2m)) 299 300 # Mask out some values in both tables and make sure they propagate 301 t1m['b'].mask[1] = True 302 t1m['c'].mask[2] = True 303 t2m['d'].mask[2] = True 304 t1m2m = table.join(t1m, t2m, join_type='inner', keys='a') 305 assert sort_eq(t1m2m.pformat(), [' a b_1 c b_2 d ', 306 '--- --- --- --- ---', 307 ' 1 -- L2 foo R1', 308 ' 1 -- L2 foo R2', 309 ' 1 bar -- foo R1', 310 ' 1 bar -- foo R2', 311 ' 2 bar L4 bar --']) 312 313 def test_classes(self): 314 """Ensure that classes and subclasses get through as expected""" 315 class MyCol(Column): 316 pass 317 318 class MyMaskedCol(MaskedColumn): 319 pass 320 321 t1 = Table() 322 t1['a'] = MyCol([1]) 323 t1['b'] = MyCol([2]) 324 t1['c'] = MyMaskedCol([3]) 325 326 t2 = Table() 327 t2['a'] = Column([1, 2]) 328 t2['d'] = MyCol([3, 4]) 329 t2['e'] = MyMaskedCol([5, 6]) 330 331 t12 = table.join(t1, t2, join_type='inner') 332 for name, exp_type in (('a', MyCol), ('b', MyCol), ('c', MyMaskedCol), 333 ('d', MyCol), ('e', MyMaskedCol)): 334 assert type(t12[name] is exp_type) 335 336 t21 = table.join(t2, t1, join_type='left') 337 # Note col 'b' gets upgraded from MyCol to MaskedColumn since it needs to be 338 # masked, but col 'c' stays since MyMaskedCol supports masking. 339 for name, exp_type in (('a', MyCol), ('b', MaskedColumn), ('c', MyMaskedCol), 340 ('d', MyCol), ('e', MyMaskedCol)): 341 assert type(t21[name] is exp_type) 342 343 def test_col_rename(self, operation_table_type): 344 self._setup(operation_table_type) 345 """ 346 Test auto col renaming when there is a conflict. Use 347 non-default values of uniq_col_name and table_names. 348 """ 349 t1 = self.t1 350 t2 = self.t2 351 t12 = table.join(t1, t2, uniq_col_name='x_{table_name}_{col_name}_y', 352 table_names=['L', 'R'], keys='a') 353 assert t12.colnames == ['a', 'x_L_b_y', 'c', 'x_R_b_y', 'd'] 354 355 def test_rename_conflict(self, operation_table_type): 356 self._setup(operation_table_type) 357 """ 358 Test that auto-column rename fails because of a conflict 359 with an existing column 360 """ 361 t1 = self.t1 362 t2 = self.t2 363 t1['b_1'] = 1 # Add a new column b_1 that will conflict with auto-rename 364 with pytest.raises(TableMergeError): 365 table.join(t1, t2, keys='a') 366 367 def test_missing_keys(self, operation_table_type): 368 self._setup(operation_table_type) 369 """Merge on a key column that doesn't exist""" 370 t1 = self.t1 371 t2 = self.t2 372 with pytest.raises(TableMergeError): 373 table.join(t1, t2, keys=['a', 'not there']) 374 375 def test_bad_join_type(self, operation_table_type): 376 self._setup(operation_table_type) 377 """Bad join_type input""" 378 t1 = self.t1 379 t2 = self.t2 380 with pytest.raises(ValueError): 381 table.join(t1, t2, join_type='illegal value') 382 383 def test_no_common_keys(self, operation_table_type): 384 self._setup(operation_table_type) 385 """Merge tables with no common keys""" 386 t1 = self.t1 387 t2 = self.t2 388 del t1['a'] 389 del t1['b'] 390 del t2['a'] 391 del t2['b'] 392 with pytest.raises(TableMergeError): 393 table.join(t1, t2) 394 395 def test_masked_key_column(self, operation_table_type): 396 self._setup(operation_table_type) 397 """Merge on a key column that has a masked element""" 398 if operation_table_type is QTable: 399 pytest.xfail('Quantity columns do not support masking.') 400 t1 = self.t1 401 t2 = operation_table_type(self.t2, masked=True) 402 table.join(t1, t2) # OK 403 t2['a'].mask[0] = True 404 with pytest.raises(TableMergeError): 405 table.join(t1, t2) 406 407 def test_col_meta_merge(self, operation_table_type): 408 self._setup(operation_table_type) 409 t1 = self.t1 410 t2 = self.t2 411 t2.rename_column('d', 'c') # force col conflict and renaming 412 meta1 = OrderedDict([('b', [1, 2]), ('c', {'a': 1}), ('d', 1)]) 413 meta2 = OrderedDict([('b', [3, 4]), ('c', {'b': 1}), ('a', 1)]) 414 415 # Key col 'a', should first value ('cm') 416 t1['a'].unit = 'cm' 417 t2['a'].unit = 'm' 418 # Key col 'b', take first value 't1_b' 419 t1['b'].info.description = 't1_b' 420 # Key col 'b', take first non-empty value 't1_b' 421 t2['b'].info.format = '%6s' 422 # Key col 'a', should be merged meta 423 t1['a'].info.meta = meta1 424 t2['a'].info.meta = meta2 425 # Key col 'b', should be meta2 426 t2['b'].info.meta = meta2 427 428 # All these should pass through 429 t1['c'].info.format = '%3s' 430 t1['c'].info.description = 't1_c' 431 432 t2['c'].info.format = '%6s' 433 t2['c'].info.description = 't2_c' 434 435 if operation_table_type is Table: 436 ctx = pytest.warns(metadata.MergeConflictWarning, match=r"In merged column 'a' the 'unit' attribute does not match \(cm != m\)") # noqa 437 else: 438 ctx = nullcontext() 439 440 with ctx: 441 t12 = table.join(t1, t2, keys=['a', 'b']) 442 443 assert t12['a'].unit == 'm' 444 assert t12['b'].info.description == 't1_b' 445 assert t12['b'].info.format == '%6s' 446 assert t12['a'].info.meta == self.meta_merge 447 assert t12['b'].info.meta == meta2 448 assert t12['c_1'].info.format == '%3s' 449 assert t12['c_1'].info.description == 't1_c' 450 assert t12['c_2'].info.format == '%6s' 451 assert t12['c_2'].info.description == 't2_c' 452 453 def test_join_multidimensional(self, operation_table_type): 454 self._setup(operation_table_type) 455 456 # Regression test for #2984, which was an issue where join did not work 457 # on multi-dimensional columns. 458 459 t1 = operation_table_type() 460 t1['a'] = [1, 2, 3] 461 t1['b'] = np.ones((3, 4)) 462 463 t2 = operation_table_type() 464 t2['a'] = [1, 2, 3] 465 t2['c'] = [4, 5, 6] 466 467 t3 = table.join(t1, t2) 468 469 np.testing.assert_allclose(t3['a'], t1['a']) 470 np.testing.assert_allclose(t3['b'], t1['b']) 471 np.testing.assert_allclose(t3['c'], t2['c']) 472 473 def test_join_multidimensional_masked(self, operation_table_type): 474 self._setup(operation_table_type) 475 """ 476 Test for outer join with multidimensional columns where masking is required. 477 (Issue #4059). 478 """ 479 if operation_table_type is QTable: 480 pytest.xfail('Quantity columns do not support masking.') 481 482 a = table.MaskedColumn([1, 2, 3], name='a') 483 a2 = table.Column([1, 3, 4], name='a') 484 b = table.MaskedColumn([[1, 2], 485 [3, 4], 486 [5, 6]], 487 name='b', 488 mask=[[1, 0], 489 [0, 1], 490 [0, 0]]) 491 c = table.Column([[1, 1], 492 [2, 2], 493 [3, 3]], 494 name='c') 495 t1 = operation_table_type([a, b]) 496 t2 = operation_table_type([a2, c]) 497 t12 = table.join(t1, t2, join_type='inner') 498 499 assert np.all(t12['b'].mask == [[True, False], 500 [False, False]]) 501 assert not hasattr(t12['c'], 'mask') 502 503 t12 = table.join(t1, t2, join_type='outer') 504 assert np.all(t12['b'].mask == [[True, False], 505 [False, True], 506 [False, False], 507 [True, True]]) 508 assert np.all(t12['c'].mask == [[False, False], 509 [True, True], 510 [False, False], 511 [False, False]]) 512 513 def test_mixin_functionality(self, mixin_cols): 514 col = mixin_cols['m'] 515 cls_name = type(col).__name__ 516 len_col = len(col) 517 idx = np.arange(len_col) 518 t1 = table.QTable([idx, col], names=['idx', 'm1']) 519 t2 = table.QTable([idx, col], names=['idx', 'm2']) 520 # Set up join mismatches for different join_type cases 521 t1 = t1[[0, 1, 3]] 522 t2 = t2[[0, 2, 3]] 523 524 # Test inner join, which works for all mixin_cols 525 out = table.join(t1, t2, join_type='inner') 526 assert len(out) == 2 527 assert out['m2'].__class__ is col.__class__ 528 assert np.all(out['idx'] == [0, 3]) 529 if cls_name == 'SkyCoord': 530 # SkyCoord doesn't support __eq__ so use our own 531 assert skycoord_equal(out['m1'], col[[0, 3]]) 532 assert skycoord_equal(out['m2'], col[[0, 3]]) 533 elif 'Repr' in cls_name or 'Diff' in cls_name: 534 assert np.all(representation_equal(out['m1'], col[[0, 3]])) 535 assert np.all(representation_equal(out['m2'], col[[0, 3]])) 536 else: 537 assert np.all(out['m1'] == col[[0, 3]]) 538 assert np.all(out['m2'] == col[[0, 3]]) 539 540 # Check for left, right, outer join which requires masking. Works for 541 # the listed mixins classes. 542 if isinstance(col, (Quantity, Time, TimeDelta)): 543 out = table.join(t1, t2, join_type='left') 544 assert len(out) == 3 545 assert np.all(out['idx'] == [0, 1, 3]) 546 assert np.all(out['m1'] == t1['m1']) 547 assert np.all(out['m2'] == t2['m2']) 548 check_mask(out['m1'], [False, False, False]) 549 check_mask(out['m2'], [False, True, False]) 550 551 out = table.join(t1, t2, join_type='right') 552 assert len(out) == 3 553 assert np.all(out['idx'] == [0, 2, 3]) 554 assert np.all(out['m1'] == t1['m1']) 555 assert np.all(out['m2'] == t2['m2']) 556 check_mask(out['m1'], [False, True, False]) 557 check_mask(out['m2'], [False, False, False]) 558 559 out = table.join(t1, t2, join_type='outer') 560 assert len(out) == 4 561 assert np.all(out['idx'] == [0, 1, 2, 3]) 562 assert np.all(out['m1'] == col) 563 assert np.all(out['m2'] == col) 564 assert check_mask(out['m1'], [False, False, True, False]) 565 assert check_mask(out['m2'], [False, True, False, False]) 566 else: 567 # Otherwise make sure it fails with the right exception message 568 for join_type in ('outer', 'left', 'right'): 569 with pytest.raises(NotImplementedError) as err: 570 table.join(t1, t2, join_type=join_type) 571 assert ('join requires masking' in str(err.value) 572 or 'join unavailable' in str(err.value)) 573 574 def test_cartesian_join(self, operation_table_type): 575 t1 = Table(rows=[(1, 'a'), 576 (2, 'b')], names=['a', 'b']) 577 t2 = Table(rows=[(3, 'c'), 578 (4, 'd')], names=['a', 'c']) 579 t12 = table.join(t1, t2, join_type='cartesian') 580 581 assert t1.colnames == ['a', 'b'] 582 assert t2.colnames == ['a', 'c'] 583 assert len(t12) == len(t1) * len(t2) 584 assert str(t12).splitlines() == [ 585 'a_1 b a_2 c ', 586 '--- --- --- ---', 587 ' 1 a 3 c', 588 ' 1 a 4 d', 589 ' 2 b 3 c', 590 ' 2 b 4 d'] 591 592 with pytest.raises(ValueError, match='cannot supply keys for a cartesian join'): 593 t12 = table.join(t1, t2, join_type='cartesian', keys='a') 594 595 @pytest.mark.skipif('not HAS_SCIPY') 596 def test_join_with_join_skycoord_sky(self): 597 sc1 = SkyCoord([0, 1, 1.1, 2], [0, 0, 0, 0], unit='deg') 598 sc2 = SkyCoord([0.5, 1.05, 2.1], [0, 0, 0], unit='deg') 599 t1 = Table([sc1], names=['sc']) 600 t2 = Table([sc2], names=['sc']) 601 t12 = table.join(t1, t2, join_funcs={'sc': join_skycoord(0.2 * u.deg)}) 602 exp = ['sc_id sc_1 sc_2 ', 603 ' deg,deg deg,deg ', 604 '----- ------- --------', 605 ' 1 1.0,0.0 1.05,0.0', 606 ' 1 1.1,0.0 1.05,0.0', 607 ' 2 2.0,0.0 2.1,0.0'] 608 assert str(t12).splitlines() == exp 609 610 @pytest.mark.skipif('not HAS_SCIPY') 611 @pytest.mark.parametrize('distance_func', ['search_around_3d', search_around_3d]) 612 def test_join_with_join_skycoord_3d(self, distance_func): 613 sc1 = SkyCoord([0, 1, 1.1, 2]*u.deg, [0, 0, 0, 0]*u.deg, [1, 1, 2, 1]*u.m) 614 sc2 = SkyCoord([0.5, 1.05, 2.1]*u.deg, [0, 0, 0]*u.deg, [1, 1, 1]*u.m) 615 t1 = Table([sc1], names=['sc']) 616 t2 = Table([sc2], names=['sc']) 617 join_func = join_skycoord(np.deg2rad(0.2) * u.m, 618 distance_func=distance_func) 619 t12 = table.join(t1, t2, join_funcs={'sc': join_func}) 620 exp = ['sc_id sc_1 sc_2 ', 621 ' deg,deg,m deg,deg,m ', 622 '----- ----------- ------------', 623 ' 1 1.0,0.0,1.0 1.05,0.0,1.0', 624 ' 2 2.0,0.0,1.0 2.1,0.0,1.0'] 625 assert str(t12).splitlines() == exp 626 627 @pytest.mark.skipif('not HAS_SCIPY') 628 def test_join_with_join_distance_1d(self): 629 c1 = [0, 1, 1.1, 2] 630 c2 = [0.5, 1.05, 2.1] 631 t1 = Table([c1], names=['col']) 632 t2 = Table([c2], names=['col']) 633 join_func = join_distance(0.2, 634 kdtree_args={'leafsize': 32}, 635 query_args={'p': 2}) 636 t12 = table.join(t1, t2, join_type='outer', join_funcs={'col': join_func}) 637 exp = ['col_id col_1 col_2', 638 '------ ----- -----', 639 ' 1 1.0 1.05', 640 ' 1 1.1 1.05', 641 ' 2 2.0 2.1', 642 ' 3 0.0 --', 643 ' 4 -- 0.5'] 644 assert str(t12).splitlines() == exp 645 646 @pytest.mark.skipif('not HAS_SCIPY') 647 def test_join_with_join_distance_1d_multikey(self): 648 from astropy.table.operations import _apply_join_funcs 649 650 c1 = [0, 1, 1.1, 1.2, 2] 651 id1 = [0, 1, 2, 2, 3] 652 o1 = ['a', 'b', 'c', 'd', 'e'] 653 c2 = [0.5, 1.05, 2.1] 654 id2 = [0, 2, 4] 655 o2 = ['z', 'y', 'x'] 656 t1 = Table([c1, id1, o1], names=['col', 'id', 'o1']) 657 t2 = Table([c2, id2, o2], names=['col', 'id', 'o2']) 658 join_func = join_distance(0.2) 659 join_funcs = {'col': join_func} 660 t12 = table.join(t1, t2, join_type='outer', join_funcs=join_funcs) 661 exp = ['col_id col_1 id o1 col_2 o2', 662 '------ ----- --- --- ----- ---', 663 ' 1 1.0 1 b -- --', 664 ' 1 1.1 2 c 1.05 y', 665 ' 1 1.2 2 d 1.05 y', 666 ' 2 2.0 3 e -- --', 667 ' 2 -- 4 -- 2.1 x', 668 ' 3 0.0 0 a -- --', 669 ' 4 -- 0 -- 0.5 z'] 670 assert str(t12).splitlines() == exp 671 672 left, right, keys = _apply_join_funcs(t1, t2, ('col', 'id'), join_funcs) 673 assert keys == ('col_id', 'id') 674 675 @pytest.mark.skipif('not HAS_SCIPY') 676 def test_join_with_join_distance_1d_quantity(self): 677 c1 = [0, 1, 1.1, 2] * u.m 678 c2 = [500, 1050, 2100] * u.mm 679 t1 = QTable([c1], names=['col']) 680 t2 = QTable([c2], names=['col']) 681 join_func = join_distance(20 * u.cm) 682 t12 = table.join(t1, t2, join_funcs={'col': join_func}) 683 exp = ['col_id col_1 col_2 ', 684 ' m mm ', 685 '------ ----- ------', 686 ' 1 1.0 1050.0', 687 ' 1 1.1 1050.0', 688 ' 2 2.0 2100.0'] 689 assert str(t12).splitlines() == exp 690 691 # Generate column name conflict 692 t2['col_id'] = [0, 0, 0] 693 t2['col__id'] = [0, 0, 0] 694 t12 = table.join(t1, t2, join_funcs={'col': join_func}) 695 exp = ['col___id col_1 col_2 col_id col__id', 696 ' m mm ', 697 '-------- ----- ------ ------ -------', 698 ' 1 1.0 1050.0 0 0', 699 ' 1 1.1 1050.0 0 0', 700 ' 2 2.0 2100.0 0 0'] 701 assert str(t12).splitlines() == exp 702 703 @pytest.mark.skipif('not HAS_SCIPY') 704 def test_join_with_join_distance_2d(self): 705 c1 = np.array([[0, 1, 1.1, 2], 706 [0, 0, 1, 0]]).transpose() 707 c2 = np.array([[0.5, 1.05, 2.1], 708 [0, 0, 0]]).transpose() 709 t1 = Table([c1], names=['col']) 710 t2 = Table([c2], names=['col']) 711 join_func = join_distance(0.2, 712 kdtree_args={'leafsize': 32}, 713 query_args={'p': 2}) 714 t12 = table.join(t1, t2, join_type='outer', join_funcs={'col': join_func}) 715 exp = ['col_id col_1 [2] col_2 [2] ', 716 '------ ---------- -----------', 717 ' 1 1.0 .. 0.0 1.05 .. 0.0', 718 ' 2 2.0 .. 0.0 2.1 .. 0.0', 719 ' 3 0.0 .. 0.0 -- .. --', 720 ' 4 1.1 .. 1.0 -- .. --', 721 ' 5 -- .. -- 0.5 .. 0.0'] 722 assert str(t12).splitlines() == exp 723 724 def test_keys_left_right_basic(self): 725 """Test using the keys_left and keys_right args to specify different 726 join keys. This takes the standard test case but renames column 'a' 727 to 'x' and 'y' respectively for tables 1 and 2. Then it compares the 728 normal join on 'a' to the new join on 'x' and 'y'.""" 729 self._setup() 730 731 for join_type in ('inner', 'left', 'right', 'outer'): 732 t1 = self.t1.copy() 733 t2 = self.t2.copy() 734 # Expected is same as joining on 'a' but with names 'x', 'y' instead 735 t12_exp = table.join(t1, t2, keys='a', join_type=join_type) 736 t12_exp.add_column(t12_exp['a'], name='x', index=1) 737 t12_exp.add_column(t12_exp['a'], name='y', index=len(t1.colnames) + 1) 738 del t12_exp['a'] 739 740 # Different key names 741 t1.rename_column('a', 'x') 742 t2.rename_column('a', 'y') 743 keys_left_list = ['x'] # Test string key name 744 keys_right_list = [['y']] # Test list of string key names 745 if join_type == 'outer': 746 # Just do this for the outer join (others are the same) 747 keys_left_list.append([t1['x'].tolist()]) # Test list key column 748 keys_right_list.append([t2['y']]) # Test Column key column 749 750 for keys_left, keys_right in zip(keys_left_list, keys_right_list): 751 t12 = table.join(t1, t2, keys_left=keys_left, keys_right=keys_right, 752 join_type=join_type) 753 754 assert t12.colnames == t12_exp.colnames 755 for col in t12.values_equal(t12_exp).itercols(): 756 assert np.all(col) 757 assert t12_exp.meta == t12.meta 758 759 def test_keys_left_right_exceptions(self): 760 """Test exceptions using the keys_left and keys_right args to specify 761 different join keys. 762 """ 763 self._setup() 764 t1 = self.t1 765 t2 = self.t2 766 767 msg = r"left table does not have key column 'z'" 768 with pytest.raises(ValueError, match=msg): 769 table.join(t1, t2, keys_left='z', keys_right=['a']) 770 771 msg = r"left table has different length from key \[1, 2\]" 772 with pytest.raises(ValueError, match=msg): 773 table.join(t1, t2, keys_left=[[1, 2]], keys_right=['a']) 774 775 msg = r"keys arg must be None if keys_left and keys_right are supplied" 776 with pytest.raises(ValueError, match=msg): 777 table.join(t1, t2, keys_left='z', keys_right=['a'], keys='a') 778 779 msg = r"keys_left and keys_right args must have same length" 780 with pytest.raises(ValueError, match=msg): 781 table.join(t1, t2, keys_left=['a', 'b'], keys_right=['a']) 782 783 msg = r"keys_left and keys_right must both be provided" 784 with pytest.raises(ValueError, match=msg): 785 table.join(t1, t2, keys_left=['a', 'b']) 786 787 msg = r"cannot supply join_funcs arg and keys_left / keys_right" 788 with pytest.raises(ValueError, match=msg): 789 table.join(t1, t2, keys_left=['a'], keys_right=['a'], join_funcs={}) 790 791 792class TestSetdiff(): 793 794 def _setup(self, t_cls=Table): 795 lines1 = [' a b ', 796 ' 0 foo ', 797 ' 1 foo ', 798 ' 1 bar ', 799 ' 2 bar '] 800 lines2 = [' a b ', 801 ' 0 foo ', 802 ' 3 foo ', 803 ' 4 bar ', 804 ' 2 bar '] 805 lines3 = [' a b d ', 806 ' 0 foo R1', 807 ' 8 foo R2', 808 ' 1 bar R3', 809 ' 4 bar R4'] 810 self.t1 = t_cls.read(lines1, format='ascii') 811 self.t2 = t_cls.read(lines2, format='ascii') 812 self.t3 = t_cls.read(lines3, format='ascii') 813 814 def test_default_same_columns(self, operation_table_type): 815 self._setup(operation_table_type) 816 out = table.setdiff(self.t1, self.t2) 817 assert type(out['a']) is type(self.t1['a']) # noqa 818 assert type(out['b']) is type(self.t1['b']) # noqa 819 assert out.pformat() == [' a b ', 820 '--- ---', 821 ' 1 bar', 822 ' 1 foo'] 823 824 def test_default_same_tables(self, operation_table_type): 825 self._setup(operation_table_type) 826 out = table.setdiff(self.t1, self.t1) 827 828 assert type(out['a']) is type(self.t1['a']) # noqa 829 assert type(out['b']) is type(self.t1['b']) # noqa 830 assert out.pformat() == [' a b ', 831 '--- ---'] 832 833 def test_extra_col_left_table(self, operation_table_type): 834 self._setup(operation_table_type) 835 836 with pytest.raises(ValueError): 837 table.setdiff(self.t3, self.t1) 838 839 def test_extra_col_right_table(self, operation_table_type): 840 self._setup(operation_table_type) 841 out = table.setdiff(self.t1, self.t3) 842 843 assert type(out['a']) is type(self.t1['a']) # noqa 844 assert type(out['b']) is type(self.t1['b']) # noqa 845 assert out.pformat() == [' a b ', 846 '--- ---', 847 ' 1 foo', 848 ' 2 bar'] 849 850 def test_keys(self, operation_table_type): 851 self._setup(operation_table_type) 852 out = table.setdiff(self.t3, self.t1, keys=['a', 'b']) 853 854 assert type(out['a']) is type(self.t1['a']) # noqa 855 assert type(out['b']) is type(self.t1['b']) # noqa 856 assert out.pformat() == [' a b d ', 857 '--- --- ---', 858 ' 4 bar R4', 859 ' 8 foo R2'] 860 861 def test_missing_key(self, operation_table_type): 862 self._setup(operation_table_type) 863 864 with pytest.raises(ValueError): 865 table.setdiff(self.t3, self.t1, keys=['a', 'd']) 866 867 868class TestVStack(): 869 870 def _setup(self, t_cls=Table): 871 self.t1 = t_cls.read([' a b', 872 ' 0. foo', 873 ' 1. bar'], format='ascii') 874 875 self.t2 = t_cls.read([' a b c', 876 ' 2. pez 4', 877 ' 3. sez 5'], format='ascii') 878 879 self.t3 = t_cls.read([' a b', 880 ' 4. 7', 881 ' 5. 8', 882 ' 6. 9'], format='ascii') 883 self.t4 = t_cls(self.t1, copy=True, masked=t_cls is Table) 884 885 # The following table has meta-data that conflicts with t1 886 self.t5 = t_cls(self.t1, copy=True) 887 888 self.t1.meta.update(OrderedDict([('b', [1, 2]), ('c', {'a': 1}), ('d', 1)])) 889 self.t2.meta.update(OrderedDict([('b', [3, 4]), ('c', {'b': 1}), ('a', 1)])) 890 self.t4.meta.update(OrderedDict([('b', [5, 6]), ('c', {'c': 1}), ('e', 1)])) 891 self.t5.meta.update(OrderedDict([('b', 3), ('c', 'k'), ('d', 1)])) 892 self.meta_merge = OrderedDict([('b', [1, 2, 3, 4, 5, 6]), 893 ('c', {'a': 1, 'b': 1, 'c': 1}), 894 ('d', 1), 895 ('a', 1), 896 ('e', 1)]) 897 898 def test_validate_join_type(self): 899 self._setup() 900 with pytest.raises(TypeError, match='Did you accidentally call vstack'): 901 table.vstack(self.t1, self.t2) 902 903 def test_stack_rows(self, operation_table_type): 904 self._setup(operation_table_type) 905 t2 = self.t1.copy() 906 t2.meta.clear() 907 out = table.vstack([self.t1, t2[1]]) 908 assert type(out['a']) is type(self.t1['a']) # noqa 909 assert type(out['b']) is type(self.t1['b']) # noqa 910 assert out.pformat() == [' a b ', 911 '--- ---', 912 '0.0 foo', 913 '1.0 bar', 914 '1.0 bar'] 915 916 def test_stack_table_column(self, operation_table_type): 917 self._setup(operation_table_type) 918 t2 = self.t1.copy() 919 t2.meta.clear() 920 out = table.vstack([self.t1, t2['a']]) 921 assert out.masked is False 922 assert out.pformat() == [' a b ', 923 '--- ---', 924 '0.0 foo', 925 '1.0 bar', 926 '0.0 --', 927 '1.0 --'] 928 929 def test_table_meta_merge(self, operation_table_type): 930 self._setup(operation_table_type) 931 out = table.vstack([self.t1, self.t2, self.t4], join_type='inner') 932 assert out.meta == self.meta_merge 933 934 def test_table_meta_merge_conflict(self, operation_table_type): 935 self._setup(operation_table_type) 936 937 with pytest.warns(metadata.MergeConflictWarning) as w: 938 out = table.vstack([self.t1, self.t5], join_type='inner') 939 assert len(w) == 2 940 941 assert out.meta == self.t5.meta 942 943 with pytest.warns(metadata.MergeConflictWarning) as w: 944 out = table.vstack([self.t1, self.t5], join_type='inner', metadata_conflicts='warn') 945 assert len(w) == 2 946 947 assert out.meta == self.t5.meta 948 949 out = table.vstack([self.t1, self.t5], join_type='inner', metadata_conflicts='silent') 950 951 assert out.meta == self.t5.meta 952 953 with pytest.raises(MergeConflictError): 954 out = table.vstack([self.t1, self.t5], join_type='inner', metadata_conflicts='error') 955 956 with pytest.raises(ValueError): 957 out = table.vstack([self.t1, self.t5], join_type='inner', metadata_conflicts='nonsense') 958 959 def test_bad_input_type(self, operation_table_type): 960 self._setup(operation_table_type) 961 with pytest.raises(ValueError): 962 table.vstack([]) 963 with pytest.raises(TypeError): 964 table.vstack(1) 965 with pytest.raises(TypeError): 966 table.vstack([self.t2, 1]) 967 with pytest.raises(ValueError): 968 table.vstack([self.t1, self.t2], join_type='invalid join type') 969 970 def test_stack_basic_inner(self, operation_table_type): 971 self._setup(operation_table_type) 972 t1 = self.t1 973 t2 = self.t2 974 t4 = self.t4 975 976 t12 = table.vstack([t1, t2], join_type='inner') 977 assert t12.masked is False 978 assert type(t12) is operation_table_type 979 assert type(t12['a']) is type(t1['a']) # noqa 980 assert type(t12['b']) is type(t1['b']) # noqa 981 assert t12.pformat() == [' a b ', 982 '--- ---', 983 '0.0 foo', 984 '1.0 bar', 985 '2.0 pez', 986 '3.0 sez'] 987 988 t124 = table.vstack([t1, t2, t4], join_type='inner') 989 assert type(t124) is operation_table_type 990 assert type(t12['a']) is type(t1['a']) # noqa 991 assert type(t12['b']) is type(t1['b']) # noqa 992 assert t124.pformat() == [' a b ', 993 '--- ---', 994 '0.0 foo', 995 '1.0 bar', 996 '2.0 pez', 997 '3.0 sez', 998 '0.0 foo', 999 '1.0 bar'] 1000 1001 def test_stack_basic_outer(self, operation_table_type): 1002 if operation_table_type is QTable: 1003 pytest.xfail('Quantity columns do not support masking.') 1004 self._setup(operation_table_type) 1005 t1 = self.t1 1006 t2 = self.t2 1007 t4 = self.t4 1008 t12 = table.vstack([t1, t2], join_type='outer') 1009 assert t12.masked is False 1010 assert t12.pformat() == [' a b c ', 1011 '--- --- ---', 1012 '0.0 foo --', 1013 '1.0 bar --', 1014 '2.0 pez 4', 1015 '3.0 sez 5'] 1016 1017 t124 = table.vstack([t1, t2, t4], join_type='outer') 1018 assert t124.masked is False 1019 assert t124.pformat() == [' a b c ', 1020 '--- --- ---', 1021 '0.0 foo --', 1022 '1.0 bar --', 1023 '2.0 pez 4', 1024 '3.0 sez 5', 1025 '0.0 foo --', 1026 '1.0 bar --'] 1027 1028 def test_stack_incompatible(self, operation_table_type): 1029 self._setup(operation_table_type) 1030 with pytest.raises(TableMergeError) as excinfo: 1031 table.vstack([self.t1, self.t3], join_type='inner') 1032 assert ("The 'b' columns have incompatible types: {}" 1033 .format([self.t1['b'].dtype.name, self.t3['b'].dtype.name]) 1034 in str(excinfo.value)) 1035 1036 with pytest.raises(TableMergeError) as excinfo: 1037 table.vstack([self.t1, self.t3], join_type='outer') 1038 assert "The 'b' columns have incompatible types:" in str(excinfo.value) 1039 1040 with pytest.raises(TableMergeError): 1041 table.vstack([self.t1, self.t2], join_type='exact') 1042 1043 t1_reshape = self.t1.copy() 1044 t1_reshape['b'].shape = [2, 1] 1045 with pytest.raises(TableMergeError) as excinfo: 1046 table.vstack([self.t1, t1_reshape]) 1047 assert "have different shape" in str(excinfo.value) 1048 1049 def test_vstack_one_masked(self, operation_table_type): 1050 if operation_table_type is QTable: 1051 pytest.xfail('Quantity columns do not support masking.') 1052 self._setup(operation_table_type) 1053 t1 = self.t1 1054 t4 = self.t4 1055 t4['b'].mask[1] = True 1056 t14 = table.vstack([t1, t4]) 1057 assert t14.masked is False 1058 assert t14.pformat() == [' a b ', 1059 '--- ---', 1060 '0.0 foo', 1061 '1.0 bar', 1062 '0.0 foo', 1063 '1.0 --'] 1064 1065 def test_col_meta_merge_inner(self, operation_table_type): 1066 self._setup(operation_table_type) 1067 t1 = self.t1 1068 t2 = self.t2 1069 t4 = self.t4 1070 1071 # Key col 'a', should last value ('km') 1072 t1['a'].info.unit = 'cm' 1073 t2['a'].info.unit = 'm' 1074 t4['a'].info.unit = 'km' 1075 1076 # Key col 'a' format should take last when all match 1077 t1['a'].info.format = '%f' 1078 t2['a'].info.format = '%f' 1079 t4['a'].info.format = '%f' 1080 1081 # Key col 'b', take first value 't1_b' 1082 t1['b'].info.description = 't1_b' 1083 1084 # Key col 'b', take first non-empty value '%6s' 1085 t4['b'].info.format = '%6s' 1086 1087 # Key col 'a', should be merged meta 1088 t1['a'].info.meta.update(OrderedDict([('b', [1, 2]), ('c', {'a': 1}), ('d', 1)])) 1089 t2['a'].info.meta.update(OrderedDict([('b', [3, 4]), ('c', {'b': 1}), ('a', 1)])) 1090 t4['a'].info.meta.update(OrderedDict([('b', [5, 6]), ('c', {'c': 1}), ('e', 1)])) 1091 1092 # Key col 'b', should be meta2 1093 t2['b'].info.meta.update(OrderedDict([('b', [3, 4]), ('c', {'b': 1}), ('a', 1)])) 1094 1095 if operation_table_type is Table: 1096 ctx = pytest.warns(metadata.MergeConflictWarning) 1097 else: 1098 ctx = nullcontext() 1099 1100 with ctx as warning_lines: 1101 out = table.vstack([t1, t2, t4], join_type='inner') 1102 1103 if operation_table_type is Table: 1104 assert len(warning_lines) == 2 1105 assert ("In merged column 'a' the 'unit' attribute does not match (cm != m)" 1106 in str(warning_lines[0].message)) 1107 assert ("In merged column 'a' the 'unit' attribute does not match (m != km)" 1108 in str(warning_lines[1].message)) 1109 # Check units are suitably ignored for a regular Table 1110 assert out.pformat() == [' a b ', 1111 ' km ', 1112 '-------- ------', 1113 '0.000000 foo', 1114 '1.000000 bar', 1115 '2.000000 pez', 1116 '3.000000 sez', 1117 '0.000000 foo', 1118 '1.000000 bar'] 1119 else: 1120 # Check QTable correctly dealt with units. 1121 assert out.pformat() == [' a b ', 1122 ' km ', 1123 '-------- ------', 1124 '0.000000 foo', 1125 '0.000010 bar', 1126 '0.002000 pez', 1127 '0.003000 sez', 1128 '0.000000 foo', 1129 '1.000000 bar'] 1130 assert out['a'].info.unit == 'km' 1131 assert out['a'].info.format == '%f' 1132 assert out['b'].info.description == 't1_b' 1133 assert out['b'].info.format == '%6s' 1134 assert out['a'].info.meta == self.meta_merge 1135 assert out['b'].info.meta == OrderedDict([('b', [3, 4]), ('c', {'b': 1}), ('a', 1)]) 1136 1137 def test_col_meta_merge_outer(self, operation_table_type): 1138 if operation_table_type is QTable: 1139 pytest.xfail('Quantity columns do not support masking.') 1140 self._setup(operation_table_type) 1141 t1 = self.t1 1142 t2 = self.t2 1143 t4 = self.t4 1144 1145 # Key col 'a', should last value ('km') 1146 t1['a'].unit = 'cm' 1147 t2['a'].unit = 'm' 1148 t4['a'].unit = 'km' 1149 1150 # Key col 'a' format should take last when all match 1151 t1['a'].info.format = '%0d' 1152 t2['a'].info.format = '%0d' 1153 t4['a'].info.format = '%0d' 1154 1155 # Key col 'b', take first value 't1_b' 1156 t1['b'].info.description = 't1_b' 1157 1158 # Key col 'b', take first non-empty value '%6s' 1159 t4['b'].info.format = '%6s' 1160 1161 # Key col 'a', should be merged meta 1162 t1['a'].info.meta.update(OrderedDict([('b', [1, 2]), ('c', {'a': 1}), ('d', 1)])) 1163 t2['a'].info.meta.update(OrderedDict([('b', [3, 4]), ('c', {'b': 1}), ('a', 1)])) 1164 t4['a'].info.meta.update(OrderedDict([('b', [5, 6]), ('c', {'c': 1}), ('e', 1)])) 1165 1166 # Key col 'b', should be meta2 1167 t2['b'].info.meta.update(OrderedDict([('b', [3, 4]), ('c', {'b': 1}), ('a', 1)])) 1168 1169 # All these should pass through 1170 t2['c'].unit = 'm' 1171 t2['c'].info.format = '%6s' 1172 t2['c'].info.description = 't2_c' 1173 1174 with pytest.warns(metadata.MergeConflictWarning) as warning_lines: 1175 out = table.vstack([t1, t2, t4], join_type='outer') 1176 1177 assert len(warning_lines) == 2 1178 assert ("In merged column 'a' the 'unit' attribute does not match (cm != m)" 1179 in str(warning_lines[0].message)) 1180 assert ("In merged column 'a' the 'unit' attribute does not match (m != km)" 1181 in str(warning_lines[1].message)) 1182 assert out['a'].unit == 'km' 1183 assert out['a'].info.format == '%0d' 1184 assert out['b'].info.description == 't1_b' 1185 assert out['b'].info.format == '%6s' 1186 assert out['a'].info.meta == self.meta_merge 1187 assert out['b'].info.meta == OrderedDict([('b', [3, 4]), ('c', {'b': 1}), ('a', 1)]) 1188 assert out['c'].info.unit == 'm' 1189 assert out['c'].info.format == '%6s' 1190 assert out['c'].info.description == 't2_c' 1191 1192 def test_vstack_one_table(self, operation_table_type): 1193 self._setup(operation_table_type) 1194 """Regression test for issue #3313""" 1195 assert (self.t1 == table.vstack(self.t1)).all() 1196 assert (self.t1 == table.vstack([self.t1])).all() 1197 1198 def test_mixin_functionality(self, mixin_cols): 1199 col = mixin_cols['m'] 1200 len_col = len(col) 1201 t = table.QTable([col], names=['a']) 1202 cls_name = type(col).__name__ 1203 1204 # Vstack works for these classes: 1205 if isinstance(col, (u.Quantity, Time, TimeDelta, SkyCoord, EarthLocation, 1206 BaseRepresentationOrDifferential)): 1207 out = table.vstack([t, t]) 1208 assert len(out) == len_col * 2 1209 if cls_name == 'SkyCoord': 1210 # Argh, SkyCoord needs __eq__!! 1211 assert skycoord_equal(out['a'][len_col:], col) 1212 assert skycoord_equal(out['a'][:len_col], col) 1213 elif 'Repr' in cls_name or 'Diff' in cls_name: 1214 assert np.all(representation_equal(out['a'][:len_col], col)) 1215 assert np.all(representation_equal(out['a'][len_col:], col)) 1216 else: 1217 assert np.all(out['a'][:len_col] == col) 1218 assert np.all(out['a'][len_col:] == col) 1219 else: 1220 with pytest.raises(NotImplementedError) as err: 1221 table.vstack([t, t]) 1222 assert ('vstack unavailable for mixin column type(s): {}' 1223 .format(cls_name) in str(err.value)) 1224 1225 # Check for outer stack which requires masking. Only Time supports 1226 # this currently. 1227 t2 = table.QTable([col], names=['b']) # different from col name for t 1228 if isinstance(col, (Time, TimeDelta, Quantity)): 1229 out = table.vstack([t, t2], join_type='outer') 1230 assert len(out) == len_col * 2 1231 assert np.all(out['a'][:len_col] == col) 1232 assert np.all(out['b'][len_col:] == col) 1233 assert check_mask(out['a'], [False] * len_col + [True] * len_col) 1234 assert check_mask(out['b'], [True] * len_col + [False] * len_col) 1235 # check directly stacking mixin columns: 1236 out2 = table.vstack([t, t2['b']]) 1237 assert np.all(out['a'] == out2['a']) 1238 assert np.all(out['b'] == out2['b']) 1239 else: 1240 with pytest.raises(NotImplementedError) as err: 1241 table.vstack([t, t2], join_type='outer') 1242 assert ('vstack requires masking' in str(err.value) 1243 or 'vstack unavailable' in str(err.value)) 1244 1245 def test_vstack_different_representation(self): 1246 """Test that representations can be mixed together.""" 1247 rep1 = CartesianRepresentation([1, 2]*u.km, [3, 4]*u.km, 1*u.km) 1248 rep2 = SphericalRepresentation([0]*u.deg, [0]*u.deg, 10*u.km) 1249 t1 = Table([rep1]) 1250 t2 = Table([rep2]) 1251 t12 = table.vstack([t1, t2]) 1252 expected = CartesianRepresentation([1, 2, 10]*u.km, 1253 [3, 4, 0]*u.km, 1254 [1, 1, 0]*u.km) 1255 assert np.all(representation_equal(t12['col0'], expected)) 1256 1257 rep3 = UnitSphericalRepresentation([0]*u.deg, [0]*u.deg) 1258 t3 = Table([rep3]) 1259 with pytest.raises(ValueError, match='representations are inconsistent'): 1260 table.vstack([t1, t3]) 1261 1262 1263class TestDStack(): 1264 1265 def _setup(self, t_cls=Table): 1266 self.t1 = t_cls.read([' a b', 1267 ' 0. foo', 1268 ' 1. bar'], format='ascii') 1269 1270 self.t2 = t_cls.read([' a b c', 1271 ' 2. pez 4', 1272 ' 3. sez 5'], format='ascii') 1273 self.t2['d'] = Time([1, 2], format='cxcsec') 1274 1275 self.t3 = t_cls({'a': [[5., 6.], [4., 3.]], 1276 'b': [['foo', 'bar'], ['pez', 'sez']]}, 1277 names=('a', 'b')) 1278 1279 self.t4 = t_cls(self.t1, copy=True, masked=t_cls is Table) 1280 1281 self.t5 = t_cls({'a': [[4., 2.], [1., 6.]], 1282 'b': [['foo', 'pez'], ['bar', 'sez']]}, 1283 names=('a', 'b')) 1284 self.t6 = t_cls.read([' a b c', 1285 ' 7. pez 2', 1286 ' 4. sez 6', 1287 ' 6. foo 3'], format='ascii') 1288 1289 def test_validate_join_type(self): 1290 self._setup() 1291 with pytest.raises(TypeError, match='Did you accidentally call dstack'): 1292 table.dstack(self.t1, self.t2) 1293 1294 @staticmethod 1295 def compare_dstack(tables, out): 1296 for ii, tbl in enumerate(tables): 1297 for name, out_col in out.columns.items(): 1298 if name in tbl.colnames: 1299 # Columns always compare equal 1300 assert np.all(tbl[name] == out[name][:, ii]) 1301 1302 # If input has a mask then output must have same mask 1303 if hasattr(tbl[name], 'mask'): 1304 assert np.all(tbl[name].mask == out[name].mask[:, ii]) 1305 1306 # If input has no mask then output might have a mask (if other table 1307 # is missing that column). If so then all mask values should be False. 1308 elif hasattr(out[name], 'mask'): 1309 assert not np.any(out[name].mask[:, ii]) 1310 1311 else: 1312 # Column missing for this table, out must have a mask with all True. 1313 assert np.all(out[name].mask[:, ii]) 1314 1315 def test_dstack_table_column(self, operation_table_type): 1316 """Stack a table with 3 cols and one column (gets auto-converted to Table). 1317 """ 1318 self._setup(operation_table_type) 1319 t2 = self.t1.copy() 1320 out = table.dstack([self.t1, t2['a']]) 1321 self.compare_dstack([self.t1, t2[('a',)]], out) 1322 1323 def test_dstack_basic_outer(self, operation_table_type): 1324 if operation_table_type is QTable: 1325 pytest.xfail('Quantity columns do not support masking.') 1326 self._setup(operation_table_type) 1327 t1 = self.t1 1328 t2 = self.t2 1329 t4 = self.t4 1330 t4['a'].mask[0] = True 1331 # Test for non-masked table 1332 t12 = table.dstack([t1, t2], join_type='outer') 1333 assert type(t12) is operation_table_type 1334 assert type(t12['a']) is type(t1['a']) # noqa 1335 assert type(t12['b']) is type(t1['b']) # noqa 1336 self.compare_dstack([t1, t2], t12) 1337 1338 # Test for masked table 1339 t124 = table.dstack([t1, t2, t4], join_type='outer') 1340 assert type(t124) is operation_table_type 1341 assert type(t124['a']) is type(t4['a']) # noqa 1342 assert type(t124['b']) is type(t4['b']) # noqa 1343 self.compare_dstack([t1, t2, t4], t124) 1344 1345 def test_dstack_basic_inner(self, operation_table_type): 1346 self._setup(operation_table_type) 1347 t1 = self.t1 1348 t2 = self.t2 1349 t4 = self.t4 1350 1351 # Test for masked table 1352 t124 = table.dstack([t1, t2, t4], join_type='inner') 1353 assert type(t124) is operation_table_type 1354 assert type(t124['a']) is type(t4['a']) # noqa 1355 assert type(t124['b']) is type(t4['b']) # noqa 1356 self.compare_dstack([t1, t2, t4], t124) 1357 1358 def test_dstack_multi_dimension_column(self, operation_table_type): 1359 self._setup(operation_table_type) 1360 t3 = self.t3 1361 t5 = self.t5 1362 t2 = self.t2 1363 t35 = table.dstack([t3, t5]) 1364 assert type(t35) is operation_table_type 1365 assert type(t35['a']) is type(t3['a']) # noqa 1366 assert type(t35['b']) is type(t3['b']) # noqa 1367 self.compare_dstack([t3, t5], t35) 1368 1369 with pytest.raises(TableMergeError): 1370 table.dstack([t2, t3]) 1371 1372 def test_dstack_different_length_table(self, operation_table_type): 1373 self._setup(operation_table_type) 1374 t2 = self.t2 1375 t6 = self.t6 1376 with pytest.raises(ValueError): 1377 table.dstack([t2, t6]) 1378 1379 def test_dstack_single_table(self): 1380 self._setup(Table) 1381 out = table.dstack(self.t1) 1382 assert np.all(out == self.t1) 1383 1384 def test_dstack_representation(self): 1385 rep1 = SphericalRepresentation([1, 2]*u.deg, [3, 4]*u.deg, 1*u.kpc) 1386 rep2 = SphericalRepresentation([10, 20]*u.deg, [30, 40]*u.deg, 10*u.kpc) 1387 t1 = Table([rep1]) 1388 t2 = Table([rep2]) 1389 t12 = table.dstack([t1, t2]) 1390 assert np.all(representation_equal(t12['col0'][:, 0], rep1)) 1391 assert np.all(representation_equal(t12['col0'][:, 1], rep2)) 1392 1393 def test_dstack_skycoord(self): 1394 sc1 = SkyCoord([1, 2]*u.deg, [3, 4]*u.deg) 1395 sc2 = SkyCoord([10, 20]*u.deg, [30, 40]*u.deg) 1396 t1 = Table([sc1]) 1397 t2 = Table([sc2]) 1398 t12 = table.dstack([t1, t2]) 1399 assert skycoord_equal(sc1, t12['col0'][:, 0]) 1400 assert skycoord_equal(sc2, t12['col0'][:, 1]) 1401 1402 1403class TestHStack(): 1404 1405 def _setup(self, t_cls=Table): 1406 self.t1 = t_cls.read([' a b', 1407 ' 0. foo', 1408 ' 1. bar'], format='ascii') 1409 1410 self.t2 = t_cls.read([' a b c', 1411 ' 2. pez 4', 1412 ' 3. sez 5'], format='ascii') 1413 1414 self.t3 = t_cls.read([' d e', 1415 ' 4. 7', 1416 ' 5. 8', 1417 ' 6. 9'], format='ascii') 1418 self.t4 = t_cls(self.t1, copy=True, masked=True) 1419 self.t4['a'].name = 'f' 1420 self.t4['b'].name = 'g' 1421 1422 # The following table has meta-data that conflicts with t1 1423 self.t5 = t_cls(self.t1, copy=True) 1424 1425 self.t1.meta.update(OrderedDict([('b', [1, 2]), ('c', {'a': 1}), ('d', 1)])) 1426 self.t2.meta.update(OrderedDict([('b', [3, 4]), ('c', {'b': 1}), ('a', 1)])) 1427 self.t4.meta.update(OrderedDict([('b', [5, 6]), ('c', {'c': 1}), ('e', 1)])) 1428 self.t5.meta.update(OrderedDict([('b', 3), ('c', 'k'), ('d', 1)])) 1429 self.meta_merge = OrderedDict([('b', [1, 2, 3, 4, 5, 6]), 1430 ('c', {'a': 1, 'b': 1, 'c': 1}), 1431 ('d', 1), 1432 ('a', 1), 1433 ('e', 1)]) 1434 1435 def test_validate_join_type(self): 1436 self._setup() 1437 with pytest.raises(TypeError, match='Did you accidentally call hstack'): 1438 table.hstack(self.t1, self.t2) 1439 1440 def test_stack_same_table(self, operation_table_type): 1441 """ 1442 From #2995, test that hstack'ing references to the same table has the 1443 expected output. 1444 """ 1445 self._setup(operation_table_type) 1446 out = table.hstack([self.t1, self.t1]) 1447 assert out.masked is False 1448 assert out.pformat() == ['a_1 b_1 a_2 b_2', 1449 '--- --- --- ---', 1450 '0.0 foo 0.0 foo', 1451 '1.0 bar 1.0 bar'] 1452 1453 def test_stack_rows(self, operation_table_type): 1454 self._setup(operation_table_type) 1455 out = table.hstack([self.t1[0], self.t2[1]]) 1456 assert out.masked is False 1457 assert out.pformat() == ['a_1 b_1 a_2 b_2 c ', 1458 '--- --- --- --- ---', 1459 '0.0 foo 3.0 sez 5'] 1460 1461 def test_stack_columns(self, operation_table_type): 1462 self._setup(operation_table_type) 1463 out = table.hstack([self.t1, self.t2['c']]) 1464 assert type(out['a']) is type(self.t1['a']) # noqa 1465 assert type(out['b']) is type(self.t1['b']) # noqa 1466 assert type(out['c']) is type(self.t2['c']) # noqa 1467 assert out.pformat() == [' a b c ', 1468 '--- --- ---', 1469 '0.0 foo 4', 1470 '1.0 bar 5'] 1471 1472 def test_table_meta_merge(self, operation_table_type): 1473 self._setup(operation_table_type) 1474 out = table.hstack([self.t1, self.t2, self.t4], join_type='inner') 1475 assert out.meta == self.meta_merge 1476 1477 def test_table_meta_merge_conflict(self, operation_table_type): 1478 self._setup(operation_table_type) 1479 1480 with pytest.warns(metadata.MergeConflictWarning) as w: 1481 out = table.hstack([self.t1, self.t5], join_type='inner') 1482 assert len(w) == 2 1483 1484 assert out.meta == self.t5.meta 1485 1486 with pytest.warns(metadata.MergeConflictWarning) as w: 1487 out = table.hstack([self.t1, self.t5], join_type='inner', metadata_conflicts='warn') 1488 assert len(w) == 2 1489 1490 assert out.meta == self.t5.meta 1491 1492 out = table.hstack([self.t1, self.t5], join_type='inner', metadata_conflicts='silent') 1493 1494 assert out.meta == self.t5.meta 1495 1496 with pytest.raises(MergeConflictError): 1497 out = table.hstack([self.t1, self.t5], join_type='inner', metadata_conflicts='error') 1498 1499 with pytest.raises(ValueError): 1500 out = table.hstack([self.t1, self.t5], join_type='inner', metadata_conflicts='nonsense') 1501 1502 def test_bad_input_type(self, operation_table_type): 1503 self._setup(operation_table_type) 1504 with pytest.raises(ValueError): 1505 table.hstack([]) 1506 with pytest.raises(TypeError): 1507 table.hstack(1) 1508 with pytest.raises(TypeError): 1509 table.hstack([self.t2, 1]) 1510 with pytest.raises(ValueError): 1511 table.hstack([self.t1, self.t2], join_type='invalid join type') 1512 1513 def test_stack_basic(self, operation_table_type): 1514 self._setup(operation_table_type) 1515 t1 = self.t1 1516 t2 = self.t2 1517 t3 = self.t3 1518 t4 = self.t4 1519 1520 out = table.hstack([t1, t2], join_type='inner') 1521 assert out.masked is False 1522 assert type(out) is operation_table_type 1523 assert type(out['a_1']) is type(t1['a']) # noqa 1524 assert type(out['b_1']) is type(t1['b']) # noqa 1525 assert type(out['a_2']) is type(t2['a']) # noqa 1526 assert type(out['b_2']) is type(t2['b']) # noqa 1527 assert out.pformat() == ['a_1 b_1 a_2 b_2 c ', 1528 '--- --- --- --- ---', 1529 '0.0 foo 2.0 pez 4', 1530 '1.0 bar 3.0 sez 5'] 1531 1532 # stacking as a list gives same result 1533 out_list = table.hstack([t1, t2], join_type='inner') 1534 assert out.pformat() == out_list.pformat() 1535 1536 out = table.hstack([t1, t2], join_type='outer') 1537 assert out.pformat() == out_list.pformat() 1538 1539 out = table.hstack([t1, t2, t3, t4], join_type='outer') 1540 assert out.masked is False 1541 assert out.pformat() == ['a_1 b_1 a_2 b_2 c d e f g ', 1542 '--- --- --- --- --- --- --- --- ---', 1543 '0.0 foo 2.0 pez 4 4.0 7 0.0 foo', 1544 '1.0 bar 3.0 sez 5 5.0 8 1.0 bar', 1545 ' -- -- -- -- -- 6.0 9 -- --'] 1546 1547 out = table.hstack([t1, t2, t3, t4], join_type='inner') 1548 assert out.masked is False 1549 assert out.pformat() == ['a_1 b_1 a_2 b_2 c d e f g ', 1550 '--- --- --- --- --- --- --- --- ---', 1551 '0.0 foo 2.0 pez 4 4.0 7 0.0 foo', 1552 '1.0 bar 3.0 sez 5 5.0 8 1.0 bar'] 1553 1554 def test_stack_incompatible(self, operation_table_type): 1555 self._setup(operation_table_type) 1556 # For join_type exact, which will fail here because n_rows 1557 # does not match 1558 with pytest.raises(TableMergeError): 1559 table.hstack([self.t1, self.t3], join_type='exact') 1560 1561 def test_hstack_one_masked(self, operation_table_type): 1562 if operation_table_type is QTable: 1563 pytest.xfail() 1564 self._setup(operation_table_type) 1565 t1 = self.t1 1566 t2 = operation_table_type(t1, copy=True, masked=True) 1567 t2.meta.clear() 1568 t2['b'].mask[1] = True 1569 out = table.hstack([t1, t2]) 1570 assert out.pformat() == ['a_1 b_1 a_2 b_2', 1571 '--- --- --- ---', 1572 '0.0 foo 0.0 foo', 1573 '1.0 bar 1.0 --'] 1574 1575 def test_table_col_rename(self, operation_table_type): 1576 self._setup(operation_table_type) 1577 out = table.hstack([self.t1, self.t2], join_type='inner', 1578 uniq_col_name='{table_name}_{col_name}', 1579 table_names=('left', 'right')) 1580 assert out.masked is False 1581 assert out.pformat() == ['left_a left_b right_a right_b c ', 1582 '------ ------ ------- ------- ---', 1583 ' 0.0 foo 2.0 pez 4', 1584 ' 1.0 bar 3.0 sez 5'] 1585 1586 def test_col_meta_merge(self, operation_table_type): 1587 self._setup(operation_table_type) 1588 t1 = self.t1 1589 t3 = self.t3[:2] 1590 t4 = self.t4 1591 1592 # Just set a bunch of meta and make sure it is the same in output 1593 meta1 = OrderedDict([('b', [1, 2]), ('c', {'a': 1}), ('d', 1)]) 1594 t1['a'].unit = 'cm' 1595 t1['b'].info.description = 't1_b' 1596 t4['f'].info.format = '%6s' 1597 t1['b'].info.meta.update(meta1) 1598 t3['d'].info.meta.update(OrderedDict([('b', [3, 4]), ('c', {'b': 1}), ('a', 1)])) 1599 t4['g'].info.meta.update(OrderedDict([('b', [5, 6]), ('c', {'c': 1}), ('e', 1)])) 1600 t3['e'].info.meta.update(OrderedDict([('b', [3, 4]), ('c', {'b': 1}), ('a', 1)])) 1601 t3['d'].unit = 'm' 1602 t3['d'].info.format = '%6s' 1603 t3['d'].info.description = 't3_c' 1604 1605 out = table.hstack([t1, t3, t4], join_type='exact') 1606 1607 for t in [t1, t3, t4]: 1608 for name in t.colnames: 1609 for attr in ('meta', 'unit', 'format', 'description'): 1610 assert getattr(out[name].info, attr) == getattr(t[name].info, attr) 1611 1612 # Make sure we got a copy of meta, not ref 1613 t1['b'].info.meta['b'] = None 1614 assert out['b'].info.meta['b'] == [1, 2] 1615 1616 def test_hstack_one_table(self, operation_table_type): 1617 self._setup(operation_table_type) 1618 """Regression test for issue #3313""" 1619 assert (self.t1 == table.hstack(self.t1)).all() 1620 assert (self.t1 == table.hstack([self.t1])).all() 1621 1622 def test_mixin_functionality(self, mixin_cols): 1623 col1 = mixin_cols['m'] 1624 col2 = col1[2:4] # Shorter version of col1 1625 t1 = table.QTable([col1]) 1626 t2 = table.QTable([col2]) 1627 1628 cls_name = type(col1).__name__ 1629 1630 out = table.hstack([t1, t2], join_type='inner') 1631 assert type(out['col0_1']) is type(out['col0_2']) # noqa 1632 assert len(out) == len(col2) 1633 1634 # Check that columns are as expected. 1635 if cls_name == 'SkyCoord': 1636 assert skycoord_equal(out['col0_1'], col1[:len(col2)]) 1637 assert skycoord_equal(out['col0_2'], col2) 1638 elif 'Repr' in cls_name or 'Diff' in cls_name: 1639 assert np.all(representation_equal(out['col0_1'], col1[:len(col2)])) 1640 assert np.all(representation_equal(out['col0_2'], col2)) 1641 else: 1642 assert np.all(out['col0_1'] == col1[:len(col2)]) 1643 assert np.all(out['col0_2'] == col2) 1644 1645 # Time class supports masking, all other mixins do not 1646 if isinstance(col1, (Time, TimeDelta, Quantity)): 1647 out = table.hstack([t1, t2], join_type='outer') 1648 assert len(out) == len(t1) 1649 assert np.all(out['col0_1'] == col1) 1650 assert np.all(out['col0_2'][:len(col2)] == col2) 1651 assert check_mask(out['col0_2'], [False, False, True, True]) 1652 1653 # check directly stacking mixin columns: 1654 out2 = table.hstack([t1, t2['col0']], join_type='outer') 1655 assert np.all(out['col0_1'] == out2['col0_1']) 1656 assert np.all(out['col0_2'] == out2['col0_2']) 1657 else: 1658 with pytest.raises(NotImplementedError) as err: 1659 table.hstack([t1, t2], join_type='outer') 1660 assert 'hstack requires masking' in str(err.value) 1661 1662 1663def test_unique(operation_table_type): 1664 t = operation_table_type.read( 1665 [' a b c d', 1666 ' 2 b 7.0 0', 1667 ' 1 c 3.0 5', 1668 ' 2 b 6.0 2', 1669 ' 2 a 4.0 3', 1670 ' 1 a 1.0 7', 1671 ' 2 b 5.0 1', 1672 ' 0 a 0.0 4', 1673 ' 1 a 2.0 6', 1674 ' 1 c 3.0 5', 1675 ], format='ascii') 1676 1677 tu = operation_table_type(np.sort(t[:-1])) 1678 1679 t_all = table.unique(t) 1680 assert sort_eq(t_all.pformat(), tu.pformat()) 1681 t_s = t.copy() 1682 del t_s['b', 'c', 'd'] 1683 t_all = table.unique(t_s) 1684 assert sort_eq(t_all.pformat(), [' a ', 1685 '---', 1686 ' 0', 1687 ' 1', 1688 ' 2']) 1689 1690 key1 = 'a' 1691 t1a = table.unique(t, key1) 1692 assert sort_eq(t1a.pformat(), [' a b c d ', 1693 '--- --- --- ---', 1694 ' 0 a 0.0 4', 1695 ' 1 c 3.0 5', 1696 ' 2 b 7.0 0']) 1697 t1b = table.unique(t, key1, keep='last') 1698 assert sort_eq(t1b.pformat(), [' a b c d ', 1699 '--- --- --- ---', 1700 ' 0 a 0.0 4', 1701 ' 1 c 3.0 5', 1702 ' 2 b 5.0 1']) 1703 t1c = table.unique(t, key1, keep='none') 1704 assert sort_eq(t1c.pformat(), [' a b c d ', 1705 '--- --- --- ---', 1706 ' 0 a 0.0 4']) 1707 1708 key2 = ['a', 'b'] 1709 t2a = table.unique(t, key2) 1710 assert sort_eq(t2a.pformat(), [' a b c d ', 1711 '--- --- --- ---', 1712 ' 0 a 0.0 4', 1713 ' 1 a 1.0 7', 1714 ' 1 c 3.0 5', 1715 ' 2 a 4.0 3', 1716 ' 2 b 7.0 0']) 1717 1718 t2b = table.unique(t, key2, keep='last') 1719 assert sort_eq(t2b.pformat(), [' a b c d ', 1720 '--- --- --- ---', 1721 ' 0 a 0.0 4', 1722 ' 1 a 2.0 6', 1723 ' 1 c 3.0 5', 1724 ' 2 a 4.0 3', 1725 ' 2 b 5.0 1']) 1726 t2c = table.unique(t, key2, keep='none') 1727 assert sort_eq(t2c.pformat(), [' a b c d ', 1728 '--- --- --- ---', 1729 ' 0 a 0.0 4', 1730 ' 2 a 4.0 3']) 1731 1732 key2 = ['a', 'a'] 1733 with pytest.raises(ValueError) as exc: 1734 t2a = table.unique(t, key2) 1735 assert exc.value.args[0] == "duplicate key names" 1736 1737 with pytest.raises(ValueError) as exc: 1738 table.unique(t, key2, keep=True) 1739 assert exc.value.args[0] == ( 1740 "'keep' should be one of 'first', 'last', 'none'") 1741 1742 t1_m = operation_table_type(t1a, masked=True) 1743 t1_m['a'].mask[1] = True 1744 1745 with pytest.raises(ValueError) as exc: 1746 t1_mu = table.unique(t1_m) 1747 assert exc.value.args[0] == ( 1748 "cannot use columns with masked values as keys; " 1749 "remove column 'a' from keys and rerun unique()") 1750 1751 t1_mu = table.unique(t1_m, silent=True) 1752 assert t1_mu.masked is False 1753 assert t1_mu.pformat() == [' a b c d ', 1754 '--- --- --- ---', 1755 ' 0 a 0.0 4', 1756 ' 2 b 7.0 0', 1757 ' -- c 3.0 5'] 1758 1759 with pytest.raises(ValueError): 1760 t1_mu = table.unique(t1_m, silent=True, keys='a') 1761 1762 t1_m = operation_table_type(t, masked=True) 1763 t1_m['a'].mask[1] = True 1764 t1_m['d'].mask[3] = True 1765 1766 # Test that multiple masked key columns get removed in the correct 1767 # order 1768 t1_mu = table.unique(t1_m, keys=['d', 'a', 'b'], silent=True) 1769 assert t1_mu.masked is False 1770 assert t1_mu.pformat() == [' a b c d ', 1771 '--- --- --- ---', 1772 ' 2 a 4.0 --', 1773 ' 2 b 7.0 0', 1774 ' -- c 3.0 5'] 1775 1776 1777def test_vstack_bytes(operation_table_type): 1778 """ 1779 Test for issue #5617 when vstack'ing bytes columns in Py3. 1780 This is really an upstream numpy issue numpy/numpy/#8403. 1781 """ 1782 t = operation_table_type([[b'a']], names=['a']) 1783 assert t['a'].itemsize == 1 1784 1785 t2 = table.vstack([t, t]) 1786 assert len(t2) == 2 1787 assert t2['a'].itemsize == 1 1788 1789 1790def test_vstack_unicode(): 1791 """ 1792 Test for problem related to issue #5617 when vstack'ing *unicode* 1793 columns. In this case the character size gets multiplied by 4. 1794 """ 1795 t = table.Table([['a']], names=['a']) 1796 assert t['a'].itemsize == 4 # 4-byte / char for U dtype 1797 1798 t2 = table.vstack([t, t]) 1799 assert len(t2) == 2 1800 assert t2['a'].itemsize == 4 1801 1802 1803def test_join_mixins_time_quantity(): 1804 """ 1805 Test for table join using non-ndarray key columns. 1806 """ 1807 tm1 = Time([2, 1, 2], format='cxcsec') 1808 q1 = [2, 1, 1] * u.m 1809 idx1 = [1, 2, 3] 1810 tm2 = Time([2, 3], format='cxcsec') 1811 q2 = [2, 3] * u.m 1812 idx2 = [10, 20] 1813 t1 = Table([tm1, q1, idx1], names=['tm', 'q', 'idx']) 1814 t2 = Table([tm2, q2, idx2], names=['tm', 'q', 'idx']) 1815 # Output: 1816 # 1817 # <Table length=4> 1818 # tm q idx_1 idx_2 1819 # m 1820 # object float64 int64 int64 1821 # ------------------ ------- ----- ----- 1822 # 0.9999999999969589 1.0 2 -- 1823 # 2.00000000000351 1.0 3 -- 1824 # 2.00000000000351 2.0 1 10 1825 # 3.000000000000469 3.0 -- 20 1826 1827 t12 = table.join(t1, t2, join_type='outer', keys=['tm', 'q']) 1828 # Key cols are lexically sorted 1829 assert np.all(t12['tm'] == Time([1, 2, 2, 3], format='cxcsec')) 1830 assert np.all(t12['q'] == [1, 1, 2, 3] * u.m) 1831 assert np.all(t12['idx_1'] == np.ma.array([2, 3, 1, 0], mask=[0, 0, 0, 1])) 1832 assert np.all(t12['idx_2'] == np.ma.array([0, 0, 10, 20], mask=[1, 1, 0, 0])) 1833 1834 1835def test_join_mixins_not_sortable(): 1836 """ 1837 Test for table join using non-ndarray key columns that are not sortable. 1838 """ 1839 sc = SkyCoord([1, 2], [3, 4], unit='deg,deg') 1840 t1 = Table([sc, [1, 2]], names=['sc', 'idx1']) 1841 t2 = Table([sc, [10, 20]], names=['sc', 'idx2']) 1842 1843 with pytest.raises(TypeError, match='one or more key columns are not sortable'): 1844 table.join(t1, t2, keys='sc') 1845 1846 1847def test_join_non_1d_key_column(): 1848 c1 = [[1, 2], [3, 4]] 1849 c2 = [1, 2] 1850 t1 = Table([c1, c2], names=['a', 'b']) 1851 t2 = t1.copy() 1852 with pytest.raises(ValueError, match="key column 'a' must be 1-d"): 1853 table.join(t1, t2, keys='a') 1854 1855 1856def test_argsort_time_column(): 1857 """Regression test for #10823.""" 1858 times = Time(['2016-01-01', '2018-01-01', '2017-01-01']) 1859 t = Table([times], names=['time']) 1860 i = t.argsort('time') 1861 assert np.all(i == times.argsort()) 1862 1863 1864def test_sort_indexed_table(): 1865 """Test fix for #9473 and #6545 - and another regression test for #10823.""" 1866 t = Table([[1, 3, 2], [6, 4, 5]], names=('a', 'b')) 1867 t.add_index('a') 1868 t.sort('a') 1869 assert np.all(t['a'] == [1, 2, 3]) 1870 assert np.all(t['b'] == [6, 5, 4]) 1871 t.sort('b') 1872 assert np.all(t['b'] == [4, 5, 6]) 1873 assert np.all(t['a'] == [3, 2, 1]) 1874 1875 times = ['2016-01-01', '2018-01-01', '2017-01-01'] 1876 tm = Time(times) 1877 t2 = Table([tm, [3, 2, 1]], names=['time', 'flux']) 1878 t2.sort('flux') 1879 assert np.all(t2['flux'] == [1, 2, 3]) 1880 t2.sort('time') 1881 assert np.all(t2['flux'] == [3, 1, 2]) 1882 assert np.all(t2['time'] == tm[[0, 2, 1]]) 1883 1884 # Using the table as a TimeSeries implicitly sets the index, so 1885 # this test is a bit different from the above. 1886 from astropy.timeseries import TimeSeries 1887 ts = TimeSeries(time=times) 1888 ts['flux'] = [3, 2, 1] 1889 ts.sort('flux') 1890 assert np.all(ts['flux'] == [1, 2, 3]) 1891 ts.sort('time') 1892 assert np.all(ts['flux'] == [3, 1, 2]) 1893 assert np.all(ts['time'] == tm[[0, 2, 1]]) 1894 1895 1896def test_get_out_class(): 1897 c = table.Column([1, 2]) 1898 mc = table.MaskedColumn([1, 2]) 1899 q = [1, 2] * u.m 1900 1901 assert _get_out_class([c, mc]) is mc.__class__ 1902 assert _get_out_class([mc, c]) is mc.__class__ 1903 assert _get_out_class([c, c]) is c.__class__ 1904 assert _get_out_class([c]) is c.__class__ 1905 1906 with pytest.raises(ValueError): 1907 _get_out_class([c, q]) 1908 1909 with pytest.raises(ValueError): 1910 _get_out_class([q, c]) 1911 1912 1913def test_masking_required_exception(): 1914 """ 1915 Test that outer join, hstack and vstack fail for a mixin column which 1916 does not support masking. 1917 """ 1918 col = table.NdarrayMixin([0, 1, 2, 3]) 1919 t1 = table.QTable([[1, 2, 3, 4], col], names=['a', 'b']) 1920 t2 = table.QTable([[1, 2], col[:2]], names=['a', 'c']) 1921 1922 with pytest.raises(NotImplementedError) as err: 1923 table.vstack([t1, t2], join_type='outer') 1924 assert 'vstack unavailable' in str(err.value) 1925 1926 with pytest.raises(NotImplementedError) as err: 1927 table.hstack([t1, t2], join_type='outer') 1928 assert 'hstack requires masking' in str(err.value) 1929 1930 with pytest.raises(NotImplementedError) as err: 1931 table.join(t1, t2, join_type='outer') 1932 assert 'join requires masking' in str(err.value) 1933 1934 1935def test_stack_columns(): 1936 c = table.Column([1, 2]) 1937 mc = table.MaskedColumn([1, 2]) 1938 q = [1, 2] * u.m 1939 time = Time(['2001-01-02T12:34:56', '2001-02-03T00:01:02']) 1940 sc = SkyCoord([1, 2], [3, 4], unit='deg') 1941 cq = table.Column([11, 22], unit=u.m) 1942 1943 t = table.hstack([c, q]) 1944 assert t.__class__ is table.QTable 1945 assert t.masked is False 1946 t = table.hstack([q, c]) 1947 assert t.__class__ is table.QTable 1948 assert t.masked is False 1949 1950 t = table.hstack([mc, q]) 1951 assert t.__class__ is table.QTable 1952 assert t.masked is False 1953 1954 t = table.hstack([c, mc]) 1955 assert t.__class__ is table.Table 1956 assert t.masked is False 1957 1958 t = table.vstack([q, q]) 1959 assert t.__class__ is table.QTable 1960 1961 t = table.vstack([c, c]) 1962 assert t.__class__ is table.Table 1963 1964 t = table.hstack([c, time]) 1965 assert t.__class__ is table.Table 1966 t = table.hstack([c, sc]) 1967 assert t.__class__ is table.Table 1968 t = table.hstack([q, time, sc]) 1969 assert t.__class__ is table.QTable 1970 1971 with pytest.raises(ValueError): 1972 table.vstack([c, q]) 1973 1974 with pytest.raises(ValueError): 1975 t = table.vstack([q, cq]) 1976 1977 1978def test_mixin_join_regression(): 1979 # This used to trigger a ValueError: 1980 # ValueError: NumPy boolean array indexing assignment cannot assign 1981 # 6 input values to the 4 output values where the mask is true 1982 1983 t1 = QTable() 1984 t1['index'] = [1, 2, 3, 4, 5] 1985 t1['flux1'] = [2, 3, 2, 1, 1] * u.Jy 1986 t1['flux2'] = [2, 3, 2, 1, 1] * u.Jy 1987 1988 t2 = QTable() 1989 t2['index'] = [3, 4, 5, 6] 1990 t2['flux1'] = [2, 1, 1, 3] * u.Jy 1991 t2['flux2'] = [2, 1, 1, 3] * u.Jy 1992 1993 t12 = table.join(t1, t2, keys=('index', 'flux1', 'flux2'), join_type='outer') 1994 1995 assert len(t12) == 6 1996