1# -*- coding: utf-8 -*-
2# Licensed under a 3-clause BSD style license - see LICENSE.rst
3
4from astropy.utils.tests.test_metadata import MetaBaseTest
5import operator
6
7import pytest
8import numpy as np
9from numpy.testing import assert_array_equal
10
11from astropy.tests.helper import assert_follows_unicode_guidelines
12from astropy import table
13from astropy import time
14from astropy import units as u
15
16
17class TestColumn():
18
19    def test_subclass(self, Column):
20        c = Column(name='a')
21        assert isinstance(c, np.ndarray)
22        c2 = c * 2
23        assert isinstance(c2, Column)
24        assert isinstance(c2, np.ndarray)
25
26    def test_numpy_ops(self, Column):
27        """Show that basic numpy operations with Column behave sensibly"""
28
29        arr = np.array([1, 2, 3])
30        c = Column(arr, name='a')
31
32        for op, test_equal in ((operator.eq, True),
33                               (operator.ne, False),
34                               (operator.ge, True),
35                               (operator.gt, False),
36                               (operator.le, True),
37                               (operator.lt, False)):
38            for eq in (op(c, arr), op(arr, c)):
39
40                assert np.all(eq) if test_equal else not np.any(eq)
41                assert len(eq) == 3
42                if Column is table.Column:
43                    assert type(eq) == np.ndarray
44                else:
45                    assert type(eq) == np.ma.core.MaskedArray
46                assert eq.dtype.str == '|b1'
47
48        lt = c - 1 < arr
49        assert np.all(lt)
50
51    def test_numpy_boolean_ufuncs(self, Column):
52        """Show that basic numpy operations with Column behave sensibly"""
53
54        arr = np.array([1, 2, 3])
55        c = Column(arr, name='a')
56
57        for ufunc, test_true in ((np.isfinite, True),
58                                 (np.isinf, False),
59                                 (np.isnan, False),
60                                 (np.sign, True),
61                                 (np.signbit, False)):
62            result = ufunc(c)
63            assert len(result) == len(c)
64            assert np.all(result) if test_true else not np.any(result)
65            if Column is table.Column:
66                assert type(result) == np.ndarray
67            else:
68                assert type(result) == np.ma.core.MaskedArray
69                if ufunc is not np.sign:
70                    assert result.dtype.str == '|b1'
71
72    def test_view(self, Column):
73        c = np.array([1, 2, 3], dtype=np.int64).view(Column)
74        assert repr(c) == f"<{Column.__name__} dtype='int64' length=3>\n1\n2\n3"
75
76    def test_format(self, Column):
77        """Show that the formatted output from str() works"""
78        from astropy import conf
79        with conf.set_temp('max_lines', 8):
80            c1 = Column(np.arange(2000), name='a', dtype=float,
81                        format='%6.2f')
82            assert str(c1).splitlines() == ['   a   ',
83                                            '-------',
84                                            '   0.00',
85                                            '   1.00',
86                                            '    ...',
87                                            '1998.00',
88                                            '1999.00',
89                                            'Length = 2000 rows']
90
91    def test_convert_numpy_array(self, Column):
92        d = Column([1, 2, 3], name='a', dtype='i8')
93
94        np_data = np.array(d)
95        assert np.all(np_data == d)
96        np_data = np.array(d, copy=False)
97        assert np.all(np_data == d)
98        np_data = np.array(d, dtype='i4')
99        assert np.all(np_data == d)
100
101    def test_convert_unit(self, Column):
102        d = Column([1, 2, 3], name='a', dtype="f8", unit="m")
103        d.convert_unit_to("km")
104        assert np.all(d.data == [0.001, 0.002, 0.003])
105
106    def test_array_wrap(self):
107        """Test that the __array_wrap__ method converts a reduction ufunc
108        output that has a different shape into an ndarray view.  Without this a
109        method call like c.mean() returns a Column array object with length=1."""
110        # Mean and sum for a 1-d float column
111        c = table.Column(name='a', data=[1., 2., 3.])
112        assert np.allclose(c.mean(), 2.0)
113        assert isinstance(c.mean(), (np.floating, float))
114        assert np.allclose(c.sum(), 6.)
115        assert isinstance(c.sum(), (np.floating, float))
116
117        # Non-reduction ufunc preserves Column class
118        assert isinstance(np.cos(c), table.Column)
119
120        # Sum for a 1-d int column
121        c = table.Column(name='a', data=[1, 2, 3])
122        assert np.allclose(c.sum(), 6)
123        assert isinstance(c.sum(), (np.integer, int))
124
125        # Sum for a 2-d int column
126        c = table.Column(name='a', data=[[1, 2, 3],
127                                         [4, 5, 6]])
128        assert c.sum() == 21
129        assert isinstance(c.sum(), (np.integer, int))
130        assert np.all(c.sum(axis=0) == [5, 7, 9])
131        assert c.sum(axis=0).shape == (3,)
132        assert isinstance(c.sum(axis=0), np.ndarray)
133
134        # Sum and mean for a 1-d masked column
135        c = table.MaskedColumn(name='a', data=[1., 2., 3.], mask=[0, 0, 1])
136        assert np.allclose(c.mean(), 1.5)
137        assert isinstance(c.mean(), (np.floating, float))
138        assert np.allclose(c.sum(), 3.)
139        assert isinstance(c.sum(), (np.floating, float))
140
141    def test_name_none(self, Column):
142        """Can create a column without supplying name, which defaults to None"""
143        c = Column([1, 2])
144        assert c.name is None
145        assert np.all(c == np.array([1, 2]))
146
147    def test_quantity_init(self, Column):
148
149        c = Column(data=np.array([1, 2, 3]) * u.m)
150        assert np.all(c.data == np.array([1, 2, 3]))
151        assert np.all(c.unit == u.m)
152
153        c = Column(data=np.array([1, 2, 3]) * u.m, unit=u.cm)
154        assert np.all(c.data == np.array([100, 200, 300]))
155        assert np.all(c.unit == u.cm)
156
157    def test_quantity_comparison(self, Column):
158        # regression test for gh-6532
159        c = Column([1, 2100, 3], unit='Hz')
160        q = 2 * u.kHz
161        check = c < q
162        assert np.all(check == [True, False, True])
163        # This already worked, but just in case.
164        check = q >= c
165        assert np.all(check == [True, False, True])
166
167    def test_attrs_survive_getitem_after_change(self, Column):
168        """
169        Test for issue #3023: when calling getitem with a MaskedArray subclass
170        the original object attributes are not copied.
171        """
172        c1 = Column([1, 2, 3], name='a', unit='m', format='%i',
173                    description='aa', meta={'a': 1})
174        c1.name = 'b'
175        c1.unit = 'km'
176        c1.format = '%d'
177        c1.description = 'bb'
178        c1.meta = {'bbb': 2}
179
180        for item in (slice(None, None), slice(None, 1), np.array([0, 2]),
181                     np.array([False, True, False])):
182            c2 = c1[item]
183            assert c2.name == 'b'
184            assert c2.unit is u.km
185            assert c2.format == '%d'
186            assert c2.description == 'bb'
187            assert c2.meta == {'bbb': 2}
188
189        # Make sure that calling getitem resulting in a scalar does
190        # not copy attributes.
191        val = c1[1]
192        for attr in ('name', 'unit', 'format', 'description', 'meta'):
193            assert not hasattr(val, attr)
194
195    def test_to_quantity(self, Column):
196        d = Column([1, 2, 3], name='a', dtype="f8", unit="m")
197
198        assert np.all(d.quantity == ([1, 2, 3.] * u.m))
199        assert np.all(d.quantity.value == ([1, 2, 3.] * u.m).value)
200        assert np.all(d.quantity == d.to('m'))
201        assert np.all(d.quantity.value == d.to('m').value)
202
203        np.testing.assert_allclose(d.to(u.km).value, ([.001, .002, .003] * u.km).value)
204        np.testing.assert_allclose(d.to('km').value, ([.001, .002, .003] * u.km).value)
205
206        np.testing.assert_allclose(d.to(u.MHz, u.equivalencies.spectral()).value,
207                                   [299.792458, 149.896229, 99.93081933])
208
209        d_nounit = Column([1, 2, 3], name='a', dtype="f8", unit=None)
210        with pytest.raises(u.UnitsError):
211            d_nounit.to(u.km)
212        assert np.all(d_nounit.to(u.dimensionless_unscaled) == np.array([1, 2, 3]))
213
214        # make sure the correct copy/no copy behavior is happening
215        q = [1, 3, 5] * u.km
216
217        # to should always make a copy
218        d.to(u.km)[:] = q
219        np.testing.assert_allclose(d, [1, 2, 3])
220
221        # explicit copying of the quantity should not change the column
222        d.quantity.copy()[:] = q
223        np.testing.assert_allclose(d, [1, 2, 3])
224
225        # but quantity directly is a "view", accessing the underlying column
226        d.quantity[:] = q
227        np.testing.assert_allclose(d, [1000, 3000, 5000])
228
229        # view should also work for integers
230        d2 = Column([1, 2, 3], name='a', dtype=int, unit="m")
231        d2.quantity[:] = q
232        np.testing.assert_allclose(d2, [1000, 3000, 5000])
233
234        # but it should fail for strings or other non-numeric tables
235        d3 = Column(['arg', 'name', 'stuff'], name='a', unit="m")
236        with pytest.raises(TypeError):
237            d3.quantity
238
239    def test_to_funcunit_quantity(self, Column):
240        """
241        Tests for #8424, check if function-unit can be retrieved from column.
242        """
243        d = Column([1, 2, 3], name='a', dtype="f8", unit="dex(AA)")
244
245        assert np.all(d.quantity == ([1, 2, 3] * u.dex(u.AA)))
246        assert np.all(d.quantity.value == ([1, 2, 3] * u.dex(u.AA)).value)
247        assert np.all(d.quantity == d.to("dex(AA)"))
248        assert np.all(d.quantity.value == d.to("dex(AA)").value)
249
250        # make sure, casting to linear unit works
251        q = [10, 100, 1000] * u.AA
252        np.testing.assert_allclose(d.to(u.AA), q)
253
254    def test_item_access_type(self, Column):
255        """
256        Tests for #3095, which forces integer item access to always return a plain
257        ndarray or MaskedArray, even in the case of a multi-dim column.
258        """
259        integer_types = (int, np.int_)
260
261        for int_type in integer_types:
262            c = Column([[1, 2], [3, 4]])
263            i0 = int_type(0)
264            i1 = int_type(1)
265            assert np.all(c[i0] == [1, 2])
266            assert type(c[i0]) == (np.ma.MaskedArray if hasattr(Column, 'mask') else np.ndarray)
267            assert c[i0].shape == (2,)
268
269            c01 = c[i0:i1]
270            assert np.all(c01 == [[1, 2]])
271            assert isinstance(c01, Column)
272            assert c01.shape == (1, 2)
273
274            c = Column([1, 2])
275            assert np.all(c[i0] == 1)
276            assert isinstance(c[i0], np.integer)
277            assert c[i0].shape == ()
278
279            c01 = c[i0:i1]
280            assert np.all(c01 == [1])
281            assert isinstance(c01, Column)
282            assert c01.shape == (1,)
283
284    def test_insert_basic(self, Column):
285        c = Column([0, 1, 2], name='a', dtype=int, unit='mJy', format='%i',
286                   description='test column', meta={'c': 8, 'd': 12})
287
288        # Basic insert
289        c1 = c.insert(1, 100)
290        assert np.all(c1 == [0, 100, 1, 2])
291        assert c1.attrs_equal(c)
292        assert type(c) is type(c1)
293        if hasattr(c1, 'mask'):
294            assert c1.data.shape == c1.mask.shape
295
296        c1 = c.insert(-1, 100)
297        assert np.all(c1 == [0, 1, 100, 2])
298
299        c1 = c.insert(3, 100)
300        assert np.all(c1 == [0, 1, 2, 100])
301
302        c1 = c.insert(-3, 100)
303        assert np.all(c1 == [100, 0, 1, 2])
304
305        c1 = c.insert(1, [100, 200, 300])
306        if hasattr(c1, 'mask'):
307            assert c1.data.shape == c1.mask.shape
308
309        # Out of bounds index
310        with pytest.raises((ValueError, IndexError)):
311            c1 = c.insert(-4, 100)
312        with pytest.raises((ValueError, IndexError)):
313            c1 = c.insert(4, 100)
314
315    def test_insert_axis(self, Column):
316        """Insert with non-default axis kwarg"""
317        c = Column([[1, 2], [3, 4]])
318
319        c1 = c.insert(1, [5, 6], axis=None)
320        assert np.all(c1 == [1, 5, 6, 2, 3, 4])
321
322        c1 = c.insert(1, [5, 6], axis=1)
323        assert np.all(c1 == [[1, 5, 2], [3, 6, 4]])
324
325    def test_insert_string_expand(self, Column):
326        c = Column(['a', 'b'])
327        c1 = c.insert(0, 'abc')
328        assert np.all(c1 == ['abc', 'a', 'b'])
329
330        c = Column(['a', 'b'])
331        c1 = c.insert(0, ['c', 'def'])
332        assert np.all(c1 == ['c', 'def', 'a', 'b'])
333
334    def test_insert_string_masked_values(self):
335        c = table.MaskedColumn(['a', 'b'])
336        c1 = c.insert(0, np.ma.masked)
337        assert np.all(c1 == ['', 'a', 'b'])
338        assert np.all(c1.mask == [True, False, False])
339        assert c1.dtype == 'U1'
340        c2 = c.insert(1, np.ma.MaskedArray(['ccc', 'dd'], mask=[True, False]))
341        assert np.all(c2 == ['a', 'ccc', 'dd', 'b'])
342        assert np.all(c2.mask == [False, True, False, False])
343        assert c2.dtype == 'U3'
344
345    def test_insert_string_type_error(self, Column):
346        c = Column([1, 2])
347        with pytest.raises(ValueError, match='invalid literal for int'):
348            c.insert(0, 'string')
349
350        c = Column(['a', 'b'])
351        with pytest.raises(TypeError, match='string operation on non-string array'):
352            c.insert(0, 1)
353
354    def test_insert_multidim(self, Column):
355        c = Column([[1, 2],
356                    [3, 4]], name='a', dtype=int)
357
358        # Basic insert
359        c1 = c.insert(1, [100, 200])
360        assert np.all(c1 == [[1, 2], [100, 200], [3, 4]])
361
362        # Broadcast
363        c1 = c.insert(1, 100)
364        assert np.all(c1 == [[1, 2], [100, 100], [3, 4]])
365
366        # Wrong shape
367        with pytest.raises(ValueError):
368            c1 = c.insert(1, [100, 200, 300])
369
370    def test_insert_object(self, Column):
371        c = Column(['a', 1, None], name='a', dtype=object)
372
373        # Basic insert
374        c1 = c.insert(1, [100, 200])
375        assert np.all(c1 == np.array(['a', [100, 200], 1, None],
376                                     dtype=object))
377
378    def test_insert_masked(self):
379        c = table.MaskedColumn([0, 1, 2], name='a', fill_value=9999,
380                               mask=[False, True, False])
381
382        # Basic insert
383        c1 = c.insert(1, 100)
384        assert np.all(c1.data.data == [0, 100, 1, 2])
385        assert c1.fill_value == 9999
386        assert np.all(c1.data.mask == [False, False, True, False])
387        assert type(c) is type(c1)
388
389        for mask in (False, True):
390            c1 = c.insert(1, 100, mask=mask)
391            assert np.all(c1.data.data == [0, 100, 1, 2])
392            assert np.all(c1.data.mask == [False, mask, True, False])
393
394    def test_masked_multidim_as_list(self):
395        data = np.ma.MaskedArray([1, 2], mask=[True, False])
396        c = table.MaskedColumn([data])
397        assert c.shape == (1, 2)
398        assert np.all(c[0].mask == [True, False])
399
400    def test_insert_masked_multidim(self):
401        c = table.MaskedColumn([[1, 2],
402                                [3, 4]], name='a', dtype=int)
403
404        c1 = c.insert(1, [100, 200], mask=True)
405        assert np.all(c1.data.data == [[1, 2], [100, 200], [3, 4]])
406        assert np.all(c1.data.mask == [[False, False], [True, True], [False, False]])
407
408        c1 = c.insert(1, [100, 200], mask=[True, False])
409        assert np.all(c1.data.data == [[1, 2], [100, 200], [3, 4]])
410        assert np.all(c1.data.mask == [[False, False], [True, False], [False, False]])
411
412        with pytest.raises(ValueError):
413            c1 = c.insert(1, [100, 200], mask=[True, False, True])
414
415    def test_mask_on_non_masked_table(self):
416        """
417        When table is not masked and trying to set mask on column then
418        it's Raise AttributeError.
419        """
420
421        t = table.Table([[1, 2], [3, 4]], names=('a', 'b'), dtype=('i4', 'f8'))
422
423        with pytest.raises(AttributeError):
424            t['a'].mask = [True, False]
425
426
class TestAttrEqual:
    """Tests of the attrs_equal method, originally taken from ATpy."""

    @staticmethod
    def _attrs(**changes):
        """Return the reference column keywords with ``changes`` applied.

        A new dictionary (including a new ``meta`` dict) is built on every
        call, so columns created from it never share state.
        """
        attrs = {
            'name': 'a',
            'dtype': int,
            'unit': 'mJy',
            'format': '%i',
            'description': 'test column',
            'meta': {'c': 8, 'd': 12},
        }
        attrs.update(changes)
        return attrs

    def test_5(self, Column):
        """Same name, dtype and unit: equal."""
        first = Column(name='a', dtype=int, unit='mJy')
        second = Column(name='a', dtype=int, unit='mJy')
        assert first.attrs_equal(second)

    def test_6(self, Column):
        """All attributes identical: equal."""
        first = Column(**self._attrs())
        second = Column(**self._attrs())
        assert first.attrs_equal(second)

    def test_7(self, Column):
        """Different name: not equal."""
        first = Column(**self._attrs())
        second = Column(**self._attrs(name='b'))
        assert not first.attrs_equal(second)

    def test_8(self, Column):
        """Different dtype: not equal."""
        first = Column(**self._attrs())
        second = Column(**self._attrs(dtype=float))
        assert not first.attrs_equal(second)

    def test_9(self, Column):
        """Different unit: not equal."""
        first = Column(**self._attrs())
        second = Column(**self._attrs(unit='erg.cm-2.s-1.Hz-1'))
        assert not first.attrs_equal(second)

    def test_10(self, Column):
        """Different format: not equal."""
        first = Column(**self._attrs())
        second = Column(**self._attrs(format='%g'))
        assert not first.attrs_equal(second)

    def test_11(self, Column):
        """Different description: not equal."""
        first = Column(**self._attrs())
        second = Column(**self._attrs(description='another test column'))
        assert not first.attrs_equal(second)

    def test_12(self, Column):
        """Different meta key: not equal."""
        first = Column(**self._attrs())
        second = Column(**self._attrs(meta={'e': 8, 'd': 12}))
        assert not first.attrs_equal(second)

    def test_13(self, Column):
        """Different meta value: not equal."""
        first = Column(**self._attrs())
        second = Column(**self._attrs(meta={'c': 9, 'd': 12}))
        assert not first.attrs_equal(second)

    def test_col_and_masked_col(self):
        """A Column and a MaskedColumn with the same attributes are equal both ways."""
        plain = table.Column(**self._attrs())
        masked = table.MaskedColumn(**self._attrs())
        assert plain.attrs_equal(masked)
        assert masked.attrs_equal(plain)
