1#   Copyright 2020 The PyMC Developers
2#
3#   Licensed under the Apache License, Version 2.0 (the "License");
4#   you may not use this file except in compliance with the License.
5#   You may obtain a copy of the License at
6#
7#       http://www.apache.org/licenses/LICENSE-2.0
8#
9#   Unless required by applicable law or agreed to in writing, software
10#   distributed under the License is distributed on an "AS IS" BASIS,
11#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12#   See the License for the specific language governing permissions and
13#   limitations under the License.
14
15"""NumPy array trace backend
16
17Store sampling values in memory as a NumPy array.
18"""
19import glob
20import json
21import os
22import shutil
23import warnings
24
25from typing import Any, Dict, List, Optional
26
27import numpy as np
28
29from pymc3.backends import base
30from pymc3.backends.base import MultiTrace
31from pymc3.exceptions import TraceDirectoryError
32from pymc3.model import Model, modelcontext
33
34
def save_trace(trace: MultiTrace, directory: Optional[str] = None, overwrite=False) -> str:
    """Save multitrace to file.

    TODO: Also save warnings.

    This is a custom data format for PyMC3 traces.  Each chain goes inside
    a directory, and each directory contains a metadata json file, and a
    numpy compressed file.  See https://docs.scipy.org/doc/numpy/neps/npy-format.html
    for more information about this format.

    Parameters
    ----------
    trace: pm.MultiTrace
        trace to save to disk
    directory: str (optional)
        path to a directory to save the trace. If None, the first unused
        name of the form ``.pymc_<n>.trace`` (n = 1, 2, ...) in the current
        working directory is used.
    overwrite: bool (default False)
        whether to overwrite an existing directory.

    Returns
    -------
    str, path to the directory where the trace was saved

    Raises
    ------
    OSError
        If ``directory`` already exists and ``overwrite`` is False.
    """
    warnings.warn(
        "The `save_trace` function will soon be removed. "
        "Instead, use `arviz.to_netcdf` to save traces.",
        DeprecationWarning,
        stacklevel=2,
    )

    if directory is None:
        # Probe .pymc_1.trace, .pymc_2.trace, ... until an unused name is found.
        template = ".pymc_{}.trace"
        idx = 1
        while os.path.exists(template.format(idx)):
            idx += 1
        directory = template.format(idx)

    if os.path.isdir(directory):
        if not overwrite:
            raise OSError(
                "Cautiously refusing to overwrite the already existing {}! Please supply "
                "a different directory, or set `overwrite=True`".format(directory)
            )
        shutil.rmtree(directory)
    os.makedirs(directory)

    # One subdirectory per chain, named after the chain key.
    for chain, ndarray in trace._straces.items():
        SerializeNDArray(os.path.join(directory, str(chain))).save(ndarray)
    return directory
84
85
def load_trace(directory: str, model=None) -> MultiTrace:
    """Loads a multitrace that has been written to file.

    The model used for the trace must be passed in, or the command
    must be run in a model context.

    Parameters
    ----------
    directory: str
        Path to a pymc3 serialized trace
    model: pm.Model (optional)
        Model used to create the trace.  Can also be inferred from context

    Returns
    -------
    pm.Multitrace that was saved in the directory

    Raises
    ------
    TraceDirectoryError
        If ``directory`` contains no chain subdirectories (including when it
        does not exist).
    """
    warnings.warn(
        "The `load_trace` function will soon be removed. "
        "Instead, use `arviz.from_netcdf` to load traces.",
        DeprecationWarning,
        stacklevel=2,
    )
    # Escape the user-supplied path so characters like "[", "*" or "?" in it are
    # matched literally instead of as glob wildcards; sort for a deterministic
    # chain loading order.
    pattern = os.path.join(glob.escape(directory), "*")
    straces = [
        SerializeNDArray(subdir).load(model)
        for subdir in sorted(glob.glob(pattern))
        if os.path.isdir(subdir)
    ]
    if not straces:
        raise TraceDirectoryError("%s is not a PyMC3 saved chain directory." % directory)
    return base.MultiTrace(straces)
