1# Licensed under a 3-clause BSD style license - see LICENSE.rst
2from concurrent.futures import ThreadPoolExecutor
3from datetime import datetime, timedelta
4
5import pytest
6import erfa
7
8from astropy.utils import iers
9from astropy.utils.exceptions import AstropyWarning
10
11import astropy.time.core
12from astropy.time import update_leap_seconds, Time
13
14
class TestUpdateLeapSeconds:
    """Tests for ``update_leap_seconds`` and how it updates ERFA's table.

    Several tests deliberately truncate ERFA's leap-second table.
    ``setup_method``/``teardown_method`` snapshot and reinstall it around
    every test so that this state does not leak into other tests.
    """

    def setup_method(self):
        # xunit-style per-test hook.  The nose-style ``setup`` name is
        # deprecated since pytest 7.2 and is no longer called by pytest 8,
        # which would leave the attributes below undefined.
        self.built_in = iers.LeapSeconds.from_iers_leap_seconds()
        self.erfa_ls = iers.LeapSeconds.from_erfa()
        # Expiry dates at least this far in the future count as fresh.
        self.good_enough = datetime.now() + timedelta(days=150)

    def teardown_method(self):
        # Reinstall the ERFA leap-second table captured in setup_method.
        self.erfa_ls.update_erfa_leap_seconds(initialize_erfa=True)

    def _install_expired_erfa_leap_seconds(self):
        """Give ERFA only the pre-2017 leap seconds and return that table.

        This drops the 2017 leap second, so ERFA's table is out of date.
        """
        expired = self.erfa_ls[self.erfa_ls['year'] < 2017]
        expired.update_erfa_leap_seconds(initialize_erfa='empty')
        return expired

    def test_auto_update_leap_seconds(self):
        # Sanity check.
        assert erfa.dat(2018, 1, 1, 0.) == 37.0
        self._install_expired_erfa_leap_seconds()
        # Check the 2017 leap second is indeed missing.
        assert erfa.dat(2018, 1, 1, 0.) == 36.0

        # Update with missing leap seconds.
        n_update = update_leap_seconds([iers.IERS_LEAP_SECOND_FILE])
        assert n_update >= 1
        assert erfa.leap_seconds.expires == self.built_in.expires
        assert erfa.dat(2018, 1, 1, 0.) == 37.0

        # Doing it again does not change anything
        n_update2 = update_leap_seconds([iers.IERS_LEAP_SECOND_FILE])
        assert n_update2 == 0
        assert erfa.dat(2018, 1, 1, 0.) == 37.0

    @pytest.mark.remote_data
    def test_never_expired_if_connected(self):
        assert self.erfa_ls.expires > datetime.now()
        assert self.erfa_ls.expires >= self.good_enough

    @pytest.mark.remote_data
    def test_auto_update_always_good(self):
        self.erfa_ls.update_erfa_leap_seconds(initialize_erfa='only')
        update_leap_seconds()
        assert not erfa.leap_seconds.expired
        assert erfa.leap_seconds.expires > self.good_enough

    def test_auto_update_bad_file(self):
        with pytest.warns(AstropyWarning, match='FileNotFound'):
            update_leap_seconds(['nonsense'])

    def test_auto_update_corrupt_file(self, tmp_path):
        # Copy the built-in file without its comment lines.  The expiration
        # date is carried in a comment line, so parsing the copy must fail.
        bad_file = str(tmp_path / 'no_expiration')
        with open(iers.IERS_LEAP_SECOND_FILE) as fh:
            lines = fh.readlines()
        with open(bad_file, 'w') as fh:
            # readlines() keeps each line's trailing newline, so write the
            # lines back unchanged instead of joining them with another '\n'.
            fh.writelines(line for line in lines if not line.startswith('#'))

        with pytest.warns(AstropyWarning,
                          match='ValueError.*did not find expiration'):
            update_leap_seconds([bad_file])

    def test_auto_update_expired_file(self, tmp_path):
        expired = self._install_expired_erfa_leap_seconds()
        # Create similarly expired file.
        expired_file = str(tmp_path / 'expired.dat')
        with open(expired_file, 'w') as fh:
            fh.write('\n'.join(['# File expires on 28 June 2010']
                               + [str(item) for item in expired]))

        with pytest.warns(iers.IERSStaleWarning):
            update_leap_seconds(['erfa', expired_file])

    def test_init_thread_safety(self, monkeypatch):
        self._install_expired_erfa_leap_seconds()
        # Force re-initialization, even if another test already did it
        monkeypatch.setattr(astropy.time.core, '_LEAP_SECONDS_CHECK',
                            astropy.time.core._LeapSecondsCheck.NOT_STARTED)

        def tai_string():
            return str(Time('2019-01-01 00:00:00.000').tai)

        workers = 4
        with ThreadPoolExecutor(max_workers=workers) as executor:
            futures = [executor.submit(tai_string) for _ in range(workers)]
            results = [future.result() for future in futures]
        # Every thread must see the 2017 leap second (TAI-UTC = 37 s in
        # 2019), even though they race on the first leap-second check.
        assert results == ['2019-01-01 00:00:37.000'] * workers
100