498
499# Check that the meta descriptor is working as expected. The MetaBaseTest class
500# takes care of defining all the tests, and we simply have to define the class
501# and any minimal set of args to pass.
502
503
class TestMetaColumn(MetaBaseTest):
    """Apply the tests inherited from ``MetaBaseTest`` to ``table.Column``."""
    # Class under test for the inherited tests.
    test_class = table.Column
    # No arguments are supplied; presumably MetaBaseTest passes these when it
    # instantiates ``test_class`` (confirm against MetaBaseTest).
    args = ()
507
508
class TestMetaMaskedColumn(MetaBaseTest):
    """Apply the tests inherited from ``MetaBaseTest`` to ``table.MaskedColumn``."""
    # Class under test for the inherited tests.
    test_class = table.MaskedColumn
    # No arguments are supplied; presumably MetaBaseTest passes these when it
    # instantiates ``test_class`` (confirm against MetaBaseTest).
    args = ()
512
513
def test_getitem_metadata_regression():
    """
    Regression test for #1471: MaskedArray does not call __array_finalize__,
    so the meta-data was not copied over.  Overloading _update_from works
    around that bug.
    """
    column_classes = (table.Column, table.MaskedColumn)

    def build(cls, data):
        # Every column gets the same attributes and its own meta dict.
        return cls(data=data, name='a', description='b',
                   unit='m', format="%i", meta={'c': 8})

    def check_metadata(selection):
        assert selection.name == 'a'
        assert selection.description == 'b'
        assert selection.unit == 'm'
        assert selection.format == '%i'
        assert selection.meta['c'] == 8

    # Meta-data must be propagated by __getitem__ with a slice
    for cls in column_classes:
        check_metadata(build(cls, [1, 2])[1:2])

    # Same for take(), called both as a method and as a numpy function
    for cls in column_classes:
        col = build(cls, [1, 2, 3])
        for subset in [col.take([0, 1]), np.take(col, [0, 1])]:
            check_metadata(subset)

        # Scalar results are not columns, so no meta-data is copied
        for scalar in [col.take(0), np.take(col, 0)]:
            assert scalar == 1
            assert scalar.shape == ()
            assert not isinstance(scalar, cls)
