1"""Base IO code for all datasets. Heavily influenced by scikit-learn's implementation."""
2import hashlib
3import itertools
4import os
5import shutil
6from collections import namedtuple
7from urllib.request import urlretrieve
8
9from ..rcparams import rcParams
10from .io_netcdf import from_netcdf
11
# Metadata for a dataset shipped inside the package:
#   filename    -- absolute path of the bundled .nc file
#   description -- human-readable summary shown by list_datasets()
LocalFileMetadata = namedtuple("LocalFileMetadata", ["filename", "description"])

# Metadata for a dataset fetched over HTTP on first use:
#   filename    -- name the downloaded file is saved under in the data home
#   url         -- download location
#   checksum    -- expected sha256 hex digest, verified after download
#   description -- human-readable summary shown by list_datasets()
RemoteFileMetadata = namedtuple(
    "RemoteFileMetadata", ["filename", "url", "checksum", "description"]
)
# Directory inside the installed package that holds the bundled .nc files.
_DATASET_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "_datasets")

# Models for local datasets are located at https://github.com/arviz-devs/arviz_example_data

# Datasets bundled with the package; keys are the names accepted by load_arviz_data().
LOCAL_DATASETS = {
    "centered_eight": LocalFileMetadata(
        filename=os.path.join(_DATASET_DIR, "centered_eight.nc"),
        description="""
A centered parameterization of the eight schools model. Provided as an example of a
model that NUTS has trouble fitting. Compare to `load_arviz_data("non_centered_eight")`.

The eight schools model is a hierarchical model used for an analysis of the effectiveness
of classes that were designed to improve students’ performance on the Scholastic Aptitude Test.

See Bayesian Data Analysis (Gelman et. al.) for more details.
""",
    ),
    "non_centered_eight": LocalFileMetadata(
        filename=os.path.join(_DATASET_DIR, "non_centered_eight.nc"),
        description="""
A non-centered parameterization of the eight schools model. This is a hierarchical model
where sampling problems may be fixed by a non-centered parametrization. Compare to
`load_arviz_data("centered_eight")`.

The eight schools model is a hierarchical model used for an analysis of the effectiveness
of classes that were designed to improve students’ performance on the Scholastic Aptitude Test.

See Bayesian Data Analysis (Gelman et. al.) for more details.
""",
    ),
}

# Datasets downloaded into the data home (see get_data_home) on first use; the
# sha256 checksum of each download is verified against the value recorded here.
REMOTE_DATASETS = {
    "radon": RemoteFileMetadata(
        filename="radon_hierarchical.nc",
        url="http://ndownloader.figshare.com/files/24067472",
        checksum="a9b2b4adf1bf9c5728e5bdc97107e69c4fc8d8b7d213e9147233b57be8b4587b",
        description="""
Radon is a radioactive gas that enters homes through contact points with the ground.
It is a carcinogen that is the primary cause of lung cancer in non-smokers. Radon
levels vary greatly from household to household.

This example uses an EPA study of radon levels in houses in Minnesota to construct a
model with a hierarchy over households within a county. The model includes estimates
(gamma) for contextual effects of the uranium per household.

See Gelman and Hill (2006) for details on the example, or
https://docs.pymc.io/notebooks/multilevel_modeling.html
by Chris Fonnesbeck for details on this implementation.
""",
    ),
    "rugby": RemoteFileMetadata(
        filename="rugby.nc",
        url="http://ndownloader.figshare.com/files/16254359",
        checksum="9eecd2c6317e45b0388dd97ae6326adecf94128b5a7d15a52c9fcfac0937e2a6",
        description="""
The Six Nations Championship is a yearly rugby competition between Italy, Ireland,
Scotland, England, France and Wales. Fifteen games are played each year, representing
all combinations of the six teams.

This example uses and includes results from 2014 - 2017, comprising 60 total
games. It models latent parameters for each team's attack and defense, as well
as a parameter for home team advantage.

See https://docs.pymc.io/notebooks/rugby_analytics.html by Peader Coyle
for more details and references.
""",
    ),
    "regression1d": RemoteFileMetadata(
        filename="regression1d.nc",
        url="http://ndownloader.figshare.com/files/16254899",
        checksum="909e8ffe344e196dad2730b1542881ab5729cb0977dd20ba645a532ffa427278",
        description="""
A synthetic one dimensional linear regression dataset with latent slope,
intercept, and noise ("eps"). One hundred data points, fit with PyMC3.

True slope and intercept are included as deterministic variables.
""",
    ),
    "regression10d": RemoteFileMetadata(
        filename="regression10d.nc",
        url="http://ndownloader.figshare.com/files/16255736",
        checksum="c6716ec7e19926ad2a52d6ae4c1d1dd5ddb747e204c0d811757c8e93fcf9f970",
        description="""
A synthetic multi-dimensional (10 dimensions) linear regression dataset with
latent weights ("w"), intercept, and noise ("eps"). Five hundred data points,
fit with PyMC3.

True weights and intercept are included as deterministic variables.
""",
    ),
    "classification1d": RemoteFileMetadata(
        filename="classification1d.nc",
        url="http://ndownloader.figshare.com/files/16256678",
        checksum="1cf3806e72c14001f6864bb69d89747dcc09dd55bcbca50aba04e9939daee5a0",
        description="""
A synthetic one dimensional logistic regression dataset with latent slope and
intercept, passed into a Bernoulli random variable. One hundred data points,
fit with PyMC3.

True slope and intercept are included as deterministic variables.
""",
    ),
    "classification10d": RemoteFileMetadata(
        filename="classification10d.nc",
        url="http://ndownloader.figshare.com/files/16256681",
        checksum="16c9a45e1e6e0519d573cafc4d266d761ba347e62b6f6a79030aaa8e2fde1367",
        description="""
A synthetic multi dimensional (10 dimensions) logistic regression dataset with
latent weights ("w") and intercept, passed into a Bernoulli random variable.
Five hundred data points, fit with PyMC3.

True weights and intercept are included as deterministic variables.
""",
    ),
    "glycan_torsion_angles": RemoteFileMetadata(
        filename="glycan_torsion_angles.nc",
        url="http://ndownloader.figshare.com/files/22882652",
        checksum="4622621fe7a1d3075c18c4c34af8cc57c59eabbb3501b20c6e2d9c6c4737034c",
        description="""
Torsion angles phi and psi are critical for determining the three dimensional
structure of bio-molecules. Combinations of phi and psi torsion angles that
produce clashes between atoms in the bio-molecule result in high energy, unlikely structures.

This model uses a Von Mises distribution to propose torsion angles for the
structure of a glycan molecule (pdb id: 2LIQ), and a Potential to estimate
the proposed structure's energy. Said Potential is bound by Boltzman's law.
""",
    ),
}
147
148
def get_data_home(data_home=None):
    """Return the path of the arviz data dir.

    This folder is used by some dataset loaders to avoid downloading the
    data several times.

    By default the data dir is set to a folder named 'arviz_data' in the
    user home folder.

    Alternatively, it can be set by the 'ARVIZ_DATA' environment
    variable or programmatically by giving an explicit folder path. The '~'
    symbol is expanded to the user home folder.

    If the folder does not already exist, it is automatically created.

    Parameters
    ----------
    data_home : str | None
        The path to arviz data dir.

    Returns
    -------
    str
        The (expanded) path to the arviz data dir; guaranteed to exist.
    """
    if data_home is None:
        data_home = os.environ.get("ARVIZ_DATA", os.path.join("~", "arviz_data"))
    data_home = os.path.expanduser(data_home)
    # exist_ok avoids a race between an existence check and the creation:
    # two processes calling this concurrently would otherwise make one of
    # them fail with FileExistsError.
    os.makedirs(data_home, exist_ok=True)
    return data_home
