#   Copyright 2020 The PyMC Developers
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.

"""NumPy array trace backend

Store sampling values in memory as a NumPy array.
"""
import glob
import json
import os
import shutil
import warnings

from typing import Any, Dict, List, Optional

import numpy as np

from pymc3.backends import base
from pymc3.backends.base import MultiTrace
from pymc3.exceptions import TraceDirectoryError
from pymc3.model import Model, modelcontext


def save_trace(trace: MultiTrace, directory: Optional[str] = None, overwrite=False) -> str:
    """Save multitrace to file.

    TODO: Also save warnings.

    This is a custom data format for PyMC3 traces.  Each chain goes inside
    a directory, and each directory contains a metadata json file, and a
    numpy compressed file.  See https://docs.scipy.org/doc/numpy/neps/npy-format.html
    for more information about this format.

    Parameters
    ----------
    trace: pm.MultiTrace
        trace to save to disk
    directory: str (optional)
        path to a directory to save the trace.  When omitted, the first
        unused name of the form ``.pymc_<n>.trace`` (n = 1, 2, ...) is used.
    overwrite: bool (default False)
        whether to overwrite an existing directory.

    Returns
    -------
    str, path to the directory where the trace was saved

    Raises
    ------
    OSError
        If `directory` already exists and `overwrite` is False.
    """
    warnings.warn(
        # The trailing space keeps the two sentences from running together.
        "The `save_trace` function will soon be removed. "
        "Instead, use `arviz.to_netcdf` to save traces.",
        DeprecationWarning,
    )

    if directory is None:
        # Probe .pymc_1.trace, .pymc_2.trace, ... until a free name is found.
        directory = ".pymc_{}.trace"
        idx = 1
        while os.path.exists(directory.format(idx)):
            idx += 1
        directory = directory.format(idx)

    if os.path.isdir(directory):
        if overwrite:
            shutil.rmtree(directory)
        else:
            raise OSError(
                "Cautiously refusing to overwrite the already existing {}! Please supply "
                "a different directory, or set `overwrite=True`".format(directory)
            )
    os.makedirs(directory)

    # One sub-directory per chain, named after the chain key.
    for chain, ndarray in trace._straces.items():
        SerializeNDArray(os.path.join(directory, str(chain))).save(ndarray)
    return directory


def load_trace(directory: str, model=None) -> MultiTrace:
    """Loads a multitrace that has been written to file.

    The model used for the trace must be passed in, or the command
    must be run in a model context.

    Parameters
    ----------
    directory: str
        Path to a pymc3 serialized trace
    model: pm.Model (optional)
        Model used to create the trace.  Can also be inferred from context

    Returns
    -------
    pm.Multitrace that was saved in the directory

    Raises
    ------
    TraceDirectoryError
        If `directory` contains no sub-directories to load chains from.
    """
    warnings.warn(
        # The trailing space keeps the two sentences from running together.
        "The `load_trace` function will soon be removed. "
        "Instead, use `arviz.from_netcdf` to load traces.",
        DeprecationWarning,
    )
    # Every sub-directory is treated as one serialized chain.
    straces = [
        SerializeNDArray(subdir).load(model)
        for subdir in glob.glob(os.path.join(directory, "*"))
        if os.path.isdir(subdir)
    ]
    if not straces:
        raise TraceDirectoryError("%s is not a PyMC3 saved chain directory." % directory)
    return base.MultiTrace(straces)