569
570
def test_unicode_guidelines():
    """Run the unicode-guidelines helper on a simple integer column."""
    col = table.Column(np.array([1, 2, 3]), name='a')
    assert_follows_unicode_guidelines(col)
576
577
def test_scalar_column():
    """
    Column is not meant to hold scalars, yet with numpy 1.6 it can happen:

      >> type(np.std(table.Column([1, 2])))
      astropy.table.column.Column

    Both repr() and str() of such a column give the bare value.
    """
    scalar_col = table.Column(1.5)
    for render in (repr, str):
        assert render(scalar_col) == '1.5'
588
589
def test_qtable_column_conversion():
    """
    A QTable column that is assigned a unit turns into a Quantity.
    """
    qtable = table.QTable([[1, 2], [3, 4.2]], names=['i', 'f'])
    plain_column = table.column.Column

    # Without units both columns are ordinary columns.
    for name in ('i', 'f'):
        assert isinstance(qtable[name], plain_column)

    # Giving 'i' a unit converts only that column.
    qtable['i'].unit = 'km/s'
    assert isinstance(qtable['i'], u.Quantity)
    assert isinstance(qtable['f'], plain_column)

    # Follows from the above, but kept as a regression test for #4497:
    # element access works in both orders, column-then-row and row-then-column.
    assert isinstance(qtable['i'][0], u.Quantity)
    assert isinstance(qtable[0]['i'], u.Quantity)
    assert not isinstance(qtable['f'][0], u.Quantity)
    assert not isinstance(qtable[0]['f'], u.Quantity)

    # Regression test for #5342: assigning a function unit makes the column
    # the matching FunctionQuantity subclass.
    qtable['f'].unit = u.dex(u.cm / u.s**2)
    assert isinstance(qtable['f'], u.Dex)