115
116
class SerializeNDArray:
    """Save and load a single :class:`NDArray` chain to/from a directory.

    The directory holds a JSON metadata file and a compressed ``.npz`` file
    with one array per variable.
    """

    metadata_file = "metadata.json"
    samples_file = "samples.npz"
    metadata_path = None  # type: str
    samples_path = None  # type: str

    def __init__(self, directory: str):
        """Helper to save and load NDArray objects

        Parameters
        ----------
        directory: str
            Directory that holds (or will hold) one serialized chain.
        """
        warnings.warn(
            "The `SerializeNDArray` class will soon be removed. "
            "Instead, use ArviZ to save/load traces.",
            DeprecationWarning,
            stacklevel=2,
        )
        self.directory = directory
        self.metadata_path = os.path.join(self.directory, self.metadata_file)
        self.samples_path = os.path.join(self.directory, self.samples_file)

    @staticmethod
    def to_metadata(ndarray):
        """Extract ndarray metadata into json-serializable content

        Sampler stats are converted to nested lists, and their dtypes are
        recorded as strings under ``"sampler_vars"``. Both are ``None`` when
        the trace has no sampler stats.
        """
        if ndarray._stats is None:
            stats = None
            sampler_vars = None
        else:
            stats = []
            sampler_vars = []
            for stat in ndarray._stats:
                stats.append({key: value.tolist() for key, value in stat.items()})
                sampler_vars.append({key: str(value.dtype) for key, value in stat.items()})

        return {
            "draw_idx": ndarray.draw_idx,
            "draws": ndarray.draws,
            "_stats": stats,
            "chain": ndarray.chain,
            "sampler_vars": sampler_vars,
        }

    def save(self, ndarray):
        """Serialize a ndarray to file

        The goal here is to be modestly safer and more portable than a
        pickle file. The expense is that the model code must be available
        to reload the multitrace.

        Any existing content of ``self.directory`` is deleted first.

        Raises
        ------
        TypeError
            If ``ndarray`` is not an ``NDArray``.
        """
        if not isinstance(ndarray, NDArray):
            raise TypeError("Can only save NDArray")

        if os.path.isdir(self.directory):
            shutil.rmtree(self.directory)

        os.mkdir(self.directory)

        with open(self.metadata_path, "w") as buff:
            json.dump(SerializeNDArray.to_metadata(ndarray), buff)

        np.savez_compressed(self.samples_path, **ndarray.samples)

    def load(self, model: Model) -> "NDArray":
        """Load the saved ndarray from file

        Raises
        ------
        TraceDirectoryError
            If the metadata or samples file is missing.
        """
        if not os.path.exists(self.samples_path) or not os.path.exists(self.metadata_path):
            raise TraceDirectoryError("%s is not a trace directory" % self.directory)

        new_trace = NDArray(model=model)
        with open(self.metadata_path) as buff:
            metadata = json.load(buff)

        # Traces without sampler stats are saved with "_stats": null (see
        # to_metadata); only convert the per-sampler lists back when present.
        if metadata.get("_stats") is not None:
            metadata["_stats"] = [
                {k: np.array(v) for k, v in stat.items()} for stat in metadata["_stats"]
            ]

        # it seems like at least some old traces don't have 'sampler_vars'.
        # Test for the key explicitly so that errors raised inside
        # _set_sampler_vars are not swallowed.
        if "sampler_vars" in metadata:
            new_trace._set_sampler_vars(metadata.pop("sampler_vars"))

        for key, value in metadata.items():
            setattr(new_trace, key, value)

        # np.load returns an NpzFile that keeps the archive open; read every
        # array eagerly and close the file handle.
        with np.load(self.samples_path) as samples:
            new_trace.samples = dict(samples)
        return new_trace