class SerializeNDArray:
    """Helper that writes one `NDArray` chain to a directory and reads it back.

    The directory holds two files: ``metadata.json`` (draw counters, chain
    number, sampler statistics) and ``samples.npz`` (the sampled values).
    """

    metadata_file = "metadata.json"
    samples_file = "samples.npz"
    metadata_path = None  # type: str
    samples_path = None  # type: str

    def __init__(self, directory: str):
        """Helper to save and load NDArray objects"""
        warnings.warn(
            "The `SerializeNDArray` class will soon be removed. "
            "Instead, use ArviZ to save/load traces.",
            DeprecationWarning,
        )
        self.directory = directory
        self.metadata_path = os.path.join(self.directory, self.metadata_file)
        self.samples_path = os.path.join(self.directory, self.samples_file)

    @staticmethod
    def to_metadata(ndarray):
        """Extract ndarray metadata into json-serializable content

        Returns a dict with the keys ``draw_idx``, ``draws``, ``_stats``,
        ``chain`` and ``sampler_vars``.  ``_stats`` and ``sampler_vars`` are
        both None when the trace carries no sampler statistics.
        """
        if ndarray._stats is None:
            stats = None
            sampler_vars = None
        else:
            # Arrays become plain lists for JSON; their dtypes are kept
            # separately (as strings) under "sampler_vars".
            stats = [
                {key: value.tolist() for key, value in stat.items()} for stat in ndarray._stats
            ]
            sampler_vars = [
                {key: str(value.dtype) for key, value in stat.items()} for stat in ndarray._stats
            ]

        metadata = {
            "draw_idx": ndarray.draw_idx,
            "draws": ndarray.draws,
            "_stats": stats,
            "chain": ndarray.chain,
            "sampler_vars": sampler_vars,
        }
        return metadata

    def save(self, ndarray):
        """Serialize a ndarray to file

        The goal here is to be modestly safer and more portable than a
        pickle file.  The expense is that the model code must be available
        to reload the multitrace.

        Any existing directory at `self.directory` is removed first.

        Raises
        ------
        TypeError
            If `ndarray` is not an `NDArray` instance.
        """
        if not isinstance(ndarray, NDArray):
            raise TypeError("Can only save NDArray")

        if os.path.isdir(self.directory):
            shutil.rmtree(self.directory)

        os.mkdir(self.directory)

        with open(self.metadata_path, "w") as buff:
            json.dump(SerializeNDArray.to_metadata(ndarray), buff)

        np.savez_compressed(self.samples_path, **ndarray.samples)

    def load(self, model: Model) -> "NDArray":
        """Load the saved ndarray from file

        Parameters
        ----------
        model: Model
            Model passed on to the `NDArray` constructor.

        Raises
        ------
        TraceDirectoryError
            If the metadata or samples file is missing from the directory.
        """
        if not os.path.exists(self.samples_path) or not os.path.exists(self.metadata_path):
            raise TraceDirectoryError("%s is not a trace directory" % self.directory)

        new_trace = NDArray(model=model)
        with open(self.metadata_path) as buff:
            metadata = json.load(buff)

        # `to_metadata` stores None when the trace has no sampler statistics;
        # pass that through instead of trying to iterate over it.
        raw_stats = metadata.get("_stats")
        if raw_stats is None:
            metadata["_stats"] = None
        else:
            metadata["_stats"] = [
                {k: np.array(v) for k, v in stat.items()} for stat in raw_stats
            ]

        # it seems like at least some old traces don't have 'sampler_vars'
        try:
            sampler_vars = metadata.pop("sampler_vars")
            new_trace._set_sampler_vars(sampler_vars)
        except KeyError:
            pass

        for key, value in metadata.items():
            setattr(new_trace, key, value)

        # Read every array while the archive is open so that the underlying
        # file handle is released as soon as loading is done.
        with np.load(self.samples_path) as archive:
            new_trace.samples = {name: archive[name] for name in archive.files}
        return new_trace


