1# Licensed under a 3-clause BSD style license - see LICENSE.rst
2
3import pytest
4import numpy as np
5
6from astropy.time import Time, TimeDelta
7from astropy.units.quantity_helper.function_helpers import ARRAY_FUNCTION_ENABLED
8
9
class TestFunctionsTime:
    """Check numpy functions applied to a ``Time`` instance.

    The reference result for a function is obtained by applying it
    separately to the ``jd1`` and ``jd2`` arrays of the test instance and
    building a new time object from the two results.
    """

    @classmethod
    def setup_class(cls):
        # Time(val, val2) with val=50000 and val2=arange(8) reshaped to
        # (4, 2), in MJD format on the TAI scale.
        cls.t = Time(50000, np.arange(8).reshape(4, 2), format='mjd',
                     scale='tai')

    def check(self, func, cls=None, scale=None, format=None, *args, **kwargs):
        """Assert that ``func(self.t)`` agrees with ``func`` on jd1/jd2.

        Parameters
        ----------
        func : callable
            The numpy function under test, e.g. ``np.diff``.
        cls : type, optional
            Class used to build the expected result.  Defaults to the
            class of ``self.t``.
        scale : str, optional
            Scale of the expected result.  Defaults to ``self.t.scale``.
        format : str, optional
            Format of the expected result.  Defaults to ``self.t.format``.
        *args, **kwargs
            Passed on to ``func``.  Extra positional arguments can only be
            given after ``cls``, ``scale`` and ``format`` have themselves
            been given positionally, so pass function options as keywords.
        """
        if cls is None:
            cls = self.t.__class__
        if scale is None:
            scale = self.t.scale
        if format is None:
            format = self.t.format
        out = func(self.t, *args, **kwargs)
        jd1 = func(self.t.jd1, *args, **kwargs)
        jd2 = func(self.t.jd2, *args, **kwargs)
        # jd1 and jd2 are Julian Dates regardless of the instance's format,
        # so the expected value must be constructed with format='jd'.
        # Only afterwards switch to the requested output format.  Building
        # it directly with, e.g., format='mjd' would misinterpret the values.
        expected = cls(jd1, jd2, format='jd', scale=scale)
        expected.format = format
        if isinstance(out, np.ndarray):
            expected = np.array(expected)

        assert np.all(out == expected)

    @pytest.mark.parametrize('axis', (0, 1))
    def test_diff(self, axis):
        """np.diff of times should give the matching TimeDelta."""
        self.check(np.diff, axis=axis, cls=TimeDelta, format='jd')
35
36
class TestFunctionsTimeDelta(TestFunctionsTime):
    """Run the ``Time`` function checks on a ``TimeDelta`` instance.

    ``test_diff`` is inherited from the base class.  Reductions such as
    ``np.sum`` are tested here in addition.
    """

    @classmethod
    def setup_class(cls):
        # A (4, 2) array of time differences of 0..7 days on the TAI scale.
        cls.t = TimeDelta(np.arange(8).reshape(4, 2), format='jd',
                          scale='tai')

    @pytest.mark.parametrize('axis', (0, 1, None))
    @pytest.mark.parametrize('func', (np.sum, np.mean, np.median))
    def test_sum_like(self, func, axis):
        """Reductions on a TimeDelta should match reductions on jd1/jd2."""
        # NOTE(review): for np.median, combining per-component medians is
        # not equivalent to the median of jd1 + jd2 in general.  Presumably
        # it is valid here because this fixture holds whole days, so jd2
        # is all zero.  Revisit if the fixture changes.
        self.check(func, axis=axis)
47
48
@pytest.mark.xfail(not ARRAY_FUNCTION_ENABLED,
                   reason="Needs __array_function__ support")
@pytest.mark.parametrize('attribute', ['shape', 'ndim', 'size'])
@pytest.mark.parametrize('t', [
    Time('2001-02-03T04:05:06'),
    Time(50000, np.arange(8).reshape(4, 2), format='mjd', scale='tai'),
    TimeDelta(100, format='jd'),
])
def test_shape_attribute_functions(t, attribute):
    """``np.shape``, ``np.ndim`` and ``np.size`` should match the attributes.

    Regression test for
    https://github.com/astropy/astropy/issues/8610#issuecomment-736855217
    """
    numpy_func = getattr(np, attribute)
    assert numpy_func(t) == getattr(t, attribute)
62