1# Licensed under a 3-clause BSD style license - see LICENSE.rst
2
3import inspect
4import sys
5from io import StringIO
6
7import pytest
8
9import numpy as np
10
11from astropy import units as u
12from astropy.cosmology import core, flrw
13from astropy.cosmology.funcs import _z_at_scalar_value, z_at_value
14from astropy.cosmology.realizations import WMAP1, WMAP3, WMAP5, WMAP7, WMAP9, Planck13, Planck15, Planck18
15from astropy.units import allclose
16from astropy.utils.compat.optional_deps import HAS_SCIPY  # noqa
17from astropy.utils.exceptions import AstropyUserWarning
18
19
@pytest.mark.skipif('not HAS_SCIPY')
def test_z_at_value_scalar():
    """Check scalar ``z_at_value`` results against reference redshifts.

    These are tests of expected values and so use a looser tolerance than
    the roundtrip tests (``test_z_at_value_roundtrip``). Here the
    cosmological calculations may give slightly different values on
    different architectures. The roundtrip tests only check internal
    consistency on a single architecture and can therefore be stricter.
    """
    cosmo = Planck13
    adist = cosmo.angular_diameter_distance

    # (function, target value, extra keyword arguments, expected z, rtol)
    cases = [
        (cosmo.age, 2 * u.Gyr, {}, 3.19812268, 1e-6),
        (cosmo.lookback_time, 7 * u.Gyr, {}, 0.795198375, 1e-6),
        (cosmo.distmod, 46 * u.mag, {}, 1.991389168, 1e-6),
        (cosmo.luminosity_distance, 1e4 * u.Mpc, {}, 1.36857907, 1e-6),
        (cosmo.luminosity_distance, 26.037193804 * u.Gpc, {'ztol': 1e-10}, 3, 1e-9),
        (adist, 1500 * u.Mpc, {'zmax': 2}, 0.681277696, 1e-6),
        (adist, 1500 * u.Mpc, {'zmin': 2.5}, 3.7914908, 1e-6),
    ]
    for func, fval, kwargs, expected, rtol in cases:
        result = z_at_value(func, fval, **kwargs)
        assert allclose(result, expected, rtol=rtol), \
            f'z_at_value({func.__name__}, {fval}, **{kwargs}) failed'

    # A solution outside the z limits warns that fval is not bracketed and
    # then raises a CosmologyError.
    for limits in ({'zmax': 0.5}, {'zmin': 4.}):
        with pytest.raises(core.CosmologyError), \
                pytest.warns(AstropyUserWarning, match=r'fval is not bracketed'):
            z_at_value(adist, 1500 * u.Mpc, **limits)
49
50
@pytest.mark.skipif('not HAS_SCIPY')
class Test_ZatValue:
    """Tests of ``z_at_value`` with array-valued (broadcast) arguments."""

    @classmethod
    def setup_class(cls):
        # Run once by pytest before the tests in this class. The cosmology is
        # stored on the class, so every test method reaches it via ``self``.
        cls.cosmo = Planck13

    def test_broadcast_arguments(self):
        """Test broadcast of arguments."""
        # broadcasting main argument
        assert allclose(
            z_at_value(self.cosmo.age, [2, 7] * u.Gyr),
            [3.1981206134773115, 0.7562044333305182], rtol=1e-6)

        # basic broadcast of secondary arguments
        assert allclose(
            z_at_value(self.cosmo.angular_diameter_distance, 1500 * u.Mpc,
                       zmin=[0, 2.5], zmax=[2, 4]),
            [0.681277696, 3.7914908], rtol=1e-6)

        # more interesting broadcast: zmin with shape (1, 2) against zmax
        # with shape (2,); the expected values have shape (1, 2)
        assert allclose(
            z_at_value(self.cosmo.angular_diameter_distance, 1500 * u.Mpc,
                       zmin=[[0, 2.5]], zmax=[2, 4]),
            [[0.681277696, 3.7914908]], rtol=1e-6)

    def test_broadcast_bracket(self):
        """`bracket` has special requirements."""
        # start with an easy one
        assert allclose(
            z_at_value(self.cosmo.age, 2 * u.Gyr, bracket=None),
            3.1981206134773115, rtol=1e-6)

        # now actually have a bracket
        assert allclose(
            z_at_value(self.cosmo.age, 2 * u.Gyr, bracket=[0, 4]),
            3.1981206134773115, rtol=1e-6)

        # now a bad length
        with pytest.raises(ValueError, match="sequence"):
            z_at_value(self.cosmo.age, 2 * u.Gyr, bracket=[0, 4, 4, 5])

        # now the wrong dtype : an ndarray, but not an object array
        with pytest.raises(TypeError, match="dtype"):
            z_at_value(self.cosmo.age, 2 * u.Gyr, bracket=np.array([0, 4]))

        # now an object array of brackets. The ragged lists produce a 1-D
        # object array holding one bracket (of 2 or 3 points) per element,
        # so one solution is expected per bracket.
        bracket = np.array([[0, 4], [0, 3, 4]], dtype=object)
        assert allclose(
            z_at_value(self.cosmo.age, 2 * u.Gyr, bracket=bracket),
            [3.1981206134773115, 3.1981206134773115], rtol=1e-6)

    def test_bad_broadcast(self):
        """Shapes mismatch as expected"""
        # zmin (length 3) and zmax (length 2) cannot be broadcast together
        with pytest.raises(ValueError, match="broadcast"):
            z_at_value(self.cosmo.angular_diameter_distance, 1500 * u.Mpc,
                       zmin=[0, 2.5, 0.1], zmax=[2, 4])

    def test_scalar_input_to_output(self):
        """Test scalar input returns a scalar."""
        z = z_at_value(self.cosmo.angular_diameter_distance, 1500 * u.Mpc,
                       zmin=0, zmax=2)
        assert isinstance(z, u.Quantity)
        assert z.dtype == np.float64
        assert z.shape == ()
