1# pylint: disable=redefined-outer-name
2import os
3import shutil
4from collections.abc import MutableMapping
5
6import numpy as np
7import pytest
8
9from ... import InferenceData, from_dict
10
11from ..helpers import (  # pylint: disable=unused-import
12    chains,
13    check_multiple_attrs,
14    draws,
15    eight_schools_params,
16    importorskip,
17    running_on_ci,
18)
19
20zarr = importorskip("zarr")  # pylint: disable=invalid-name
21
22
class TestDataZarr:
    """Round-trip tests for saving ``InferenceData`` to zarr and loading it back."""

    @pytest.fixture(scope="class")
    def data(self, draws, chains):
        """Return a class whose ``obj`` attribute holds fake 8-school output.

        ``obj`` maps each variable name to standard-normal samples of shape
        ``(chains, draws, *shape)``.
        """

        class Data:
            # fake 8-school output; built with a comprehension so no loop
            # variables are left behind as attributes of ``Data``
            obj = {
                name: np.random.randn(chains, draws, *shape)
                for name, shape in {"mu": [], "tau": [], "eta": [8], "theta": [8]}.items()
            }

        return Data

    def get_inference_data(self, data, eight_schools_params):
        """Build an ``InferenceData`` reusing ``data.obj`` for every sampled group.

        ``eight_schools_params`` is stored as ``observed_data``; ``theta`` and
        ``eta`` are given a ``school`` dimension with coordinates ``0..7``.
        """
        return from_dict(
            posterior=data.obj,
            posterior_predictive=data.obj,
            sample_stats=data.obj,
            prior=data.obj,
            prior_predictive=data.obj,
            sample_stats_prior=data.obj,
            observed_data=eight_schools_params,
            coords={"school": np.arange(8)},
            dims={"theta": ["school"], "eta": ["school"]},
        )

    @pytest.mark.parametrize("store", [0, 1, 2])
    def test_io_method(self, data, eight_schools_params, store):
        """Write ``InferenceData`` with ``to_zarr`` and read it back with ``from_zarr``.

        ``store`` selects how the target is handed to ``to_zarr``:

        * ``0`` -- ``store=None``; the returned store is used for reading,
        * ``1`` -- a filesystem path string,
        * ``2`` -- a ``zarr.storage.DirectoryStore`` opened on that path.
        """
        # create InferenceData and check it has been properly created
        inference_data = self.get_inference_data(data, eight_schools_params)
        test_dict = {
            "posterior": ["eta", "theta", "mu", "tau"],
            "posterior_predictive": ["eta", "theta", "mu", "tau"],
            "sample_stats": ["eta", "theta", "mu", "tau"],
            "prior": ["eta", "theta", "mu", "tau"],
            "prior_predictive": ["eta", "theta", "mu", "tau"],
            "sample_stats_prior": ["eta", "theta", "mu", "tau"],
            "observed_data": ["J", "y", "sigma"],
        }
        fails = check_multiple_attrs(test_dict, inference_data)
        assert not fails

        # check the target path does not exist before writing; this stays
        # outside the try block so we never delete something we did not create
        here = os.path.dirname(os.path.abspath(__file__))
        data_directory = os.path.join(here, "..", "saved_models")
        filepath = os.path.join(data_directory, "zarr")
        assert not os.path.exists(filepath)

        # Keep the int selector (``store``) separate from the store object.
        zarr_store = None
        try:
            # InferenceData method
            if store == 0:
                # no explicit target: to_zarr returns the store it wrote to
                zarr_store = inference_data.to_zarr(store=None)
                assert isinstance(zarr_store, MutableMapping)
            elif store == 1:
                inference_data.to_zarr(store=filepath)
                # the store directory must exist and actually contain entries
                # (getsize() on a directory says nothing about its contents)
                assert os.path.exists(filepath)
                assert os.listdir(filepath)
            elif store == 2:
                zarr_store = zarr.storage.DirectoryStore(filepath)
                inference_data.to_zarr(store=zarr_store)
                # the store directory must exist and actually contain entries
                assert os.path.exists(filepath)
                assert os.listdir(filepath)

            # read back from the store object when there is one, else from the path
            if isinstance(zarr_store, MutableMapping):
                inference_data2 = InferenceData.from_zarr(zarr_store)
            else:
                inference_data2 = InferenceData.from_zarr(filepath)

            # Everything in dict still available in inference_data2 ?
            fails = check_multiple_attrs(test_dict, inference_data2)
            assert not fails
        finally:
            # Remove created folder structure even when an assertion above
            # failed, otherwise the leftover directory trips the existence
            # precondition of the remaining parametrized cases and later runs.
            if os.path.exists(filepath):
                shutil.rmtree(filepath)
        assert not os.path.exists(filepath)
101