# Licensed under a 3-clause BSD style license - see LICENSE.rst

import pytest

from numpy.testing import assert_equal

from astropy import units as u
from astropy.table import Table, QTable, vstack, join
from astropy.time import Time

from astropy.timeseries.sampled import TimeSeries
from astropy.timeseries.binned import BinnedTimeSeries


# The times are not in chronological order: the earliest one is the second
# entry. test_row_indexing relies on that position.
INPUT_TIME = Time(['2016-03-22T12:30:31', '2015-01-21T12:30:32', '2016-03-22T12:30:40'])
PLAIN_TABLE = Table([[1., 2., 11.], [3, 4, 1], ['x', 'y', 'z']], names=['a', 'b', 'c'])


class CommonTimeSeriesTests:
    """Tests shared by the sampled and binned time series test classes.

    Subclasses must set, in ``setup_method``:

    * ``series`` -- a fresh time series built from ``INPUT_TIME`` and
      ``PLAIN_TABLE``;
    * ``time_attr`` -- the name of the series' time column;

    and define ``_row`` -- a dict accepted by ``series.add_row``.
    """

    def test_stacking(self):
        ts = vstack([self.series, self.series])
        assert isinstance(ts, self.series.__class__)
        assert len(ts) == 2 * len(self.series)

    def test_row_slicing(self):
        ts = self.series[:2]
        assert isinstance(ts, self.series.__class__)
        assert len(ts) == 2

    def test_row_indexing(self):
        # Rows keep the input order (no sorting by time), so the 2015 time,
        # which is the second entry of INPUT_TIME, is found at index 1.
        assert self.series[1][self.time_attr] == Time('2015-01-21T12:30:32')
        assert self.series[self.time_attr][1] == Time('2015-01-21T12:30:32')

    def test_column_indexing(self):
        assert_equal(self.series['a'], [1, 2, 11])

    def test_column_slicing_notime(self):
        # Selecting columns without the time column(s) cannot produce a valid
        # time series, so a plain QTable is expected instead.
        tab = self.series['a', 'b']
        assert not isinstance(tab, self.series.__class__)
        assert isinstance(tab, QTable)

    def test_add_column(self):
        self.series['d'] = [1, 2, 3]
        assert_equal(self.series['d'], [1, 2, 3])

    def test_add_row(self):
        n_rows = len(self.series)
        self.series.add_row(self._row)
        assert len(self.series) == n_rows + 1

    def test_set_unit(self):
        self.series['d'] = [1, 2, 3]
        self.series['d'].unit = 's'
        assert self.series['d'].unit == u.s

    def test_replace_column(self):
        self.series.replace_column('c', [1, 3, 4])
        assert_equal(self.series['c'], [1, 3, 4])

    def test_required_after_stacking(self):
        # When stacking, we have to temporarily relax the checking of the
        # columns in the time series, but we need to make sure that the
        # checking works again afterwards
        ts = vstack([self.series, self.series])
        with pytest.raises(ValueError) as exc:
            ts.remove_columns(ts.colnames)
        assert 'TimeSeries object is invalid' in exc.value.args[0]

    def test_join(self):
        # ts_other has one extra row and an extra column 'd'; it shares the
        # time column(s) and 'c' with self.series.
        ts_other = self.series.copy()
        ts_other.add_row(self._row)
        ts_other['d'] = [11, 22, 33, 44]
        ts_other.remove_columns(['a', 'b'])
        # Inner join keeps only the rows common to both inputs.
        ts = join(self.series, ts_other)
        assert len(ts) == len(self.series)
        # Outer join keeps every row, including the one only in ts_other.
        ts = join(self.series, ts_other, join_type='outer')
        assert len(ts) == len(ts_other)


class TestTimeSeries(CommonTimeSeriesTests):
    """Run the common tests against a sampled ``TimeSeries``."""

    _row = {'time': '2016-03-23T12:30:40', 'a': 1., 'b': 2, 'c': 'a'}

    def setup_method(self, method):
        self.series = TimeSeries(time=INPUT_TIME, data=PLAIN_TABLE)
        self.time_attr = 'time'

    def test_column_slicing(self):
        # Keeping the 'time' column preserves the TimeSeries class.
        ts = self.series['time', 'a']
        assert isinstance(ts, TimeSeries)


class TestBinnedTimeSeries(CommonTimeSeriesTests):
    """Run the common tests against a ``BinnedTimeSeries``."""

    _row = {'time_bin_start': '2016-03-23T12:30:40',
            'time_bin_size': 2 * u.s, 'a': 1., 'b': 2, 'c': 'a'}

    def setup_method(self, method):
        self.series = BinnedTimeSeries(time_bin_start=INPUT_TIME,
                                       time_bin_size=3 * u.s,
                                       data=PLAIN_TABLE)
        self.time_attr = 'time_bin_start'

    def test_column_slicing(self):
        # Keeping both bin-definition columns preserves the class.
        ts = self.series['time_bin_start', 'time_bin_size', 'a']
        assert isinstance(ts, BinnedTimeSeries)