class NDArray(base.BaseTrace):
    """NDArray trace object

    Parameters
    ----------
    name: str
        Name of backend. This has no meaning for the NDArray backend.
    model: Model
        If None, the model is taken from the `with` context.
    vars: list of variables
        Sampling values will be stored for these variables. If None,
        `model.unobserved_RVs` is used.
    """

    supports_sampler_stats = True

    def __init__(self, name=None, model=None, vars=None, test_point=None):
        super().__init__(name, model, vars, test_point)
        self.draw_idx = 0  # index of the next draw to be written
        self.draws = None  # number of preallocated draws; set by `setup`
        self.samples = {}  # variable name -> array of sampled values
        self._stats = None  # one dict of stat arrays per sampler, or None

    # Sampling methods

    def setup(self, draws, chain, sampler_vars=None) -> None:
        """Perform chain-specific setup.

        Preallocates storage for `draws` new draws.  If the trace already
        holds samples, the existing arrays are extended instead of replaced.

        Parameters
        ----------
        draws: int
            Expected number of draws
        chain: int
            Chain number
        sampler_vars: list of dicts
            Names and dtypes of the variables that are
            exported by the samplers.

        Raises
        ------
        ValueError
            If stats already exist and `sampler_vars` has different keys.
        """
        super().setup(draws, chain, sampler_vars)

        self.chain = chain
        if self.samples:  # Concatenate new array if chain is already present.
            old_draws = len(self)
            self.draws = old_draws + draws
            self.draw_idx = old_draws
            for varname, shape in self.var_shapes.items():
                old_var_samples = self.samples[varname]
                new_var_samples = np.zeros((draws,) + shape, self.var_dtypes[varname])
                self.samples[varname] = np.concatenate((old_var_samples, new_var_samples), axis=0)
        else:  # Otherwise, make array of zeros for each variable.
            self.draws = draws
            for varname, shape in self.var_shapes.items():
                self.samples[varname] = np.zeros((draws,) + shape, dtype=self.var_dtypes[varname])

        if sampler_vars is None:
            return

        if self._stats is None:
            self._stats = []
            for sampler in sampler_vars:
                data = {}  # type: Dict[str, np.ndarray]
                self._stats.append(data)
                for varname, dtype in sampler.items():
                    data[varname] = np.zeros(draws, dtype=dtype)
        else:
            for data, vars in zip(self._stats, sampler_vars):
                if vars.keys() != data.keys():
                    raise ValueError("Sampler vars can't change")
                for varname, dtype in vars.items():
                    old = data[varname]
                    new = np.zeros(draws, dtype=dtype)
                    data[varname] = np.concatenate([old, new])

    def record(self, point, sampler_stats=None) -> None:
        """Record results of a sampling iteration.

        Parameters
        ----------
        point: dict
            Values mapped to variable names
        sampler_stats: list of dicts (optional)
            One dict of statistic values per sampler.  Must be given if and
            only if the trace was set up with sampler variables.

        Raises
        ------
        ValueError
            If `sampler_stats` is missing although stats were set up, or
            given although they were not.
        """
        for varname, value in zip(self.varnames, self.fn(point)):
            self.samples[varname][self.draw_idx] = value

        if self._stats is not None and sampler_stats is None:
            raise ValueError("Expected sampler_stats")
        if self._stats is None and sampler_stats is not None:
            raise ValueError("Unknown sampler_stats")
        if sampler_stats is not None:
            for data, vars in zip(self._stats, sampler_stats):
                for key, val in vars.items():
                    data[key][self.draw_idx] = val
        self.draw_idx += 1

    def _get_sampler_stats(self, varname, sampler_idx, burn, thin):
        """Return the stat `varname` of sampler `sampler_idx`, burned and thinned."""
        return self._stats[sampler_idx][varname][burn::thin]

    def close(self):
        """Trim the preallocated arrays down to the draws actually recorded."""
        if self.draw_idx == self.draws:
            return
        # Remove trailing zeros if interrupted before completed all
        # draws.
        self.samples = {var: vtrace[: self.draw_idx] for var, vtrace in self.samples.items()}
        if self._stats is not None:
            self._stats = [
                {var: trace[: self.draw_idx] for var, trace in stats.items()}
                for stats in self._stats
            ]

    # Selection methods

    def __len__(self):
        if not self.samples:  # `setup` has not been called.
            return 0
        return self.draw_idx

    def get_values(self, varname: str, burn=0, thin=1) -> np.ndarray:
        """Get values from trace.

        Parameters
        ----------
        varname: str
        burn: int
        thin: int

        Returns
        -------
        A NumPy array
        """
        return self.samples[varname][burn::thin]

    def _slice(self, idx):
        """Return a new `NDArray` holding the draws selected by slice `idx`."""
        # Slicing directly instead of using _slice_as_ndarray to
        # support stop value in slice (which is needed by
        # iter_sample).

        # Only the first `draw_idx` values are valid because of preallocation,
        # so truncate to them first and then apply the caller's slice as is.
        # Rebuilding the slice from `idx.indices()` would turn the open stop
        # of a negative-step slice into -1, which selects nothing.
        n_valid = len(self)

        sliced = NDArray(model=self.model, vars=self.vars)
        sliced.chain = self.chain
        sliced.samples = {
            varname: values[:n_valid][idx] for varname, values in self.samples.items()
        }
        sliced.sampler_vars = self.sampler_vars
        # `range` gives the exact number of selected draws, also when the step
        # does not evenly divide the span or the slice is empty.
        sliced.draw_idx = len(range(*idx.indices(n_valid)))

        if self._stats is None:
            return sliced
        sliced._stats = [
            {key: vals[:n_valid][idx] for key, vals in vars.items()} for vars in self._stats
        ]

        return sliced

    def point(self, idx) -> Dict[str, Any]:
        """Return dictionary of point values at `idx` for current chain
        with variable names as keys.
        """
        idx = int(idx)
        return {varname: values[idx] for varname, values in self.samples.items()}