175
176
def clear_data_home(data_home=None):
    """Delete all the content of the data home cache.

    Parameters
    ----------
    data_home : str | None
        The path to arviz data dir.
    """
    # Resolve (and, if necessary, create) the data dir first, then remove
    # the whole tree so cached downloads are re-fetched next time.
    shutil.rmtree(get_data_home(data_home))
187
188
189def _sha256(path):
190    """Calculate the sha256 hash of the file at path."""
191    sha256hash = hashlib.sha256()
192    chunk_size = 8192
193    with open(path, "rb") as buff:
194        while True:
195            buffer = buff.read(chunk_size)
196            if not buffer:
197                break
198            sha256hash.update(buffer)
199    return sha256hash.hexdigest()
200
201
def load_arviz_data(dataset=None, data_home=None):
    """Load a local or remote pre-made dataset.

    Run with no parameters to get a list of all available models.

    The directory to save to can also be set with the environment
    variable `ARVIZ_DATA`. The checksum of the dataset is checked against a
    hardcoded value to watch for data corruption.

    Run `az.clear_data_home` to clear the data directory.

    Parameters
    ----------
    dataset : str
        Name of dataset to load.

    data_home : str, optional
        Where to save remote datasets

    Returns
    -------
    xarray.Dataset

    Raises
    ------
    IOError
        If the sha256 checksum of a downloaded file does not match the
        expected value recorded in ``REMOTE_DATASETS``.
    ValueError
        If ``dataset`` is not the name of a known dataset.
    """
    # No dataset requested: return the registry of all available datasets.
    if dataset is None:
        return dict(itertools.chain(LOCAL_DATASETS.items(), REMOTE_DATASETS.items()))

    if dataset in LOCAL_DATASETS:
        return from_netcdf(LOCAL_DATASETS[dataset].filename)

    if dataset in REMOTE_DATASETS:
        remote = REMOTE_DATASETS[dataset]
        home_dir = get_data_home(data_home=data_home)
        file_path = os.path.join(home_dir, remote.filename)

        if not os.path.exists(file_path):
            http_type = rcParams["data.http_protocol"]
            url = remote.url
            # Swap only the scheme prefix; a plain str.replace("http", ...)
            # would also rewrite any other "http" substring in the URL.
            if http_type != "http" and url.startswith("http://"):
                url = "{}://{}".format(http_type, url[len("http://"):])
            urlretrieve(url, file_path)

        # Always verify the cached file, not just fresh downloads, so a
        # corrupted cache is detected too.
        checksum = _sha256(file_path)
        if remote.checksum != checksum:
            raise IOError(
                "{} has an SHA256 checksum ({}) differing from expected ({}), "
                "file may be corrupted. Run `arviz.clear_data_home()` and try "
                "again, or please open an issue.".format(file_path, checksum, remote.checksum)
            )
        return from_netcdf(file_path)

    raise ValueError(
        "Dataset {} not found! The following are available:\n{}".format(dataset, list_datasets())
    )
258
259
def list_datasets():
    """Get a string representation of all available datasets with descriptions."""
    entries = []
    all_datasets = itertools.chain(LOCAL_DATASETS.items(), REMOTE_DATASETS.items())
    for name, resource in all_datasets:
        # Local datasets point at a bundled file; remote ones at a URL.
        if isinstance(resource, LocalFileMetadata):
            location = f"local: {resource.filename}"
        elif isinstance(resource, RemoteFileMetadata):
            location = f"remote: {resource.url}"
        else:
            location = "unknown"
        underline = "=" * len(name)
        entries.append(f"{name}\n{underline}\n{resource.description}\n{location}")

    separator = "\n\n" + "-" * 10 + "\n\n"
    return separator.join(entries)
274