115
116
@pytest.mark.skipif('not HAS_SCIPY')
def test_z_at_value_numpyvectorize():
    """Check that ``np.vectorize`` cannot handle Quantity inputs.

    If this test starts failing, ``np.vectorize`` has learned to deal with
    Quantities and can replace the home-brewed vectorization in
    ``z_at_value``. Please submit a PR making that change.
    """
    # Named so as not to shadow the module-level ``z_at_value`` under test.
    vectorized = np.vectorize(
        _z_at_scalar_value,
        excluded=["func", "method", "verbose"],
    )
    target_age = 10 * u.Gyr
    with pytest.raises(u.UnitConversionError, match="dimensionless quantities"):
        vectorized(Planck15.age, target_age)
128
129
@pytest.mark.skipif('not HAS_SCIPY')
def test_z_at_value_verbose(monkeypatch):
    """With ``verbose=True`` the found redshift shows up on stdout."""
    # The verbose output goes through ``print``, so capture it by replacing
    # ``sys.stdout`` for the duration of the test.
    captured = StringIO()
    monkeypatch.setattr(sys, 'stdout', captured)

    redshift = z_at_value(Planck13.age, 2 * u.Gyr, verbose=True)
    printed = captured.getvalue()
    assert str(redshift.value) in printed
140
141
@pytest.mark.skipif('not HAS_SCIPY')
@pytest.mark.parametrize('method', ['Brent', 'Golden', 'Bounded'])
def test_z_at_value_bracketed(method):
    """
    Test 2 solutions for angular diameter distance by not constraining zmin, zmax,
    but setting `bracket` on the appropriate side of the turning point z.
    Setting zmin / zmax should override `bracket`.

    For 'Bounded', `bracket` is expected to be ignored with a warning. The test
    checks that a bracket pointing at the other solution does not change the
    result.
    """
    cosmo = Planck13
    # the two redshifts at which the angular diameter distance is 1500 Mpc
    z_low, z_high = 0.6812777, 3.7914908

    def solve(**kwargs):
        """Solve for 1500 Mpc angular diameter distance with ``method``."""
        return z_at_value(cosmo.angular_diameter_distance, 1500*u.Mpc,
                          method=method, **kwargs)

    if method == 'Bounded':
        # First find the solution 'Bounded' reaches on its own, then check
        # that a bracket around the other solution is ignored.
        with pytest.warns(AstropyUserWarning, match=r'fval is not bracketed'):
            z_free = solve()
        if z_free > 1.6:
            expected, bracket = z_high, (0.9, 1.5)
        else:
            expected, bracket = z_low, (1.6, 2.0)
        with pytest.warns(UserWarning, match=r"Option 'bracket' is ignored"):
            assert allclose(solve(bracket=bracket), expected, rtol=1e-6)
    else:
        assert allclose(solve(bracket=(0.3, 1.0)), z_low, rtol=1e-6)
        assert allclose(solve(bracket=(2.0, 4.0)), z_high, rtol=1e-6)
        assert allclose(solve(bracket=(0.1, 1.5)), z_low, rtol=1e-6)
        # a three-point bracket is accepted too
        assert allclose(solve(bracket=(0.1, 1.0, 2.0)), z_low, rtol=1e-6)
        # Brackets that do not bracket fval still converge but warn. Each call
        # gets its own ``pytest.warns``: a single context around both would
        # pass as long as only one of them warned.
        with pytest.warns(AstropyUserWarning, match=r'fval is not bracketed'):
            assert allclose(solve(bracket=(0.9, 1.5)), z_low, rtol=1e-6)
        with pytest.warns(AstropyUserWarning, match=r'fval is not bracketed'):
            assert allclose(solve(bracket=(1.6, 2.0)), z_high, rtol=1e-6)
        # zmin / zmax take precedence over bracket
        assert allclose(solve(bracket=(1.6, 2.0), zmax=1.6), z_low, rtol=1e-6)
        assert allclose(solve(bracket=(0.9, 1.5), zmin=1.5), z_high, rtol=1e-6)

    # no solution above zmin=4: warns, then raises
    with pytest.raises(core.CosmologyError):
        with pytest.warns(AstropyUserWarning, match=r'fval is not bracketed'):
            solve(bracket=(3.9, 5.0), zmin=4.)