def _slice_as_ndarray(strace, idx):
    """Return the draws of `strace` selected by slice `idx` as a new `NDArray`.

    Parameters
    ----------
    strace: trace object
        Must provide `model`, `vars`, `chain`, `varnames`, `get_values`
        and `__len__`.
    idx: slice
    """
    sliced = NDArray(model=strace.model, vars=strace.vars)
    sliced.chain = strace.chain

    n_draws = len(strace)
    start, stop, step = idx.indices(n_draws)

    # Happy path where we do not need to load everything from the trace
    if (idx.step is None or idx.step >= 1) and (idx.stop is None or idx.stop == n_draws):
        sliced.samples = {
            v: strace.get_values(v, burn=idx.start, thin=idx.step) for v in strace.varnames
        }
    else:
        # Truncate to the valid draws, then apply the caller's slice as is;
        # the normalized (start, stop, step) triple would select nothing for
        # a negative step with an open stop.
        sliced.samples = {v: strace.get_values(v)[:n_draws][idx] for v in strace.varnames}

    # Exact number of selected draws, also for steps that do not evenly
    # divide the span and for empty slices.
    sliced.draw_idx = len(range(start, stop, step))

    return sliced


def point_list_to_multitrace(
    point_list: List[Dict[str, np.ndarray]], model: Optional[Model] = None
) -> MultiTrace:
    """transform point list into MultiTrace

    Parameters
    ----------
    point_list: list of dicts
        Points mapping variable names to values.  The keys of the first
        point determine which variables are stored.
    model: Model (optional)
        Model the variables belong to; resolved through `modelcontext`.

    Returns
    -------
    MultiTrace with a single chain (chain 0) holding one draw per point.
    """
    _model = modelcontext(model)
    varnames = list(point_list[0].keys())
    with _model:
        chain = NDArray(model=_model, vars=[_model[vn] for vn in varnames])
        chain.setup(draws=len(point_list), chain=0)

        # since we are simply loading a trace by hand, we need only a vacuous function for
        # chain.record() to use. This crushes the default.
        def point_fun(point):
            return [point[vn] for vn in varnames]

        chain.fn = point_fun
        for point in point_list:
            chain.record(point)
    return MultiTrace([chain])