# pylint: disable=redefined-outer-name
import os
import shutil
from collections.abc import MutableMapping

import numpy as np
import pytest

from ... import InferenceData, from_dict

from ..helpers import (  # pylint: disable=unused-import
    chains,
    check_multiple_attrs,
    draws,
    eight_schools_params,
    importorskip,
    running_on_ci,
)

zarr = importorskip("zarr")  # pylint: disable=invalid-name


class TestDataZarr:
    """Round-trip tests for ``InferenceData.to_zarr`` / ``InferenceData.from_zarr``."""

    @pytest.fixture(scope="class")
    def data(self, draws, chains):
        """Return a holder whose ``obj`` dict contains fake 8-school draws.

        Each entry is ``np.random.randn(chains, draws, *shape)``: ``mu`` and
        ``tau`` are scalar per draw, ``eta`` and ``theta`` have length 8.
        """

        class Data:
            # fake 8-school output
            obj = {}
            for key, shape in {"mu": [], "tau": [], "eta": [8], "theta": [8]}.items():
                obj[key] = np.random.randn(chains, draws, *shape)

        return Data

    def get_inference_data(self, data, eight_schools_params):
        """Build an InferenceData whose sampled groups all reuse ``data.obj``."""
        return from_dict(
            posterior=data.obj,
            posterior_predictive=data.obj,
            sample_stats=data.obj,
            prior=data.obj,
            prior_predictive=data.obj,
            sample_stats_prior=data.obj,
            observed_data=eight_schools_params,
            coords={"school": np.arange(8)},
            dims={"theta": ["school"], "eta": ["school"]},
        )

    @pytest.mark.parametrize("store", [0, 1, 2])
    def test_io_method(self, data, eight_schools_params, store):
        """Write InferenceData to zarr and read it back.

        ``store`` selects how the zarr target is given:
        0 -> ``to_zarr(store=None)``; the returned mapping is read back,
        1 -> a directory path; the path is read back,
        2 -> an explicit ``zarr.storage.DirectoryStore``; the store is read back.
        """
        # create InferenceData and check it has been properly created
        inference_data = self.get_inference_data(data, eight_schools_params)
        test_dict = {
            "posterior": ["eta", "theta", "mu", "tau"],
            "posterior_predictive": ["eta", "theta", "mu", "tau"],
            "sample_stats": ["eta", "theta", "mu", "tau"],
            "prior": ["eta", "theta", "mu", "tau"],
            "prior_predictive": ["eta", "theta", "mu", "tau"],
            "sample_stats_prior": ["eta", "theta", "mu", "tau"],
            "observed_data": ["J", "y", "sigma"],
        }
        fails = check_multiple_attrs(test_dict, inference_data)
        assert not fails

        # check filename does not exist and use to_zarr method
        here = os.path.dirname(os.path.abspath(__file__))
        data_directory = os.path.join(here, "..", "saved_models")
        filepath = os.path.join(data_directory, "zarr")
        assert not os.path.exists(filepath)

        # Separate local so the parametrized ``store`` flag is not rebound.
        zarr_store = None
        try:
            if store == 0:
                # Tempdir-backed store returned by to_zarr itself
                zarr_store = inference_data.to_zarr(store=None)
                assert isinstance(zarr_store, MutableMapping)
            elif store == 1:
                inference_data.to_zarr(store=filepath)
                # A directory store was written: it must exist and hold entries.
                # (os.path.getsize on a directory is platform dependent.)
                assert os.path.isdir(filepath)
                assert os.listdir(filepath)
            elif store == 2:
                zarr_store = zarr.storage.DirectoryStore(filepath)
                inference_data.to_zarr(store=zarr_store)
                assert os.path.isdir(filepath)
                assert os.listdir(filepath)

            if isinstance(zarr_store, MutableMapping):
                inference_data2 = InferenceData.from_zarr(zarr_store)
            else:
                inference_data2 = InferenceData.from_zarr(filepath)

            # Everything in dict still available in inference_data2 ?
            fails = check_multiple_attrs(test_dict, inference_data2)
            assert not fails
        finally:
            # Always remove the created folder structure so a failure here does
            # not make every later run fail the "does not exist" precondition.
            if os.path.exists(filepath):
                shutil.rmtree(filepath)
        # Checked outside ``finally`` so it cannot mask an earlier failure.
        assert not os.path.exists(filepath)