187
188
@pytest.mark.skipif('not HAS_SCIPY')
@pytest.mark.parametrize('method', ['Brent', 'Golden', 'Bounded'])
def test_z_at_value_unconverged(method):
    """
    Check the warning for a non-converged solution, forced by setting `maxfun`
    to too small a number of iterations. Only 'Bounded' reports a specific
    status value and message.
    """
    cosmo = Planck18
    # relative tolerances for the low-z and the high-z solution, respectively
    rtol_low, rtol_high = {
        'Brent': [1e-4, 1e-4],
        'Golden': [1e-3, 1e-2],
        'Bounded': [1e-3, 1e-1],
    }[method]

    status, message = ((1, 'Maximum number of function calls reached.')
                       if method == 'Bounded' else (None, 'Unsuccessful'))
    diag = rf'Solver returned {status}: {message}'

    solutions = []
    for limit in ({'zmax': 2}, {'zmin': 2}):
        with pytest.warns(AstropyUserWarning, match=diag):
            solutions.append(z_at_value(cosmo.angular_diameter_distance, 1*u.Gpc,
                                        maxfun=13, method=method, **limit))
    z0, z1 = solutions

    assert allclose(z0, 0.32442, rtol=rtol_low)
    assert allclose(z1, 8.18551, rtol=rtol_high)
212
213
@pytest.mark.skipif('not HAS_SCIPY')
@pytest.mark.parametrize('cosmo', [Planck13, Planck15, Planck18, WMAP1, WMAP3, WMAP5, WMAP7, WMAP9,
                                   flrw.LambdaCDM, flrw.FlatLambdaCDM, flrw.wpwaCDM, flrw.w0wzCDM,
                                   flrw.wCDM, flrw.FlatwCDM, flrw.w0waCDM, flrw.Flatw0waCDM])
def test_z_at_value_roundtrip(cosmo):
    """
    Calculate values from a known redshift, and then check that
    z_at_value returns the right answer.
    """
    z = 0.5

    # Methods that are skipped:
    # - Ok, w, de_density_scale: redshift independent in the Planck
    #   cosmologies and hence uninvertible;
    # - angular_diameter_distance_z1z2: takes two redshifts, handled below;
    # - clone, is_equivalent: not functions of redshift.
    # nu_relative_density is additionally redshift independent in the WMAP
    # cosmologies.
    skip = ('Ok',
            'angular_diameter_distance_z1z2',
            'clone', 'is_equivalent',
            'de_density_scale', 'w')
    if str(cosmo.name).startswith('WMAP'):
        skip += ('nu_relative_density', )

    methods = inspect.getmembers(cosmo, predicate=inspect.ismethod)

    for name, func in methods:
        if name.startswith('_') or name in skip:
            continue
        fval = func(z)
        # ``bracket`` steers the solver to the low-z solution for
        # non-monotonic methods such as angular_diameter_distance.
        # rtol is slightly more generous than the requested ztol=1e-12.
        got = z_at_value(func, fval, bracket=[0.3, 1.0], ztol=1e-12)
        assert allclose(got, z, rtol=2e-11), f'Round-trip testing {name} failed'

    # Distance functions between two redshifts, with z2 fixed. Only run for
    # realizations (a string ``name``); the classes in the parametrization
    # are skipped here.
    if isinstance(cosmo.name, str):
        z2 = 2.0
        funcs_z1z2 = {
            '_comoving_distance_z1z2':
                lambda z1: cosmo._comoving_distance_z1z2(z1, z2),
            '_comoving_transverse_distance_z1z2':
                lambda z1: cosmo._comoving_transverse_distance_z1z2(z1, z2),
            'angular_diameter_distance_z1z2':
                lambda z1: cosmo.angular_diameter_distance_z1z2(z1, z2),
        }
        for name, func_z1 in funcs_z1z2.items():
            fval = func_z1(z)
            got = z_at_value(func_z1, fval, zmax=1.5, ztol=1e-12)
            assert allclose(z, got, rtol=2e-11), f'Round-trip testing {name} failed'
262