613
614
@pytest.mark.parametrize('masked', [True, False])
def test_string_truncation_warning(masked):
    """
    Test warnings associated with in-place assignment to a string
    column that results in truncation of the right hand side.

    The test is run for both a masked and an unmasked table through the
    ``masked`` parameter.
    """
    from inspect import currentframe, getframeinfo

    # Assign strings of the same length as the existing elements, outside of
    # any ``pytest.warns`` block.
    t = table.Table([['aa', 'bb']], names=['a'], masked=masked)
    t['a'][1] = 'cc'
    t['a'][:] = 'dd'

    # NOTE: inside this block the ``frameinfo`` line must stay directly above
    # the assignment, with nothing in between: the line-number assertion
    # further down expects the warning exactly one line after ``frameinfo``.
    with pytest.warns(table.StringTruncateWarning, match=r'truncated right side '
                      r'string\(s\) longer than 2 character\(s\)') as w:
        frameinfo = getframeinfo(currentframe())
        t['a'][0] = 'eee'  # replace item with string that gets truncated
    assert t['a'][0] == 'ee'
    assert len(w) == 1

    # Make sure the warning points back to the user code line
    assert w[0].lineno == frameinfo.lineno + 1
    assert 'test_column' in w[0].filename

    # Assigning a whole list in which one string is too long gives one warning.
    with pytest.warns(table.StringTruncateWarning, match=r'truncated right side '
                      r'string\(s\) longer than 2 character\(s\)') as w:
        t['a'][:] = ['ff', 'ggg']  # replace item with string that gets truncated
    assert np.all(t['a'] == ['ff', 'gg'])
    assert len(w) == 1

    # Test the obscure case of assigning from an array that was originally
    # wider than any of the current elements (i.e. dtype is U4 but actual
    # elements are U1 at the time of assignment).
    val = np.array(['ffff', 'gggg'])
    val[:] = ['f', 'g']
    t['a'][:] = val
    assert np.all(t['a'] == ['f', 'g'])
