1# Copyright 2020 The PyMC Developers 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14 15"""Functions for MCMC sampling.""" 16 17import collections.abc as abc 18import logging 19import pickle 20import sys 21import time 22import warnings 23 24from collections import defaultdict 25from copy import copy, deepcopy 26from typing import Any, Dict, Iterable, List, Optional, Set, Union, cast 27 28import arviz 29import numpy as np 30import packaging 31import theano.gradient as tg 32import xarray 33 34from arviz import InferenceData 35from fastprogress.fastprogress import progress_bar 36 37import pymc3 as pm 38 39from pymc3.backends.base import BaseTrace, MultiTrace 40from pymc3.backends.ndarray import NDArray 41from pymc3.distributions.distribution import draw_values 42from pymc3.distributions.posterior_predictive import fast_sample_posterior_predictive 43from pymc3.exceptions import IncorrectArgumentsError, SamplingError 44from pymc3.model import Model, Point, all_continuous, modelcontext 45from pymc3.parallel_sampling import Draw, _cpu_count 46from pymc3.step_methods import ( 47 NUTS, 48 PGBART, 49 BinaryGibbsMetropolis, 50 BinaryMetropolis, 51 CategoricalGibbsMetropolis, 52 CompoundStep, 53 DEMetropolis, 54 HamiltonianMC, 55 Metropolis, 56 Slice, 57) 58from pymc3.step_methods.arraystep import BlockedStep, PopulationArrayStepShared 59from pymc3.step_methods.hmc import quadpotential 60from pymc3.util import ( 61 chains_and_samples, 62 check_start_vals, 63 dataset_to_point_list, 64 get_default_varnames, 65 get_untransformed_name, 66 is_transformed_name, 67 update_start_vals, 68) 69from pymc3.vartypes import discrete_types 70 71sys.setrecursionlimit(10000) 72 73__all__ = [ 74 "sample", 75 "iter_sample", 76 "sample_posterior_predictive", 77 "sample_posterior_predictive_w", 78 "init_nuts", 79 "sample_prior_predictive", 80 "fast_sample_posterior_predictive", 81] 82 83STEP_METHODS = ( 84 NUTS, 85 HamiltonianMC, 86 Metropolis, 87 BinaryMetropolis, 88 BinaryGibbsMetropolis, 89 Slice, 90 CategoricalGibbsMetropolis, 91 PGBART, 92) 93Step = Union[BlockedStep, CompoundStep] 94 95ArrayLike = Union[np.ndarray, List[float]] 96PointType = Dict[str, np.ndarray] 97PointList = List[PointType] 98Backend = Union[BaseTrace, MultiTrace, NDArray] 99 100_log = logging.getLogger("pymc3") 101 102 103def instantiate_steppers( 104 _model, steps: List[Step], selected_steps, step_kwargs=None 105) -> Union[Step, List[Step]]: 106 """Instantiate steppers assigned to the model variables. 107 108 This function is intended to be called automatically from ``sample()``, but 109 may be called manually. 110 111 Parameters 112 ---------- 113 model : Model object 114 A fully-specified model object; legacy argument -- ignored 115 steps : list 116 A list of zero or more step function instances that have been assigned to some subset of 117 the model's parameters. 118 selected_steps : dict 119 A dictionary that maps a step method class to a list of zero or more model variables. 120 step_kwargs : dict 121 Parameters for the samplers. Keys are the lower case names of 122 the step method, values a dict of arguments. Defaults to None. 123 124 Returns 125 ------- 126 methods : list or step 127 List of step methods associated with the model's variables, or step method 128 if there is only one. 129 """ 130 if step_kwargs is None: 131 step_kwargs = {} 132 133 used_keys = set() 134 for step_class, vars in selected_steps.items(): 135 if vars: 136 args = step_kwargs.get(step_class.name, {}) 137 used_keys.add(step_class.name) 138 step = step_class(vars=vars, **args) 139 steps.append(step) 140 141 unused_args = set(step_kwargs).difference(used_keys) 142 if unused_args: 143 raise ValueError("Unused step method arguments: %s" % unused_args) 144 145 if len(steps) == 1: 146 return steps[0] 147 148 return steps 149 150 151def assign_step_methods(model, step=None, methods=STEP_METHODS, step_kwargs=None): 152 """Assign model variables to appropriate step methods. 153 154 Passing a specified model will auto-assign its constituent stochastic 155 variables to step methods based on the characteristics of the variables. 156 This function is intended to be called automatically from ``sample()``, but 157 may be called manually. Each step method passed should have a 158 ``competence()`` method that returns an ordinal competence value 159 corresponding to the variable passed to it. This value quantifies the 160 appropriateness of the step method for sampling the variable. 161 162 Parameters 163 ---------- 164 model : Model object 165 A fully-specified model object 166 step : step function or vector of step functions 167 One or more step functions that have been assigned to some subset of 168 the model's parameters. Defaults to ``None`` (no assigned variables). 169 methods : vector of step method classes 170 The set of step methods from which the function may choose. Defaults 171 to the main step methods provided by PyMC3. 172 step_kwargs : dict 173 Parameters for the samplers. Keys are the lower case names of 174 the step method, values a dict of arguments. 175 176 Returns 177 ------- 178 methods : list 179 List of step methods associated with the model's variables. 180 """ 181 steps = [] 182 assigned_vars = set() 183 184 if step is not None: 185 try: 186 steps += list(step) 187 except TypeError: 188 steps.append(step) 189 for step in steps: 190 try: 191 assigned_vars = assigned_vars.union(set(step.vars)) 192 except AttributeError: 193 for method in step.methods: 194 assigned_vars = assigned_vars.union(set(method.vars)) 195 196 # Use competence classmethods to select step methods for remaining 197 # variables 198 selected_steps = defaultdict(list) 199 for var in model.free_RVs: 200 if var not in assigned_vars: 201 # determine if a gradient can be computed 202 has_gradient = var.dtype not in discrete_types 203 if has_gradient: 204 try: 205 tg.grad(model.logpt, var) 206 except (AttributeError, NotImplementedError, tg.NullTypeGradError): 207 has_gradient = False 208 # select the best method 209 selected = max( 210 methods, 211 key=lambda method, var=var, has_gradient=has_gradient: method._competence( 212 var, has_gradient 213 ), 214 ) 215 selected_steps[selected].append(var) 216 217 return instantiate_steppers(model, steps, selected_steps, step_kwargs) 218 219 220def _print_step_hierarchy(s: Step, level=0) -> None: 221 if isinstance(s, CompoundStep): 222 _log.info(">" * level + "CompoundStep") 223 for i in s.methods: 224 _print_step_hierarchy(i, level + 1) 225 else: 226 varnames = ", ".join( 227 [ 228 get_untransformed_name(v.name) if is_transformed_name(v.name) else v.name 229 for v in s.vars 230 ] 231 ) 232 _log.info(">" * level + f"{s.__class__.__name__}: [{varnames}]") 233 234 235def sample( 236 draws=1000, 237 step=None, 238 init="auto", 239 n_init=200000, 240 start=None, 241 trace=None, 242 chain_idx=0, 243 chains=None, 244 cores=None, 245 tune=1000, 246 progressbar=True, 247 model=None, 248 random_seed=None, 249 discard_tuned_samples=True, 250 compute_convergence_checks=True, 251 callback=None, 252 jitter_max_retries=10, 253 *, 254 return_inferencedata=None, 255 idata_kwargs: dict = None, 256 mp_ctx=None, 257 pickle_backend: str = "pickle", 258 **kwargs, 259): 260 r"""Draw samples from the posterior using the given step methods. 261 262 Multiple step methods are supported via compound step methods. 263 264 Parameters 265 ---------- 266 draws : int 267 The number of samples to draw. Defaults to 1000. The number of tuned samples are discarded 268 by default. See ``discard_tuned_samples``. 269 init : str 270 Initialization method to use for auto-assigned NUTS samplers. 271 272 * auto: Choose a default initialization method automatically. 273 Currently, this is ``jitter+adapt_diag``, but this can change in the future. 274 If you depend on the exact behaviour, choose an initialization method explicitly. 275 * adapt_diag: Start with a identity mass matrix and then adapt a diagonal based on the 276 variance of the tuning samples. All chains use the test value (usually the prior mean) 277 as starting point. 278 * jitter+adapt_diag: Same as ``adapt_diag``, but add uniform jitter in [-1, 1] to the 279 starting point in each chain. 280 * advi+adapt_diag: Run ADVI and then adapt the resulting diagonal mass matrix based on the 281 sample variance of the tuning samples. 282 * advi+adapt_diag_grad: Run ADVI and then adapt the resulting diagonal mass matrix based 283 on the variance of the gradients during tuning. This is **experimental** and might be 284 removed in a future release. 285 * advi: Run ADVI to estimate posterior mean and diagonal mass matrix. 286 * advi_map: Initialize ADVI with MAP and use MAP as starting point. 287 * map: Use the MAP as starting point. This is discouraged. 288 * adapt_full: Adapt a dense mass matrix using the sample covariances 289 290 step : function or iterable of functions 291 A step function or collection of functions. If there are variables without step methods, 292 step methods for those variables will be assigned automatically. By default the NUTS step 293 method will be used, if appropriate to the model; this is a good default for beginning 294 users. 295 n_init : int 296 Number of iterations of initializer. Only works for 'ADVI' init methods. 297 start : dict, or array of dict 298 Starting point in parameter space (or partial point) 299 Defaults to ``trace.point(-1))`` if there is a trace provided and model.test_point if not 300 (defaults to empty dict). Initialization methods for NUTS (see ``init`` keyword) can 301 overwrite the default. 302 trace : backend, list, or MultiTrace 303 This should be a backend instance, a list of variables to track, or a MultiTrace object 304 with past values. If a MultiTrace object is given, it must contain samples for the chain 305 number ``chain``. If None or a list of variables, the NDArray backend is used. 306 chain_idx : int 307 Chain number used to store sample in backend. If ``chains`` is greater than one, chain 308 numbers will start here. 309 chains : int 310 The number of chains to sample. Running independent chains is important for some 311 convergence statistics and can also reveal multiple modes in the posterior. If ``None``, 312 then set to either ``cores`` or 2, whichever is larger. 313 cores : int 314 The number of chains to run in parallel. If ``None``, set to the number of CPUs in the 315 system, but at most 4. 316 tune : int 317 Number of iterations to tune, defaults to 1000. Samplers adjust the step sizes, scalings or 318 similar during tuning. Tuning samples will be drawn in addition to the number specified in 319 the ``draws`` argument, and will be discarded unless ``discard_tuned_samples`` is set to 320 False. 321 progressbar : bool, optional default=True 322 Whether or not to display a progress bar in the command line. The bar shows the percentage 323 of completion, the sampling speed in samples per second (SPS), and the estimated remaining 324 time until completion ("expected time of arrival"; ETA). 325 model : Model (optional if in ``with`` context) 326 random_seed : int or list of ints 327 A list is accepted if ``cores`` is greater than one. 328 discard_tuned_samples : bool 329 Whether to discard posterior samples of the tune interval. 330 compute_convergence_checks : bool, default=True 331 Whether to compute sampler statistics like Gelman-Rubin and ``effective_n``. 332 callback : function, default=None 333 A function which gets called for every sample from the trace of a chain. The function is 334 called with the trace and the current draw and will contain all samples for a single trace. 335 the ``draw.chain`` argument can be used to determine which of the active chains the sample 336 is drawn from. 337 Sampling can be interrupted by throwing a ``KeyboardInterrupt`` in the callback. 338 jitter_max_retries : int 339 Maximum number of repeated attempts (per chain) at creating an initial matrix with uniform jitter 340 that yields a finite probability. This applies to ``jitter+adapt_diag`` and ``jitter+adapt_full`` 341 init methods. 342 return_inferencedata : bool, default=False 343 Whether to return the trace as an :class:`arviz:arviz.InferenceData` (True) object or a `MultiTrace` (False) 344 Defaults to `False`, but we'll switch to `True` in an upcoming release. 345 idata_kwargs : dict, optional 346 Keyword arguments for :func:`arviz:arviz.from_pymc3` 347 mp_ctx : multiprocessing.context.BaseContent 348 A multiprocessing context for parallel sampling. See multiprocessing 349 documentation for details. 350 pickle_backend : str 351 One of `'pickle'` or `'dill'`. The library used to pickle models 352 in parallel sampling if the multiprocessing context is not of type 353 `fork`. 354 355 Returns 356 ------- 357 trace : pymc3.backends.base.MultiTrace or arviz.InferenceData 358 A ``MultiTrace`` or ArviZ ``InferenceData`` object that contains the samples. 359 360 Notes 361 ----- 362 Optional keyword arguments can be passed to ``sample`` to be delivered to the 363 ``step_method``\ s used during sampling. 364 365 If your model uses only one step method, you can address step method kwargs 366 directly. In particular, the NUTS step method has several options including: 367 368 * target_accept : float in [0, 1]. The step size is tuned such that we 369 approximate this acceptance rate. Higher values like 0.9 or 0.95 often 370 work better for problematic posteriors 371 * max_treedepth : The maximum depth of the trajectory tree 372 * step_scale : float, default 0.25 373 The initial guess for the step size scaled down by :math:`1/n**(1/4)` 374 375 If your model uses multiple step methods, aka a Compound Step, then you have 376 two ways to address arguments to each step method: 377 378 A. If you let ``sample()`` automatically assign the ``step_method``\ s, 379 and you can correctly anticipate what they will be, then you can wrap 380 step method kwargs in a dict and pass that to sample() with a kwarg set 381 to the name of the step method. 382 e.g. for a CompoundStep comprising NUTS and BinaryGibbsMetropolis, 383 you could send: 384 385 1. ``target_accept`` to NUTS: nuts={'target_accept':0.9} 386 2. ``transit_p`` to BinaryGibbsMetropolis: binary_gibbs_metropolis={'transit_p':.7} 387 388 Note that available names are: 389 390 ``nuts``, ``hmc``, ``metropolis``, ``binary_metropolis``, 391 ``binary_gibbs_metropolis``, ``categorical_gibbs_metropolis``, 392 ``DEMetropolis``, ``DEMetropolisZ``, ``slice`` 393 394 B. If you manually declare the ``step_method``\ s, within the ``step`` 395 kwarg, then you can address the ``step_method`` kwargs directly. 396 e.g. for a CompoundStep comprising NUTS and BinaryGibbsMetropolis, 397 you could send :: 398 399 step=[pm.NUTS([freeRV1, freeRV2], target_accept=0.9), 400 pm.BinaryGibbsMetropolis([freeRV3], transit_p=.7)] 401 402 You can find a full list of arguments in the docstring of the step methods. 403 404 Examples 405 -------- 406 .. code:: ipython 407 408 In [1]: import pymc3 as pm 409 ...: n = 100 410 ...: h = 61 411 ...: alpha = 2 412 ...: beta = 2 413 414 In [2]: with pm.Model() as model: # context management 415 ...: p = pm.Beta("p", alpha=alpha, beta=beta) 416 ...: y = pm.Binomial("y", n=n, p=p, observed=h) 417 ...: trace = pm.sample() 418 419 In [3]: az.summary(trace, kind="stats") 420 421 Out[3]: 422 mean sd hdi_3% hdi_97% 423 p 0.609 0.047 0.528 0.699 424 """ 425 model = modelcontext(model) 426 start = deepcopy(start) 427 if start is None: 428 check_start_vals(model.test_point, model) 429 else: 430 if isinstance(start, dict): 431 update_start_vals(start, model.test_point, model) 432 else: 433 for chain_start_vals in start: 434 update_start_vals(chain_start_vals, model.test_point, model) 435 check_start_vals(start, model) 436 437 if cores is None: 438 cores = min(4, _cpu_count()) 439 440 if chains is None: 441 chains = max(2, cores) 442 if isinstance(start, dict): 443 start = [start] * chains 444 if random_seed == -1: 445 random_seed = None 446 if chains == 1 and isinstance(random_seed, int): 447 random_seed = [random_seed] 448 if random_seed is None or isinstance(random_seed, int): 449 if random_seed is not None: 450 np.random.seed(random_seed) 451 random_seed = [np.random.randint(2 ** 30) for _ in range(chains)] 452 if not isinstance(random_seed, abc.Iterable): 453 raise TypeError("Invalid value for `random_seed`. Must be tuple, list or int") 454 455 if not discard_tuned_samples and not return_inferencedata: 456 warnings.warn( 457 "Tuning samples will be included in the returned `MultiTrace` object, which can lead to" 458 " complications in your downstream analysis. Please consider to switch to `InferenceData`:\n" 459 "`pm.sample(..., return_inferencedata=True)`", 460 UserWarning, 461 ) 462 463 if return_inferencedata is None: 464 v = packaging.version.parse(pm.__version__) 465 if v.release[0] > 3 or v.release[1] >= 10: # type: ignore 466 warnings.warn( 467 "In v4.0, pm.sample will return an `arviz.InferenceData` object instead of a `MultiTrace` by default. " 468 "You can pass return_inferencedata=True or return_inferencedata=False to be safe and silence this warning.", 469 FutureWarning, 470 stacklevel=2, 471 ) 472 # set the default 473 return_inferencedata = False 474 475 if start is not None: 476 for start_vals in start: 477 _check_start_shape(model, start_vals) 478 479 # small trace warning 480 if draws == 0: 481 msg = "Tuning was enabled throughout the whole trace." 482 _log.warning(msg) 483 elif draws < 500: 484 msg = "Only %s samples in chain." % draws 485 _log.warning(msg) 486 487 draws += tune 488 489 if model.ndim == 0: 490 raise ValueError("The model does not contain any free variables.") 491 492 if step is None and init is not None and all_continuous(model.vars): 493 try: 494 # By default, try to use NUTS 495 _log.info("Auto-assigning NUTS sampler...") 496 start_, step = init_nuts( 497 init=init, 498 chains=chains, 499 n_init=n_init, 500 model=model, 501 random_seed=random_seed, 502 progressbar=progressbar, 503 jitter_max_retries=jitter_max_retries, 504 **kwargs, 505 ) 506 if start is None: 507 start = start_ 508 check_start_vals(start, model) 509 except (AttributeError, NotImplementedError, tg.NullTypeGradError): 510 # gradient computation failed 511 _log.info("Initializing NUTS failed. " "Falling back to elementwise auto-assignment.") 512 _log.debug("Exception in init nuts", exec_info=True) 513 step = assign_step_methods(model, step, step_kwargs=kwargs) 514 else: 515 step = assign_step_methods(model, step, step_kwargs=kwargs) 516 517 if isinstance(step, list): 518 step = CompoundStep(step) 519 if start is None: 520 start = {} 521 if isinstance(start, dict): 522 start = [start] * chains 523 524 sample_args = { 525 "draws": draws, 526 "step": step, 527 "start": start, 528 "trace": trace, 529 "chain": chain_idx, 530 "chains": chains, 531 "tune": tune, 532 "progressbar": progressbar, 533 "model": model, 534 "random_seed": random_seed, 535 "cores": cores, 536 "callback": callback, 537 "discard_tuned_samples": discard_tuned_samples, 538 } 539 parallel_args = { 540 "pickle_backend": pickle_backend, 541 "mp_ctx": mp_ctx, 542 } 543 544 sample_args.update(kwargs) 545 546 has_population_samplers = np.any( 547 [ 548 isinstance(m, PopulationArrayStepShared) 549 for m in (step.methods if isinstance(step, CompoundStep) else [step]) 550 ] 551 ) 552 553 parallel = cores > 1 and chains > 1 and not has_population_samplers 554 t_start = time.time() 555 if parallel: 556 _log.info(f"Multiprocess sampling ({chains} chains in {cores} jobs)") 557 _print_step_hierarchy(step) 558 try: 559 trace = _mp_sample(**sample_args, **parallel_args) 560 except pickle.PickleError: 561 _log.warning("Could not pickle model, sampling singlethreaded.") 562 _log.debug("Pickling error:", exec_info=True) 563 parallel = False 564 except AttributeError as e: 565 if not str(e).startswith("AttributeError: Can't pickle"): 566 raise 567 _log.warning("Could not pickle model, sampling singlethreaded.") 568 _log.debug("Pickling error:", exec_info=True) 569 parallel = False 570 if not parallel: 571 if has_population_samplers: 572 has_demcmc = np.any( 573 [ 574 isinstance(m, DEMetropolis) 575 for m in (step.methods if isinstance(step, CompoundStep) else [step]) 576 ] 577 ) 578 _log.info(f"Population sampling ({chains} chains)") 579 if has_demcmc and chains < 3: 580 raise ValueError( 581 "DEMetropolis requires at least 3 chains. " 582 "For this {}-dimensional model you should use ≥{} chains".format( 583 model.ndim, model.ndim + 1 584 ) 585 ) 586 if has_demcmc and chains <= model.ndim: 587 warnings.warn( 588 "DEMetropolis should be used with more chains than dimensions! " 589 "(The model has {} dimensions.)".format(model.ndim), 590 UserWarning, 591 ) 592 _print_step_hierarchy(step) 593 trace = _sample_population(parallelize=cores > 1, **sample_args) 594 else: 595 _log.info(f"Sequential sampling ({chains} chains in 1 job)") 596 _print_step_hierarchy(step) 597 trace = _sample_many(**sample_args) 598 599 t_sampling = time.time() - t_start 600 # count the number of tune/draw iterations that happened 601 # ideally via the "tune" statistic, but not all samplers record it! 602 if "tune" in trace.stat_names: 603 stat = trace.get_sampler_stats("tune", chains=0) 604 # when CompoundStep is used, the stat is 2 dimensional! 605 if len(stat.shape) == 2: 606 stat = stat[:, 0] 607 stat = tuple(stat) 608 n_tune = stat.count(True) 609 n_draws = stat.count(False) 610 else: 611 # these may be wrong when KeyboardInterrupt happened, but they're better than nothing 612 n_tune = min(tune, len(trace)) 613 n_draws = max(0, len(trace) - n_tune) 614 615 if discard_tuned_samples: 616 trace = trace[n_tune:] 617 618 # save metadata in SamplerReport 619 trace.report._n_tune = n_tune 620 trace.report._n_draws = n_draws 621 trace.report._t_sampling = t_sampling 622 623 if "variable_inclusion" in trace.stat_names: 624 variable_inclusion = np.stack(trace.get_sampler_stats("variable_inclusion")).mean(0) 625 trace.report.variable_importance = variable_inclusion / variable_inclusion.sum() 626 627 n_chains = len(trace.chains) 628 _log.info( 629 f'Sampling {n_chains} chain{"s" if n_chains > 1 else ""} for {n_tune:_d} tune and {n_draws:_d} draw iterations ' 630 f"({n_tune*n_chains:_d} + {n_draws*n_chains:_d} draws total) " 631 f"took {trace.report.t_sampling:.0f} seconds." 632 ) 633 634 idata = None 635 if compute_convergence_checks or return_inferencedata: 636 ikwargs = dict(model=model, save_warmup=not discard_tuned_samples) 637 if idata_kwargs: 638 ikwargs.update(idata_kwargs) 639 idata = arviz.from_pymc3(trace, **ikwargs) 640 641 if compute_convergence_checks: 642 if draws - tune < 100: 643 warnings.warn("The number of samples is too small to check convergence reliably.") 644 else: 645 trace.report._run_convergence_checks(idata, model) 646 trace.report._log_summary() 647 648 if return_inferencedata: 649 return idata 650 else: 651 return trace 652 653 654def _check_start_shape(model, start): 655 if not isinstance(start, dict): 656 raise TypeError("start argument must be a dict or an array-like of dicts") 657 e = "" 658 for var in model.vars: 659 if var.name in start.keys(): 660 var_shape = var.shape.tag.test_value 661 start_var_shape = np.shape(start[var.name]) 662 if start_var_shape: 663 if not np.array_equal(var_shape, start_var_shape): 664 e += "\nExpected shape {} for var '{}', got: {}".format( 665 tuple(var_shape), var.name, start_var_shape 666 ) 667 # if start var has no shape 668 else: 669 # if model var has a specified shape 670 if var_shape.size > 0: 671 e += "\nExpected shape {} for var " "'{}', got scalar {}".format( 672 tuple(var_shape), var.name, start[var.name] 673 ) 674 675 if e != "": 676 raise ValueError(f"Bad shape for start argument:{e}") 677 678 679def _sample_many( 680 draws, 681 chain: int, 682 chains: int, 683 start: list, 684 random_seed: list, 685 step, 686 callback=None, 687 **kwargs, 688): 689 """Samples all chains sequentially. 690 691 Parameters 692 ---------- 693 draws: int 694 The number of samples to draw 695 chain: int 696 Number of the first chain in the sequence. 697 chains: int 698 Total number of chains to sample. 699 start: list 700 Starting points for each chain 701 random_seed: list 702 A list of seeds, one for each chain 703 step: function 704 Step function 705 706 Returns 707 ------- 708 trace: MultiTrace 709 Contains samples of all chains 710 """ 711 traces: List[Backend] = [] 712 for i in range(chains): 713 trace = _sample( 714 draws=draws, 715 chain=chain + i, 716 start=start[i], 717 step=step, 718 random_seed=random_seed[i], 719 callback=callback, 720 **kwargs, 721 ) 722 if trace is None: 723 if len(traces) == 0: 724 raise ValueError("Sampling stopped before a sample was created.") 725 else: 726 break 727 elif len(trace) < draws: 728 if len(traces) == 0: 729 traces.append(trace) 730 break 731 else: 732 traces.append(trace) 733 return MultiTrace(traces) 734 735 736def _sample_population( 737 draws: int, 738 chain: int, 739 chains: int, 740 start, 741 random_seed, 742 step, 743 tune, 744 model, 745 progressbar: bool = True, 746 parallelize=False, 747 **kwargs, 748): 749 """Performs sampling of a population of chains using the ``PopulationStepper``. 750 751 Parameters 752 ---------- 753 draws : int 754 The number of samples to draw 755 chain : int 756 The number of the first chain in the population 757 chains : int 758 The total number of chains in the population 759 start : list 760 Start points for each chain 761 random_seed : int or list of ints, optional 762 A list is accepted if more if ``cores`` is greater than one. 763 step : function 764 Step function (should be or contain a population step method) 765 tune : int, optional 766 Number of iterations to tune, if applicable (defaults to None) 767 model : Model (optional if in ``with`` context) 768 progressbar : bool 769 Show progress bars? (defaults to True) 770 parallelize : bool 771 Setting for multiprocess parallelization 772 773 Returns 774 ------- 775 trace : MultiTrace 776 Contains samples of all chains 777 """ 778 sampling = _prepare_iter_population( 779 draws, 780 [chain + c for c in range(chains)], 781 step, 782 start, 783 parallelize, 784 tune=tune, 785 model=model, 786 random_seed=random_seed, 787 progressbar=progressbar, 788 ) 789 790 if progressbar: 791 sampling = progress_bar(sampling, total=draws, display=progressbar) 792 793 latest_traces = None 794 for it, traces in enumerate(sampling): 795 latest_traces = traces 796 return MultiTrace(latest_traces) 797 798 799def _sample( 800 chain: int, 801 progressbar: bool, 802 random_seed, 803 start, 804 draws: int, 805 step=None, 806 trace=None, 807 tune=None, 808 model: Optional[Model] = None, 809 callback=None, 810 **kwargs, 811): 812 """Main iteration for singleprocess sampling. 813 814 Multiple step methods are supported via compound step methods. 815 816 Parameters 817 ---------- 818 chain : int 819 Number of the chain that the samples will belong to. 820 progressbar : bool 821 Whether or not to display a progress bar in the command line. The bar shows the percentage 822 of completion, the sampling speed in samples per second (SPS), and the estimated remaining 823 time until completion ("expected time of arrival"; ETA). 824 random_seed : int or list of ints 825 A list is accepted if ``cores`` is greater than one. 826 start : dict 827 Starting point in parameter space (or partial point) 828 draws : int 829 The number of samples to draw 830 step : function 831 Step function 832 trace : backend, list, or MultiTrace 833 This should be a backend instance, a list of variables to track, or a MultiTrace object 834 with past values. If a MultiTrace object is given, it must contain samples for the chain 835 number ``chain``. If None or a list of variables, the NDArray backend is used. 836 tune : int, optional 837 Number of iterations to tune, if applicable (defaults to None) 838 model : Model (optional if in ``with`` context) 839 840 Returns 841 ------- 842 strace : pymc3.backends.base.BaseTrace 843 A ``BaseTrace`` object that contains the samples for this chain. 844 """ 845 skip_first = kwargs.get("skip_first", 0) 846 847 sampling = _iter_sample(draws, step, start, trace, chain, tune, model, random_seed, callback) 848 _pbar_data = {"chain": chain, "divergences": 0} 849 _desc = "Sampling chain {chain:d}, {divergences:,d} divergences" 850 if progressbar: 851 sampling = progress_bar(sampling, total=draws, display=progressbar) 852 sampling.comment = _desc.format(**_pbar_data) 853 try: 854 strace = None 855 for it, (strace, diverging) in enumerate(sampling): 856 if it >= skip_first and diverging: 857 _pbar_data["divergences"] += 1 858 if progressbar: 859 sampling.comment = _desc.format(**_pbar_data) 860 except KeyboardInterrupt: 861 pass 862 return strace 863 864 865def iter_sample( 866 draws: int, 867 step, 868 start: Optional[Dict[Any, Any]] = None, 869 trace=None, 870 chain=0, 871 tune: Optional[int] = None, 872 model: Optional[Model] = None, 873 random_seed: Optional[Union[int, List[int]]] = None, 874 callback=None, 875): 876 """Generate a trace on each iteration using the given step method. 877 878 Multiple step methods ared supported via compound step methods. Returns the 879 amount of time taken. 880 881 Parameters 882 ---------- 883 draws : int 884 The number of samples to draw 885 step : function 886 Step function 887 start : dict 888 Starting point in parameter space (or partial point). Defaults to trace.point(-1)) if 889 there is a trace provided and model.test_point if not (defaults to empty dict) 890 trace : backend, list, or MultiTrace 891 This should be a backend instance, a list of variables to track, or a MultiTrace object 892 with past values. If a MultiTrace object is given, it must contain samples for the chain 893 number ``chain``. If None or a list of variables, the NDArray backend is used. 894 chain : int, optional 895 Chain number used to store sample in backend. If ``cores`` is greater than one, chain numbers 896 will start here. 897 tune : int, optional 898 Number of iterations to tune, if applicable (defaults to None) 899 model : Model (optional if in ``with`` context) 900 random_seed : int or list of ints, optional 901 A list is accepted if more if ``cores`` is greater than one. 902 callback : 903 A function which gets called for every sample from the trace of a chain. The function is 904 called with the trace and the current draw and will contain all samples for a single trace. 905 the ``draw.chain`` argument can be used to determine which of the active chains the sample 906 is drawn from. 907 Sampling can be interrupted by throwing a ``KeyboardInterrupt`` in the callback. 908 909 Yields 910 ------ 911 trace : MultiTrace 912 Contains all samples up to the current iteration 913 914 Examples 915 -------- 916 :: 917 918 for trace in iter_sample(500, step): 919 ... 920 """ 921 sampling = _iter_sample(draws, step, start, trace, chain, tune, model, random_seed, callback) 922 for i, (strace, _) in enumerate(sampling): 923 yield MultiTrace([strace[: i + 1]]) 924 925 926def _iter_sample( 927 draws, 928 step, 929 start=None, 930 trace=None, 931 chain=0, 932 tune=None, 933 model=None, 934 random_seed=None, 935 callback=None, 936): 937 """Generator for sampling one chain. (Used in singleprocess sampling.) 938 939 Parameters 940 ---------- 941 draws : int 942 The number of samples to draw 943 step : function 944 Step function 945 start : dict, optional 946 Starting point in parameter space (or partial point). Defaults to trace.point(-1)) if 947 there is a trace provided and model.test_point if not (defaults to empty dict) 948 trace : backend, list, MultiTrace, or None 949 This should be a backend instance, a list of variables to track, or a MultiTrace object 950 with past values. If a MultiTrace object is given, it must contain samples for the chain 951 number ``chain``. If None or a list of variables, the NDArray backend is used. 952 chain : int, optional 953 Chain number used to store sample in backend. If ``cores`` is greater than one, chain numbers 954 will start here. 955 tune : int, optional 956 Number of iterations to tune, if applicable (defaults to None) 957 model : Model (optional if in ``with`` context) 958 random_seed : int or list of ints, optional 959 A list is accepted if more if ``cores`` is greater than one. 960 961 Yields 962 ------ 963 strace : BaseTrace 964 The trace object containing the samples for this chain 965 diverging : bool 966 Indicates if the draw is divergent. Only available with some samplers. 967 """ 968 model = modelcontext(model) 969 draws = int(draws) 970 if random_seed is not None: 971 np.random.seed(random_seed) 972 if draws < 1: 973 raise ValueError("Argument `draws` must be greater than 0.") 974 975 if start is None: 976 start = {} 977 978 strace = _choose_backend(trace, chain, model=model) 979 980 if len(strace) > 0: 981 update_start_vals(start, strace.point(-1), model) 982 else: 983 update_start_vals(start, model.test_point, model) 984 985 try: 986 step = CompoundStep(step) 987 except TypeError: 988 pass 989 990 point = Point(start, model=model) 991 992 if step.generates_stats and strace.supports_sampler_stats: 993 strace.setup(draws, chain, step.stats_dtypes) 994 else: 995 strace.setup(draws, chain) 996 997 try: 998 step.tune = bool(tune) 999 if hasattr(step, "reset_tuning"): 1000 step.reset_tuning() 1001 for i in range(draws): 1002 stats = None 1003 diverging = False 1004 1005 if i == 0 and hasattr(step, "iter_count"): 1006 step.iter_count = 0 1007 if i == tune: 1008 step = stop_tuning(step) 1009 if step.generates_stats: 1010 point, stats = step.step(point) 1011 if strace.supports_sampler_stats: 1012 strace.record(point, stats) 1013 diverging = i > tune and stats and stats[0].get("diverging") 1014 else: 1015 strace.record(point) 1016 else: 1017 point = step.step(point) 1018 strace.record(point) 1019 if callback is not None: 1020 warns = getattr(step, "warnings", None) 1021 callback( 1022 trace=strace, 1023 draw=Draw(chain, i == draws, i, i < tune, stats, point, warns), 1024 ) 1025 1026 yield strace, diverging 1027 except KeyboardInterrupt: 1028 strace.close() 1029 if hasattr(step, "warnings"): 1030 warns = step.warnings() 1031 strace._add_warnings(warns) 1032 raise 1033 except BaseException: 1034 strace.close() 1035 raise 1036 else: 1037 strace.close() 1038 if hasattr(step, "warnings"): 1039 warns = step.warnings() 1040 strace._add_warnings(warns) 1041 1042 1043class PopulationStepper: 1044 """Wraps population of step methods to step them in parallel with single or multiprocessing.""" 1045 1046 def __init__(self, steppers, parallelize, progressbar=True): 1047 """Use multiprocessing to parallelize chains. 1048 1049 Falls back to sequential evaluation if multiprocessing fails. 1050 1051 In the multiprocessing mode of operation, a new process is started for each 1052 chain/stepper and Pipes are used to communicate with the main process. 1053 1054 Parameters 1055 ---------- 1056 steppers : list 1057 A collection of independent step methods, one for each chain. 1058 parallelize : bool 1059 Indicates if parallelization via multiprocessing is desired. 1060 progressbar : bool 1061 Should we display a progress bar showing relative progress? 1062 """ 1063 self.nchains = len(steppers) 1064 self.is_parallelized = False 1065 self._primary_ends = [] 1066 self._processes = [] 1067 self._steppers = steppers 1068 if parallelize: 1069 try: 1070 # configure a child process for each stepper 1071 _log.info( 1072 "Attempting to parallelize chains to all cores. You can turn this off with `pm.sample(cores=1)`." 1073 ) 1074 import multiprocessing 1075 1076 for c, stepper in ( 1077 enumerate(progress_bar(steppers)) if progressbar else enumerate(steppers) 1078 ): 1079 secondary_end, primary_end = multiprocessing.Pipe() 1080 stepper_dumps = pickle.dumps(stepper, protocol=4) 1081 process = multiprocessing.Process( 1082 target=self.__class__._run_secondary, 1083 args=(c, stepper_dumps, secondary_end), 1084 name=f"ChainWalker{c}", 1085 ) 1086 # we want the child process to exit if the parent is terminated 1087 process.daemon = True 1088 # Starting the process might fail and takes time. 1089 # By doing it in the constructor, the sampling progress bar 1090 # will not be confused by the process start. 1091 process.start() 1092 self._primary_ends.append(primary_end) 1093 self._processes.append(process) 1094 self.is_parallelized = True 1095 except Exception: 1096 _log.info( 1097 "Population parallelization failed. " 1098 "Falling back to sequential stepping of chains." 1099 ) 1100 _log.debug("Error was: ", exec_info=True) 1101 else: 1102 _log.info( 1103 "Chains are not parallelized. You can enable this by passing " 1104 "`pm.sample(cores=n)`, where n > 1." 1105 ) 1106 return super().__init__() 1107 1108 def __enter__(self): 1109 """Do nothing: processes are already started in ``__init__``.""" 1110 return 1111 1112 def __exit__(self, exc_type, exc_val, exc_tb): 1113 if len(self._processes) > 0: 1114 try: 1115 for primary_end in self._primary_ends: 1116 primary_end.send(None) 1117 for process in self._processes: 1118 process.join(timeout=3) 1119 except Exception: 1120 _log.warning("Termination failed.") 1121 return 1122 1123 @staticmethod 1124 def _run_secondary(c, stepper_dumps, secondary_end): 1125 """This method is started on a separate process to perform stepping of a chain. 1126 1127 Parameters 1128 ---------- 1129 c : int 1130 number of this chain 1131 stepper : BlockedStep 1132 a step method such as CompoundStep 1133 secondary_end : multiprocessing.connection.PipeConnection 1134 This is our connection to the main process 1135 """ 1136 # re-seed each child process to make them unique 1137 np.random.seed(None) 1138 try: 1139 stepper = pickle.loads(stepper_dumps) 1140 # the stepper is not necessarily a PopulationArraySharedStep itself, 1141 # but rather a CompoundStep. PopulationArrayStepShared.population 1142 # has to be updated, therefore we identify the substeppers first. 1143 population_steppers = [] 1144 for sm in stepper.methods if isinstance(stepper, CompoundStep) else [stepper]: 1145 if isinstance(sm, PopulationArrayStepShared): 1146 population_steppers.append(sm) 1147 while True: 1148 incoming = secondary_end.recv() 1149 # receiving a None is the signal to exit 1150 if incoming is None: 1151 break 1152 tune_stop, population = incoming 1153 if tune_stop: 1154 stop_tuning(stepper) 1155 # forward the population to the PopulationArrayStepShared objects 1156 # This is necessary because due to the process fork, the population 1157 # object is no longer shared between the steppers. 1158 for popstep in population_steppers: 1159 popstep.population = population 1160 update = stepper.step(population[c]) 1161 secondary_end.send(update) 1162 except Exception: 1163 _log.exception(f"ChainWalker{c}") 1164 return 1165 1166 def step(self, tune_stop, population): 1167 """Step the entire population of chains. 1168 1169 Parameters 1170 ---------- 1171 tune_stop : bool 1172 Indicates if the condition (i == tune) is fulfilled 1173 population : list 1174 Current Points of all chains 1175 1176 Returns 1177 ------- 1178 update : list 1179 List of (Point, stats) tuples for all chains 1180 """ 1181 updates = [None] * self.nchains 1182 if self.is_parallelized: 1183 for c in range(self.nchains): 1184 self._primary_ends[c].send((tune_stop, population)) 1185 # Blockingly get the step outcomes 1186 for c in range(self.nchains): 1187 updates[c] = self._primary_ends[c].recv() 1188 else: 1189 for c in range(self.nchains): 1190 if tune_stop: 1191 self._steppers[c] = stop_tuning(self._steppers[c]) 1192 updates[c] = self._steppers[c].step(population[c]) 1193 return updates 1194 1195 1196def _prepare_iter_population( 1197 draws: int, 1198 chains: list, 1199 step, 1200 start: list, 1201 parallelize: bool, 1202 tune=None, 1203 model=None, 1204 random_seed=None, 1205 progressbar=True, 1206): 1207 """Prepare a PopulationStepper and traces for population sampling. 1208 1209 Parameters 1210 ---------- 1211 draws : int 1212 The number of samples to draw 1213 chains : list 1214 The chain numbers in the population 1215 step : function 1216 Step function (should be or contain a population step method) 1217 start : list 1218 Start points for each chain 1219 parallelize : bool 1220 Setting for multiprocess parallelization 1221 tune : int, optional 1222 Number of iterations to tune, if applicable (defaults to None) 1223 model : Model (optional if in ``with`` context) 1224 random_seed : int or list of ints, optional 1225 A list is accepted if more if ``cores`` is greater than one. 1226 progressbar : bool 1227 ``progressbar`` argument for the ``PopulationStepper``, (defaults to True) 1228 1229 Returns 1230 ------- 1231 _iter_population : generator 1232 Yields traces of all chains at the same time 1233 """ 1234 # chains contains the chain numbers, but for indexing we need indices... 1235 nchains = len(chains) 1236 model = modelcontext(model) 1237 draws = int(draws) 1238 if random_seed is not None: 1239 np.random.seed(random_seed) 1240 if draws < 1: 1241 raise ValueError("Argument `draws` should be above 0.") 1242 1243 # The initialization of traces, samplers and points must happen in the right order: 1244 # 1. traces are initialized and update_start_vals configures variable transforms 1245 # 2. population of points is created 1246 # 3. steppers are initialized and linked to the points object 1247 # 4. traces are configured to track the sampler stats 1248 # 5. a PopulationStepper is configured for parallelized stepping 1249 1250 # 1. prepare a BaseTrace for each chain 1251 traces = [_choose_backend(None, chain, model=model) for chain in chains] 1252 for c, strace in enumerate(traces): 1253 # initialize the trace size and variable transforms 1254 if len(strace) > 0: 1255 update_start_vals(start[c], strace.point(-1), model) 1256 else: 1257 update_start_vals(start[c], model.test_point, model) 1258 1259 # 2. create a population (points) that tracks each chain 1260 # it is updated as the chains are advanced 1261 population = [Point(start[c], model=model) for c in range(nchains)] 1262 1263 # 3. Set up the steppers 1264 steppers: List[Step] = [] 1265 for c in range(nchains): 1266 # need indepenent samplers for each chain 1267 # it is important to copy the actual steppers (but not the delta_logp) 1268 if isinstance(step, CompoundStep): 1269 chainstep = CompoundStep([copy(m) for m in step.methods]) 1270 else: 1271 chainstep = copy(step) 1272 # link population samplers to the shared population state 1273 for sm in chainstep.methods if isinstance(step, CompoundStep) else [chainstep]: 1274 if isinstance(sm, PopulationArrayStepShared): 1275 sm.link_population(population, c) 1276 steppers.append(chainstep) 1277 1278 # 4. configure tracking of sampler stats 1279 for c in range(nchains): 1280 if steppers[c].generates_stats and traces[c].supports_sampler_stats: 1281 traces[c].setup(draws, c, steppers[c].stats_dtypes) 1282 else: 1283 traces[c].setup(draws, c) 1284 1285 # 5. configure the PopulationStepper (expensive call) 1286 popstep = PopulationStepper(steppers, parallelize, progressbar=progressbar) 1287 1288 # Because the preparations above are expensive, the actual iterator is 1289 # in another method. This way the progbar will not be disturbed. 1290 return _iter_population(draws, tune, popstep, steppers, traces, population) 1291 1292 1293def _iter_population(draws, tune, popstep, steppers, traces, points): 1294 """Iterate a ``PopulationStepper``. 1295 1296 Parameters 1297 ---------- 1298 draws : int 1299 number of draws per chain 1300 tune : int 1301 number of tuning steps 1302 popstep : PopulationStepper 1303 the helper object for (parallelized) stepping of chains 1304 steppers : list 1305 The step methods for each chain 1306 traces : list 1307 Traces for each chain 1308 points : list 1309 population of chain states 1310 1311 Yields 1312 ------ 1313 traces : list 1314 List of trace objects of the individual chains 1315 """ 1316 try: 1317 with popstep: 1318 # iterate draws of all chains 1319 for i in range(draws): 1320 # this call steps all chains and returns a list of (point, stats) 1321 # the `popstep` may interact with subprocesses internally 1322 updates = popstep.step(i == tune, points) 1323 1324 # apply the update to the points and record to the traces 1325 for c, strace in enumerate(traces): 1326 if steppers[c].generates_stats: 1327 points[c], stats = updates[c] 1328 if strace.supports_sampler_stats: 1329 strace.record(points[c], stats) 1330 else: 1331 strace.record(points[c]) 1332 else: 1333 points[c] = updates[c] 1334 strace.record(points[c]) 1335 # yield the state of all chains in parallel 1336 yield traces 1337 except KeyboardInterrupt: 1338 for c, strace in enumerate(traces): 1339 strace.close() 1340 if hasattr(steppers[c], "report"): 1341 steppers[c].report._finalize(strace) 1342 raise 1343 except BaseException: 1344 for c, strace in enumerate(traces): 1345 strace.close() 1346 raise 1347 else: 1348 for c, strace in enumerate(traces): 1349 strace.close() 1350 if hasattr(steppers[c], "report"): 1351 steppers[c].report._finalize(strace) 1352 1353 1354def _choose_backend(trace, chain, **kwds) -> Backend: 1355 """Selects or creates a NDArray trace backend for a particular chain. 1356 1357 Parameters 1358 ---------- 1359 trace : BaseTrace, list, MultiTrace, or None 1360 This should be a BaseTrace, list of variables to track, 1361 or a MultiTrace object with past values. 1362 If a MultiTrace object is given, it must contain samples for the chain number ``chain``. 1363 If None or a list of variables, the NDArray backend is used. 1364 chain : int 1365 Number of the chain of interest. 1366 **kwds : 1367 keyword arguments to forward to the backend creation 1368 1369 Returns 1370 ------- 1371 trace : BaseTrace 1372 A trace object for the selected chain 1373 """ 1374 if isinstance(trace, BaseTrace): 1375 return trace 1376 if isinstance(trace, MultiTrace): 1377 return trace._straces[chain] 1378 if trace is None: 1379 return NDArray(**kwds) 1380 1381 return NDArray(vars=trace, **kwds) 1382 1383 1384def _mp_sample( 1385 draws: int, 1386 tune: int, 1387 step, 1388 chains: int, 1389 cores: int, 1390 chain: int, 1391 random_seed: list, 1392 start: list, 1393 progressbar=True, 1394 trace=None, 1395 model=None, 1396 callback=None, 1397 discard_tuned_samples=True, 1398 mp_ctx=None, 1399 pickle_backend="pickle", 1400 **kwargs, 1401): 1402 """Main iteration for multiprocess sampling. 1403 1404 Parameters 1405 ---------- 1406 draws : int 1407 The number of samples to draw 1408 tune : int, optional 1409 Number of iterations to tune, if applicable (defaults to None) 1410 step : function 1411 Step function 1412 chains : int 1413 The number of chains to sample. 1414 cores : int 1415 The number of chains to run in parallel. 1416 chain : int 1417 Number of the first chain. 1418 random_seed : list of ints 1419 Random seeds for each chain. 1420 start : list 1421 Starting points for each chain. 1422 progressbar : bool 1423 Whether or not to display a progress bar in the command line. 1424 trace : BaseTrace, list, MultiTrace or None 1425 This should be a backend instance, a list of variables to track, or a MultiTrace object 1426 with past values. If a MultiTrace object is given, it must contain samples for the chain 1427 number ``chain``. If None or a list of variables, the NDArray backend is used. 1428 model : Model (optional if in ``with`` context) 1429 callback : Callable 1430 A function which gets called for every sample from the trace of a chain. The function is 1431 called with the trace and the current draw and will contain all samples for a single trace. 1432 the ``draw.chain`` argument can be used to determine which of the active chains the sample 1433 is drawn from. 1434 Sampling can be interrupted by throwing a ``KeyboardInterrupt`` in the callback. 1435 1436 Returns 1437 ------- 1438 trace : pymc3.backends.base.MultiTrace 1439 A ``MultiTrace`` object that contains the samples for all chains. 1440 """ 1441 import pymc3.parallel_sampling as ps 1442 1443 # We did draws += tune in pm.sample 1444 draws -= tune 1445 1446 traces = [] 1447 for idx in range(chain, chain + chains): 1448 if trace is not None: 1449 strace = _choose_backend(copy(trace), idx, model=model) 1450 else: 1451 strace = _choose_backend(None, idx, model=model) 1452 # for user supply start value, fill-in missing value if the supplied 1453 # dict does not contain all parameters 1454 update_start_vals(start[idx - chain], model.test_point, model) 1455 if step.generates_stats and strace.supports_sampler_stats: 1456 strace.setup(draws + tune, idx + chain, step.stats_dtypes) 1457 else: 1458 strace.setup(draws + tune, idx + chain) 1459 traces.append(strace) 1460 1461 sampler = ps.ParallelSampler( 1462 draws, 1463 tune, 1464 chains, 1465 cores, 1466 random_seed, 1467 start, 1468 step, 1469 chain, 1470 progressbar, 1471 mp_ctx=mp_ctx, 1472 pickle_backend=pickle_backend, 1473 ) 1474 try: 1475 try: 1476 with sampler: 1477 for draw in sampler: 1478 trace = traces[draw.chain - chain] 1479 if trace.supports_sampler_stats and draw.stats is not None: 1480 trace.record(draw.point, draw.stats) 1481 else: 1482 trace.record(draw.point) 1483 if draw.is_last: 1484 trace.close() 1485 if draw.warnings is not None: 1486 trace._add_warnings(draw.warnings) 1487 1488 if callback is not None: 1489 callback(trace=trace, draw=draw) 1490 1491 except ps.ParallelSamplingError as error: 1492 trace = traces[error._chain - chain] 1493 trace._add_warnings(error._warnings) 1494 for trace in traces: 1495 trace.close() 1496 1497 multitrace = MultiTrace(traces) 1498 multitrace._report._log_summary() 1499 raise 1500 return MultiTrace(traces) 1501 except KeyboardInterrupt: 1502 if discard_tuned_samples: 1503 traces, length = _choose_chains(traces, tune) 1504 else: 1505 traces, length = _choose_chains(traces, 0) 1506 return MultiTrace(traces)[:length] 1507 finally: 1508 for trace in traces: 1509 trace.close() 1510 1511 1512def _choose_chains(traces, tune): 1513 """ 1514 Filter and slice traces such that (n_traces * len(shortest_trace)) is maximized. 1515 1516 We get here after a ``KeyboardInterrupt``, and so the different 1517 traces have different lengths. We therefore pick the number of 1518 traces such that (number of traces) * (length of shortest trace) 1519 is maximised. 1520 """ 1521 if tune is None: 1522 tune = 0 1523 1524 if not traces: 1525 return [] 1526 1527 lengths = [max(0, len(trace) - tune) for trace in traces] 1528 if not sum(lengths): 1529 raise ValueError("Not enough samples to build a trace.") 1530 1531 idxs = np.argsort(lengths) 1532 l_sort = np.array(lengths)[idxs] 1533 1534 use_until = np.argmax(l_sort * np.arange(1, l_sort.shape[0] + 1)[::-1]) 1535 final_length = l_sort[use_until] 1536 1537 return [traces[idx] for idx in idxs[use_until:]], final_length + tune 1538 1539 1540def stop_tuning(step): 1541 """Stop tuning the current step method.""" 1542 step.stop_tuning() 1543 return step 1544 1545 1546class _DefaultTrace: 1547 """ 1548 Utility for collecting samples into a dictionary. 1549 1550 Name comes from its similarity to ``defaultdict``: 1551 entries are lazily created. 1552 1553 Parameters 1554 ---------- 1555 samples : int 1556 The number of samples that will be collected, per variable, 1557 into the trace. 1558 1559 Attributes 1560 ---------- 1561 trace_dict : Dict[str, np.ndarray] 1562 A dictionary constituting a trace. Should be extracted 1563 after a procedure has filled the `_DefaultTrace` using the 1564 `insert()` method 1565 """ 1566 1567 trace_dict: Dict[str, np.ndarray] = {} 1568 _len: Optional[int] = None 1569 1570 def __init__(self, samples: int): 1571 self._len = samples 1572 self.trace_dict = {} 1573 1574 def insert(self, k: str, v, idx: int): 1575 """ 1576 Insert `v` as the value of the `idx`th sample for the variable `k`. 1577 1578 Parameters 1579 ---------- 1580 k: str 1581 Name of the variable. 1582 v: anything that can go into a numpy array (including a numpy array) 1583 The value of the `idx`th sample from variable `k` 1584 ids: int 1585 The index of the sample we are inserting into the trace. 1586 """ 1587 value_shape = np.shape(v) 1588 1589 # initialize if necessary 1590 if k not in self.trace_dict: 1591 array_shape = (self._len,) + value_shape 1592 self.trace_dict[k] = np.empty(array_shape, dtype=np.array(v).dtype) 1593 1594 # do the actual insertion 1595 if value_shape == (): 1596 self.trace_dict[k][idx] = v 1597 else: 1598 self.trace_dict[k][idx, :] = v 1599 1600 1601def sample_posterior_predictive( 1602 trace, 1603 samples: Optional[int] = None, 1604 model: Optional[Model] = None, 1605 var_names: Optional[List[str]] = None, 1606 size: Optional[int] = None, 1607 keep_size: Optional[bool] = False, 1608 random_seed=None, 1609 progressbar: bool = True, 1610) -> Dict[str, np.ndarray]: 1611 """Generate posterior predictive samples from a model given a trace. 1612 1613 Parameters 1614 ---------- 1615 trace : backend, list, xarray.Dataset, arviz.InferenceData, or MultiTrace 1616 Trace generated from MCMC sampling, or a list of dicts (eg. points or from find_MAP()), 1617 or xarray.Dataset (eg. InferenceData.posterior or InferenceData.prior) 1618 samples : int 1619 Number of posterior predictive samples to generate. Defaults to one posterior predictive 1620 sample per posterior sample, that is, the number of draws times the number of chains. It 1621 is not recommended to modify this value; when modified, some chains may not be represented 1622 in the posterior predictive sample. 1623 model : Model (optional if in ``with`` context) 1624 Model used to generate ``trace`` 1625 vars : iterable 1626 Variables for which to compute the posterior predictive samples. 1627 Deprecated: please use ``var_names`` instead. 1628 var_names : Iterable[str] 1629 Names of variables for which to compute the posterior predictive samples. 1630 size : int 1631 The number of random draws from the distribution specified by the parameters in each 1632 sample of the trace. Not recommended unless more than ndraws times nchains posterior 1633 predictive samples are needed. 1634 keep_size : bool, optional 1635 Force posterior predictive sample to have the same shape as posterior and sample stats 1636 data: ``(nchains, ndraws, ...)``. Overrides samples and size parameters. 1637 random_seed : int 1638 Seed for the random number generator. 1639 progressbar : bool 1640 Whether or not to display a progress bar in the command line. The bar shows the percentage 1641 of completion, the sampling speed in samples per second (SPS), and the estimated remaining 1642 time until completion ("expected time of arrival"; ETA). 1643 1644 Returns 1645 ------- 1646 samples : dict 1647 Dictionary with the variable names as keys, and values numpy arrays containing 1648 posterior predictive samples. 1649 """ 1650 1651 _trace: Union[MultiTrace, PointList] 1652 if isinstance(trace, InferenceData): 1653 _trace = dataset_to_point_list(trace.posterior) 1654 elif isinstance(trace, xarray.Dataset): 1655 _trace = dataset_to_point_list(trace) 1656 else: 1657 _trace = trace 1658 1659 nchain: int 1660 len_trace: int 1661 if isinstance(trace, (InferenceData, xarray.Dataset)): 1662 nchain, len_trace = chains_and_samples(trace) 1663 else: 1664 len_trace = len(_trace) 1665 try: 1666 nchain = _trace.nchains 1667 except AttributeError: 1668 nchain = 1 1669 1670 if keep_size and samples is not None: 1671 raise IncorrectArgumentsError("Should not specify both keep_size and samples arguments") 1672 if keep_size and size is not None: 1673 raise IncorrectArgumentsError("Should not specify both keep_size and size arguments") 1674 1675 if samples is None: 1676 if isinstance(_trace, MultiTrace): 1677 samples = sum(len(v) for v in _trace._straces.values()) 1678 elif isinstance(_trace, list) and all(isinstance(x, dict) for x in _trace): 1679 # this is a list of points 1680 samples = len(_trace) 1681 else: 1682 raise TypeError( 1683 "Do not know how to compute number of samples for trace argument of type %s" 1684 % type(_trace) 1685 ) 1686 1687 assert samples is not None 1688 if samples < len_trace * nchain: 1689 warnings.warn( 1690 "samples parameter is smaller than nchains times ndraws, some draws " 1691 "and/or chains may not be represented in the returned posterior " 1692 "predictive sample" 1693 ) 1694 1695 model = modelcontext(model) 1696 1697 if model.potentials: 1698 warnings.warn( 1699 "The effect of Potentials on other parameters is ignored during posterior predictive sampling. " 1700 "This is likely to lead to invalid or biased predictive samples.", 1701 UserWarning, 1702 ) 1703 1704 if var_names is not None: 1705 vars_ = [model[x] for x in var_names] 1706 else: 1707 vars_ = model.observed_RVs 1708 1709 if random_seed is not None: 1710 np.random.seed(random_seed) 1711 1712 indices = np.arange(samples) 1713 1714 if progressbar: 1715 indices = progress_bar(indices, total=samples, display=progressbar) 1716 1717 ppc_trace_t = _DefaultTrace(samples) 1718 try: 1719 for idx in indices: 1720 if nchain > 1: 1721 # the trace object will either be a MultiTrace (and have _straces)... 1722 if hasattr(_trace, "_straces"): 1723 chain_idx, point_idx = np.divmod(idx, len_trace) 1724 param = cast(MultiTrace, _trace)._straces[chain_idx % nchain].point(point_idx) 1725 # ... or a PointList 1726 else: 1727 param = cast(PointList, _trace)[idx % (len_trace * nchain)] 1728 # there's only a single chain, but the index might hit it multiple times if 1729 # the number of indices is greater than the length of the trace. 1730 else: 1731 param = _trace[idx % len_trace] 1732 1733 values = draw_values(vars_, point=param, size=size) 1734 for k, v in zip(vars_, values): 1735 ppc_trace_t.insert(k.name, v, idx) 1736 except KeyboardInterrupt: 1737 pass 1738 1739 ppc_trace = ppc_trace_t.trace_dict 1740 if keep_size: 1741 for k, ary in ppc_trace.items(): 1742 ppc_trace[k] = ary.reshape((nchain, len_trace, *ary.shape[1:])) 1743 1744 return ppc_trace 1745 1746 1747def sample_posterior_predictive_w( 1748 traces, 1749 samples: Optional[int] = None, 1750 models: Optional[List[Model]] = None, 1751 weights: Optional[ArrayLike] = None, 1752 random_seed: Optional[int] = None, 1753 progressbar: bool = True, 1754): 1755 """Generate weighted posterior predictive samples from a list of models and 1756 a list of traces according to a set of weights. 1757 1758 Parameters 1759 ---------- 1760 traces : list or list of lists 1761 List of traces generated from MCMC sampling (xarray.Dataset, arviz.InferenceData, or 1762 MultiTrace), or a list of list containing dicts from find_MAP() or points. The number of 1763 traces should be equal to the number of weights. 1764 samples : int, optional 1765 Number of posterior predictive samples to generate. Defaults to the 1766 length of the shorter trace in traces. 1767 models : list of Model 1768 List of models used to generate the list of traces. The number of models should be equal to 1769 the number of weights and the number of observed RVs should be the same for all models. 1770 By default a single model will be inferred from ``with`` context, in this case results will 1771 only be meaningful if all models share the same distributions for the observed RVs. 1772 weights : array-like, optional 1773 Individual weights for each trace. Default, same weight for each model. 1774 random_seed : int, optional 1775 Seed for the random number generator. 1776 progressbar : bool, optional default True 1777 Whether or not to display a progress bar in the command line. The bar shows the percentage 1778 of completion, the sampling speed in samples per second (SPS), and the estimated remaining 1779 time until completion ("expected time of arrival"; ETA). 1780 1781 Returns 1782 ------- 1783 samples : dict 1784 Dictionary with the variables as keys. The values corresponding to the 1785 posterior predictive samples from the weighted models. 1786 """ 1787 np.random.seed(random_seed) 1788 1789 if isinstance(traces[0], InferenceData): 1790 n_samples = [ 1791 trace.posterior.sizes["chain"] * trace.posterior.sizes["draw"] for trace in traces 1792 ] 1793 traces = [dataset_to_point_list(trace.posterior) for trace in traces] 1794 elif isinstance(traces[0], xarray.Dataset): 1795 n_samples = [trace.sizes["chain"] * trace.sizes["draw"] for trace in traces] 1796 traces = [dataset_to_point_list(trace) for trace in traces] 1797 else: 1798 n_samples = [len(i) * i.nchains for i in traces] 1799 1800 if models is None: 1801 models = [modelcontext(models)] * len(traces) 1802 1803 for model in models: 1804 if model.potentials: 1805 warnings.warn( 1806 "The effect of Potentials on other parameters is ignored during posterior predictive sampling. " 1807 "This is likely to lead to invalid or biased predictive samples.", 1808 UserWarning, 1809 ) 1810 break 1811 1812 if weights is None: 1813 weights = [1] * len(traces) 1814 1815 if len(traces) != len(weights): 1816 raise ValueError("The number of traces and weights should be the same") 1817 1818 if len(models) != len(weights): 1819 raise ValueError("The number of models and weights should be the same") 1820 1821 length_morv = len(models[0].observed_RVs) 1822 if any(len(i.observed_RVs) != length_morv for i in models): 1823 raise ValueError("The number of observed RVs should be the same for all models") 1824 1825 weights = np.asarray(weights) 1826 p = weights / np.sum(weights) 1827 1828 min_tr = min(n_samples) 1829 1830 n = (min_tr * p).astype("int") 1831 # ensure n sum up to min_tr 1832 idx = np.argmax(n) 1833 n[idx] = n[idx] + min_tr - np.sum(n) 1834 trace = [] 1835 for i, j in enumerate(n): 1836 tr = traces[i] 1837 len_trace = len(tr) 1838 try: 1839 nchain = tr.nchains 1840 except AttributeError: 1841 nchain = 1 1842 1843 indices = np.random.randint(0, nchain * len_trace, j) 1844 if nchain > 1: 1845 chain_idx, point_idx = np.divmod(indices, len_trace) 1846 for idx in zip(chain_idx, point_idx): 1847 trace.append(tr._straces[idx[0]].point(idx[1])) 1848 else: 1849 for idx in indices: 1850 trace.append(tr[idx]) 1851 1852 obs = [x for m in models for x in m.observed_RVs] 1853 variables = np.repeat(obs, n) 1854 1855 lengths = list({np.atleast_1d(observed).shape for observed in obs}) 1856 1857 if len(lengths) == 1: 1858 size = [None for i in variables] 1859 elif len(lengths) > 2: 1860 raise ValueError("Observed variables could not be broadcast together") 1861 else: 1862 size = [] 1863 x = np.zeros(shape=lengths[0]) 1864 y = np.zeros(shape=lengths[1]) 1865 b = np.broadcast(x, y) 1866 for var in variables: 1867 shape = np.shape(np.atleast_1d(var.distribution.default())) 1868 if shape != b.shape: 1869 size.append(b.shape) 1870 else: 1871 size.append(None) 1872 len_trace = len(trace) 1873 1874 if samples is None: 1875 samples = len_trace 1876 1877 indices = np.random.randint(0, len_trace, samples) 1878 1879 if progressbar: 1880 indices = progress_bar(indices, total=samples, display=progressbar) 1881 1882 try: 1883 ppc = defaultdict(list) 1884 for idx in indices: 1885 param = trace[idx] 1886 var = variables[idx] 1887 # TODO sample_posterior_predictive_w is currently only work for model with 1888 # one observed. 1889 ppc[var.name].append(draw_values([var], point=param, size=size[idx])[0]) 1890 1891 except KeyboardInterrupt: 1892 pass 1893 else: 1894 return {k: np.asarray(v) for k, v in ppc.items()} 1895 1896 1897def sample_prior_predictive( 1898 samples=500, 1899 model: Optional[Model] = None, 1900 var_names: Optional[Iterable[str]] = None, 1901 random_seed=None, 1902) -> Dict[str, np.ndarray]: 1903 """Generate samples from the prior predictive distribution. 1904 1905 Parameters 1906 ---------- 1907 samples : int 1908 Number of samples from the prior predictive to generate. Defaults to 500. 1909 model : Model (optional if in ``with`` context) 1910 var_names : Iterable[str] 1911 A list of names of variables for which to compute the posterior predictive 1912 samples. Defaults to both observed and unobserved RVs. 1913 random_seed : int 1914 Seed for the random number generator. 1915 1916 Returns 1917 ------- 1918 dict 1919 Dictionary with variable names as keys. The values are numpy arrays of prior 1920 samples. 1921 """ 1922 model = modelcontext(model) 1923 1924 if model.potentials: 1925 warnings.warn( 1926 "The effect of Potentials on other parameters is ignored during prior predictive sampling. " 1927 "This is likely to lead to invalid or biased predictive samples.", 1928 UserWarning, 1929 ) 1930 1931 if var_names is None: 1932 prior_pred_vars = model.observed_RVs 1933 prior_vars = ( 1934 get_default_varnames(model.unobserved_RVs, include_transformed=True) + model.potentials 1935 ) 1936 vars_: Set[str] = {var.name for var in prior_vars + prior_pred_vars} 1937 else: 1938 vars_ = set(var_names) 1939 1940 if random_seed is not None: 1941 np.random.seed(random_seed) 1942 names = get_default_varnames(vars_, include_transformed=False) 1943 # draw_values fails with auto-transformed variables. transform them later! 1944 values = draw_values([model[name] for name in names], size=samples) 1945 1946 data = {k: v for k, v in zip(names, values)} 1947 if data is None: 1948 raise AssertionError("No variables sampled: attempting to sample %s" % names) 1949 1950 prior: Dict[str, np.ndarray] = {} 1951 for var_name in vars_: 1952 if var_name in data: 1953 prior[var_name] = data[var_name] 1954 elif is_transformed_name(var_name): 1955 untransformed = get_untransformed_name(var_name) 1956 if untransformed in data: 1957 prior[var_name] = model[untransformed].transformation.forward_val( 1958 data[untransformed] 1959 ) 1960 return prior 1961 1962 1963def _init_jitter(model, chains, jitter_max_retries): 1964 """Apply a uniform jitter in [-1, 1] to the test value as starting point in each chain. 1965 1966 pymc3.util.check_start_vals is used to test whether the jittered starting values produce 1967 a finite log probability. Invalid values are resampled unless `jitter_max_retries` is achieved, 1968 in which case the last sampled values are returned. 1969 1970 Parameters 1971 ---------- 1972 model : pymc3.Model 1973 chains : int 1974 jitter_max_retries : int 1975 Maximum number of repeated attempts at initializing values (per chain). 1976 1977 Returns 1978 ------- 1979 start : ``pymc3.model.Point`` 1980 Starting point for sampler 1981 """ 1982 start = [] 1983 for _ in range(chains): 1984 for i in range(jitter_max_retries + 1): 1985 mean = {var: val.copy() for var, val in model.test_point.items()} 1986 for val in mean.values(): 1987 val[...] += 2 * np.random.rand(*val.shape) - 1 1988 1989 if i < jitter_max_retries: 1990 try: 1991 check_start_vals(mean, model) 1992 except SamplingError: 1993 pass 1994 else: 1995 break 1996 1997 start.append(mean) 1998 return start 1999 2000 2001def init_nuts( 2002 init="auto", 2003 chains=1, 2004 n_init=500000, 2005 model=None, 2006 random_seed=None, 2007 progressbar=True, 2008 jitter_max_retries=10, 2009 **kwargs, 2010): 2011 """Set up the mass matrix initialization for NUTS. 2012 2013 NUTS convergence and sampling speed is extremely dependent on the 2014 choice of mass/scaling matrix. This function implements different 2015 methods for choosing or adapting the mass matrix. 2016 2017 Parameters 2018 ---------- 2019 init : str 2020 Initialization method to use. 2021 2022 * auto: Choose a default initialization method automatically. 2023 Currently, this is ``jitter+adapt_diag``, but this can change in the future. If you 2024 depend on the exact behaviour, choose an initialization method explicitly. 2025 * adapt_diag: Start with a identity mass matrix and then adapt a diagonal based on the 2026 variance of the tuning samples. All chains use the test value (usually the prior mean) 2027 as starting point. 2028 * jitter+adapt_diag: Same as ``adapt_diag``, but use test value plus a uniform jitter in 2029 [-1, 1] as starting point in each chain. 2030 * advi+adapt_diag: Run ADVI and then adapt the resulting diagonal mass matrix based on the 2031 sample variance of the tuning samples. 2032 * advi+adapt_diag_grad: Run ADVI and then adapt the resulting diagonal mass matrix based 2033 on the variance of the gradients during tuning. This is **experimental** and might be 2034 removed in a future release. 2035 * advi: Run ADVI to estimate posterior mean and diagonal mass matrix. 2036 * advi_map: Initialize ADVI with MAP and use MAP as starting point. 2037 * map: Use the MAP as starting point. This is discouraged. 2038 * adapt_full: Adapt a dense mass matrix using the sample covariances. All chains use the 2039 test value (usually the prior mean) as starting point. 2040 * jitter+adapt_full: Same as ``adapt_full``, but use test value plus a uniform jitter in 2041 [-1, 1] as starting point in each chain. 2042 2043 chains : int 2044 Number of jobs to start. 2045 n_init : int 2046 Number of iterations of initializer. Only works for 'ADVI' init methods. 2047 model : Model (optional if in ``with`` context) 2048 progressbar : bool 2049 Whether or not to display a progressbar for advi sampling. 2050 jitter_max_retries : int 2051 Maximum number of repeated attempts (per chain) at creating an initial matrix with uniform jitter 2052 that yields a finite probability. This applies to ``jitter+adapt_diag`` and ``jitter+adapt_full`` 2053 init methods. 2054 **kwargs : keyword arguments 2055 Extra keyword arguments are forwarded to pymc3.NUTS. 2056 2057 Returns 2058 ------- 2059 start : ``pymc3.model.Point`` 2060 Starting point for sampler 2061 nuts_sampler : ``pymc3.step_methods.NUTS`` 2062 Instantiated and initialized NUTS sampler object 2063 """ 2064 model = modelcontext(model) 2065 2066 vars = kwargs.get("vars", model.vars) 2067 if set(vars) != set(model.vars): 2068 raise ValueError("Must use init_nuts on all variables of a model.") 2069 if not all_continuous(vars): 2070 raise ValueError("init_nuts can only be used for models with only " "continuous variables.") 2071 2072 if not isinstance(init, str): 2073 raise TypeError("init must be a string.") 2074 2075 if init is not None: 2076 init = init.lower() 2077 2078 if init == "auto": 2079 init = "jitter+adapt_diag" 2080 2081 _log.info(f"Initializing NUTS using {init}...") 2082 2083 if random_seed is not None: 2084 random_seed = int(np.atleast_1d(random_seed)[0]) 2085 np.random.seed(random_seed) 2086 2087 cb = [ 2088 pm.callbacks.CheckParametersConvergence(tolerance=1e-2, diff="absolute"), 2089 pm.callbacks.CheckParametersConvergence(tolerance=1e-2, diff="relative"), 2090 ] 2091 2092 if init == "adapt_diag": 2093 start = [model.test_point] * chains 2094 mean = np.mean([model.dict_to_array(vals) for vals in start], axis=0) 2095 var = np.ones_like(mean) 2096 potential = quadpotential.QuadPotentialDiagAdapt(model.ndim, mean, var, 10) 2097 elif init == "jitter+adapt_diag": 2098 start = _init_jitter(model, chains, jitter_max_retries) 2099 mean = np.mean([model.dict_to_array(vals) for vals in start], axis=0) 2100 var = np.ones_like(mean) 2101 potential = quadpotential.QuadPotentialDiagAdapt(model.ndim, mean, var, 10) 2102 elif init == "advi+adapt_diag_grad": 2103 approx: pm.MeanField = pm.fit( 2104 random_seed=random_seed, 2105 n=n_init, 2106 method="advi", 2107 model=model, 2108 callbacks=cb, 2109 progressbar=progressbar, 2110 obj_optimizer=pm.adagrad_window, 2111 ) 2112 start = approx.sample(draws=chains) 2113 start = list(start) 2114 stds = approx.bij.rmap(approx.std.eval()) 2115 cov = model.dict_to_array(stds) ** 2 2116 mean = approx.bij.rmap(approx.mean.get_value()) 2117 mean = model.dict_to_array(mean) 2118 weight = 50 2119 potential = quadpotential.QuadPotentialDiagAdaptGrad(model.ndim, mean, cov, weight) 2120 elif init == "advi+adapt_diag": 2121 approx = pm.fit( 2122 random_seed=random_seed, 2123 n=n_init, 2124 method="advi", 2125 model=model, 2126 callbacks=cb, 2127 progressbar=progressbar, 2128 obj_optimizer=pm.adagrad_window, 2129 ) 2130 start = approx.sample(draws=chains) 2131 start = list(start) 2132 stds = approx.bij.rmap(approx.std.eval()) 2133 cov = model.dict_to_array(stds) ** 2 2134 mean = approx.bij.rmap(approx.mean.get_value()) 2135 mean = model.dict_to_array(mean) 2136 weight = 50 2137 potential = quadpotential.QuadPotentialDiagAdapt(model.ndim, mean, cov, weight) 2138 elif init == "advi": 2139 approx = pm.fit( 2140 random_seed=random_seed, 2141 n=n_init, 2142 method="advi", 2143 model=model, 2144 callbacks=cb, 2145 progressbar=progressbar, 2146 obj_optimizer=pm.adagrad_window, 2147 ) 2148 start = approx.sample(draws=chains) 2149 start = list(start) 2150 stds = approx.bij.rmap(approx.std.eval()) 2151 cov = model.dict_to_array(stds) ** 2 2152 potential = quadpotential.QuadPotentialDiag(cov) 2153 elif init == "advi_map": 2154 start = pm.find_MAP(include_transformed=True) 2155 approx = pm.MeanField(model=model, start=start) 2156 pm.fit( 2157 random_seed=random_seed, 2158 n=n_init, 2159 method=pm.KLqp(approx), 2160 callbacks=cb, 2161 progressbar=progressbar, 2162 obj_optimizer=pm.adagrad_window, 2163 ) 2164 start = approx.sample(draws=chains) 2165 start = list(start) 2166 stds = approx.bij.rmap(approx.std.eval()) 2167 cov = model.dict_to_array(stds) ** 2 2168 potential = quadpotential.QuadPotentialDiag(cov) 2169 elif init == "map": 2170 start = pm.find_MAP(include_transformed=True) 2171 cov = pm.find_hessian(point=start) 2172 start = [start] * chains 2173 potential = quadpotential.QuadPotentialFull(cov) 2174 elif init == "adapt_full": 2175 start = [model.test_point] * chains 2176 mean = np.mean([model.dict_to_array(vals) for vals in start], axis=0) 2177 cov = np.eye(model.ndim) 2178 potential = quadpotential.QuadPotentialFullAdapt(model.ndim, mean, cov, 10) 2179 elif init == "jitter+adapt_full": 2180 start = _init_jitter(model, chains, jitter_max_retries) 2181 mean = np.mean([model.dict_to_array(vals) for vals in start], axis=0) 2182 cov = np.eye(model.ndim) 2183 potential = quadpotential.QuadPotentialFullAdapt(model.ndim, mean, cov, 10) 2184 else: 2185 raise ValueError(f"Unknown initializer: {init}.") 2186 2187 step = pm.NUTS(potential=potential, model=model, **kwargs) 2188 2189 return start, step 2190