200
201
class NDArray(base.BaseTrace):
    """NDArray trace object

    Sampling values are stored in memory as NumPy arrays that are
    preallocated by :meth:`setup` and filled one draw at a time by
    :meth:`record`.

    Parameters
    ----------
    name: str
        Name of backend. This has no meaning for the NDArray backend.
    model: Model
        If None, the model is taken from the `with` context.
    vars: list of variables
        Sampling values will be stored for these variables. If None,
        `model.unobserved_RVs` is used.
    test_point: dict (optional)
        Forwarded unchanged to ``base.BaseTrace``.
    """

    supports_sampler_stats = True

    def __init__(self, name=None, model=None, vars=None, test_point=None):
        super().__init__(name, model, vars, test_point)
        self.draw_idx = 0
        self.draws = None
        self.samples = {}
        self._stats = None

    # Sampling methods

    def setup(self, draws, chain, sampler_vars=None) -> None:
        """Perform chain-specific setup.

        If the trace already holds samples, the arrays are extended by
        ``draws`` rows and recording resumes right after the existing draws.

        Parameters
        ----------
        draws: int
            Expected number of draws
        chain: int
            Chain number
        sampler_vars: list of dicts
            Names and dtypes of the variables that are
            exported by the samplers.

        Raises
        ------
        ValueError
            If ``sampler_vars`` does not match the sampler variables of an
            earlier ``setup`` call.
        """
        super().setup(draws, chain, sampler_vars)

        self.chain = chain
        if self.samples:  # Concatenate new array if chain is already present.
            old_draws = len(self)
            self.draws = old_draws + draws
            self.draw_idx = old_draws
            for varname, shape in self.var_shapes.items():
                # Keep only the recorded rows: if `close` was not called the
                # array may still end in unused preallocated zero rows.
                old_var_samples = self.samples[varname][:old_draws]
                new_var_samples = np.zeros((draws,) + shape, self.var_dtypes[varname])
                self.samples[varname] = np.concatenate((old_var_samples, new_var_samples), axis=0)
        else:  # Otherwise, make array of zeros for each variable.
            self.draws = draws
            for varname, shape in self.var_shapes.items():
                self.samples[varname] = np.zeros((draws,) + shape, dtype=self.var_dtypes[varname])

        if sampler_vars is None:
            return

        if self._stats is None:
            # One dict of zero-filled stat arrays per sampler.
            self._stats = [
                {varname: np.zeros(draws, dtype=dtype) for varname, dtype in sampler.items()}
                for sampler in sampler_vars
            ]
        else:
            # zip() would silently drop extra/missing samplers, so check the count.
            if len(sampler_vars) != len(self._stats):
                raise ValueError("Sampler vars can't change")
            # Rows of the existing stat arrays that hold recorded values.
            n_valid = self.draw_idx
            for data, vars in zip(self._stats, sampler_vars):
                if vars.keys() != data.keys():
                    raise ValueError("Sampler vars can't change")
                for varname, dtype in vars.items():
                    old = data[varname][:n_valid]
                    new = np.zeros(draws, dtype=dtype)
                    data[varname] = np.concatenate([old, new])

    def record(self, point, sampler_stats=None) -> None:
        """Record results of a sampling iteration.

        Parameters
        ----------
        point: dict
            Values mapped to variable names
        sampler_stats: list of dicts (optional)
            One dict of stat values per sampler. Must be given if and only
            if the trace has sampler stats.

        Raises
        ------
        ValueError
            If ``sampler_stats`` is missing although expected, or given
            although the trace has no sampler stats.
        """
        for varname, value in zip(self.varnames, self.fn(point)):
            self.samples[varname][self.draw_idx] = value

        if self._stats is not None and sampler_stats is None:
            raise ValueError("Expected sampler_stats")
        if self._stats is None and sampler_stats is not None:
            raise ValueError("Unknown sampler_stats")
        if sampler_stats is not None:
            for data, vars in zip(self._stats, sampler_stats):
                for key, val in vars.items():
                    data[key][self.draw_idx] = val
        self.draw_idx += 1

    def _get_sampler_stats(self, varname, sampler_idx, burn, thin):
        return self._stats[sampler_idx][varname][burn::thin]

    def close(self):
        """Trim preallocated arrays down to the draws actually recorded."""
        if self.draw_idx == self.draws:
            return
        # Remove trailing zeros if interrupted before completed all
        # draws.
        self.samples = {var: vtrace[: self.draw_idx] for var, vtrace in self.samples.items()}
        if self._stats is not None:
            self._stats = [
                {var: trace[: self.draw_idx] for var, trace in stats.items()}
                for stats in self._stats
            ]

    # Selection methods

    def __len__(self):
        if not self.samples:  # `setup` has not been called.
            return 0
        return self.draw_idx

    def get_values(self, varname: str, burn=0, thin=1) -> np.ndarray:
        """Get values from trace.

        Parameters
        ----------
        varname: str
        burn: int
        thin: int

        Returns
        -------
        A NumPy array
        """
        return self.samples[varname][burn::thin]

    def _slice(self, idx):
        # Slicing directly instead of using _slice_as_ndarray to
        # support stop value in slice (which is needed by
        # iter_sample).

        # Only the first `draw_idx` values are valid because of preallocation
        idx = slice(*idx.indices(len(self)))

        sliced = NDArray(model=self.model, vars=self.vars)
        sliced.chain = self.chain
        sliced.samples = {varname: values[idx] for varname, values in self.samples.items()}
        sliced.sampler_vars = self.sampler_vars
        # Exact number of selected draws for any step sign, including empty
        # slices; (stop - start) // step is off for |step| > 1.
        sliced.draw_idx = len(range(idx.start, idx.stop, idx.step))

        if self._stats is None:
            return sliced
        sliced._stats = [
            {key: vals[idx] for key, vals in stats.items()} for stats in self._stats
        ]

        return sliced

    def point(self, idx) -> Dict[str, Any]:
        """Return dictionary of point values at `idx` for current chain
        with variable names as keys.
        """
        idx = int(idx)
        return {varname: values[idx] for varname, values in self.samples.items()}