651
652
def test_string_truncation_warning_masked():
    """
    Test warnings for in-place assignment to a masked string column when the
    right hand side contains np.ma.masked.
    """
    # Besides strings, also cover assigning np.ma.masked to int and float
    # masked columns.  That was previously exercised only by an unrelated
    # io.ascii test (test_line_endings), which revealed an unexpected
    # difference in how str and numeric masked arrays were handled.
    for initial in (['a', 'b'], [1, 2], [1.0, 2.0]):
        col = table.MaskedColumn(initial)

        col[1] = np.ma.masked
        assert np.all(col.mask == [False, True])

        col[:] = np.ma.masked
        assert np.all(col.mask == [True, True])

    text_col = table.MaskedColumn(['aa', 'bb'])
    expected_message = (r'truncated right side '
                        r'string\(s\) longer than 2 character\(s\)')

    with pytest.warns(table.StringTruncateWarning, match=expected_message) as caught:
        # the second item is too long and gets truncated
        text_col[:] = [np.ma.masked, 'ggg']
    assert text_col[1] == 'gg'
    assert np.all(text_col.mask == [True, False])
    assert len(caught) == 1
682
683
@pytest.mark.parametrize('Column', (table.Column, table.MaskedColumn))
def test_col_unicode_sandwich_create_from_str(Column):
    """
    Build a bytestring column from str values, one of them non-ASCII.
    """
    # The a-umlaut takes two bytes in utf-8, so an ascii encoding would fail.
    text = 'bä'
    col = Column([text, 'def'], dtype='S')
    assert col.dtype.char == 'S'

    # Item access gives back str, not bytes.
    first = col[0]
    assert first == text
    assert isinstance(first, str)

    assert isinstance(col[:0], table.Column)
    expected = np.array([text, 'def'])
    assert np.all(col[:2] == expected)
