1# Licensed under a 3-clause BSD style license - see LICENSE.rst
2
3from math import inf
4
5import pytest
6
7import numpy as np
8
9from astropy.cosmology.utils import inf_like, vectorize_if_needed, vectorize_redshift_method
10from astropy.utils.exceptions import AstropyDeprecationWarning
11
12
def test_vectorize_redshift_method():
    """Test :func:`astropy.cosmology.utils.vectorize_redshift_method`.

    The decorated method must expose a ``__vectorized__`` attribute holding a
    :class:`numpy.vectorize`. Scalar inputs (Python numbers and numpy scalars)
    must come back with their type preserved. Array and list inputs must come
    back as :class:`numpy.ndarray` with the same values.
    """
    class Class:

        @vectorize_redshift_method
        def method(self, z):
            # Identity, so outputs can be compared directly with the inputs.
            return z

    c = Class()

    # The decorator attaches a numpy-vectorized version of the method.
    assert hasattr(c.method, "__vectorized__")
    assert isinstance(c.method.__vectorized__, np.vectorize)

    # calling with a Python Number: value and ``int`` type preserved
    result = c.method(1)
    assert result == 1
    assert isinstance(result, int)

    # calling with a numpy scalar: value and ``np.float64`` type preserved
    result = c.method(np.float64(1))
    assert result == np.float64(1)
    assert isinstance(result, np.float64)

    # numpy array: returned as an ndarray with the same shape and values
    # (``np.array_equal`` also checks shape, unlike ``all(a == b)``)
    result = c.method(np.array([1, 2]))
    assert isinstance(result, np.ndarray)
    assert np.array_equal(result, [1, 2])

    # non-scalar, non-array input (a list): converted to an ndarray
    result = c.method([1, 2])
    assert isinstance(result, np.ndarray)
    assert np.array_equal(result, [1, 2])
41
42
def test_vectorize_if_needed():
    """
    Test :func:`astropy.cosmology.utils.vectorize_if_needed`.

    There's no need to test 'veckw' because that is directly passed to
    `numpy.vectorize`, which thoroughly tests the various inputs.

    Each call is checked for the deprecation warning separately. A single
    ``pytest.warns`` around both calls would pass if only one of them warned.
    """
    def func(x):
        return x ** 2

    # not vectorized: scalar input
    with pytest.warns(AstropyDeprecationWarning):
        assert vectorize_if_needed(func, 2) == 4

    # vectorized: a list input must be evaluated element-wise
    # (``[2, 3] ** 2`` itself would raise)
    with pytest.warns(AstropyDeprecationWarning):
        assert np.array_equal(vectorize_if_needed(func, [2, 3]), [4, 9])
57
58
@pytest.mark.parametrize(
    "arr, expected",
    [
        (0.0, inf),  # float scalar
        (1, inf),  # integer scalar should give float output
        ([0.0, 1.0, 2.0, 3.0], (inf, inf, inf, inf)),  # float list
        ([0, 1, 2, 3], (inf, inf, inf, inf)),  # integer list -> float output
    ],
)
def test_inf_like(arr, expected):
    """
    Test :func:`astropy.cosmology.utils.inf_like`.

    All inputs should give an all-``inf`` float output with the same shape as
    the input. These tests are also in the docstring, but it's better to have
    them also in one consolidated location.
    """
    with pytest.warns(AstropyDeprecationWarning):
        result = inf_like(arr)

    assert np.all(result == expected)
    # ``result == expected`` broadcasts: a scalar ``inf`` returned for a list
    # input would still compare equal, so check shape and dtype explicitly.
    assert np.shape(result) == np.shape(arr)
    assert np.asarray(result).dtype.kind == "f"
74