1# Licensed under a 3-clause BSD style license - see LICENSE.rst
2
3from datetime import date
4from itertools import count
5
6import pytest
7
8import numpy as np
9from erfa import DJM0
10
11from astropy.time import Time, TimeFormat
12from astropy.time.utils import day_frac
13
14
class SpecificException(ValueError):
    """Marker exception raised from inside the custom formats in this module.

    It subclasses ``ValueError``.  Tests that construct a ``Time`` expect to
    find it as the ``__cause__`` of the ``ValueError`` that ``Time`` raises.
    """
17
18
@pytest.fixture
def custom_format_name():
    """Yield a format name that is not yet registered in ``Time.FORMATS``.

    Candidates are tried in order: ``custom_format_name``, then
    ``custom_format_name_1``, ``custom_format_name_2``, ... until one is
    unused.  After the test, any format registered under that name is
    removed so later tests start from the same registry.
    """
    for i in count():
        custom = f"custom_format_name_{i}" if i else "custom_format_name"
        if custom not in Time.FORMATS:
            break
    yield custom
    # Use a default: the test may not have managed to register anything
    # (e.g. when defining the format class is expected to fail).
    Time.FORMATS.pop(custom, None)
30
31
def test_custom_time_format_set_jds_exception(custom_format_name):
    """An exception from ``set_jds`` surfaces as a chained ``ValueError``.

    ``Time`` must raise ``ValueError`` whose ``__cause__`` is the exception
    raised by the format.  ``pytest.raises`` is used instead of a bare
    ``try``/``except`` so the test fails if ``Time`` stops raising at all.
    """

    class Custom(TimeFormat):
        name = custom_format_name

        def set_jds(self, val, val2):
            raise SpecificException

    with pytest.raises(ValueError) as excinfo:
        Time(7.0, format=custom_format_name)
    # SpecificException is itself a ValueError, so also require the chaining:
    # an unwrapped SpecificException would have no __cause__ and fail here.
    assert isinstance(excinfo.value.__cause__, SpecificException)
43
44
def test_custom_time_format_val_type_exception(custom_format_name):
    """An exception from ``_check_val_type`` surfaces as a chained ``ValueError``.

    ``Time`` must raise ``ValueError`` whose ``__cause__`` is the exception
    raised by the format.  ``pytest.raises`` is used instead of a bare
    ``try``/``except`` so the test fails if ``Time`` stops raising at all.
    """

    class Custom(TimeFormat):
        name = custom_format_name

        def _check_val_type(self, val, val2):
            raise SpecificException

    with pytest.raises(ValueError) as excinfo:
        Time(7.0, format=custom_format_name)
    # SpecificException is itself a ValueError, so also require the chaining:
    # an unwrapped SpecificException would have no __cause__ and fail here.
    assert isinstance(excinfo.value.__cause__, SpecificException)
56
57
def test_custom_time_format_value_exception(custom_format_name):
    """An exception raised by a custom ``value`` property propagates as-is."""

    class Custom(TimeFormat):
        name = custom_format_name

        def set_jds(self, val, val2):
            self.jd1 = val
            self.jd2 = val2

        @property
        def value(self):
            raise SpecificException

    now = Time.now()
    # Reading the custom format off a Time evaluates the ``value`` property.
    with pytest.raises(SpecificException):
        getattr(now, custom_format_name)
72
73
def test_custom_time_format_fine(custom_format_name):
    """A well-behaved custom format works for both output and input."""

    class Custom(TimeFormat):
        name = custom_format_name

        def set_jds(self, val, val2):
            self.jd1 = val
            self.jd2 = val2

        @property
        def value(self):
            return self.jd1 + self.jd2

    # Express an existing Time in the custom format ...
    current = Time.now()
    getattr(current, custom_format_name)
    # ... and construct one directly from a (val, val2) pair in that format.
    explicit = Time(7, 9, format=custom_format_name)
    getattr(explicit, custom_format_name)
89
90
def test_custom_time_format_forgot_property(custom_format_name):
    """Defining ``value`` as a plain method instead of a property is rejected.

    Only the ``class`` statement runs inside ``pytest.raises``, so the
    ``ValueError`` must come from defining the subclass itself.
    """
    with pytest.raises(ValueError):

        class Custom(TimeFormat):
            name = custom_format_name

            def set_jds(self, val, val2):
                self.jd1 = val
                self.jd2 = val2

            # Deliberately missing the @property decorator.
            def value(self):
                return self.jd1, self.jd2
101
102
def test_custom_time_format_problematic_name():
    """A format named like a ``Time`` method must not shadow that method.

    Registers a format called ``sort``, which is also a ``Time`` attribute.
    It then checks that ``Time.sort`` still behaves and that the format is
    usable through ``.format``/``.value`` and as a constructor format.  The
    registration is always removed afterwards.
    """
    assert "sort" not in Time.FORMATS, "problematic name in default FORMATS!"
    assert hasattr(Time, "sort")

    try:

        class Custom(TimeFormat):
            name = "sort"
            _dtype = np.dtype([('jd1', 'f8'), ('jd2', 'f8')])

            def set_jds(self, val, val2):
                self.jd1 = val
                self.jd2 = val2

            @property
            def value(self):
                # Pack both JD components into one structured array.
                packed = np.empty(self.jd1.shape, self._dtype)
                for field in ('jd1', 'jd2'):
                    packed[field] = getattr(self, field)
                return packed

        now = Time.now()
        assert now.sort() == now, "bogus time format clobbers everyone's Time objects"

        now.format = "sort"
        assert now.value.dtype == Custom._dtype

        from_pair = Time(7, 9, format="sort")
        assert from_pair.value == np.array((7, 9), Custom._dtype)

    finally:
        Time.FORMATS.pop("sort", None)