698
699
@pytest.mark.parametrize('Column', (table.Column, table.MaskedColumn))
def test_col_unicode_sandwich_bytes_obj(Column):
    """
    An object-dtype column holding a bytestring keeps it as bytes and does
    not convert it to str when accessed.
    """
    col = Column([None, b'def'])
    assert col.dtype.char == 'O'
    assert not col[0]

    stored = col[1]
    assert stored == b'def'
    assert isinstance(stored, bytes)
    assert not isinstance(stored, str)

    assert isinstance(col[:0], table.Column)
    head = col[:2]
    assert np.all(head == np.array([None, b'def']))
    assert not np.all(head == np.array([None, 'def']))
715
716
@pytest.mark.parametrize('Column', (table.Column, table.MaskedColumn))
def test_col_unicode_sandwich_bytes(Column):
    """
    Build a bytestring column from bytes and check that item access and
    comparisons accept both str and bytes.
    """
    # The a-umlaut takes two bytes in utf-8, so an ascii encoding would fail.
    text = 'bä'
    encoded = text.encode('utf-8')
    col = Column([encoded, b'def'])
    assert col.dtype.char == 'S'

    first = col[0]
    assert first == text
    assert isinstance(first, str)
    assert isinstance(col[:0], table.Column)
    assert np.all(col[:2] == np.array([text, 'def']))

    # A full slice is still a bytestring column
    full = col[:]
    assert isinstance(full, table.Column)
    assert full.dtype.char == 'S'

    # Comparison with lists
    assert np.all(col == [text, 'def'])

    matches = col == [encoded, b'def']
    assert type(matches) is type(col.data)  # noqa
    assert matches.dtype.char == '?'
    assert np.all(matches)

    # Comparison with arrays
    assert np.all(col == np.array([text, 'def']))
    assert np.all(col == np.array([encoded, b'def']))

    # Comparison with a scalar, given as str and as bytes
    for scalar in (text, encoded):
        matches = col == scalar
        assert type(matches) is type(col.data)  # noqa
        assert np.all(matches == [True, False])
