# Licensed under a 3-clause BSD style license - see LICENSE.rst

from datetime import date
from itertools import count

import pytest

import numpy as np
from erfa import DJM0

from astropy.time import Time, TimeFormat
from astropy.time.utils import day_frac


class SpecificException(ValueError):
    """Marker ``ValueError`` subclass raised by the custom formats below.

    Using a dedicated type lets the tests recognise the exact exception
    they raised themselves.
    """


@pytest.fixture
def custom_format_name():
    """Yield a format name that is not currently in ``Time.FORMATS``.

    Candidates are ``custom_format_name``, ``custom_format_name_1``,
    ``custom_format_name_2``, ...; the first unused one is yielded.  On
    teardown the name is removed from ``Time.FORMATS`` (if present) so a
    format defined by one test does not leak into the next.
    """
    for i in count():
        custom = f"custom_format_name_{i}" if i else "custom_format_name"
        if custom not in Time.FORMATS:
            break
    yield custom
    Time.FORMATS.pop(custom, None)


def test_custom_time_format_set_jds_exception(custom_format_name):
    """An exception raised in ``set_jds`` is chained onto the ValueError."""

    class Custom(TimeFormat):
        name = custom_format_name

        def set_jds(self, val, val2):
            raise SpecificException

    # pytest.raises (rather than try/except) so that the test fails if no
    # exception is raised at all.
    with pytest.raises(ValueError) as excinfo:
        Time(7.0, format=custom_format_name)
    assert isinstance(excinfo.value.__cause__, SpecificException)


def test_custom_time_format_val_type_exception(custom_format_name):
    """An exception raised in ``_check_val_type`` is chained likewise."""

    class Custom(TimeFormat):
        name = custom_format_name

        def _check_val_type(self, val, val2):
            raise SpecificException

    # pytest.raises (rather than try/except) so that the test fails if no
    # exception is raised at all.
    with pytest.raises(ValueError) as excinfo:
        Time(7.0, format=custom_format_name)
    assert isinstance(excinfo.value.__cause__, SpecificException)


def test_custom_time_format_value_exception(custom_format_name):
    """An exception raised by the ``value`` property propagates unchanged."""

    class Custom(TimeFormat):
        name = custom_format_name

        def set_jds(self, val, val2):
            self.jd1, self.jd2 = val, val2

        @property
        def value(self):
            raise SpecificException

    t = Time.now()
    with pytest.raises(SpecificException):
        getattr(t, custom_format_name)


def test_custom_time_format_fine(custom_format_name):
    """A minimal well-formed custom format works for conversion and input."""

    class Custom(TimeFormat):
        name = custom_format_name

        def set_jds(self, val, val2):
            self.jd1, self.jd2 = val, val2

        @property
        def value(self):
            return self.jd1 + self.jd2

    # Access via an existing Time object ...
    t = Time.now()
    getattr(t, custom_format_name)
    # ... and via direct construction in the custom format.
    t2 = Time(7, 9, format=custom_format_name)
    getattr(t2, custom_format_name)


def test_custom_time_format_forgot_property(custom_format_name):
    """Defining ``value`` as a plain method (no ``@property``) is rejected.

    The ValueError is expected while the class body is being defined.
    """
    with pytest.raises(ValueError):

        class Custom(TimeFormat):
            name = custom_format_name

            def set_jds(self, val, val2):
                self.jd1, self.jd2 = val, val2

            def value(self):
                return self.jd1, self.jd2


def test_custom_time_format_problematic_name():
    """A format named like an existing ``Time`` attribute does not shadow it.

    ``sort`` is used because ``Time`` has a ``sort`` attribute and no
    default format of that name.  The format must still be usable through
    ``t.format`` / ``t.value`` and ``Time(..., format="sort")``.
    """
    assert "sort" not in Time.FORMATS, "problematic name in default FORMATS!"
    assert hasattr(Time, "sort")

    try:

        class Custom(TimeFormat):
            name = "sort"
            _dtype = np.dtype([('jd1', 'f8'), ('jd2', 'f8')])

            def set_jds(self, val, val2):
                self.jd1, self.jd2 = val, val2

            @property
            def value(self):
                result = np.empty(self.jd1.shape, self._dtype)
                result['jd1'] = self.jd1
                result['jd2'] = self.jd2
                return result

        t = Time.now()
        assert t.sort() == t, "bogus time format clobbers everyone's Time objects"

        t.format = "sort"
        assert t.value.dtype == Custom._dtype

        t2 = Time(7, 9, format="sort")
        assert t2.value == np.array((7, 9), Custom._dtype)

    finally:
        # This test does not use the fixture, so clean up by hand.
        Time.FORMATS.pop("sort", None)