134
135
def test_mjd_longdouble_preserves_precision(custom_format_name):
    """A long-double MJD format keeps its extra precision through ``Time``.

    The custom format accepts a single ``np.longdouble`` MJD.  It stores that
    MJD as two float64 halves via ``day_frac`` and reassembles a long double
    on output.  Two inputs that differ only in their low-order bits must
    remain distinguishable, both as ``Time`` objects and as output values.
    """

    class CustomMJD(TimeFormat):
        name = custom_format_name

        def _check_val_type(self, val, val2):
            as_longdouble = np.longdouble(val)
            if val2 is not None:
                raise ValueError("Only one value permitted")
            return as_longdouble, 0

        def set_jds(self, val, val2):
            # Split off the integer day while still in long double, so the
            # float64 fractional part keeps as many bits as possible.
            whole = np.float64(np.floor(val))
            frac = np.float64(val - whole)
            self.jd1, self.jd2 = day_frac(whole + DJM0, frac)

        @property
        def value(self):
            whole, frac = day_frac(self.jd1 - DJM0, self.jd2)
            return np.longdouble(whole) + np.longdouble(frac)

    base = 58000.0
    t_base = Time(base, format=custom_format_name)
    # Pick a different long double (ensuring it will give a different jd2
    # even when long doubles are more precise than Time, as on arm64).
    step = max(2. * base * np.finfo(np.longdouble).eps,
               np.finfo(float).eps)
    bumped = np.longdouble(base) + step
    assert bumped != base, 'long double is weird!'
    t_bumped = Time(bumped, format=custom_format_name)
    assert t_base != t_bumped
    base_value = getattr(t_base, custom_format_name)
    assert isinstance(base_value, np.longdouble)
    assert base_value != getattr(t_bumped, custom_format_name)
167
168
@pytest.mark.parametrize(
    "jd1, jd2",
    [
        ("foo", None),
        (np.arange(3), np.arange(4)),
        ("foo", "bar"),
        (1j, 2j),
        pytest.param(
            np.longdouble(3), np.longdouble(5),
            marks=pytest.mark.skipif(
                np.longdouble().itemsize == np.dtype(float).itemsize,
                reason="long double == double on this platform")),
        ({1: 2}, {3: 4}),
        ({1, 2}, {3, 4}),
        ([1, 2], [3, 4]),
        (lambda: 4, lambda: 7),
    ],
)
def test_custom_format_cannot_make_bogus_jd1(custom_format_name, jd1, jd2):
    """``Time`` refuses unusable ``jd1``/``jd2`` values stored by ``set_jds``.

    Each case has ``set_jds`` store a bad pair: strings, ``None``, complex
    numbers, long doubles, containers, callables, or integer arrays whose
    shapes do not match.  Constructing the ``Time`` must fail with
    ``ValueError`` or ``TypeError``.
    """

    class Custom(TimeFormat):
        name = custom_format_name

        def set_jds(self, val, val2):
            # Ignore the inputs and store the parametrized pair instead.
            self.jd1, self.jd2 = jd1, jd2

        @property
        def value(self):
            return self.jd1 + self.jd2

    with pytest.raises((ValueError, TypeError)):
        Time(5, format=custom_format_name)
201
202
def test_custom_format_scalar_jd1_jd2_okay(custom_format_name):
    """Plain Python floats are accepted as ``jd1``/``jd2`` values."""

    class Custom(TimeFormat):
        name = custom_format_name

        def set_jds(self, val, val2):
            # Ignore the inputs and store fixed scalar floats.
            self.jd1 = 7.0
            self.jd2 = 3.0

        @property
        def value(self):
            return self.jd1 + self.jd2

    built = Time(5, format=custom_format_name)
    getattr(built, custom_format_name)
215
216
@pytest.mark.parametrize(
    "thing",
    [
        1,
        1.0,
        np.longdouble(1),
        1.0j,
        "foo",
        b"foo",
        Time(5, format="mjd"),
        lambda: 7,
        np.datetime64('2005-02-25'),
        date(2006, 2, 25),
    ],
)
def test_custom_format_can_return_any_scalar(custom_format_name, thing):
    """A ``value`` returning ``np.array(thing)`` comes back as ``thing`` itself.

    The test expects the attribute access to yield an object of exactly
    ``type(thing)`` that compares equal to ``thing``, even though ``value``
    wraps it in a 0-d array.
    """

    class Custom(TimeFormat):
        name = custom_format_name

        def set_jds(self, val, val2):
            self.jd1, self.jd2 = 2., 0.

        @property
        def value(self):
            return np.array(thing)

    result = getattr(Time(5, format=custom_format_name), custom_format_name)
    # An exact type match is intended (isinstance would accept subclasses),
    # so compare the type objects by identity.
    assert type(result) is type(thing)
    assert np.all(result == thing)
247
248
@pytest.mark.parametrize(
    "thing",
    [
        (1, 2),
        [1, 2],
        np.array([2, 3]),
        np.array([2, 3, 5, 7]),
        {6: 7},
        {1, 2},
    ],
)
def test_custom_format_can_return_any_iterable(custom_format_name, thing):
    """A ``value`` returning a container yields that container's exact type.

    The attribute access must return an object of exactly ``type(thing)``
    that compares equal to ``thing``.
    """

    class Custom(TimeFormat):
        name = custom_format_name

        def set_jds(self, val, val2):
            self.jd1, self.jd2 = 2., 0.

        @property
        def value(self):
            return thing

    result = getattr(Time(5, format=custom_format_name), custom_format_name)
    # An exact type match is intended (isinstance would accept subclasses),
    # so compare the type objects by identity.
    assert type(result) is type(thing)
    assert np.all(result == thing)
275