754
755
def test_col_unicode_sandwich_unicode():
    """
    Sanity check that a unicode (dtype 'U') Column behaves normally.
    """
    text = 'bä'
    encoded = text.encode('utf-8')

    col = table.Column([text, 'def'], dtype='U')
    first = col[0]
    assert first == text
    assert isinstance(col[:0], table.Column)
    assert isinstance(first, str)
    assert np.all(col[:2] == np.array([text, 'def']))

    # A full slice is still a unicode column
    full = col[:]
    assert isinstance(full, table.Column)
    assert full.dtype.char == 'U'

    # Comparison with a list of str gives a plain boolean ndarray
    matches = col == [text, 'def']
    assert type(matches) == np.ndarray
    assert matches.dtype.char == '?'
    assert np.all(matches)

    # The unicode column is not equal to the corresponding bytes
    assert np.all(col != [encoded, b'def'])
779
780
def test_masked_col_unicode_sandwich():
    """
    Build a bytestring MaskedColumn with one masked element and check item
    access and comparisons against str and bytes.
    """
    col = table.MaskedColumn([b'abc', b'def'])
    col[1] = np.ma.masked
    assert isinstance(col[:0], table.MaskedColumn)
    assert isinstance(col[0], str)

    assert col[0] == 'abc'
    assert col[1] is np.ma.masked

    # A full slice is still a bytestring masked column
    full = col[:]
    assert isinstance(full, table.MaskedColumn)
    assert full.dtype.char == 'S'

    # The masked element stays masked in the comparison result
    matches = col == ['abc', 'def']
    assert matches[0] == True  # noqa
    assert matches[1] is np.ma.masked

    assert np.all(col == [b'abc', b'def'])
    assert np.all(col == np.array(['abc', 'def']))
    assert np.all(col == np.array([b'abc', b'def']))

    # Comparison with a scalar, given as str and as bytes
    for scalar in ('abc', b'abc'):
        matches = col == scalar
        assert type(matches) is np.ma.MaskedArray
        assert matches[0] == True  # noqa
        assert matches[1] is np.ma.masked
809
810
@pytest.mark.parametrize('Column', (table.Column, table.MaskedColumn))
def test_unicode_sandwich_set(Column):
    """Check item and slice assignment on a bytes column.

    Both ``bytes`` and ``str`` values (including a non-ASCII one) are
    assigned to a single element and to the whole column, and the result is
    verified through comparison and ``pformat``.
    """
    text = 'bä'

    col = Column([b'abc', b'def'])

    # Single element, bytes value.
    col[0] = b'aa'
    assert np.all(col == ['aa', 'def'])

    # Single element, str value; 'ä' takes two bytes in UTF-8.
    col[0] = text
    assert np.all(col == [text, 'def'])
    expected_lines = ['None', '----', f'  {text}', ' def']
    assert col.pformat() == expected_lines

    # Whole column, bytes scalar.
    col[:] = b'cc'
    assert np.all(col == ['cc', 'cc'])

    # Whole column, str scalar.
    col[:] = text
    assert np.all(col == [text, text])

    # Blank the column, then assign a list mixing str and bytes.
    col[:] = ''
    mixed = [text, b'def']
    col[:] = mixed
    assert np.all(col == [text, b'def'])
