1# Licensed under a 3-clause BSD style license - see LICENSE.rst
2
3import pytest
4
5from numpy.testing import assert_equal
6
7from astropy import units as u
8from astropy.table import Table, QTable, vstack, join
9from astropy.time import Time
10
11from astropy.timeseries.sampled import TimeSeries
12from astropy.timeseries.binned import BinnedTimeSeries
13
14
# Shared inputs for every test series. Note the timestamps are not in
# chronological order: the earliest one ('2015-01-21T12:30:32') is at index 1.
INPUT_TIME = Time([
    '2016-03-22T12:30:31',
    '2015-01-21T12:30:32',
    '2016-03-22T12:30:40',
])
# Three data columns ('a' float, 'b' int, 'c' str), one row per timestamp.
PLAIN_TABLE = Table(
    [[1., 2., 11.], [3, 4, 1], ['x', 'y', 'z']],
    names=['a', 'b', 'c'],
)
17
18
class CommonTimeSeriesTests:
    """Tests shared by the sampled and binned time series test classes.

    Subclasses must set up, before each test, a ``series`` attribute (a time
    series built from ``INPUT_TIME`` and ``PLAIN_TABLE``) and a ``time_attr``
    attribute naming that series' time column.  They must also define a
    ``_row`` class attribute: a dict accepted by ``series.add_row``.
    """

    def test_stacking(self):
        # Stacking two series must preserve the time series subclass.
        ts = vstack([self.series, self.series])
        assert isinstance(ts, self.series.__class__)

    def test_row_slicing(self):
        # Slicing rows must preserve the time series subclass.
        ts = self.series[:2]
        assert isinstance(ts, self.series.__class__)

    def test_row_indexing(self):
        # INPUT_TIME is unsorted: '2015-01-21T12:30:32' is the entry at
        # index 1.  Check both row-then-column and column-then-row access.
        # (Previously these were bare comparisons without ``assert`` against
        # index 0, so the test could never fail.)
        expected = Time('2015-01-21T12:30:32')
        assert self.series[1][self.time_attr] == expected
        assert self.series[self.time_attr][1] == expected

    def test_column_indexing(self):
        assert_equal(self.series['a'], [1, 2, 11])

    def test_column_slicing_notime(self):
        # Selecting only data columns (no time column) cannot produce a
        # valid time series, so a plain QTable is expected instead.
        tab = self.series['a', 'b']
        assert not isinstance(tab, self.series.__class__)
        assert isinstance(tab, QTable)

    def test_add_column(self):
        self.series['d'] = [1, 2, 3]
        assert_equal(self.series['d'], [1, 2, 3])

    def test_add_row(self):
        n_rows = len(self.series)
        self.series.add_row(self._row)
        assert len(self.series) == n_rows + 1

    def test_set_unit(self):
        self.series['d'] = [1, 2, 3]
        self.series['d'].unit = 's'
        assert self.series['d'].unit == u.s

    def test_replace_column(self):
        self.series.replace_column('c', [1, 3, 4])
        assert_equal(self.series['c'], [1, 3, 4])

    def test_required_after_stacking(self):
        # When stacking, we have to temporarily relax the checking of the
        # columns in the time series, but we need to make sure that the
        # checking works again afterwards
        ts = vstack([self.series, self.series])
        with pytest.raises(ValueError) as exc:
            ts.remove_columns(ts.colnames)
        assert 'TimeSeries object is invalid' in exc.value.args[0]

    def test_join(self):
        # ts_other is a copy of the series with one extra row, an extra 'd'
        # column, and the 'a' and 'b' columns removed.
        ts_other = self.series.copy()
        ts_other.add_row(self._row)
        ts_other['d'] = [11, 22, 33, 44]
        ts_other.remove_columns(['a', 'b'])
        # Default (inner) join: only the original rows have a match.
        ts = join(self.series, ts_other)
        assert len(ts) == len(self.series)
        # Outer join: the extra row of ts_other is kept as well.
        ts = join(self.series, ts_other, join_type='outer')
        assert len(ts) == len(ts_other)
72
73
class TestTimeSeries(CommonTimeSeriesTests):
    """Run the shared time series tests against a sampled ``TimeSeries``."""

    # Row appended by the shared ``add_row``/``join`` tests.
    _row = {'time': '2016-03-23T12:30:40', 'a': 1., 'b': 2, 'c': 'a'}

    def setup_method(self, method):
        # Name of the column holding the sample times.
        self.time_attr = 'time'
        # Rebuilt before each test so tests that mutate it stay isolated.
        self.series = TimeSeries(data=PLAIN_TABLE, time=INPUT_TIME)

    def test_column_slicing(self):
        # Keeping the time column in the selection must yield a TimeSeries.
        selected = ('time', 'a')
        assert isinstance(self.series[selected], TimeSeries)
85
86
class TestBinnedTimeSeries(CommonTimeSeriesTests):
    """Run the shared time series tests against a ``BinnedTimeSeries``."""

    # Row appended by the shared ``add_row``/``join`` tests.
    _row = {'time_bin_start': '2016-03-23T12:30:40',
            'time_bin_size': 2 * u.s, 'a': 1., 'b': 2, 'c': 'a'}

    def setup_method(self, method):
        # Name of the column holding the start time of each bin.
        self.time_attr = 'time_bin_start'
        # Every bin in the fixture spans the same duration.
        bin_size = 3 * u.s
        # Rebuilt before each test so tests that mutate it stay isolated.
        self.series = BinnedTimeSeries(data=PLAIN_TABLE,
                                       time_bin_start=INPUT_TIME,
                                       time_bin_size=bin_size)

    def test_column_slicing(self):
        # Keeping both time columns in the selection must yield a
        # BinnedTimeSeries.
        selected = ('time_bin_start', 'time_bin_size', 'a')
        assert isinstance(self.series[selected], BinnedTimeSeries)
101