365
366
def _slice_as_ndarray(strace, idx):
    """Return a new NDArray holding the draws of ``strace`` selected by slice ``idx``.

    Parameters
    ----------
    strace: trace object
        Source single-chain trace; must provide ``model``, ``vars``, ``chain``,
        ``varnames``, ``get_values`` and ``__len__``.
    idx: slice
        Slice over the draws of ``strace``.

    Raises
    ------
    ValueError
        If ``idx.step`` is zero (raised by ``slice.indices``).
    """
    sliced = NDArray(model=strace.model, vars=strace.vars)
    sliced.chain = strace.chain

    # Normalize None/negative bounds against the trace length once.
    start, stop, step = idx.indices(len(strace))

    # Happy path where we do not need to load everything from the trace
    if (idx.step is None or idx.step >= 1) and (idx.stop is None or idx.stop == len(strace)):
        sliced.samples = {v: strace.get_values(v, burn=start, thin=step) for v in strace.varnames}
    else:
        sliced.samples = {v: strace.get_values(v)[start:stop:step] for v in strace.varnames}

    # Exact count of selected draws for any step sign, including empty slices;
    # (stop - start) // step undercounts when |step| > 1.
    sliced.draw_idx = len(range(start, stop, step))

    return sliced
384
385
def point_list_to_multitrace(
    point_list: List[Dict[str, np.ndarray]], model: Optional[Model] = None
) -> MultiTrace:
    """Build a single-chain MultiTrace (chain 0) from a list of points.

    The variable names are taken from the keys of the first point; every
    point is recorded as one draw, in list order.
    """
    resolved_model = modelcontext(model)
    names = list(point_list[0].keys())
    with resolved_model:
        trace = NDArray(model=resolved_model, vars=[resolved_model[name] for name in names])
        trace.setup(draws=len(point_list), chain=0)
        # The points already contain the values to store, so replace the
        # default value-extraction function with a plain lookup by name.
        trace.fn = lambda pt: [pt[name] for name in names]
        for pt in point_list:
            trace.record(pt)
    return MultiTrace([trace])
404