def test_mjd_longdouble_preserves_precision(custom_format_name):
    """A long-double MJD format distinguishes adjacent long-double inputs."""

    class CustomMJD(TimeFormat):
        name = custom_format_name

        def _check_val_type(self, val, val2):
            val = np.longdouble(val)
            if val2 is not None:
                raise ValueError("Only one value permitted")
            return val, 0

        def set_jds(self, val, val2):
            # Split into integer and fractional days, each as float64.
            mjd1 = np.float64(np.floor(val))
            mjd2 = np.float64(val - mjd1)
            self.jd1, self.jd2 = day_frac(mjd1 + DJM0, mjd2)

        @property
        def value(self):
            mjd1, mjd2 = day_frac(self.jd1 - DJM0, self.jd2)
            return np.longdouble(mjd1) + np.longdouble(mjd2)

    m = 58000.0
    t = Time(m, format=custom_format_name)
    # Pick a different long double (ensuring it will give a different jd2
    # even when long doubles are more precise than Time, as on arm64).
    m2 = np.longdouble(m) + max(2. * m * np.finfo(np.longdouble).eps,
                                np.finfo(float).eps)
    assert m2 != m, 'long double is weird!'
    t2 = Time(m2, format=custom_format_name)
    assert t != t2
    assert isinstance(getattr(t, custom_format_name), np.longdouble)
    assert getattr(t, custom_format_name) != getattr(t2, custom_format_name)


@pytest.mark.parametrize(
    "jd1, jd2",
    [
        ("foo", None),
        (np.arange(3), np.arange(4)),
        ("foo", "bar"),
        (1j, 2j),
        pytest.param(
            np.longdouble(3), np.longdouble(5),
            marks=pytest.mark.skipif(
                np.longdouble().itemsize == np.dtype(float).itemsize,
                reason="long double == double on this platform")),
        ({1: 2}, {3: 4}),
        ({1, 2}, {3, 4}),
        ([1, 2], [3, 4]),
        (lambda: 4, lambda: 7),
    ],
)
def test_custom_format_cannot_make_bogus_jd1(custom_format_name, jd1, jd2):
    """A format that sets unsuitable ``jd1``/``jd2`` values is rejected.

    Constructing a ``Time`` in such a format must raise ValueError or
    TypeError for every parametrized ``(jd1, jd2)`` pair.
    """

    class Custom(TimeFormat):
        name = custom_format_name

        def set_jds(self, val, val2):
            # Ignore the input; inject the parametrized values.
            self.jd1, self.jd2 = jd1, jd2

        @property
        def value(self):
            return self.jd1 + self.jd2

    with pytest.raises((ValueError, TypeError)):
        Time(5, format=custom_format_name)


def test_custom_format_scalar_jd1_jd2_okay(custom_format_name):
    """A format that sets plain Python float ``jd1``/``jd2`` is accepted."""

    class Custom(TimeFormat):
        name = custom_format_name

        def set_jds(self, val, val2):
            self.jd1, self.jd2 = 7.0, 3.0

        @property
        def value(self):
            return self.jd1 + self.jd2

    getattr(Time(5, format=custom_format_name), custom_format_name)


@pytest.mark.parametrize(
    "thing",
    [
        1,
        1.0,
        np.longdouble(1),
        1.0j,
        "foo",
        b"foo",
        Time(5, format="mjd"),
        lambda: 7,
        np.datetime64('2005-02-25'),
        date(2006, 2, 25),
    ],
)
def test_custom_format_can_return_any_scalar(custom_format_name, thing):
    """A ``value`` returning ``np.array(thing)`` comes back as ``thing``.

    Both the exact type and the value of the scalar must survive.
    """

    class Custom(TimeFormat):
        name = custom_format_name

        def set_jds(self, val, val2):
            self.jd1, self.jd2 = 2., 0.

        @property
        def value(self):
            return np.array(thing)

    # Exact type match is intended (e.g. np.longdouble vs. float), hence
    # an identity check on the types rather than isinstance.
    assert type(getattr(Time(5, format=custom_format_name),
                        custom_format_name)) is type(thing)
    assert np.all(getattr(Time(5, format=custom_format_name),
                          custom_format_name) == thing)


@pytest.mark.parametrize(
    "thing",
    [
        (1, 2),
        [1, 2],
        np.array([2, 3]),
        np.array([2, 3, 5, 7]),
        {6: 7},
        {1, 2},
    ],
)
def test_custom_format_can_return_any_iterable(custom_format_name, thing):
    """A ``value`` returning an arbitrary container is passed through.

    Both the exact type and the contents of the container must survive.
    """

    class Custom(TimeFormat):
        name = custom_format_name

        def set_jds(self, val, val2):
            self.jd1, self.jd2 = 2., 0.

        @property
        def value(self):
            return thing

    # Exact type match is intended, hence an identity check on the types.
    assert type(getattr(Time(5, format=custom_format_name),
                        custom_format_name)) is type(thing)
    assert np.all(getattr(Time(5, format=custom_format_name),
                          custom_format_name) == thing)