836
837
@pytest.mark.parametrize('class1', [table.MaskedColumn, table.Column])
@pytest.mark.parametrize('class2', [table.MaskedColumn, table.Column, str, list])
def test_unicode_sandwich_compare(class1, class2):
    """Compare a bytes Column/MaskedColumn with str-valued objects.

    The bytes column ``[b'a', b'c']`` is compared, in both operand orders,
    with a str scalar, a list of str, or a str Column/MaskedColumn, using
    all six rich comparison operators.  Tests #6838.
    """
    obj1 = class1([b'a', b'c'])

    if class2 is str:
        obj2 = 'a'
    else:
        str_values = ['a', 'b']
        obj2 = str_values if class2 is list else class2(str_values)

    # Each row: (comparison, expected obj1 <op> obj2, expected obj2 <op> obj1)
    cases = (
        (operator.eq, [True, False], [True, False]),
        (operator.ne, [False, True], [False, True]),
        (operator.gt, [False, True], [False, False]),
        (operator.le, [True, False], [True, True]),
        (operator.lt, [False, False], [False, True]),
        (operator.ge, [True, True], [True, False]),
    )
    for compare, expected_forward, expected_reverse in cases:
        assert np.all(compare(obj1, obj2) == expected_forward)
        assert np.all(compare(obj2, obj1) == expected_reverse)
869
870
def test_unicode_sandwich_masked_compare():
    """Equality tests between masked str and masked bytes columns.

    Positions masked in at least one operand must be masked in the result;
    the only position unmasked in both must hold the actual comparison value.
    Covers the fix for #6839 from #6899.
    """
    str_col = table.MaskedColumn(['a', 'b', 'c', 'd'],
                                 mask=[True, False, True, False])
    bytes_col = table.MaskedColumn([b'a', b'b', b'c', b'd'],
                                   mask=[True, True, False, False])

    masked_positions = (0, 1, 2)

    for equal in (str_col == bytes_col, bytes_col == str_col):
        for pos in masked_positions:
            assert equal[pos] is np.ma.masked
        assert equal[3]

    for unequal in (str_col != bytes_col, bytes_col != str_col):
        for pos in masked_positions:
            assert unequal[pos] is np.ma.masked
        assert not unequal[3]

    # Note: the ordering comparisons (<, >, >=, <=) are not checked here
    # because they fail to return a masked array entirely, see
    # https://github.com/numpy/numpy/issues/10092.
892
893
def test_structured_masked_column_roundtrip():
    """Build a MaskedColumn from a structured MaskedColumn and compare them."""
    rows = [(1., 2.), (3., 4.)]
    row_masks = [(False, False), (False, False)]
    original = table.MaskedColumn(rows, mask=row_masks, dtype='f8,f8')

    # The dtype string 'f8,f8' must yield a two-field structured dtype.
    n_fields = len(original.dtype.fields)
    assert n_fields == 2

    rebuilt = table.MaskedColumn(original)
    assert_array_equal(rebuilt, original)
900
901
@pytest.mark.parametrize('dtype', ['i4,f4', 'f4,(2,)f8'])
def test_structured_empty_column_init(dtype):
    """Create a data-less Column from length, shape and a structured dtype.

    The resulting column must have shape ``(length,) + shape`` and the
    requested dtype.
    """
    expected_dtype = np.dtype(dtype)
    length = 5
    item_shape = (2,)

    col = table.Column(length=length, shape=item_shape, dtype=expected_dtype)

    assert col.shape == (length,) + item_shape
    assert col.dtype == expected_dtype
908
909
def test_column_value_access():
    """Check the exact type returned by ``.value`` for each QTable column kind.

    Covers a `Column`, a `MaskedColumn`, a `Quantity` and a `Time` column,
    all built from the same integer array.
    """
    data = np.array([1, 2, 3])
    columns = {
        'a': table.Column(data),
        'b': table.MaskedColumn(data),
        'c': u.Quantity(data),
        'd': time.Time(data, format='mjd'),
    }
    tbl = table.QTable(columns)

    # Exact type (not isinstance) expected from `.value` for each column.
    expected_types = {
        'a': np.ndarray,
        'b': np.ma.MaskedArray,
        'c': np.ndarray,
        'd': np.ndarray,
    }
    for name, expected in expected_types.items():
        assert type(tbl[name].value) is expected
922
923
def test_masked_column_serialize_method_propagation():
    """Check that a changed ``info.serialize_method`` reaches derived columns.

    After switching the 'ecsv' entry from 'null_value' to 'data_mask', the
    new setting must be visible on a copy, on a column constructed from the
    original, on a view, and on a slice.
    """
    fmt = 'ecsv'
    source = table.MaskedColumn([1., 2., 3.], mask=[True, False, True])

    # Starting value, then override it on the source column.
    assert source.info.serialize_method[fmt] == 'null_value'
    source.info.serialize_method[fmt] = 'data_mask'
    assert source.info.serialize_method[fmt] == 'data_mask'

    # Ways of deriving a new column, checked one at a time in this order.
    derivations = (
        lambda col: col.copy(),
        lambda col: table.MaskedColumn(col),
        lambda col: col.view(table.MaskedColumn),
        lambda col: col[1:],
    )
    for derive in derivations:
        derived = derive(source)
        assert derived.info.serialize_method[fmt] == 'data_mask'
937