1# Licensed under a 3-clause BSD style license - see LICENSE.rst
2
3"""Configure the tests for :mod:`astropy.cosmology`."""
4
5##############################################################################
6# IMPORTS
7
8# STDLIB
9import json
10import os
11
12import pytest
13
14import astropy.cosmology.units as cu
15import astropy.units as u
16from astropy.cosmology import core
17from astropy.cosmology.core import Cosmology
18from astropy.utils.misc import isiterable
19
20###############################################################################
21# FUNCTIONS
22
23
def read_json(filename, **kwargs):
    """Read a Cosmology from JSON.

    Parameters
    ----------
    filename : str, bytes, path-like, or file-like
        Where to read the JSON from. A str / bytes / `os.PathLike` is opened
        as a UTF-8 text file; any other object is treated as file-like and
        its ``read()`` method is called.
    **kwargs
        Keyword arguments into :meth:`~astropy.cosmology.Cosmology.from_format`

    Returns
    -------
    `~astropy.cosmology.Cosmology` instance

    """
    # read
    if isinstance(filename, (str, bytes, os.PathLike)):
        # JSON text is UTF-8; do not depend on the platform's locale encoding.
        with open(filename, "r", encoding="utf-8") as file:
            data = file.read()
    else:  # file-like: anything else is expected to provide ``read()``
        data = filename.read()

    mapping = json.loads(data)  # parse json mappable to dict

    def revive_quantities(container):
        """Replace ``{"value": ..., "unit": ...}`` entries by Quantity, in place."""
        # Only existing keys are reassigned, so the dict never changes size
        # while it is being iterated.
        for key, val in container.items():
            if isinstance(val, dict) and "value" in val and "unit" in val:
                container[key] = u.Quantity(val["value"], val["unit"])

    # deserialize Quantity, with the cosmology redshift unit enabled while the
    # unit strings are parsed.
    with u.add_enabled_units(cu.redshift):
        revive_quantities(mapping)
        # Looked up only after the top level was processed, then the metadata.
        revive_quantities(mapping.get("meta", {}))

    return Cosmology.from_format(mapping, **kwargs)
57
58
def write_json(cosmology, file, *, overwrite=False):
    """Write Cosmology to JSON.

    Parameters
    ----------
    cosmology : `astropy.cosmology.Cosmology` subclass instance
    file : path-like or file-like
        Destination. An object exposing a ``write`` method is treated as an
        already-open text stream and is written to directly; anything else is
        treated as a filesystem path.
    overwrite : bool (optional, keyword-only)
        Whether to write over an existing file. Only consulted when ``file``
        is a path.

    Raises
    ------
    OSError
        If ``file`` is a path that already exists and ``overwrite`` is False.
    """
    data = cosmology.to_format("mapping")  # start by turning into dict
    # A class object is not JSON-serializable; store its qualified name.
    data["cosmology"] = data["cosmology"].__qualname__

    def quantity_to_dict(q):
        """Represent a Quantity as a ``{"value", "unit"}`` dict."""
        return {"value": q.value.tolist(), "unit": str(q.unit)}

    # serialize Quantity (only existing keys are reassigned while iterating)
    for key, val in data.items():
        if isinstance(val, u.Quantity):
            data[key] = quantity_to_dict(val)
    meta = data.get("meta", {})  # also serialize the metadata
    for key, val in meta.items():
        if isinstance(val, u.Quantity):
            meta[key] = quantity_to_dict(val)

    # file-like: the caller owns the stream, so there is nothing to check for
    # existence and nothing to close here.
    if hasattr(file, "write"):
        json.dump(data, file)
        return

    # check that file exists and whether to overwrite.
    if os.path.exists(file) and not overwrite:
        raise OSError(f"{file} exists. Set 'overwrite' to write over.")
    with open(file, "w", encoding="utf-8") as write_file:
        json.dump(data, write_file)
84
85
def json_identify(origin, filepath, fileobj, *args, **kwargs):
    """Identify a JSON file by its ``.json`` extension.

    Parameters
    ----------
    origin, fileobj, *args, **kwargs
        Not used; accepted so the function can be called with the full
        identifier argument list (presumably the astropy I/O registry's
        identifier signature -- confirm against the registration site).
    filepath : str, bytes, path-like, or None
        Path of the file being identified.

    Returns
    -------
    bool
        True if ``filepath`` is a path ending in ``".json"``; False otherwise,
        including when ``filepath`` is None or is not a path at all.
    """
    if not isinstance(filepath, (str, bytes, os.PathLike)):
        return False
    # ``os.fsdecode`` turns str / bytes / PathLike into str, so the suffix
    # test also works for `pathlib.Path` objects and bytes paths.
    return os.fsdecode(filepath).endswith(".json")
88
89
90###############################################################################
91# FIXTURES
92
@pytest.fixture
def clean_registry():
    """Provide an empty ``core._COSMOLOGY_CLASSES`` for the duration of a test.

    The current registry object is set aside, a fresh empty dict is installed
    in its place and yielded, and the saved registry is put back afterwards.
    """
    # TODO! with monkeypatch instead for thread safety.
    saved_registry = core._COSMOLOGY_CLASSES

    # install a brand-new, empty registry
    core._COSMOLOGY_CLASSES = dict()
    yield core._COSMOLOGY_CLASSES

    # teardown: restore the registry that was active before the test
    core._COSMOLOGY_CLASSES = saved_registry
102